計算グラフを用いて誤差逆伝播法を理解してみる

誤差逆伝播法とは


数値微分による勾配の導出は、損失関数を算出するまでの全ての過程をひとつの大きな関数とみなすことで、その入力と出力のみから関数のふるまいを分析するという方法だった。この方法はシンプルで簡単なのが良いが、ループするごとにニューラルネットの全ての計算を行うので、とにかく時間がかかる。ここで、もう少し頑張ってニューラルネットの中身を解析的に解き、効率良く勾配を求めようというのが「誤差逆伝播法(Backpropagation)」である。

 

画像①数値微分は大雑把

 

計算グラフと連鎖律


誤差逆伝播法のデメリットは複雑なことである。「計算グラフ(computational graph)」は、この複雑な手法を理解する助けとなるもの。簡単に言えば計算グラフは「演算をノードにして計算の流れを描くことで、数値の”変化”のみを孤立させてから、これを逆流することで変化率=微分係数を求める」というもの。

 

画像②計算グラフ概略

計算グラフの強みは、①局所的な計算に分割することで問題を単純化できる、②途中の計算結果を保持できる、③逆方向の伝播によって「微分」を効率良く計算できる、というところだろう。③について、なぜ逆流で微分が求められるかという疑問に関しては、微分の性質と「連鎖律(chain rule)」によって説明できる。

 

画像③局所的な関数の微分

まず、微分とは変化率のことだ。つまり、変化=関数の結果がわかっていれば求めることができる。(もちろん可微分な関数であることは条件だが。)誤差逆伝播法では変化が起こる局所局所を分断しているので、その局所ごとで微分が求められるのは当然である。これが逆流の理由である。

 

画像④数式

連鎖律は合成関数の微分の法則のことだ。ニューラルネットを関数の連結としてみると、全体は1つの合成関数になる。つまり、これを逆伝播させて辿っていき特定の入力に到達すれば、それはその入力に関する合成関数の微分値を求めていることと同義であり、その入力の持つ全体への影響を見ることができるわけだ。

 

画像⑤りんご

例えば上の計算グラフでリンゴの個数に辿れば「リンゴが一個増えたら値段がいくら増えるのか=110円/個」という微分値が求まる。辿るルールは、出力を入力で微分するだけだ。ようは「その入力が増幅なり縮小なり、どれだけの変化を受けて出力に影響を及ぼしているか」ということだから、加算ノードではそのまま影響をするから1倍(そのまま)で、乗算ノードは相手の数字が掛かる。

 

画像⑥加算と乗算の逆伝播

 

可算と乗算の逆伝播を実装


誤差逆伝播法の概略を理解をしたところで本題に戻る。求めたいのは、重みパラメータ(入力)を動かした際の損失関数(出力)の変化率である。これを求めるためには、実際に使われる関数を逆伝播させるシステムを作らなくてはいけない。しかし、ここで活性化関数に含まれる加算・乗算を全て分解して1つ1つ実装に持っていく必要はない。活性化関数はそれぞれレイヤー化されているから、逆伝播の仕組みも方程式で簡単な形にまとめ、実装もクラスとしてまとめ、使うときはただ組み合わせれば良い。まさにオブジェクト志向的な解き方だ。ここでは最後に加算と乗算の逆伝播をクラス化しておく。

Leave a Comment