【PyTorch】モデルの可視化を行うtorchinfoを使ってみた
PyTorchでモデルを可視化する方法はいくつかありますが,今回はその中でtorchinfoというものを見つけました.
実際にtorchinfoを使用してみたので,その使い方についてこちらにメモを残しておきます.
そのほかの可視化ライブラリについてもまとめておりますので,良ければご参照ください.
torchinfo
とは
PyTorchのモデルを可視化してくれるライブラリです.
Tensorflowのmodel.summary()
のようにPyTorchのモデルを可視化してくれます.
実行環境
- python 3.8.2
- pytorch 1.7.0
- torchvision 0.8.1
コードは全てJupyter Notebook上で実行しております.
インストール方法
pip
でインストールできます.
pip install torchinfo
使い方
使い方はとてもシンプルで,pytorchのモデルと入力サイズをtorchinfo.summary
に渡すだけです.
from torchinfo import summary from torchvision.models import resnet18 model = resnet18() batch_size = 2 summary( model, input_size=(batch_size, 3, 224, 224), col_names=["output_size", "num_params"], )
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ResNet -- -- ├─Conv2d: 1-1 [2, 64, 112, 112] 9,408 ├─BatchNorm2d: 1-2 [2, 64, 112, 112] 128 ├─ReLU: 1-3 [2, 64, 112, 112] -- ├─MaxPool2d: 1-4 [2, 64, 56, 56] -- ├─Sequential: 1-5 [2, 64, 56, 56] -- │ └─BasicBlock: 2-1 [2, 64, 56, 56] -- │ │ └─Conv2d: 3-1 [2, 64, 56, 56] 36,864 │ │ └─BatchNorm2d: 3-2 [2, 64, 56, 56] 128 │ │ └─ReLU: 3-3 [2, 64, 56, 56] -- │ │ └─Conv2d: 3-4 [2, 64, 56, 56] 36,864 │ │ └─BatchNorm2d: 3-5 [2, 64, 56, 56] 128 │ │ └─ReLU: 3-6 [2, 64, 56, 56] -- │ └─BasicBlock: 2-2 [2, 64, 56, 56] -- │ │ └─Conv2d: 3-7 [2, 64, 56, 56] 36,864 │ │ └─BatchNorm2d: 3-8 [2, 64, 56, 56] 128 │ │ └─ReLU: 3-9 [2, 64, 56, 56] -- │ │ └─Conv2d: 3-10 [2, 64, 56, 56] 36,864 │ │ └─BatchNorm2d: 3-11 [2, 64, 56, 56] 128 │ │ └─ReLU: 3-12 [2, 64, 56, 56] -- ├─Sequential: 1-6 [2, 128, 28, 28] -- │ └─BasicBlock: 2-3 [2, 128, 28, 28] -- │ │ └─Conv2d: 3-13 [2, 128, 28, 28] 73,728 │ │ └─BatchNorm2d: 3-14 [2, 128, 28, 28] 256 │ │ └─ReLU: 3-15 [2, 128, 28, 28] -- │ │ └─Conv2d: 3-16 [2, 128, 28, 28] 147,456 │ │ └─BatchNorm2d: 3-17 [2, 128, 28, 28] 256 │ │ └─Sequential: 3-18 [2, 128, 28, 28] 8,448 │ │ └─ReLU: 3-19 [2, 128, 28, 28] -- │ └─BasicBlock: 2-4 [2, 128, 28, 28] -- │ │ └─Conv2d: 3-20 [2, 128, 28, 28] 147,456 │ │ └─BatchNorm2d: 3-21 [2, 128, 28, 28] 256 │ │ └─ReLU: 3-22 [2, 128, 28, 28] -- │ │ └─Conv2d: 3-23 [2, 128, 28, 28] 147,456 │ │ └─BatchNorm2d: 3-24 [2, 128, 28, 28] 256 │ │ └─ReLU: 3-25 [2, 128, 28, 28] -- ├─Sequential: 1-7 [2, 256, 14, 14] -- │ └─BasicBlock: 2-5 [2, 256, 14, 14] -- │ │ └─Conv2d: 3-26 [2, 256, 14, 14] 294,912 │ │ └─BatchNorm2d: 3-27 [2, 256, 14, 14] 512 │ │ └─ReLU: 3-28 [2, 256, 14, 14] -- │ │ └─Conv2d: 3-29 [2, 256, 14, 14] 589,824 │ │ └─BatchNorm2d: 3-30 [2, 256, 14, 14] 512 │ │ └─Sequential: 3-31 [2, 256, 14, 14] 33,280 │ │ └─ReLU: 3-32 [2, 256, 14, 14] -- │ └─BasicBlock: 2-6 [2, 256, 14, 14] -- │ │ └─Conv2d: 3-33 [2, 256, 14, 14] 589,824 │ │ └─BatchNorm2d: 3-34 [2, 256, 14, 14] 512 │ │ └─ReLU: 3-35 [2, 256, 14, 14] -- │ │ └─Conv2d: 3-36 [2, 256, 14, 14] 589,824 │ │ └─BatchNorm2d: 3-37 [2, 256, 14, 14] 512 │ │ └─ReLU: 3-38 [2, 256, 14, 14] -- ├─Sequential: 1-8 [2, 512, 7, 7] -- │ └─BasicBlock: 2-7 [2, 512, 7, 7] -- │ │ └─Conv2d: 3-39 [2, 512, 7, 7] 1,179,648 │ │ └─BatchNorm2d: 3-40 [2, 512, 7, 7] 1,024 │ │ └─ReLU: 3-41 [2, 512, 7, 7] -- │ │ └─Conv2d: 3-42 [2, 512, 7, 7] 2,359,296 │ │ └─BatchNorm2d: 3-43 [2, 512, 7, 7] 1,024 │ │ └─Sequential: 3-44 [2, 512, 7, 7] 132,096 │ │ └─ReLU: 3-45 [2, 512, 7, 7] -- │ └─BasicBlock: 2-8 [2, 512, 7, 7] -- │ │ └─Conv2d: 3-46 [2, 512, 7, 7] 2,359,296 │ │ └─BatchNorm2d: 3-47 [2, 512, 7, 7] 1,024 │ │ └─ReLU: 3-48 [2, 512, 7, 7] -- │ │ └─Conv2d: 3-49 [2, 512, 7, 7] 2,359,296 │ │ └─BatchNorm2d: 3-50 [2, 512, 7, 7] 1,024 │ │ └─ReLU: 3-51 [2, 512, 7, 7] -- ├─AdaptiveAvgPool2d: 1-9 [2, 512, 1, 1] -- ├─Linear: 1-10 [2, 1000] 513,000 ========================================================================================== Total params: 11,689,512 Trainable params: 11,689,512 Non-trainable params: 0 Total mult-adds (G): 3.63 ========================================================================================== Input size (MB): 1.20 Forward/backward pass size (MB): 79.49 Params size (MB): 46.76 Estimated Total Size (MB): 127.46 ==========================================================================================
このように出力のテンソルの形状やパラメータの数だけなく,フォーワードやバックワードにかかる容量なども出力してくれます.
torchinfo.summary
の引数
torchinfo.summary
には様々な引数を渡すことができ,それによって出力の結果を変えることができます.
input_size
... バッチの次元を含めて入力のサイズを指定する.複数の入力にも対応.col_names
... 可視化したいパラメータを指定できる.複数指定可能.2021/06/01現在は,"input_size"
,"output_size"
,"num_params"
,"kernel_size"
,"mult_adds"
に対応.depth
... ネストしたレイヤー(nn.Sequential
など)に対して,どの深さまで可視化させるか.(後述の例参照)device
... 使用するデバイス名.dtypes
... 入力のdata typeを指定.複数入力に対して,それぞれ指定することができる.row_settings
... レイヤーの名前を指定できる.複数指定可能.2021/06/01現在は,var_names
(変数名表示)と,depth
(depth-index表示)の二つが使用可能.verbose
...verbose=0
とすると何も表示しない.verbose=1
とすると,上の例のようにモデルのsummaryのみを出力.verbose=2
とすると,各レイヤーが持つ重みパラメータまで出力.
その他の引数や詳細に関しては,こちらを参照してください.
上で紹介した引数を指定しながら,再度モデルの可視化を行ってみます.
from torchinfo import summary from torchvision.models import resnet18 model = resnet18() batch_size = 2 summary( model, input_size=(batch_size, 3, 224, 224), col_names=["input_size", "output_size", "num_params", "mult_adds"], depth=2, row_settings=["var_names"], )
============================================================================================================================================ Layer (type (var_name)) Input Shape Output Shape Param # Mult-Adds ============================================================================================================================================ ResNet -- -- -- -- ├─Conv2d (conv1) [2, 3, 224, 224] [2, 64, 112, 112] 9,408 236,027,904 ├─BatchNorm2d (bn1) [2, 64, 112, 112] [2, 64, 112, 112] 128 256 ├─ReLU (relu) [2, 64, 112, 112] [2, 64, 112, 112] -- -- ├─MaxPool2d (maxpool) [2, 64, 112, 112] [2, 64, 56, 56] -- -- ├─Sequential (layer1) [2, 64, 56, 56] [2, 64, 56, 56] -- -- │ └─BasicBlock (0) [2, 64, 56, 56] [2, 64, 56, 56] 73,984 -- │ └─BasicBlock (1) [2, 64, 56, 56] [2, 64, 56, 56] 73,984 -- ├─Sequential (layer2) [2, 64, 56, 56] [2, 128, 28, 28] -- -- │ └─BasicBlock (0) [2, 64, 56, 56] [2, 128, 28, 28] 230,144 -- │ └─BasicBlock (1) [2, 128, 28, 28] [2, 128, 28, 28] 295,424 -- ├─Sequential (layer3) [2, 128, 28, 28] [2, 256, 14, 14] -- -- │ └─BasicBlock (0) [2, 128, 28, 28] [2, 256, 14, 14] 919,040 -- │ └─BasicBlock (1) [2, 256, 14, 14] [2, 256, 14, 14] 1,180,672 -- ├─Sequential (layer4) [2, 256, 14, 14] [2, 512, 7, 7] -- -- │ └─BasicBlock (0) [2, 256, 14, 14] [2, 512, 7, 7] 3,673,088 -- │ └─BasicBlock (1) [2, 512, 7, 7] [2, 512, 7, 7] 4,720,640 -- ├─AdaptiveAvgPool2d (avgpool) [2, 512, 7, 7] [2, 512, 1, 1] -- -- ├─Linear (fc) [2, 512] [2, 1000] 513,000 1,026,000 ============================================================================================================================================ Total params: 11,689,512 Trainable params: 11,689,512 Non-trainable params: 0 Total mult-adds (G): 3.63 ============================================================================================================================================ Input size (MB): 1.20 Forward/backward pass size (MB): 79.49 Params size (MB): 46.76 Estimated Total Size (MB): 127.46 ============================================================================================================================================
こんな感じで可視化するレイヤーの深さ(depth)を変更できたり,入出力の形状や計算量などをレイヤーごとに可視化できるので非常に便利でした!
参考
リンク
リンク
リンク