「PyTorchのネステロフモーメンタムの実装は間違っていますか?」
PyTorchのネステロフモーメンタムの実装は間違っていますか?
はじめに
PyTorchのSGDのドキュメンテーションを注意深く見ると、Nesterov運動量の実装には元の論文といくつかの違いがあることがわかります。特に、PyTorchの実装では現在のパラメータで勾配を評価しますが、Nesterov運動量の本来の目的はシフトされたパラメータで勾配を評価することです。残念ながら、これらの違いについての議論はインターネット上ではほとんど見当たりません。この記事では、PyTorchの実装と元のNesterov運動量の公式の違いを調べ、説明します。最終的に、PyTorchの実装は間違っているのではなく、近似値であり、その実装の利点について推測します。
公式
元の論文では、Nesterov運動量は以下の更新ルールを使用して説明されています:
ここで、v_{t+1}とθ_{t+1}は時間tでの速度ベクトルとモデルパラメータです。μは運動量係数、εは学習率です。PyTorchのSGDドキュメンテーションのノートには、以下の更新ルールが使用されていると記載されています:
ここで、g_{t+1}はv_{t+1}を計算するために使用される勾配を表します。θ_{t+1}の更新ルールを展開すると以下のようになります:
- 「ニューラルネットワークの多様性の力を解き放つ:適応ニューロンが画像分類と非線形回帰で均一性を上回る方法」
- 「本番環境での機械学習モデルのモニタリング:なぜ必要であり、どのように行うか?」
- 「機械学習の公衆の認識に関する問題」
これから私たちは以下のことが推察できます:
そして、更新ルールは以下のようになります:
これらはPyTorchが理論上使用する更新ルールです。先ほど述べたように、PyTorchは実際にはシフトされたパラメータではなく現在のパラメータで勾配を評価します。これはPyTorchのSGDドキュメンテーションのアルゴリズムの説明を見ることで確認できます。後ほどこれについて詳しく調査します。
元の(1, 2)とPyTorchの(3, 4)の公式の場合、v_0 = 0であれば、θへの最初の更新は次のようになります:
PyTorchのSGDドキュメンテーションのノートには、アルゴリズムが最初のステップで勾配をモーメンタムバッファに初期化すると記述されていますが、後ほどv_0 = 0を意味することを示します。
初期の違い
元の(1, 2)からPyTorchの(3, 4)の公式に移る際には、2つの即座の違いがあります:
- 学習率がv_{t+1}の外に移動されます。
- v_{t+1}の更新ルールでは、勾配を引く代わりに足し、θ_{t+1}の更新ルールでは速度ベクトルを足す代わりに引きます。勾配項内の符号の違いは、前の節で示した通り、この結果として生じるものです。
これらの違いを理解するために、まず更新ルールを展開しましょう。ここで示唆されているように、学習率スケジュールを考慮した場合、最初の違いの効果はより明確になります。したがって、εが固定されていないが時間によって変化すると考えられる更新ルールの一般化を考えます。tステップでの学習率をε_tとしましょう。簡単のために、以下のようにします:
v_0 = 0と仮定すると、元の公式は以下のようになります:
そして、PyTorchの公式は以下のようになります:
元の公式(6)では、学習率が時間tで変化する場合、和のi = tの項の大きさのみが影響を受け、他の項の大きさは変わらないことになります。その結果、学習率の変化の即時の影響は非常に限定的であり、学習率の変化が後続の時間ステップに「徐々に」影響を及ぼすのを待たなければ、全体的なステップサイズに強い影響を与えるには時間がかかります。対照的に、PyTorchの公式(7)では、学習率が時間tで変化する場合、全体のステップの大きさが即座に影響を受けます。
v_0 = 0の場合、展開されたルールからは、2つ目の違いは最終的には影響を与えないことが明らかです。どちらの公式でも、ステップは現在のパラメータから引かれる勾配の割引された合計となります。
主な違い
重み減衰と減衰を無視して、PyTorchのドキュメントでSGDアルゴリズムを分析することにより、実装された更新ルールは次のようになります:
ここで、θ’_{t+1}は時刻tのモデルパラメータを表します。
式3と式4をPyTorchの「ノート」式と呼び、式8と式9をPyTorchの「実装」式と呼びます。θとθ’の区別をする理由は、すぐに明らかになります。ノート式との最も顕著な違いは、勾配がシフトされたパラメータではなく、現在のパラメータで評価されることです。これだけから、アルゴリズムがNesterovの運動量の適切な実装ではないように見えるかもしれません。
次に、PyTorchアルゴリズムが最終的にNesterovの運動量を近似する方法を調べてみましょう。以前のバージョンのPyTorchの導出は、このGitHubの問題で参照されたIvo Danihelkaによってここで見つけることができます。現在のバージョンのPyTorchの導出は、以前の導出から比較的簡単な調整を行ったもので、ここで見つけることができます。これらの導出のLaTeXのレンダリングをここで提供します。実装された式は、変数の単純な変更によって導出されます。具体的には、次のようにします:
変数の変更後、ノートの更新ルールv_{t+1}(3)は、変数の変更後の実装された更新ルールv_{t+1}(8)と等価になることがすぐに明らかになります。次に、θ’_{t+1}の更新ルールをθ’_tの関数として導出したいと思います:
これは、PyTorchで実装された更新ルール(9)と同じです。高レベルでは、PyTorchの実装では、現在のパラメータθ’_tがすでに「実際の」パラメータθ_tのシフトバージョンであると仮定しています。したがって、各時刻で、「実際の」パラメータθ_tは現在のパラメータθ’_tと次のように関連しています:
ただし、ソースコードからは、PyTorchのSGD実装ではアルゴリズムの最後にいかなる修正も行われていないように見えるため、最終出力は技術的には「実際の」パラメータの近似です。
最後に、v_0は0でなければならないことを示します:
さらに、v_0 = 0の場合に元の形式で行われる最初の更新と同じ最初の更新が「実際の」パラメータに対して行われることを確認できます:
これは、式5と等価であることがわかります。
実装された式の利点
もちろん、最も重要な質問は次のとおりです:PyTorchはなぜ元のNesterovの運動量の更新ルール(1、2)からドキュメントのノート(3、4)への式の再定式化に取り組むのでしょうか?ひとつの可能性は、再定式化が必要な算術演算の数を削減するかもしれないということです。この可能性を評価するために、算術演算の数を数えてみましょう。ノートの式(3、4)では、次のようになります:
ここでは、合計で7つの演算があります。実装された式(8、9)では、次のようになります:
ここでは、合計で6つの演算があります。PyTorchの実装では、2番目の勾配は最初の勾配計算から保存された結果を使用するだけであり、各時刻で1つの勾配計算が実行されます。したがって、明らかな利点は、PyTorchの実装が各ステップで追加の乗算演算を削減することです。
結論
まとめると:
- PyTorchのSGDドキュメントのノート(3、4)に記載された更新ルールは、元のNesterovの運動量の更新ルール(1、2)と比較して学習率の場所が異なるため、学習率のスケジュールが全体のステップサイズに直ちに影響を与えることができます。一方、元の形式では学習率の変更の影響が後続の時間ステップに徐々に「伝播」する効果があります。
- PyTorchのSGDアルゴリズムで実装された更新ルール(8、9)は、ドキュメントのノート(3、4)に記載された更新ルールの単純な変数の変更後の近似です。各時刻で、「実際の」パラメータは現在のパラメータから簡単に復元できますが、PyTorchの実装ではアルゴリズムの最後にそのような修正は行われず、したがって最終パラメータは技術的には「実際の」最終パラメータの近似のままです。
- PyTorchの実装の明らかな利点は、各時刻で追加の乗算演算を削減することです。
参考文献
- 「SGD.」SGD — PyTorch 2.0 ドキュメンテーション, pytorch.org/docs/stable/generated/torch.optim.SGD.html. 2023年9月2日にアクセス。
- Sutskever, Ilya、他。「深層学習における初期化とモメンタムの重要性について」。国際機械学習会議。PMLR、2013年。
- Danihelka, Ivo。「Nesterovのモメンタムを簡単にする」。2012年8月25日。
- Chintala, Soumith。「nesterov momentum is wrong in sgd · Issue #27 · torch/optim」。GitHub、2014年10月13日、github.com/torch/optim/issues/27。
- Gross, Sam。「optimのモメンタムの定式化に関するドキュメントに注意を追加する · Issue #1099 · pytorch/pytorch」。GitHub、2017年3月25日、github.com/pytorch/pytorch/issues/1099#issuecomment-289190614。
- Zhao, Yilong。「Nesterov Momentum Bugを修正する · Issue #5920 · pytorch/pytorch」。GitHub、2018年3月21日、https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908。
We will continue to update VoAGI; if you have any questions or suggestions, please contact us!
Was this article helpful?
93 out of 132 found this helpful
Related articles