【PyTorch】テンソルのデータ型や形状をアノテーションするtorchtypingを使ってみた
概要
PyTorchのテンソルのデータ型や形状をアノテーションできるtorchtyping
というライブラリを使ってみたので,
使い方やその使用感について,こちらにメモを残しておきます.
torchtypingとは
その名の通りPyTorchのテンソルのデータ型や形状をアノテーションできるライブラリです.
Pythonの標準ライブラリであるtypingでは,PyTorchのテンソル型の形状を記述することができません.
そのため以下のようにコメントに形状を記述する必要があります.
import torch def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # x has shape (batch, x_channels) # y has shape (batch, y_channels) # return has shape (batch, x_channels, y_channels) return x.unsqueeze(-1) * y.unsqueeze(-2)
これをtorchtypingを使って書くと,以下のようになります.
import torch from torchtyping import TensorType def batch_outer_product( x: TensorType["batch", "x_channels"], y: TensorType["batch", "y_channels"], ) -> TensorType["batch", "x_channels", "y_channels"]: return x.unsqueeze(-1) * y.unsqueeze(-2)
このように,テンソルの形状をわかりやすく書くことができます.
また,typeguard
と一緒に実行時することで,実行時にテンソルが所望の形状になっているかどうかを確認することができます.
インストール方法
pipでインストールが可能です.
$ pip install torchtyping
使い方
アノテーションの記法
torchtyping
からTensorType
をimportして型アノテーションします.
torchtyping.TensorType[shape, dtype, layout, details]
shape
... テンソルの形状をint
,str
,str: int
,...
などの形で記述します.int
... その次元が固定長であることを意味します.str
... ある次元に対して,この名前が与えられ,実行時に他のテンソルと矛盾がないか確認されます.str: int
... ある次元に対して名前が与えられ,かつ固定長であることを意味を意味します....
... その次元が任意のサイズであることを意味します.
dtype
...int
/float
/bool
のいずれかを指定します.layout
...torch.strided
/torch.sparse_coo
のいずれかを指定します.details
... その他のフラグを追加できます.詳細はこちらを参照してください.
def batch_outer_product( x: TensorType["batch", "x_channels": 4, float], y: TensorType["batch", "y_channels": 8, float], ) -> TensorType["batch", "x_channels":4, "y_channels":8, float]: return x.unsqueeze(-1) * y.unsqueeze(-2)
ただし,このままでは型のヒントを与えているだけで,
実際にその型が想定したものになっているかどうかの確認はしてくれません.
例えば以下のような例でもエラーは吐かれません.
import torch from torchtyping import TensorType def batch_outer_product( x: TensorType["batch", "x_channels":4, float], y: TensorType["batch", "y_channels":8, float], ) -> TensorType["batch", "x_channels":4, "y_channels":8, float]: return x.unsqueeze(-1) * y.unsqueeze(-2) x = torch.randn(4, 5) y = torch.randn(4, 15) print(batch_outer_product(x, y).shape)
typeguard
を使用することで,実行時に型のチェックを行ってくれるようになります.
typeguard
との併用
typeguard
とはPEP 484で定義されている実行時型チェックを提供するライブラリです.
以下のようにtorchtyping.patch_typeguard()
を呼びだし,
型チェックを行いたい関数に@typechecked
デコレータを追加することで,
実行時の型チェックが行われます.
import torch from torchtyping import TensorType, patch_typeguard from typeguard import typechecked patch_typeguard() # use before @typechecked @typechecked def batch_outer_product( x: TensorType["batch", "x_channels":4, float], y: TensorType["batch", "y_channels":8, float], ) -> TensorType["batch", "x_channels":4, "y_channels":8, float]: return x.unsqueeze(-1) * y.unsqueeze(-2) x = torch.randn(4, 5) y = torch.randn(4, 15) # raise error print(batch_outer_product(x, y))
$ python tests_torchtyping.py Traceback (most recent call last): File "tests_torchtyping.py", line 20, in <module> print(batch_outer_product(x, y)) File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/typeguard/__init__.py", line 926, in wrapper check_argument_types(memo) File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/torchtyping/typechecker.py", line 338, in check_argument_types retval = _check_argument_types(*args, **kwargs) File "/Users/yuchi/.anyenv/envs/pyenv/versions/torch/lib/python3.8/site-packages/typeguard/__init__.py", line 768, in check_argument_types raise TypeError(*exc.args) from None TypeError: argument "x" must be of type TensorType['batch', x_channels: 4, torch.float32], got type TensorType[4, 5, torch.float32] instead.
typeguard
を使用した時の速度について
typeguard
を用いた実行時型チェックを行った時と,そうでない時の速度を比較しました.
実行したスクリプトは以下です.
import time # (中略: 上と同じ関数を使用) start = time.time() for _ in range(100000): x = torch.randn(4, 4) y = torch.randn(4, 8) res = batch_outer_product(x, y) print((time.time() - start), "sec")
結果は,typeguardがある場合は17.7秒,ない場合は1.57秒となりました.かなり速度に差が出ますね.
なので,ユニットテストなどのテスト時にtypeguardを使用すると良いかもしれません.
pytestでテンソルの型チェックを行う
torchtypingではpytestのプラグインも提供しており,
テスト時に型のチェックを行ってくれます.(typeguardを使用する必要あり)
以下のオプションでtorchtyping
による型チェックを有効にします.
$ pytest --torchtyping-patch-typeguard
実際にテストしてみます.
import pytest # (中略: 上と同じ関数を使用) def test_batch_outer_product() -> None: x = torch.randn(4, 4) y = torch.randn(4, 8) res = batch_outer_product(x, y) assert res.shape == (4, 4, 8) x = torch.randn(4, 5) y = torch.randn(4, 15) with pytest.raises(TypeError): res = batch_outer_product(x, y)
$ pytest --torchtyping-patch-typeguard tests_torchtyping.py ============================== test session starts =============================== platform darwin -- Python 3.8.2, pytest-6.1.2, py-1.9.0, pluggy-0.13.1 rootdir: /Users/yuchi/Documents/hoge plugins: anyio-2.2.0, Faker-5.0.1, torchtyping-0.1.3, django-3.10.0, typeguard-2.12.1, mock-3.3.1, cov-2.10.1, dash-1.16.3 collected 1 item tests_torchtyping.py . [100%] =============================== 1 passed in 0.12s ================================
所感
テンソルの形状や型を,型ヒントとしてかけるのが便利で,コードの可読性が上がったと感じました.
また,実行時型チェックによりテストがしやすくなったと感じました.
一方でmypy
やflake8
との相性が悪い点が非常に気になりました.
上で記述した関数に対してmypy
を実行すると,
torchtyping_sample.py:6: error: syntax error in type comment torchtyping_sample.py:7: error: syntax error in type comment torchtyping_sample.py:8: error: syntax error in type comment Found 3 errors in 1 file (checked 1 source file)
このようなエラーが吐かれてしまいます.
これは型ヒントにstr: int
のようにコロンを使った場合に,
mypyがバグを起こしていることが原因だと述べられています.
またflake8とも相性が良くなく,以下のように型ヒントに対して未定義の警告を吐かれてしまいます.
このあたりは今後の改善に期待されます.
【追記】mypyの警告を無視する
mypy
による警告を無視するには,エラーが出ている行に# type: ignore
というコメントを追加すれば良いみたいです.
from torchtyping import TensorType # type: ignore def batch_outer_product( x: TensorType["batch", "x_channels"], # type: ignore y: TensorType["batch", "y_channels"], # type: ignore ) -> TensorType["batch", "x_channels", "y_channels"]: # type: ignore return x.unsqueeze(-1) * y.unsqueeze(-2)
$ mypy torchtyping_sample.py Success: no issues found in 1 source file
これで警告は防げるものの,毎回これをつけるのは少し面倒...