バッチ正規化による学習の安定化

ぐちゃぐちゃなLossの下がり方を直す!!

2020年11月02日

損失が全然下がらない

だいぶ今のモデルにもプログルムにも慣れてきて, 自分でいろいろいじっているうちに, 学習が全然うまくいかない... 具体的には, Lossがもうぐわんぐわんして,「なんやねん!! 全然下がってないやん!!」状態になっていた... 先生にも「う〜〜ん... 全然うまく学習できてないね!! やり直し!!」っと, 中間報告会が終わって一息ついたことろで, その一言... 今思うと, こんなにもうまく学習できていない結果を良くもまあ, 中間報告会で紹介したな...っていう感じ. 中間報告会で出した損失... ぐにゃんぐにゃん...

バッチ正規化

なんかいろいろとモデルをいじりすぎたので, めちゃシンプルな形にして再学習!! でも下がらない... そんなときにロボコンの友達が, 「バッチ正規化はどれくらいでやってる??」の一言!! それだ!それ気にしてなかったわ!! 実際に, これまでは入力時刻数を増やすことばかり考えていて, バッチサイズ1とか2とかにしてた... よし! これを改善しよう!!

>

さっそくバッチサイズを10にした途端に, Memory Exhausted Errorが飛び出た!! しか〜〜し! ぼくらには名大さんのスパコンがついている!! ということで, すぐさまスパコンにプログラムを投げて実行!! Lossの下がり方は以下のよう!! バッチ正規化したらこうなったよ!! Pytorchだと, optimizerの重みを自動的に調整して減衰してくれるWeight Decay機能や, 1行でバッチ正則化をしてくれるnn.BatchNorm1dなどがある!! うれしい!! とてもうれしい!! そして, ロボコンの友達にありがとう!!
cf) Pytorchの様々な最適化手法
cf) Pytorchで関数フィッティング その2:Batch正規化の導入
ちなみに, 実際のGPUメモリ使用の状況は, こんな感じ... やばっ!!(笑) 合わせて36GB以上?まじ??