2014-10-21
CaffeでDeep Q-Networkを実装して深層強化学習してみた
概要
深層学習フレームワークCaffeを使って,Deep Q-Networkという深層強化学習アルゴリズムをC++で実装して,Atari 2600のゲームをプレイさせてみました.
Deep Q-Network
Deep Q-Network(以下DQN)は,2013年のNIPSのDeep Learning Workshopの”Playing Atari with Deep Reinforcement Learning”という論文で提案されたアルゴリズムで,行動価値関数Q(s,a)を深層ニューラルネットワークにより近似するという,近年の深層学習の研究成果を強化学習に活かしたものです.Atari 2600のゲームに適用され,既存手法を圧倒するとともに一部のゲームでは人間のエキスパートを上回るスコアを達成しています.論文の著者らは今年Googleに買収されたDeepMindの研究者です.
NIPS2013読み会で自分が紹介した際のスライドがこちらになります.
他の方が作成したスライドもあります.
必要なもの
- Caffe
- まだ本家にマージされていないpull requestを修正した上で使用しています.とりあえず動かしてみたい方は場合は自分のforkレポジトリのdqnブランチを使えば動くと思います.
- https://github.com/BVLC/caffe/pull/1228 (ソルバーのステップ実行に必要)
- https://github.com/BVLC/caffe/pull/1122 (AdaDelta)
- まだ本家にマージされていないpull requestを修正した上で使用しています.とりあえず動かしてみたい方は場合は自分のforkレポジトリのdqnブランチを使えば動くと思います.
- Arcade Learning Environment
- http://www.arcadelearningenvironment.org/ からダウンロードしてビルドします.ゲームスクリーンを表示するためにMakefileのUSE_SDLを1にセットします.libsdl,libsdl-gfx,libsdl-imageが必要になります.
ソースコード
GitHubで公開しています.DQN-in-the-Caffe
ネットワークの構成
- 入力層:84x84x4(ラスト4フレームのダウンサンプリング&グレイスケール化)
- 隠れ層1:8x8のフィルタx8(ストライド4)による畳込み+ReLU
- 隠れ層2:4x4のフィルタx16(ストライド2)による畳込み+ReLU
- 隠れ層3:fully-connectedなノードx256+ReLU
- 出力層:fully-connectedなノードx18(18種類のアクションそれぞれの行動価値)
としました.このネットワークを逆伝播により学習するためには,複数ある出力のうち1つの出力のみに対して誤差を計算する必要があるのですが,それを可能にするためにCaffeのELTWISEレイヤーを使い,1つの要素のみ1で残りは0であるようなベクトルをネットワークの出力に掛け合わせることで望みの出力だけを取り出せるようにしています.Caffeのネットワーク表記でネットワーク全体を書くと下のようになりました.
パラメータの学習
パラメータの学習のためには,「状態で行動
を選択したところ,報酬
を獲得し,次の状態が
であった」という状態遷移
の経験をreplay memoryというメモリに保管していき,パラメータ更新の際にはそこからランダムサンプリングした一定数の遷移それぞれについて
となるように勾配を計算した上で,まとめて更新を行うミニバッチ学習を行います.
元論文ではここでRMSPropというパラメータ更新量の自動調節アルゴリズムを用いていますが,Caffeには今のところRMSPropは実装されておらず,その代わりAdaDeltaというRMSPropによく似たアルゴリズムをすでに実装してpull requestを投げている人がいたので,それを使いました.ただし,AdaDeltaをそのまま使用するとパラメータが発散してしまうことが多かったため,AdaDeltaによる更新量にさらに一定の係数(最初の100万イテレーションでは0.2,次の100万イテレーションでは0.02)を掛けて用いました(同じようなことをやっている?人).ミニバッチの大きさは元論文と同じ32,割引率は元論文では示されていませんが0.95としました.
元論文ではreplay memoryの容量は100万フレームでしたが,メモリの都合上,半分の50万で実験しました.
学習時間
実行環境は
です.CaffeはGPUモードで,さらにcuDNNを使いました.この構成でミニバッチ5万個の学習に45分ほどかかりましたが,元論文では5万個分を30分ほどで学習しているので,1.5倍ほど遅い結果となりました.
結果
上の動画はPongというゲームを200万イテレーション(およそ30時間)学習させた後のプレイ動画です.右の緑がDQNで,元論文と同じく各フレームごとに5%の確率で完全にランダムにアクションを選ぶようにしています.3回の試行のスコアがそれぞれ16,13,19と,元論文の20という平均スコアには達していませんが,元論文では1000万イテレーションの学習を行っているので,より学習が進めば同等のスコアが出せるかもしれません.
元論文ではHuman Expertのスコアは-3とされていますが,現時点でもそれよりは大幅に上回っているので,DQNは人間より強いという結果が再現出来て何よりです.
- 22 http://t.co/XeWcEFLiQq
- 3 http://b.hatena.ne.jp/
- 2 http://www.google.com/url?source=web&url=http://d.hatena.ne.jp/muupan/20141021/1413850461
- 1 http://api.twitter.com/1/statuses/show/524377978680000513.json
- 1 http://pipes.yahoo.com/pipes/pipe.info?_id=e4c70514b5136c08ae93591f390be2e2
- 1 http://www.google.co.jp/url?sa=t&rct=j&q=&esrc=s&source=web&cd=4&ved=0CCMQFjAB&url=http://d.hatena.ne.jp/muupan/20141021/1413850461&ei=mKpFVL2oA7uRcMPCAQ&usg=AFQjCNHUgxXSYaw5AO7gVjZ3viikwNMVoA