[python/PyTorch] 乱数を固定する

CNNの既存コードを見ていると、torch.manual_seed()なんていう一文があります。
おまじないみたいなものだろうと全然気にしないでいたのですが、調べてみたら深い意味をもつものでしたので、備忘のために書いておきます。

RNG (乱数ジェネレータ)

PyTorchに限らず、ランダムな処理を行う場合には、RNG (Random Number Generator) と呼ばれるジェネレータが乱数を生成し、その数字に基づきランダム処理を行っています。

例えば、データセットからバッチを生成するときに、データを取ってくるためのランダム性だったり、重みの初期化に使う乱数だったり。

この乱数は、seedと呼ばれる数字から生成しています。
(ちゃんと理解できていませんが、)seedが同じであれば、同じ乱数が生成されるそうなのです。

同じ乱数が生成されて何が嬉しいかというと、結果に再現性をもたせることができるのです。同じバッチを使って、同じ初期化重みを使えるので、学習は同じ経過をたどりますね。

最初、seedの意味が分からなかったので、何度学習しても同じ画像を同じタイミングで読み込んでいるのがずっと不思議でした。shuffleが機能していないのだと勘違いしたり……
ランダムに取り出すデータが固定されていたからなのですね。

RNGの動作

numpyの関数で、RNGの動作を確認してみます。

import numpy as np

for n in range(5):
    np.random.seed(0) # seedの設定
    x = np.random.randn(1) # 乱数の生成
    print(x)

[1.76405235]
[1.76405235]
[1.76405235]
[1.76405235]
[1.76405235]

出力結果を見ると、seedに同じ値が入っているため、生成される乱数が同一であることが分かります。

ちなみに、以下のようにseedの設定ループを前に出すと、繰り返すたびに乱数は変わります。毎回、同じseedの値を使うわけではなく、1度使ったseedの続きの値を使っているからだそうです。

import numpy as np

np.random.seed(0) # seedの設定
for n in range(5):
    x = np.random.randn(1) # 乱数の生成
    print(x)

[1.76405235]
[0.40015721]
[0.97873798]
[2.2408932]
[1.86755799]

RNGの固定

再現性のある学習を行うために、以下をコードのはじめに設定します。引数としてのseedは任意の数字を入れます。
様々な段階で乱数を使っていますので、完全再現のためには4種類、忘れずに設定しましょう。

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です