【Gradio / PyTorch】Gradioでセマンティックセグメンテーションを行うデモアプリを作る
概要
今回はGradioでセマンティックセグメンテーションを行うデモを作成してみました.
こちらのデモを使用することで,簡単にセマンティックセグメンテーションを試すことができます.
今回使用したコードはこちらで公開しています.
Gradioの簡単な使い方や画像分類のデモ,画像の前処理可視化のデモについてもメモを残しておりますので,
よければ以下の記事をご覧いただけたらと思います.
セマンティックセグメンテーションのデモアプリの内容
今回はセマンティックセグメンテーションを行うデモを作成します.
モデルには,torchvisionで提供されているResNet50をバックボーンとしたDeepLabv3の学習済みモデルを使用します.
画像を入力とし,画像を出力する処理なので,画像の前処理の可視化と同じような形で実装できます.
今回作成したコード
それでは実際のコードを解説していきたいと思います.
コードは非常にシンプルで,入出力の形式と,処理を行う関数を定義し,gradio.Interface
に渡すだけです.
まず必要なパッケージのインポートします.
import gradio as gr import numpy as np import torch from PIL import Image from torchvision.models.segmentation import deeplabv3_resnet50 from torchvision.transforms import functional as F
次にメインの処理を書いていきます.
gradio.Interface
に渡す処理では,入出力の画像はnp.ndarray
となる点に注意します.
またセグメンテーションの可視化の際には,Pascal VOCのカラーパレットを使用しました.
そのため,Pascal VOCのラベル画像を適当に一つ用意する必要があります.
def main(): # モデルの定義 model = deeplabv3_resnet50(pretrained=True) model.eval() # Pascal VOCのカラーパレットを読み込み voc = Image.open("./imgs/voc_sample.png") palette = voc.getpalette() # gradio.Interfaceで実行して欲しい処理を定義 @torch.no_grad() def inference(gr_input): img = Image.fromarray(gr_input.astype("uint8"), "RGB") # 前処理 img = F.to_tensor(img) img = img.unsqueeze(0) img = F.normalize( img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) # セグメンテーションを実行 output = model(img)["out"].squeeze(0) _, mask = output.max(dim=0) mask = mask.numpy().astype("uint8") # マスク画像にPascal VOCのカラーパレットを適用 mask = Image.fromarray(mask) mask = mask.convert("P") mask.putpalette(palette) mask = mask.convert("RGB") return np.asarray(mask) # 入力の定義とサーバーの立ち上げ inputs = gr.inputs.Image() interface = gr.Interface(fn=inference, inputs=inputs, outputs="image") interface.launch() if __name__ == "__main__": main()
デモの実行結果
以上の内容を実行すると,以下のような画面が立ち上がります.
こちらに画像を追加し,Submitしてみます.
おおよそ正しい位置で猫の位置を認識していることが確認できました!
このように数行のコードで,実際のモデルの出力結果を確認できるのは非常に便利です.
まとめ
Gradioを使用して,セマンティックセグメンテーションのデモを作成しました.
次回は物体検出のデモアプリを作成してみようと思います.
参考リンク
- Gradio
- GitHub - yiskw713/gradio_sample
- 【Gradio / PyTorch】Gradioで画像分類を行うデモアプリを簡単に作る - yiskw note
- 【Gradio / PyTorch】Gradioで画像の前処理を可視化するデモを作る - yiskw note