본문 바로가기

코틀린

KotlinDl - 코틀린 버전의 딥러닝

반응형
 

GitHub - Kotlin/kotlindl: High-level Deep Learning Framework written in Kotlin and inspired by Keras

High-level Deep Learning Framework written in Kotlin and inspired by Keras - GitHub - Kotlin/kotlindl: High-level Deep Learning Framework written in Kotlin and inspired by Keras

github.com

Kotlin에서 딥러닝을 위한 KotlinDl 라이브러리가 나왔다. 딥러닝은 할 줄 모르니 Quick start guide 나 따라해보자.

1. 라이브러리 적용

// for Kotlin
dependencies {
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:0.5.0")

    // options
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-visualization:0.5.0")
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.5.0")
}

// for Android
dependencies {
    implementation("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:0.5.0")
}

2. 모델 학습시키기

fun main(){
    val (train, test) = fashionMnist()

    ...
}

일단 학습을 시킬 데이터가 필요하다.

KotlinDl 에는 몇 가지 데이터셋이 다운로드할 수 있는 형태로 내장되어 있다.
요청이 발생할 때 알아서 다운로드해서 저장해준다.

  • MNIST : 숫자
  • Fashion-MNIST : 의류
  • MNIST 3D : 3D 숫자
  • Free Spoken Digits : 숫자 음성
  • Cifar'10 : 엄청난 양의 이미지
  • Dogs-vs-Cats : 개와 고양이

커스텀한 데이터가 필요한 경우
OnHeapDataset : 데이터를 Heap Memory에 올림
OnFlyImageDataset : 데이터를 Disk에 올림
를 쓰라고 되어 있다.

val model = Sequential.of(
    Input(28, 28, 1),
    Conv2D(filters = 32, kernelSize = intArrayOf(3, 3)),
    Conv2D(filters = 64, kernelSize = intArrayOf(3, 3)),
    MaxPool2D(poolSize = 2),
    Dropout(0.25f),
    Flatten(),
    Dense(128),
    Dropout(0.5f),
    Dense(10, activation = Activations.Softmax)
)

예제의 학습 모델은 너무 심심해 보여서 버리고, 봐도 모르겠지만 일단 그럴듯한 걸로 가져왔다.

fun main() {
    val (train, test) = fashionMnist()

    model.use {
        it.compile(
            optimizer = Adam(),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )

        it.printSummary()
        ...
    }
}

use 는 기존에 kotlin에서 제공해주는 auto close block

compile 은 모델을 만들기 시작하고, 모델이 잘못되면 여기서 에러를 던져준다.

printSummary 은 모델에 대한 정보를 출력시켜줍니다.

모델에 대한 정보를 출력해준다.

fun main() {
    val (train, test) = fashionMnist()

    model.use {
        ...

          // 학습 시작
        it.fit(
            dataset = train,
            epochs = 4,
            batchSize = 100,
            listOf(object : Callback(){
                override fun onEpochBegin(epoch: Int, logs: TrainingHistory) {
                    println("onEpochBegin $epoch")
                    super.onEpochBegin(epoch, logs)
                }

                override fun onEpochEnd(epoch: Int, event: EpochTrainingEvent, logs: TrainingHistory) {
                    println("onEpochEnd $epoch")
                    super.onEpochEnd(epoch, event, logs)
                }
            })
        )

        val accuracy = it.evaluate(dataset = test, batchSize = 100).metrics[Metrics.ACCURACY]
        println("Accuracy: $accuracy")
        ...
    }
}

fit 을 사용하면 학습을 시작한다.
학습 과정에서 별도의 프로그래스나 출력이 없지만, 직접 추가할 수 있도록 열어준다.

 

evaluate 를 사용하면 학습된 모델을 테스트할 수 있다.

fun main() {
    val (train, test) = fashionMnist()

    model.use {
        ...

        it.save(File("model/my_model"), writingMode = WritingMode.OVERRIDE)
    }
}

save 로 모델을 저장할 수 있다.


제일 기대한 게 Android의 한 구석에서 학습시키고, 그걸 가져다 쓰는 건데,

Exception in thread "main" java.lang.NoClassDefFoundError: org/jetbrains/kotlinx/dl/impl/util/ByteArrayUtilKt
    at org.jetbrains.kotlinx.dl.dataset.embedded.MnistUtilKt.extractImages(MnistUtil.kt:58)

안타깝게도 학습 과정은 Android에서 진행할 수 없다.
Android에서는 기본 제공해주는 DataSet을 추출하는 과정에서 위의 에러와 함께 터져버린다.
( 커스텀한 데이터셋을 쓰거나 다른 방법이 어딘가에는 있을 수 있다. )

3. 모델 가져다 쓰기

val stringLabels = mapOf(
    0 to "T-shirt/top",
    1 to "Trousers",
    2 to "Pullover",
    3 to "Dress",
    4 to "Coat",
    5 to "Sandals",
    6 to "Shirt",
    7 to "Sneakers",
    8 to "Bag",
    9 to "Ankle boots"
)

TensorFlowInferenceModel.load(File("model/my_model")) .use {
    // 데이터셋에서 하나의 데이터의 크기 지정
    it.reshape(28, 28, 1)

    // 하나의 데이터 들고와서 측정 X : 데이터 Y : 레이블
    val prediction = it.predict(test.getX(0))
    val actualLabel = test.getY(0)

    println("Predicted label is: $prediction. This corresponds to class ${stringLabels[prediction]}.")
    println("Actual label is: $actualLabel.")
}

학습한 결과는 TensorFlowInferenceModel 를 통해서 가져와 확인할 수 있습니다.

4. 모델 관련 정보 화면 출력

visualization 라이브러리에는 화면 출력용 함수도 Swing 코드가 포함되어 있다.

fun main() {
    val (train, test) = fashionMnist()

    model.use {
        ...
          val fashionPlots = List(12) { imageIndex ->
            flattenImagePlot(
                imageIndex, test,
                predict = it::predict,
                labelEncoding = stringLabels::get,
                plotFeature = PlotFeature.GRAY
            )
        }
        columnPlot(fashionPlots, 3, 512).show()

        filtersPlot(it.layers[2] as Conv2D).show()
        ...
    }
}

show를 하면 화면을 그려주는 하나의 앱을 열어서 결과를 보여준다.
창이 너무 큰 경우에는 Swing 이 깨지는 경우가 있는데, 창 크기를 줄였다가 키우면 제대로 보여준다.

실제&예측 레이블 / 뭔지 모르겠지만 필터...

5. 마무리

다른 딥러닝 언어나 라이브러리를 찾아봤는데, 거진 다 비슷하게 생긴 거 같아서
Kotlin이라는 것 말고는 크게 메리트를 못 찾았다.

안드로이드 용으로 이미 학습된 모델을 제공해주는 데, 그건 다음 글이다.

반응형