Kotlin에서 딥러닝을 위한 KotlinDl 라이브러리가 나왔다. 딥러닝은 할 줄 모르니 Quick start guide 나 따라해보자.
1. 라이브러리 적용
// for Kotlin
dependencies {
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:0.5.0")
// options
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-visualization:0.5.0")
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.5.0")
}
// for Android
dependencies {
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.5.0")
}
2. 모델 학습시키기
fun main(){
val (train, test) = fashionMnist()
...
}
일단 학습을 시킬 데이터가 필요하다.
KotlinDl 에는 몇 가지 데이터셋이 다운로드할 수 있는 형태로 내장되어 있다.
요청이 발생할 때 알아서 다운로드해서 저장해준다.
- MNIST : 숫자
- Fashion-MNIST : 의류
- MNIST 3D : 3D 숫자
- Free Spoken Digits : 숫자 음성
- Cifar'10 : 엄청난 양의 이미지
- Dogs-vs-Cats : 개와 고양이
커스텀한 데이터가 필요한 경우OnHeapDataset
: 데이터를 Heap Memory에 올림OnFlyImageDataset
: 데이터를 Disk에 올림
를 쓰라고 되어 있다.
val model = Sequential.of(
Input(28, 28, 1),
Conv2D(filters = 32, kernelSize = intArrayOf(3, 3)),
Conv2D(filters = 64, kernelSize = intArrayOf(3, 3)),
MaxPool2D(poolSize = 2),
Dropout(0.25f),
Flatten(),
Dense(128),
Dropout(0.5f),
Dense(10, activation = Activations.Softmax)
)
예제의 학습 모델은 너무 심심해 보여서 버리고, 봐도 모르겠지만 일단 그럴듯한 걸로 가져왔다.
fun main() {
val (train, test) = fashionMnist()
model.use {
it.compile(
optimizer = Adam(),
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
metric = Metrics.ACCURACY
)
it.printSummary()
...
}
}
use
는 기존에 kotlin에서 제공해주는 auto close block
compile
은 모델을 만들기 시작하고, 모델이 잘못되면 여기서 에러를 던져준다.
printSummary
은 모델에 대한 정보를 출력시켜줍니다.
fun main() {
val (train, test) = fashionMnist()
model.use {
...
// 학습 시작
it.fit(
dataset = train,
epochs = 4,
batchSize = 100,
listOf(object : Callback(){
override fun onEpochBegin(epoch: Int, logs: TrainingHistory) {
println("onEpochBegin $epoch")
super.onEpochBegin(epoch, logs)
}
override fun onEpochEnd(epoch: Int, event: EpochTrainingEvent, logs: TrainingHistory) {
println("onEpochEnd $epoch")
super.onEpochEnd(epoch, event, logs)
}
})
)
val accuracy = it.evaluate(dataset = test, batchSize = 100).metrics[Metrics.ACCURACY]
println("Accuracy: $accuracy")
...
}
}
fit
을 사용하면 학습을 시작한다.
학습 과정에서 별도의 프로그래스나 출력이 없지만, 직접 추가할 수 있도록 열어준다.
evaluate
를 사용하면 학습된 모델을 테스트할 수 있다.
fun main() {
val (train, test) = fashionMnist()
model.use {
...
it.save(File("model/my_model"), writingMode = WritingMode.OVERRIDE)
}
}
save
로 모델을 저장할 수 있다.
제일 기대한 게 Android의 한 구석에서 학습시키고, 그걸 가져다 쓰는 건데,
Exception in thread "main" java.lang.NoClassDefFoundError: org/jetbrains/kotlinx/dl/impl/util/ByteArrayUtilKt
at org.jetbrains.kotlinx.dl.dataset.embedded.MnistUtilKt.extractImages(MnistUtil.kt:58)
안타깝게도 학습 과정은 Android에서 진행할 수 없다.
Android에서는 기본 제공해주는 DataSet을 추출하는 과정에서 위의 에러와 함께 터져버린다.
( 커스텀한 데이터셋을 쓰거나 다른 방법이 어딘가에는 있을 수 있다. )
3. 모델 가져다 쓰기
val stringLabels = mapOf(
0 to "T-shirt/top",
1 to "Trousers",
2 to "Pullover",
3 to "Dress",
4 to "Coat",
5 to "Sandals",
6 to "Shirt",
7 to "Sneakers",
8 to "Bag",
9 to "Ankle boots"
)
TensorFlowInferenceModel.load(File("model/my_model")) .use {
// 데이터셋에서 하나의 데이터의 크기 지정
it.reshape(28, 28, 1)
// 하나의 데이터 들고와서 측정 X : 데이터 Y : 레이블
val prediction = it.predict(test.getX(0))
val actualLabel = test.getY(0)
println("Predicted label is: $prediction. This corresponds to class ${stringLabels[prediction]}.")
println("Actual label is: $actualLabel.")
}
학습한 결과는 TensorFlowInferenceModel
를 통해서 가져와 확인할 수 있습니다.
4. 모델 관련 정보 화면 출력
visualization 라이브러리에는 화면 출력용 함수도 Swing 코드가 포함되어 있다.
fun main() {
val (train, test) = fashionMnist()
model.use {
...
val fashionPlots = List(12) { imageIndex ->
flattenImagePlot(
imageIndex, test,
predict = it::predict,
labelEncoding = stringLabels::get,
plotFeature = PlotFeature.GRAY
)
}
columnPlot(fashionPlots, 3, 512).show()
filtersPlot(it.layers[2] as Conv2D).show()
...
}
}
show를 하면 화면을 그려주는 하나의 앱을 열어서 결과를 보여준다.
창이 너무 큰 경우에는 Swing 이 깨지는 경우가 있는데, 창 크기를 줄였다가 키우면 제대로 보여준다.
5. 마무리
다른 딥러닝 언어나 라이브러리를 찾아봤는데, 거진 다 비슷하게 생긴 거 같아서
Kotlin이라는 것 말고는 크게 메리트를 못 찾았다.
안드로이드 용으로 이미 학습된 모델을 제공해주는 데, 그건 다음 글이다.
'코틀린' 카테고리의 다른 글
이펙티브 코틀린 - 음..? (1) | 2022.03.29 |
---|---|
Kotlin Channel - 코루틴간 데이터 통신 (0) | 2020.07.30 |
Kotlin Flow #4 - onXXX 함수와 예외처리 (0) | 2020.07.11 |
Kotlin Flow #3 - Zip And Combine (0) | 2020.07.10 |
Kotlin Flow #2 - 연산자들 (0) | 2020.07.08 |