yiskw note

機械学習やプログラミングについて気まぐれで書きます

【Gradio / PyTorch】Gradioで画像分類を行うデモアプリを簡単に作る


概要

f:id:yiskw713:20211223155339p:plain
Gradioを用いた画像分類アプリ.※猫の写真はK LによるPixabayからの画像

機械学習モデルのデモアプリを作成したいと思っていたところ,
Gradioというライブラリを見つけました.
Gradioを使用することで,たった数行のpythonコードで簡単にデモを作成できるようです.
その使い方を備忘録としてこちらに残しておきます.
今回使用したコードや,環境構築のためのDockerfileなどはこちらで公開しております.

Gradioとは

簡単に機械学習モデルのデモを簡単に作成できるpythonのライブラリです.
特にフロント側のUIを実装する必要がなく,バックエンド側の数行のコードだけで,デモアプリを作成することができます.
jupyter notebookやgoogle colab上での動作にも対応しております.
また,作成したデモアプリを公開することも可能です.

基本的な使い方

まず自分の環境にGradioをインストールします.

pip install gradio

google colabを使用する場合は,以下をセルで実行します.

!pip install gradio

続いて以下のコードをnotebook,もしくはpythonファイルに記述します.

import gradio as gr

def greet(name):
  return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()

上記を実行し,http://localhost:7860にアクセスすると,以下のように挨拶をする(入力テキストに"Hello"を追加する)アプリで遊ぶことができます.

f:id:yiskw713:20211223162558p:plain

このようにデモで行いたい処理とその入出力の形式をgradio.Interfaceに指定し,launchメソッドを実行するだけで,簡単にデモアプリを作成することができます.

画像分類アプリの作成

それでは本題の画像分類アプリの使い方に入ろうと思います.
今回はPyTorchと,torchvisionで提供されているResNet50を用いた画像分類デモを作成します.

まずライブラリをimportします.必要に応じて各種ライブラリはインストールしてください.

import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet50
from torchvision.transforms import functional as F

続いてメインの処理を記述していきます.

def main():
    # モデルの準備
    model = resnet50(pretrained=True)
    model.eval()

    # ImageNetのラベルの取得
    response = requests.get("https://git.io/JJkYN")
    labels = response.text.split("\n")

    # 画像分類を行う関数を定義
    @torch.no_grad()
    def inference(gr_input):
        img = Image.fromarray(gr_input.astype("uint8"), "RGB")

        # 前処理
        img = F.resize(img, (224, 224))
        img = F.to_tensor(img)
        img = img.unsqueeze(0)
        img = F.normalize(
            img,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # 推論
        output = model(img).squeeze(0)
        probs = nn.functional.softmax(output, dim=0).numpy()

        # ラベルごとの確率をdictとして返す
        return {labels[i]: float(probs[i]) for i in range(1000)}

    # 入力の形式を画像とする
    inputs = gr.inputs.Image()
    # 出力はラベル形式で,top5まで表示する
    outputs = gr.outputs.Label(num_top_classes=5)

    # サーバーの立ち上げ
    interface = gr.Interface(fn=inference, inputs=inputs, outputs=outputs)
    interface.launch()


if __name__ == "__main__":
    main()

入力を画像,出力をラベル形式とし,画像に対してラベルごとの予測確率を出力しています.

デモの実行結果

以上の内容を実行した結果が以下です.

f:id:yiskw713:20211223180515p:plain

こちらの画面から画像をアップロードし,submitを押すことで推論可能です.
試しに画像をアップロードしてみます.

f:id:yiskw713:20211228112740p:plain
犬の写真は,Szabolcs MolnarによるPixabayからの画像

無事推論できました!
ウェルシュ・コーギー・ペンブローク(Pembroke Welsh Corgi)と正しく認識できています.

まとめ

Gradioを使用して,画像分類アプリのデモを作成しました.
入出力の形と処理を書くだけで,簡単にデモの実装ができるので,非常に便利だと感じました.
社内プレゼンでの簡単なデモや,研究発表でのデモなんかにも良いと思うので,ぜひ積極的に使っていきたいです.

参考リンク