ネットワークの重みを表示する

学習の過程で重みが更新されているかを確認したいときに、毎回とまどいながら設定していましたので、備忘的に書いておくことにします。

例として、torchvisionで読み込んだVGG16の重みを確認してみます。

学習済みモデルの読み込み

import numpy as np
import torch
import torchvision.models as models

# VGG16のpretrainモデルを読み込み
model = models.vgg16(pretrained=True)

まずは、読み込んだモデルの形状を確認します。

print(model)

VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

VGG16は、「features」に13の畳み込み層、「classifier」に3つの全結合層が存在しているのがわかります。13+3=16なので、VGG16と呼ばれているわけですね。

重みの取り出し

では、ここに含まれている重みを見るにはどうすればいいでしょうか。重みは、モデルの「state_dict」メソッドに辞書形式で格納されています。
state_dictを取り出して、表示してみます。

# state_dictの呼び出し
state_dict = model.state_dict()

# 項目表示
print(state_dict.keys())

odict_keys([‘features.0.weight’, ‘features.0.bias’, ‘features.2.weight’, ‘features.2.bias’, ‘features.5.weight’, ‘features.5.bias’, ‘features.7.weight’, ‘features.7.bias’, ‘features.10.weight’, ‘features.10.bias’, ‘features.12.weight’, ‘features.12.bias’, ‘features.14.weight’, ‘features.14.bias’, ‘features.17.weight’, ‘features.17.bias’, ‘features.19.weight’, ‘features.19.bias’, ‘features.21.weight’, ‘features.21.bias’, ‘features.24.weight’, ‘features.24.bias’, ‘features.26.weight’, ‘features.26.bias’, ‘features.28.weight’, ‘features.28.bias’, ‘classifier.0.weight’, ‘classifier.0.bias’, ‘classifier.3.weight’, ‘classifier.3.bias’, ‘classifier.6.weight’, ‘classifier.6.bias’])

これがVGG16が保持・学習する重み群です。名前についている番号は、上で表示したmodelのレイヤ番号(カッコ内の項番)に対応します。

state_dictは辞書形式のため、例えば7番目のfeaturesの重みを確認したいときは、以下のように呼び出せます。

print(state_dict['features.7.weight'])

tensor([[[[ 2.5788e-02, -1.9852e-02, -1.0697e-02], [-1.6114e-02, -4.1759e-03, 8.4582e-03], [ 4.0309e-03, 1.8973e-02, 3.8059e-02]], [[-5.4261e-02, -2.9872e-02, 1.1506e-02], [-1.8266e-02, -1.5708e-02, -1.2726e-02], [ 2.0867e-02, -5.6425e-03, -5.1218e-04]], [[ 1.8961e-02, -2.3766e-03, 6.1427e-03], [-5.1761e-02, -1.7272e-02, 8.7075e-03], [-4.4383e-02, 2.3661e-02, 9.0710e-02]], …, (以下略)

辞書形式でテンソルが格納されているだけなので、ここまでできてしまえば、情報を収集するのは慣れたやり方でできます。

# 重みのサイズを確認
print(state_dict['features.7.weight'].size()) 

# 一部の重みだけを表示
print(state_dict['features.7.weight'][0,0,0,0])

torch.Size([128, 128, 3, 3])
tensor(0.0258)

中身の確認はこれでOKです。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です