"機械学習","信号解析","ディープラーニング"の勉強

読者です 読者をやめる 読者になる 読者になる

HELLO CYBERNETICS

深層学習、機械学習、強化学習、信号処理、制御工学などをテーマに扱っていきます

今更聞けないLSTMの基本

 

 

f:id:s0sem0y:20170506164912p:plain

ディープラーニングで畳込みニューラルネットに並ぶ重要な要素のであるLong Short-Term Memoryについて、その基本を解説します。

LSTMとは

LSTMとはLong Short-Term Memoryの略です。

short-term memoryとは短期記憶のことであり、短期記憶を長期に渡って活用することを可能にしたのが、LSTMの重大な成果です。

 

 

 

リカレントニューラルネットワーク

従来はニューラルネットに入力データの記憶を保持させるために、中間層にループをもたせていました。これによりニューラルネットは1個前のデータを判断材料に使うことができました。

 

f:id:s0sem0y:20170114233157p:plain

 

中間層が1つ前のデータの依存性を扱うことができ、当然1つ前のデータを扱うときにはそのまた1つ前のデータとの依存性を扱うことができているはずであり、理屈の上では、データの系列の長期的な依存性にも対応できるはずでした。

 

このようなタイプのニューラルネットワークをリカレントニューラルネットワーク(RNN)と呼びます。リカレントニューラルネットはディープラーニングが流行する以前より考案されていましたが、学習が上手くできないなどの問題により、陽の目を浴びることはありませんでした。

 

LSTMの役割

以下のようにLSTMが1つの中間層に相当すると思って構いません。層の中で複雑な処理を行い、普通に中間層のような役割を担ってくれます。

 

 

f:id:s0sem0y:20170506170354p:plain

 

LSTMはRNNを実現するために考案され、前の情報を上手く扱うことに特化した層を提供してくれると考えればいいでしょう。LSTMもいろいろな改良がなされて、中身は変わっていっていますが、LSTMの目指す姿とはいつでも、系列データを上手く扱うことです。

 

LSTMの計算

LSTMの中身を1つ1つ見ていき、どのような計算を担っていくるのかを見てみましょう。以下ボールド体を用いなくとも、小文字は基本的にベクトルであり、大文字は行列を表します。

基本的なLSTMブロックの中身を見ていくと、以下のように図式化できます。

 

f:id:s0sem0y:20170506172239p:plain

 

LSTMはx_tを受け取ると、前回の自身の出力h_{t-1}x_tを使い内部で計算を行い、今回の出力h_tを決めます。またh_tは次回の出力h_{t+1}を使うためにも使います。

 

次はLSTMが持つ計算の役割をそれぞれ見ていきましょう。

Output Gate

以下の赤い矢印が使われている部分が、Output Gataに関わってくる部分です。x_th_{t-1}が割と出力h_tに直結しています。この部分が、過去の出力を自身の入力にフィードバックしていた原始的なリカレントネットワークの発想の部分になります。

 

f:id:s0sem0y:20170506172605p:plain

 

Output GateにはOutput Gateのための線形変換W_oが準備されており、x_tに対して以下のような計算を行います。至って普通のニューラルネットと変わりません。

 

W_ox_t

 

同様に、h_{t-1}に対しては線形変換R_oが準備されており、

 

R_oh_{t-1}

 

という計算をします。これらの計算結果と、バイアスb_oを使って、Output Gateは

 

\displaystyle o_t = σ \left( W_ox_t + R_oh_{t-1} + b_o \right)

 

という計算をします。何も難しくありませんね。最終的にOutput Gateの赤い矢印の最後は上記の値となっています。他のゲートの計算と合わせて最終的な出力を決めますが、このゲートの役割はこれだけです。

 

Input GateとForget Gate

こちらのほうがLSTMでのポイントとなってきます。

これらは両方合わせて初めて重要な役割を担うので、まとめて話していきます。

 

Forget Gate

下記の赤色の部分を見ましょう。

 

f:id:s0sem0y:20170506173644p:plain

 

こちらも特に難しい部分は実はありません。このゲートにもW_f,R_f,b_fというパラメータがあり、下記のような計算を行います。

 

\displaystyle f_t = σ \left( W_fx_t + R_fh_{t-1} + b_f \right)

 

基本的には、全部似たような計算をしていることに注意してください。それぞれ別々のパラメータを持っているというだけの話です。

 

Input Gate

Input Gateに関しても同様で、以下の赤い部分は

 

f:id:s0sem0y:20170506174103p:plain

 


\displaystyle i_t = σ \left( W_ix_t + R_ih_{t-1} + b_i \right)

 

となっており、以下の赤い部分は

 

f:id:s0sem0y:20170506174212p:plain

 

\displaystyle z_t = tanh \left( W_zx_t + R_zh_{t-1} + b_z \right)

 

となります。活性化関数がtanh()であること以外に違いはありません。この次の段階が一気に複雑になります。

 

LSTMの肝であるMemory Cell周辺

緑色のメモリーセルを含めると図は以下のようになっています。メモリーセルから出てくる点線はc_{t-1}という1個前の何らかの値を出しているとしましょう。

 

f:id:s0sem0y:20170506174601p:plain

 

Forget Gate側の出来事

f:id:s0sem0y:20170506174836p:plain

Forget Gateでの以下の値

 

