ネットワークモデルを設計する際、既存の (実績ある) モデルをベースにすることがほとんどです。torchvisionには多くのモデルがすでに実装されていますので、それをうまく利用することで、ネットワーク構築のコストを下げていけるかと思います。
調べていたら、torchvisionのv0.14でモデルの追加があったり、メソッド、パラメータの変更などがなされたようでしたので、v0.14をベースに動きを調べてみました。
今回使用したバージョンは以下のとおりです。pythonのバージョンが古すぎると、torchvisionのimportでエラーが生じるため、注意が必要です。ちなみに、python3.7ではエラーが出ました。
- python==3.10.0
- torch==1.13.0
- torchvision==0.14
モデルの一覧を表示
torchvision.models.list_models()で定義されているモデルの一覧表示ができます (version0.14以降)。
import torchvision.models as models models.list_models()
出力
['alexnet', 'convnext_base', 'convnext_large', 'convnext_small', 'convnext_tiny', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_l', 'efficientnet_v2_m', 'efficientnet_v2_s', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fcn_resnet101', 'fcn_resnet50', 'fcos_resnet50_fpn', 'googlenet', 'inception_v3', 'keypointrcnn_resnet50_fpn', 'lraspp_mobilenet_v3_large', 'maskrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'maxvit_t', 'mc3_18', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'mvit_v1_b', 'mvit_v2_s', 'quantized_googlenet', 'quantized_inception_v3', 'quantized_mobilenet_v2', 'quantized_mobilenet_v3_large', 'quantized_resnet18', 'quantized_resnet50', 'quantized_resnext101_32x8d', 'quantized_resnext101_64x4d', 'quantized_shufflenet_v2_x0_5', 'quantized_shufflenet_v2_x1_0', 'quantized_shufflenet_v2_x1_5', 'quantized_shufflenet_v2_x2_0', 'r2plus1d_18', 'r3d_18', 'raft_large', 'raft_small', 'regnet_x_16gf', 'regnet_x_1_6gf', 'regnet_x_32gf', 'regnet_x_3_2gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_8gf', 'regnet_y_128gf', 'regnet_y_16gf', 'regnet_y_1_6gf', 'regnet_y_32gf', 'regnet_y_3_2gf', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_8gf', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext101_64x4d', 'resnext50_32x4d', 'retinanet_resnet50_fpn', 'retinanet_resnet50_fpn_v2', 's3d', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet1_0', 'squeezenet1_1', 'ssd300_vgg16', 'ssdlite320_mobilenet_v3_large', 'swin_b', 'swin_s', 'swin_t', 'swin_v2_b', 'swin_v2_s', 'swin_v2_t', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vit_b_16', 'vit_b_32', 'vit_h_14', 'vit_l_16', 'vit_l_32', 'wide_resnet101_2', 'wide_resnet50_2']
定義済みモデルの使用
定義済みモデルを使用する方法は以下の2つです。ただし、get_model()はまだベータ版との記述がありました (https://pytorch.org/vision/stable/models.html)。
- torchvision.modelsで定義された各モデルを呼び出し
- get_model()でモデルの名前を指定して読み込み
VGG16
まずは定番VGG16です。学習済の重みも読み込むようにしています。
# vgg16モデル呼び出し model_vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT) # get_model() model_vgg16 = models.get_model('vgg16', weights='DEFAULT')
結果
VGG( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ...以下略...
ViT
ViTも同じようにモデルを定義できます。
# vit_b_16モデル呼び出し model_vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT) # get_model() model_vit = models.get_model('vit_b_16', weights='DEFAULT')
結果
VisionTransformer( (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) (encoder): Encoder( (dropout): Dropout(p=0.0, inplace=False) ...以下略...
Swin Transformer
Swin Transformerのモデルもありました。進化していますね。
# swin_bモデル呼び出し model_swin = models.swin_b(weights=models.Swin_B_Weights.DEFAULT) # get_model() model_swin = models.get_model('swin_b', weights='DEFAULT')
結果
SwinTransformer( (features): Sequential( (0): Sequential( (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4)) (1): Permute() (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) ) ...以下略...
学習済の重みの設定
モデルの重みの指定方法も、個々のモデル指定での呼び出しとget_modelの場合とで若干違いがあります。最新版の重みを使う場合、個々のモデル指定での呼び出しの場合はmodels.モデル名_Weights.DEFAULT 、get_modelの場合はweights=’DEFAULT’を指定すればよさそうです。(pretrained=Trueも今は使えますが、v0.15で削除予定とのこと)
また、ランダムに初期化された重みを使用する場合は、weights=Noneにします。
とりあえず今回はここまで。
torchvisionの提供モデルがすごく増えててびっくりしました。