RNN
目次
はじめに
RNN(Recurrent Neural Network)は、時間に依存したデータのパターンを学習してくれるネットワークです。
今回はRNNについて計算グラフを用いて行列で順伝播と逆伝播の式を導出してみたいと思います。
計算グラフについては以下過去記事を参照ください。
RNNで扱われるデータは、サンプル数とユニット数のインデックスの他に、時間インデックスが増えるのため、3階のテンソルになります。
例えば、ネットワークの
層目の中間層ユニットのデータ形式は
となります。(
がサンプル数、
が時間、
がユニット数を表すインデックスです)
もちろん、特徴量データや目標値データ(教師データ)も同じように3階のテンソル
、
になります。
行列で表現するときは、3階テンソル
のサンプル数とユニット数のインデックスを行列形式にまとめ、時間インデックスだけで
と表してやります。
RNNブロックの順伝播
- 以下の様にRNNブロック(RNNの計算処理をブロックにまとめて表記してます。)に時間依存データ
を入力し、
が出力される状態を考えていきます。
RNNでは出力を再帰的に入力に戻してやります。RNNブロックには、
と前の時間の出力
が入力される形になります。
紙面上方向が時間の正方向として、全てのデータの入出力をちゃんと書くと以下の様になります。
- 各時刻
におけるRNNブロックの中身、つまり
層から
層への順伝播式は以下のようになっています。
前の時間
からの出力
に重み
をかけて足しこむところ以外は通常の全結合と同じです。
重み
、
と閾値
は、時間に依存しない値であることに注意してください。
時間インデックスのMaxを
とすると
から
まで順番に、全ての時間
において
を求めていきます。(ただし
にしておきます。)
RNNブロックの中身を計算グラフで表現すると以下の様になります。
RNNブロックの逆伝播
- 逆伝播時は、順伝播の出力
が時間
のRNNブロックへ分岐しているので、時間
からの逆伝播も考慮に入れる必要があります。以下計算グラフの赤線で示しています。(以下の逆伝播の計算グラフは、以前の記事行列演算と計算グラフでまとめた計算グラフの逆伝播のルールに従い機械的に書いたものです。)
- 上記の計算グラフより、
から
までの逆伝播式は、
(3)式の
や(4)式の
の右上についている
は転置を表すので気をつけてください。(以降、行列の右上に付いている
は全て転置を表します。時間のMax値の記号
と被っているので不本意なのですが他に記号が思いつかず...)
逆伝播時は
から
まで時間をさかのぼって、全ての時間
において(3)、(4)式を計算していきます。(ただし、
にしておきます。)
RNNにおけるバックプロパゲーションは、レイヤー間だけではなく、時間軸に沿っても行う為、BPTT(Back Propagation Through Time)と呼ばれています。
また計算グラフより、重み・閾値の勾配は以下の様に求まりますが、
- 順伝播での分岐は逆伝播では和になるので、(5)、(6)、(7)式を時間で和をとったものが、真の重み・閾値の勾配になります。
- 計算グラフを使うとRNNの複雑な逆伝播もすんなり求めることができますね。
損失関数(2乗誤差)
RNNの例題として回帰を扱おうと思うので、回帰における一般的な損失関数である2乗誤差について、順伝播と逆伝播をまとめておこうと思います。
RNNで損失関数を計算するときは、時間のインデックスを考慮して時間に関して和を取る必要があります。(分類問題の場合も、各時間における交差エントロピーを算出後、時間について和を取ってやります。)
ネットワークの最終出力を
、 目標値を
とした時、2乗誤差
は以下の式で計算されます。
- 順伝播時の演算を計算グラフで表すと、
- 逆伝播時は、
分類でソフトマックス&交差エントロピーを使った時と同じく、誤差は
になります。うまくできてます。
以下、逆伝播時の計算グラフです。
特徴量データと目標値データの時間の個数が合わない場合
例えば10個の時間インデックスを持つ特徴量データから1個の時間インデックスを持つ目標値データを学習させたい場合などもあると思います。
ネットワークの入力である特徴量データと、目標値データの時間の個数を必ずしも合わせる必要はないです。
今ネットワークに
個の特徴量データ
(
)を入力して、
個の出力
が得られたとします。
それに対し学習させたい目標値データ
が、
個(
)しかない場合を考えます。
この場合の損失関数は、目標値データの個数分だけ時間の和を取る形になります。
- 逆伝播時の
の計算では、
の出力
が損失値
に寄与していないため、その微分
は全部0になります。