[PyTorch/torchvision] ネットワークを設定する (2)

既存モデルをカスタマイズ

torchvisionで実装されているモデルをそのまま使うのではなく、一部だけを流用したいという場合もあります。

途中までの構造を使用

VGG16モデルの途中 (ここでは2回目のMaxpool直後) までを使う場合は、順にレイヤをリストに格納していって、最後にnn.Sequential()に変換するという手順を踏みます。

import torchvision.models as models
import torch.nn as nn

model_vgg16 = models.vgg16(pretrained=True)

modules = []
n_maxpool = 0

for layer in model_vgg16.features:
    n_maxpool = n_maxpool + 1 if isinstance(layer, nn.MaxPool2d) else n_maxpool
    modules.append(layer)
    
    if n_maxpool==2:
        break

clipped_model = nn.Sequential(*modules)

出力です。VGG16の一部のみがclipped_modelとして定義されました。

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

ネットワークの構造に追加

途中まで取り出した構造にレイヤを追加する場合は、リストに追加したいレイヤをappendすればOKです。

下では、VGG16の構造にin: 128/out: 512の畳み込み層を加えた例を示しています。

modules.append(nn.Conv2d(128, 512, 3, 1, 1))

clipped_model = nn.Sequential(*modules)

出力です。上記のclipped modelに10番目のレイヤが追加されました。

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

短いですが、今回は以上です。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です