[PyTorch/torchvision] ネットワークを設定する (1)

ネットワークモデルを設計する際、既存の (実績ある) モデルをベースにすることがほとんどです。torchvisionには多くのモデルがすでに実装されていますので、それをうまく利用することで、ネットワーク構築のコストを下げていけるかと思います。

調べていたら、torchvisionのv0.14でモデルの追加があったり、メソッド、パラメータの変更などがなされたようでしたので、v0.14をベースに動きを調べてみました。

今回使用したバージョンは以下のとおりです。pythonのバージョンが古すぎると、torchvisionのimportでエラーが生じるため、注意が必要です。ちなみに、python3.7ではエラーが出ました。

  • python==3.10.0
  • torch==1.13.0
  • torchvision==0.14

モデルの一覧を表示

torchvision.models.list_models()で定義されているモデルの一覧表示ができます (version0.14以降)。

import torchvision.models as models

models.list_models()

出力

['alexnet', 'convnext_base', 'convnext_large', 'convnext_small', 'convnext_tiny', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_l', 'efficientnet_v2_m', 'efficientnet_v2_s', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fcn_resnet101', 'fcn_resnet50', 'fcos_resnet50_fpn', 'googlenet', 'inception_v3', 'keypointrcnn_resnet50_fpn', 'lraspp_mobilenet_v3_large', 'maskrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'maxvit_t', 'mc3_18', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'mvit_v1_b', 'mvit_v2_s', 'quantized_googlenet', 'quantized_inception_v3', 'quantized_mobilenet_v2', 'quantized_mobilenet_v3_large', 'quantized_resnet18', 'quantized_resnet50', 'quantized_resnext101_32x8d', 'quantized_resnext101_64x4d', 'quantized_shufflenet_v2_x0_5', 'quantized_shufflenet_v2_x1_0', 'quantized_shufflenet_v2_x1_5', 'quantized_shufflenet_v2_x2_0', 'r2plus1d_18', 'r3d_18', 'raft_large', 'raft_small', 'regnet_x_16gf', 'regnet_x_1_6gf', 'regnet_x_32gf', 'regnet_x_3_2gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_8gf', 'regnet_y_128gf', 'regnet_y_16gf', 'regnet_y_1_6gf', 'regnet_y_32gf', 'regnet_y_3_2gf', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_8gf', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext101_64x4d', 'resnext50_32x4d', 'retinanet_resnet50_fpn', 'retinanet_resnet50_fpn_v2', 's3d', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet1_0', 'squeezenet1_1', 'ssd300_vgg16', 'ssdlite320_mobilenet_v3_large', 'swin_b', 'swin_s', 'swin_t', 'swin_v2_b', 'swin_v2_s', 'swin_v2_t', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vit_b_16', 'vit_b_32', 'vit_h_14', 'vit_l_16', 'vit_l_32', 'wide_resnet101_2', 'wide_resnet50_2']

定義済みモデルの使用

定義済みモデルを使用する方法は以下の2つです。ただし、get_model()はまだベータ版との記述がありました (https://pytorch.org/vision/stable/models.html)。

  • torchvision.modelsで定義された各モデルを呼び出し
  • get_model()でモデルの名前を指定して読み込み

VGG16

まずは定番VGG16です。学習済の重みも読み込むようにしています。

# vgg16モデル呼び出し
model_vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# get_model()
model_vgg16 = models.get_model('vgg16', weights='DEFAULT')

結果

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...以下略...

ViT

ViTも同じようにモデルを定義できます。

# vit_b_16モデル呼び出し
model_vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

# get_model()
model_vit = models.get_model('vit_b_16', weights='DEFAULT')

結果

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    ...以下略...

Swin Transformer

Swin Transformerのモデルもありました。進化していますね。

# swin_bモデル呼び出し
model_swin = models.swin_b(weights=models.Swin_B_Weights.DEFAULT)

# get_model()
model_swin = models.get_model('swin_b', weights='DEFAULT')

結果

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    ...以下略...

学習済の重みの設定

モデルの重みの指定方法も、個々のモデル指定での呼び出しとget_modelの場合とで若干違いがあります。最新版の重みを使う場合、個々のモデル指定での呼び出しの場合はmodels.モデル名_Weights.DEFAULT 、get_modelの場合はweights=’DEFAULT’を指定すればよさそうです。(pretrained=Trueも今は使えますが、v0.15で削除予定とのこと)

また、ランダムに初期化された重みを使用する場合は、weights=Noneにします。

とりあえず今回はここまで。
torchvisionの提供モデルがすごく増えててびっくりしました。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です