yiskw note

機械学習やプログラミングについて気まぐれで書きます

【PyTorch】テンソルのデータ型や形状をアノテーションするtorchtypingを使ってみた


概要

github.com

PyTorchのテンソルのデータ型や形状をアノテーションできるtorchtypingというライブラリを使ってみたので,
使い方やその使用感について,こちらにメモを残しておきます.

torchtypingとは

その名の通りPyTorchのテンソルのデータ型や形状をアノテーションできるライブラリです.
Pythonの標準ライブラリであるtypingでは,PyTorchのテンソル型の形状を記述することができません.
そのため以下のようにコメントに形状を記述する必要があります.

import torch


def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

これをtorchtypingを使って書くと,以下のようになります.

import torch
from torchtyping import TensorType


def batch_outer_product(
    x: TensorType["batch", "x_channels"],
    y: TensorType["batch", "y_channels"],
) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

このように,テンソルの形状をわかりやすく書くことができます.
また,typeguardと一緒に実行時することで,実行時にテンソルが所望の形状になっているかどうかを確認することができます.

インストール方法

pipでインストールが可能です.

$ pip install torchtyping

使い方

アノテーションの記法

torchtypingからTensorTypeをimportして型アノテーションします.

torchtyping.TensorType[shape, dtype, layout, details]
  • shape ... テンソルの形状をintstrstr: int, ...などの形で記述します.
    • int ... その次元が固定長であることを意味します.
    • str ... ある次元に対して,この名前が与えられ,実行時に他のテンソルと矛盾がないか確認されます.
    • str: int ... ある次元に対して名前が与えられ,かつ固定長であることを意味を意味します.
    • ... ... その次元が任意のサイズであることを意味します.
  • dtype ... int / float / bool のいずれかを指定します.
  • layout ... torch.strided / torch.sparse_cooのいずれかを指定します.
  • details ... その他のフラグを追加できます.詳細はこちらを参照してください.

例えばこのようにテンソルアノテーションできます.

def batch_outer_product(
    x: TensorType["batch", "x_channels": 4, float],
    y: TensorType["batch", "y_channels": 8, float],
) -> TensorType["batch", "x_channels":4, "y_channels":8, float]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)

ただし,このままでは型のヒントを与えているだけで,
実際にその型が想定したものになっているかどうかの確認はしてくれません.
例えば以下のような例でもエラーは吐かれません.

import torch
from torchtyping import TensorType


def batch_outer_product(
    x: TensorType["batch", "x_channels":4, float],
    y: TensorType["batch", "y_channels":8, float],
) -> TensorType["batch", "x_channels":4, "y_channels":8, float]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)


x = torch.randn(4, 5)
y = torch.randn(4, 15)

print(batch_outer_product(x, y).shape)

typeguardを使用することで,実行時に型のチェックを行ってくれるようになります.

typeguardとの併用

typeguard.readthedocs.io

typeguardとはPEP 484で定義されている実行時型チェックを提供するライブラリです.
以下のようにtorchtyping.patch_typeguard()を呼びだし,
型チェックを行いたい関数に@typecheckedデコレータを追加することで,
実行時の型チェックが行われます.

import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked


@typechecked
def batch_outer_product(
    x: TensorType["batch", "x_channels":4, float],
    y: TensorType["batch", "y_channels":8, float],
) -> TensorType["batch", "x_channels":4, "y_channels":8, float]:
    return x.unsqueeze(-1) * y.unsqueeze(-2)


x = torch.randn(4, 5)
y = torch.randn(4, 15)

# raise error
print(batch_outer_product(x, y))
$ python tests_torchtyping.py
Traceback (most recent call last):
  File "tests_torchtyping.py", line 20, in <module>
    print(batch_outer_product(x, y))
  File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/typeguard/__init__.py", line 926, in wrapper
    check_argument_types(memo)
  File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/torchtyping/typechecker.py", line 338, in check_argument_types
    retval = _check_argument_types(*args, **kwargs)
  File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/typeguard/__init__.py", line 768, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: argument "x" must be of type TensorType['batch', x_channels: 4, torch.float32], got type TensorType[4, 5, torch.float32] instead.

typeguardを使用した時の速度について

typeguardを用いた実行時型チェックを行った時と,そうでない時の速度を比較しました.
実行したスクリプトは以下です.

import time

# (中略: 上と同じ関数を使用)

start = time.time()
for _ in range(100000):
    x = torch.randn(4, 4)
    y = torch.randn(4, 8)

    res = batch_outer_product(x, y)

print((time.time() - start), "sec")

結果は,typeguardがある場合は17.7秒,ない場合は1.57秒となりました.かなり速度に差が出ますね.
なので,ユニットテストなどのテスト時にtypeguardを使用すると良いかもしれません.

pytestでテンソルの型チェックを行う

torchtypingではpytestのプラグインも提供しており,
テスト時に型のチェックを行ってくれます.(typeguardを使用する必要あり)
以下のオプションでtorchtypingによる型チェックを有効にします.

$ pytest --torchtyping-patch-typeguard

実際にテストしてみます.

import pytest

# (中略: 上と同じ関数を使用)

def test_batch_outer_product() -> None:
    x = torch.randn(4, 4)
    y = torch.randn(4, 8)
    res = batch_outer_product(x, y)
    assert res.shape == (4, 4, 8)

    x = torch.randn(4, 5)
    y = torch.randn(4, 15)
    with pytest.raises(TypeError):
        res = batch_outer_product(x, y)
$ pytest --torchtyping-patch-typeguard tests_torchtyping.py

============================== test session starts ===============================
platform darwin -- Python 3.8.2, pytest-6.1.2, py-1.9.0, pluggy-0.13.1
rootdir: /Users/yuchi/Documents/hoge
plugins: anyio-2.2.0, Faker-5.0.1, torchtyping-0.1.3, django-3.10.0, typeguard-2.12.1, mock-3.3.1, cov-2.10.1, dash-1.16.3
collected 1 item                                                                 

tests_torchtyping.py .                                                     [100%]

=============================== 1 passed in 0.12s ================================

所感

テンソルの形状や型を,型ヒントとしてかけるのが便利で,コードの可読性が上がったと感じました.
また,実行時型チェックによりテストがしやすくなったと感じました.

一方でmypyflake8との相性が悪い点が非常に気になりました.
上で記述した関数に対してmypyを実行すると,

torchtyping_sample.py:6: error: syntax error in type comment
torchtyping_sample.py:7: error: syntax error in type comment
torchtyping_sample.py:8: error: syntax error in type comment
Found 3 errors in 1 file (checked 1 source file)

このようなエラーが吐かれてしまいます.
これは型ヒントにstr: intのようにコロンを使った場合に,
mypyがバグを起こしていることが原因だと述べられています.

またflake8とも相性が良くなく,以下のように型ヒントに対して未定義の警告を吐かれてしまいます.

f:id:yiskw713:20210608072333p:plain

このあたりは今後の改善に期待されます.

【追記】mypyの警告を無視する

mypyによる警告を無視するには,エラーが出ている行に# type: ignoreというコメントを追加すれば良いみたいです.

from torchtyping import TensorType  # type: ignore


def batch_outer_product(
    x: TensorType["batch", "x_channels"],  # type: ignore
    y: TensorType["batch", "y_channels"],  # type: ignore
) -> TensorType["batch", "x_channels", "y_channels"]:  # type: ignore

    return x.unsqueeze(-1) * y.unsqueeze(-2)
$ mypy torchtyping_sample.py
Success: no issues found in 1 source file

これで警告は防げるものの,毎回これをつけるのは少し面倒...

Reference