yiskw note

機械学習やプログラミングについて気まぐれで書きます

【PyTorch】モデルの可視化を行うtorchinfoを使ってみた


PyTorchでモデルを可視化する方法はいくつかありますが,今回はその中でtorchinfoというものを見つけました.
実際にtorchinfoを使用してみたので,その使い方についてこちらにメモを残しておきます.

そのほかの可視化ライブラリについてもまとめておりますので,良ければご参照ください.

yiskw713.hatenablog.com

torchinfoとは

github.com

PyTorchのモデルを可視化してくれるライブラリです.
Tensorflowのmodel.summary()のようにPyTorchのモデルを可視化してくれます.

実行環境

  • python 3.8.2
  • pytorch 1.7.0
  • torchvision 0.8.1

コードは全てJupyter Notebook上で実行しております.

インストール方法

pipでインストールできます.

pip install torchinfo

使い方

使い方はとてもシンプルで,pytorchのモデルと入力サイズをtorchinfo.summaryに渡すだけです.

from torchinfo import summary
from torchvision.models import resnet18

model = resnet18()
batch_size = 2

summary(
    model,
    input_size=(batch_size, 3, 224, 224),
    col_names=["output_size", "num_params"],
)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   --                        --
├─Conv2d: 1-1                            [2, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [2, 64, 112, 112]         128
├─ReLU: 1-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [2, 64, 56, 56]           --
├─Sequential: 1-5                        [2, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [2, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [2, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [2, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [2, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [2, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [2, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [2, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [2, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [2, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [2, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [2, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [2, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [2, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [2, 64, 56, 56]           --
├─Sequential: 1-6                        [2, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [2, 128, 28, 28]          --
│    │    └─Conv2d: 3-13                 [2, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-14            [2, 128, 28, 28]          256
│    │    └─ReLU: 3-15                   [2, 128, 28, 28]          --
│    │    └─Conv2d: 3-16                 [2, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-17            [2, 128, 28, 28]          256
│    │    └─Sequential: 3-18             [2, 128, 28, 28]          8,448
│    │    └─ReLU: 3-19                   [2, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [2, 128, 28, 28]          --
│    │    └─Conv2d: 3-20                 [2, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-21            [2, 128, 28, 28]          256
│    │    └─ReLU: 3-22                   [2, 128, 28, 28]          --
│    │    └─Conv2d: 3-23                 [2, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-24            [2, 128, 28, 28]          256
│    │    └─ReLU: 3-25                   [2, 128, 28, 28]          --
├─Sequential: 1-7                        [2, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [2, 256, 14, 14]          --
│    │    └─Conv2d: 3-26                 [2, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-27            [2, 256, 14, 14]          512
│    │    └─ReLU: 3-28                   [2, 256, 14, 14]          --
│    │    └─Conv2d: 3-29                 [2, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-30            [2, 256, 14, 14]          512
│    │    └─Sequential: 3-31             [2, 256, 14, 14]          33,280
│    │    └─ReLU: 3-32                   [2, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [2, 256, 14, 14]          --
│    │    └─Conv2d: 3-33                 [2, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-34            [2, 256, 14, 14]          512
│    │    └─ReLU: 3-35                   [2, 256, 14, 14]          --
│    │    └─Conv2d: 3-36                 [2, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-37            [2, 256, 14, 14]          512
│    │    └─ReLU: 3-38                   [2, 256, 14, 14]          --
├─Sequential: 1-8                        [2, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [2, 512, 7, 7]            --
│    │    └─Conv2d: 3-39                 [2, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-40            [2, 512, 7, 7]            1,024
│    │    └─ReLU: 3-41                   [2, 512, 7, 7]            --
│    │    └─Conv2d: 3-42                 [2, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-43            [2, 512, 7, 7]            1,024
│    │    └─Sequential: 3-44             [2, 512, 7, 7]            132,096
│    │    └─ReLU: 3-45                   [2, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [2, 512, 7, 7]            --
│    │    └─Conv2d: 3-46                 [2, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-47            [2, 512, 7, 7]            1,024
│    │    └─ReLU: 3-48                   [2, 512, 7, 7]            --
│    │    └─Conv2d: 3-49                 [2, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-50            [2, 512, 7, 7]            1,024
│    │    └─ReLU: 3-51                   [2, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [2, 512, 1, 1]            --
├─Linear: 1-10                           [2, 1000]                 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 3.63
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
==========================================================================================

このように出力のテンソルの形状やパラメータの数だけなく,フォーワードやバックワードにかかる容量なども出力してくれます.

torchinfo.summaryの引数

torchinfo.summaryには様々な引数を渡すことができ,それによって出力の結果を変えることができます.

  • input_size ... バッチの次元を含めて入力のサイズを指定する.複数の入力にも対応.
  • col_names ... 可視化したいパラメータを指定できる.複数指定可能.2021/06/01現在は,"input_size", "output_size", "num_params", "kernel_size", "mult_adds"に対応.
  • depth ... ネストしたレイヤー(nn.Sequentialなど)に対して,どの深さまで可視化させるか.(後述の例参照)
  • device ... 使用するデバイス名.
  • dtypes ... 入力のdata typeを指定.複数入力に対して,それぞれ指定することができる.
  • row_settings ... レイヤーの名前を指定できる.複数指定可能.2021/06/01現在は,var_names(変数名表示)と,depth(depth-index表示)の二つが使用可能.
  • verbose ... verbose=0とすると何も表示しない.verbose=1とすると,上の例のようにモデルのsummaryのみを出力.verbose=2とすると,各レイヤーが持つ重みパラメータまで出力.

その他の引数や詳細に関しては,こちらを参照してください.

上で紹介した引数を指定しながら,再度モデルの可視化を行ってみます.

from torchinfo import summary
from torchvision.models import resnet18

model = resnet18()
batch_size = 2

summary(
    model,
    input_size=(batch_size, 3, 224, 224),
    col_names=["input_size", "output_size", "num_params", "mult_adds"],
    depth=2,
    row_settings=["var_names"],
)
============================================================================================================================================
Layer (type (var_name))                  Input Shape               Output Shape              Param #                   Mult-Adds
============================================================================================================================================
ResNet                                   --                        --                        --                        --
├─Conv2d (conv1)                         [2, 3, 224, 224]          [2, 64, 112, 112]         9,408                     236,027,904
├─BatchNorm2d (bn1)                      [2, 64, 112, 112]         [2, 64, 112, 112]         128                       256
├─ReLU (relu)                            [2, 64, 112, 112]         [2, 64, 112, 112]         --                        --
├─MaxPool2d (maxpool)                    [2, 64, 112, 112]         [2, 64, 56, 56]           --                        --
├─Sequential (layer1)                    [2, 64, 56, 56]           [2, 64, 56, 56]           --                        --
│    └─BasicBlock (0)                    [2, 64, 56, 56]           [2, 64, 56, 56]           73,984                    --
│    └─BasicBlock (1)                    [2, 64, 56, 56]           [2, 64, 56, 56]           73,984                    --
├─Sequential (layer2)                    [2, 64, 56, 56]           [2, 128, 28, 28]          --                        --
│    └─BasicBlock (0)                    [2, 64, 56, 56]           [2, 128, 28, 28]          230,144                   --
│    └─BasicBlock (1)                    [2, 128, 28, 28]          [2, 128, 28, 28]          295,424                   --
├─Sequential (layer3)                    [2, 128, 28, 28]          [2, 256, 14, 14]          --                        --
│    └─BasicBlock (0)                    [2, 128, 28, 28]          [2, 256, 14, 14]          919,040                   --
│    └─BasicBlock (1)                    [2, 256, 14, 14]          [2, 256, 14, 14]          1,180,672                 --
├─Sequential (layer4)                    [2, 256, 14, 14]          [2, 512, 7, 7]            --                        --
│    └─BasicBlock (0)                    [2, 256, 14, 14]          [2, 512, 7, 7]            3,673,088                 --
│    └─BasicBlock (1)                    [2, 512, 7, 7]            [2, 512, 7, 7]            4,720,640                 --
├─AdaptiveAvgPool2d (avgpool)            [2, 512, 7, 7]            [2, 512, 1, 1]            --                        --
├─Linear (fc)                            [2, 512]                  [2, 1000]                 513,000                   1,026,000
============================================================================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 3.63
============================================================================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 79.49
Params size (MB): 46.76
Estimated Total Size (MB): 127.46
============================================================================================================================================

こんな感じで可視化するレイヤーの深さ(depth)を変更できたり,入出力の形状や計算量などをレイヤーごとに可視化できるので非常に便利でした!

参考