\displaystyle f_t = σ \left( W_fx_t + R_fh_{t-1} + b_f \right)

 

と、Cellから出てきているc_{t-1}が合わさって、

 

c_{t-1} \otimes f_t

 

という計算がされます。\otimesは要素ごとの積です。

 

Input Gate側での出来事

f:id:s0sem0y:20170506175353p:plain

 

以下で計算される2つの値

 

\displaystyle i_t = σ \left( W_ix_t + R_ih_{t-1} + b_i \right)

 

\displaystyle z_t = tanh \left( W_zx_t + R_zh_{t-1} + b_z \right)

 

が合わさって

 

i_t \otimes z_t 

 

という計算がされています。ここまで来るとかなり単純です。

 

Cellの手前での出来事

f:id:s0sem0y:20170506175522p:plain

 

上記で見てきた2つの計算結果

 

i_t \otimes z_t

 

c_{t-1} \otimes f_t

 

これらを合わせて

 

c_t = i_t \otimes z_t+c_{t-1} \otimes f_t

 

という値が計算されていることになります。セルは、c_tという値を次の計算に渡したり、c_{t-1}という値を保持するという役割を持っているだけで、特別な計算はしません。従ってこれまでの計算は以下の計算に集約されます。

 

\displaystyle i_t = σ \left( W_ix_t + R_ih_{t-1} + b_i \right)

 

\displaystyle z_t = tanh \left( W_zx_t + R_zh_{t-1} + b_z \right)

 

\displaystyle f_t = σ \left( W_fx_t + R_fh_{t-1} + b_f \right)

 

c_t = i_t \otimes z_t+c_{t-1} \otimes f_t

 

大事なのはこの部分なので、後々振り返っていきます。

 

出力付近の話

以下の青で囲まれた部分について説明します。

ここは特に難しいことはありません。

 

f:id:s0sem0y:20170506180226p:plain

 

メモリーセルで得られた以下の値

 

c_t = i_t \otimes z_t+c_{t-1} \otimes f_t

 

とOutput Gateで得られた以下の値

 

\displaystyle o_t = σ \left( W_ox_t + R_oh_{t-1} + b_o \right)

 

を使って、

 

h_t = o_t \otimes tanh(c_t)

 

という計算を行うだけです。この値は次のデータが入力された時の計算にも用いられます。

 

 

LSTMの役割

これまで計算がどのように行われるのかを見てきましたが、このような計算をさせることによって何が期待できるのか、LSTMの役割を見ていきましょう。

 

セル付近の役割

LSTMではセル付近での以下の計算が重要な役割を担っています。

 

\displaystyle i_t = σ \left( W_ix_t + R_ih_{t-1} + b_i \right)

 

\displaystyle z_t = tanh \left( W_zx_t + R_zh_{t-1} + b_z \right)

 

\displaystyle f_t = σ \left( W_fx_t + R_fh_{t-1} + b_f \right)

 

c_t = i_t \otimes z_t+c_{t-1} \otimes f_t

 

Forget Gateが過去の情報をどれだけ保持するか決める

以下の計算を見てみましょう。

 

c_t = i_t \otimes z_t+c_{t-1} \otimes f_t

 

第二項を見てください。c_{t-1}というのは1個前のデータに対して得られている計算結果です。これに対して、仮にf_tの要素がほとんど0になっていたらどうでしょうか。1個前の計算によって得られた結果はほとんど失われることになります。

 

ちなみに、第一項については、現在得られた入力値i_tをどれだけ反映するかをz_tで調整していると言えます。

 

このようにして、Forget Gateで過去の情報の保持具合を調整し(ある程度は忘れてしまい)、Input Gateで現在の情報の反映具合を調整しておいて、メモリに渡します。

 

セルはこの値をそのまま通過させ、Output Gateとの計算に使うと同時に、次の入力に備えて値を保持します。

 

全体を通しての役割

出力h_tを保存するだけでなく、内部でもセルによってc_tを保存しているということは理解できたでしょう。しかし、系列データを扱う上で、このような処理が本当に必要なのでしょうか。h_tを上手く扱えば、それなりに動作しそうなものではあります。

 

LSTMは学習を長らく困難にしてきた勾配消失問題の緩和に成功したという実績があります。

 

しかし、系列データを扱うことのできるGRU(Gated Recurrent Unit)にはメモリセルに相当するものはありません。GRUよりLSTMの方が歴史が長く、GRUはまさに簡易版でLSTMと同じことができないのかを検討した結果生まれてきたものだと考えればいいですが、まだまだ発展の余地はありそうです。

 

最後に

ここまでを把握できるとLSTMの数式のかなりの部分を理解できるかと思います。実際LSTMをChainerで提供されているLSTMを使わなくとも、自分で実装できたりします。

 

ニューラルネットはそれぞれの基本的な手法を理解できるのが前提で、それらの出力や中間層にどのような意味を込めて学習させるかが肝になってきます。

 

畳み込みやLSTMは早めに理解して、応用的な論文を読めるようにしていきましょう。

 

 

 

s0sem0y.hatenablog.com

s0sem0y.hatenablog.com

s0sem0y.hatenablog.com

s0sem0y.hatenablog.com

 

 

 

s0sem0y.hatenablog.com

 

 

s0sem0y.hatenablog.com