学習を行うときの学習率の設定はなかなか悩ましいものがあります。
収束は早くしてほしいけど、きちんとロスの底にたどり着いてほしい。
そういう場合に、Schedulerの機能を使います。Schedulerは学習率を習の途中で変化させることができます。
下の図は、学習率が学習に与える影響を表しています。グラフの底になるように学習をしたいのですが、学習率が高すぎる(左)と底を飛び越えてしまう可能性があり、低すぎる(中央)と収束に時間がかかったり局所解に陥る可能性もあります。
なので、右のように学習の進行に従って学習率を下げていくのが効率的な学習と考えることができます。
PyTorchにもそんなschedulerがいくつか用意されています。
全部見ようと思ったのですが、理解するのが大変そうなので、考え方が分かりやすかったものを2つだけピックアップすることにします。
torch.optim.lr_scheduler.StepLR
一番シンプルなアップデート方法でしょうか。
【書式】StepLR(optimizer, step_size, gamma)
optimizer:ラップ対象のオプティマイザ。SGDとかAdamやらを定義したものが個々に入ります。
step_size:更新タイミングのエポック数
例えば、step_size=30を設定した場合、30, 60, 90, …のタイミングで学習率にgammaが乗算されます。
gamma:更新率。デフォルトは0.1でこの場合、更新タイミングで学習率が1/10になっていきます。
更新タイミングを2エポックごとに設定して、学習率の推移を見ていきます。
学習率の更新はscheduler.step()で、学習率の取得はscheduler.get_lr()で行えます。
※コードは一部を抜き出しています。
num_epoch = 10 scheduler = optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1) for m in range(num_epoch): for i_batch, sample_batched in enumerate(train_data): x = model(sample_batched[0]) y = sample_batched[1] loss = loss_func(x, y) loss.backward() opt.step() opt.zero_grad() print('epoch:{}, lr:{}'.format(epoch, scheduler.get_lr()[0])) scheduler.step()
結果は以下となり、2エポックごとに学習率が1/10になっていることがわかります。
epoch:0, lr:0.001
epoch:1, lr:0.001
epoch:2, lr:0.0001
epoch:3, lr:0.0001
epoch:4, lr:1.0000000000000003e-05
epoch:5, lr:1.0000000000000003e-05
epoch:6, lr:1.0000000000000002e-06
epoch:7, lr:1.0000000000000002e-06
epoch:8, lr:1.0000000000000002e-07
epoch:9, lr:1.0000000000000002e-07
torch.optim.lr_scheduler.MultiStepLR
Multiという単語がついているように、 こちらの関数はより柔軟な設定ができます。
【書式】MultiStepLR(optimizer, milestones, gamma)
違いは、milestonesです。
milestones:学習率の更新エポックのリスト
milestonesのエポックごとに、学習率はgammaを乗算した値になります。
schedulerの設定を、2エポック、6エポックで学習率の更新をするように変更して結果を見てみます。変えるのはこの行だけです。
optim.lr_scheduler.MultiStepLR(opt, milestones=[2, 6], gamma=0.1)
動かした結果です。設定通り、2エポックと6エポックで学習率が1/10になっています。
epoch:0, lr:0.001
epoch:1, lr:0.001
epoch:2, lr:0.0001
epoch:3, lr:0.0001
epoch:4, lr:0.0001
epoch:5, lr:0.0001
epoch:6, lr:1.0000000000000003e-05
epoch:7, lr:1.0000000000000003e-05
epoch:8, lr:1.0000000000000003e-05
epoch:9, lr:1.0000000000000003e-05
他にもいろいろな更新方法があるようなので、かなりフレキシブルな学習が行えるのですね。
追記
スケジューラの動作について、PyTorch1.1にはどうもバグがあるみたいです。
参照:https://github.com/pytorch/pytorch/issues/22107
import torch print(torch.__version__)
1.1.0
PyTorch1.1のバージョンで、StepLR()を動かしてみます。
2エポックごとだと動きが分かりづらいため、step_sizeを4にします。
scheduler = optim.lr_scheduler.StepLR(opt, step_size=4, gamma=0.1)
下に示すように、更新エポックのときだけ学習率がおかしくなっています。gammaが2回かけられているみたいですね。
epoch:0, lr:0.001
epoch:1, lr:0.001
epoch:2, lr:0.001
epoch:3, lr:0.001
epoch:4, lr:1e-05
epoch:5, lr:0.0001
epoch:6, lr:0.0001
epoch:7, lr:0.0001
epoch:8, lr:1.0000000000000002e-06
epoch:9, lr:1e-05
最初に動かした環境がPyTorch1.1だったので、結構はまってしまいました。こういうこともあるんですね。