yiskw note

機械学習やプログラミングについて気まぐれで書きます

【Gradio / PyTorch】Gradioでセマンティックセグメンテーションを行うデモアプリを作る


概要

今回はGradioでセマンティックセグメンテーションを行うデモを作成してみました.
こちらのデモを使用することで,簡単にセマンティックセグメンテーションを試すことができます. 今回使用したコードはこちらで公開しています.

Gradioの簡単な使い方や画像分類のデモ,画像の前処理可視化のデモについてもメモを残しておりますので,
よければ以下の記事をご覧いただけたらと思います.

yiskw713.hatenablog.com

yiskw713.hatenablog.com

セマンティックセグメンテーションのデモアプリの内容

今回はセマンティックセグメンテーションを行うデモを作成します.
モデルには,torchvisionで提供されているResNet50をバックボーンとしたDeepLabv3の学習済みモデルを使用します.
画像を入力とし,画像を出力する処理なので,画像の前処理の可視化と同じような形で実装できます.

今回作成したコード

それでは実際のコードを解説していきたいと思います.
コードは非常にシンプルで,入出力の形式と,処理を行う関数を定義し,gradio.Interfaceに渡すだけです.

まず必要なパッケージのインポートします.

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.transforms import functional as F

次にメインの処理を書いていきます.
gradio.Interfaceに渡す処理では,入出力の画像はnp.ndarrayとなる点に注意します.
またセグメンテーションの可視化の際には,Pascal VOCのカラーパレットを使用しました.
そのため,Pascal VOCのラベル画像を適当に一つ用意する必要があります.

def main():
    # モデルの定義
    model = deeplabv3_resnet50(pretrained=True)
    model.eval()

    # Pascal VOCのカラーパレットを読み込み
    voc = Image.open("./imgs/voc_sample.png")
    palette = voc.getpalette()

    # gradio.Interfaceで実行して欲しい処理を定義
    @torch.no_grad()
    def inference(gr_input):
        img = Image.fromarray(gr_input.astype("uint8"), "RGB")

        # 前処理
        img = F.to_tensor(img)
        img = img.unsqueeze(0)
        img = F.normalize(
            img,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # セグメンテーションを実行
        output = model(img)["out"].squeeze(0)
        _, mask = output.max(dim=0)
        mask = mask.numpy().astype("uint8")

        # マスク画像にPascal VOCのカラーパレットを適用
        mask = Image.fromarray(mask)
        mask = mask.convert("P")
        mask.putpalette(palette)
        mask = mask.convert("RGB")

        return np.asarray(mask)

    # 入力の定義とサーバーの立ち上げ
    inputs = gr.inputs.Image()
    interface = gr.Interface(fn=inference, inputs=inputs, outputs="image")

    interface.launch()


if __name__ == "__main__":
    main()

デモの実行結果

以上の内容を実行すると,以下のような画面が立ち上がります.

こちらに画像を追加し,Submitしてみます.

f:id:yiskw713:20211225154613p:plain:w800
※使用した猫の写真はK LによるPixabayからの画像

おおよそ正しい位置で猫の位置を認識していることが確認できました!
このように数行のコードで,実際のモデルの出力結果を確認できるのは非常に便利です.

まとめ

Gradioを使用して,セマンティックセグメンテーションのデモを作成しました.
次回は物体検出のデモアプリを作成してみようと思います.

参考リンク