DL4Jをバリバリ使ってる人は日本にあまり居ない?ような気がしてきたので、1か月弱くらい色々使ってみた感触などを述べてみる。
DeepLearning4Jってなんだ
何が出来るんだ
最近の機械学習系の大体の事は出来るような気がする。RNN(LSTM. GRUは実装中)、Word2Vec、FeedFoward、CNN、RBM(Deep Belief Nets)、AutoEncorderなど。
他のライブラリと何が違うんだ
型付き言語(JVM言語)で組める。他の機械学習ライブラリはC++とか特にPythonが多い。C++は単純に書くのがしんどい。Pythonも悪くは無いのだけど、私にとってはもはやScalaの方が使いやすい(使い慣れてる)のでできればScalaで組みたい。あと、コンパイルされる言語なので数値計算以外の部分はJVMの方がPythonよりは早いはず。
あとはAkkaやSparkで分散処理が可能とのこと。Sparkは実際に動かしてみたが、まだまだバグが多いしバッドノウハウを駆使しないと動かなかったりと未熟な感じ。今後に期待。
ちなみに開発陣は次世代のデータ処理に向いた言語としてScalaを意識しているとのこと。Scalaのサポートにも期待したいが、後述する数値計算ライブラリのND4JのScala向けラッパND4Sはリリースが追い付いていないし、開発リソースが足りてないみたい。
ビジュアライズはどうなってんだ
Javaの弱みがここだと思う。matplotlibのような手軽で強力なビジュアライズツールが無い。一応、deepLearning4j-uiというビジュアライズツールがある。dropwizardで作られたWebアプリが立ち上がってブラウザから現在のスコア、各層のウェイトのヒストグラム、パラメータの更新の勾配の平均などを見ることが出来る。必要最小限のものは出来ますよ、って感じ。以下が実際の画面。
CNNなどでよくある、ウェイトを可視化するようなものは自前で実装するしか無い。Javaではmatplotlibのようなライブラリが無いのが痛い。
数値計算はどうなってんだ
ND4Jが担当。CPUネイティブやGPU(CUDA)で動作。NumpyのndarrayみたいなAPI。
ちゃんと動くのか
まだまだバグだらけ。何かやろうとするとバグに出会う。ローカルで最新ソースをビルドできるようにしておいたほうが良いかも。開発を主導しているのはSkymindという会社で、本記事作成時点のGitHubスター数は3305。
でも上手く(?)動かせば普通に使い物になる。バージョン0.4.0の正式版がリリースされてからは、既存のバグはかなり修正された。
ただ、まだ未実装のままになっている機能も多く、まだまだ発展途上といった印象。
コードを見せろ
build.sbtはこんな感じ。これが最小構成に近い。
scalaVersion := "2.11.8" val nd4jVersion = "0.4.0" val dl4jVersion = "0.4.0" classpathTypes += "maven-plugin" libraryDependencies ++= Seq( "org.deeplearning4j" % "deeplearning4j-core" % dl4jVersion, "org.nd4j" %% "nd4s" % "0.4-rc3.8", "org.nd4j" % "nd4j-native" % nd4jVersion classifier "" classifier "linux-x86_64" )
nd4j-nativeのプラットフォーム部分はOSによって適宜書き換える必要あり。
- windows-x86_64
- linux-x86_64
- linux-ppc64
- linux-ppc64le
- macosx-x86_64
があるみたい。
コードは以下。AND演算の学習。
import org.deeplearning4j.eval.Evaluation import org.deeplearning4j.nn.api.OptimizationAlgorithm import org.deeplearning4j.nn.conf.NeuralNetConfiguration.Builder import org.deeplearning4j.nn.conf.distribution.UniformDistribution import org.deeplearning4j.nn.conf.layers.OutputLayer import org.deeplearning4j.nn.multilayer.MultiLayerNetwork import org.deeplearning4j.nn.weights.WeightInit import org.deeplearning4j.optimize.listeners.ScoreIterationListener import org.nd4j.linalg.dataset.DataSet import org.nd4j.linalg.lossfunctions.LossFunctions import org.nd4s.Implicits._ object ANDSample extends App { // Input val input = Array( Array(0, 0), Array(0, 1), Array(1, 0), Array(1, 1) ).toNDArray // Output - T, F: 2 classification problem val labels = Array( Array(0, 1), Array(0, 1), Array(0, 1), Array(1, 0) ).toNDArray val ds = new DataSet(input, labels) val conf = new Builder() .seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(50000) .learningRate(0.1) .useDropConnect(false) .list() .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nIn(2) .nOut(2) .activation("sigmoid") .weightInit(WeightInit.DISTRIBUTION) .dist(new UniformDistribution(0, 1)) .build()) .pretrain(false) .backprop(true) .build() val net = new MultiLayerNetwork(conf) net.init() net.setListeners(new ScoreIterationListener(500)) net.fit(ds) val output = net.output(ds.getFeatureMatrix) System.out.println(output) val eval = new Evaluation(2) eval.eval(ds.getLabels, output) System.out.println(eval.stats()) }
KerasやChainerのように層を強く意識したAPIになっています。Builderパターンをふんだんに使っていて、JavaでありながらPythonの手軽さに近い感じの使い勝手を実現できている…気がする。もっと複雑なネットワークはComputation Graphとして実装できるみたい。