【Gradio / PyTorch】Gradioで画像分類を行うデモアプリを簡単に作る
概要
機械学習モデルのデモアプリを作成したいと思っていたところ,
Gradioというライブラリを見つけました.
Gradioを使用することで,たった数行のpythonコードで簡単にデモを作成できるようです.
その使い方を備忘録としてこちらに残しておきます.
今回使用したコードや,環境構築のためのDockerfileなどはこちらで公開しております.
Gradioとは
簡単に機械学習モデルのデモを簡単に作成できるpythonのライブラリです.
特にフロント側のUIを実装する必要がなく,バックエンド側の数行のコードだけで,デモアプリを作成することができます.
jupyter notebookやgoogle colab上での動作にも対応しております.
また,作成したデモアプリを公開することも可能です.
基本的な使い方
まず自分の環境にGradioをインストールします.
pip install gradio
google colabを使用する場合は,以下をセルで実行します.
!pip install gradio
続いて以下のコードをnotebook,もしくはpythonファイルに記述します.
import gradio as gr def greet(name): return "Hello " + name + "!!" iface = gr.Interface(fn=greet, inputs="text", outputs="text") iface.launch()
上記を実行し,http://localhost:7860にアクセスすると,以下のように挨拶をする(入力テキストに"Hello"を追加する)アプリで遊ぶことができます.
このようにデモで行いたい処理とその入出力の形式をgradio.Interface
に指定し,launch
メソッドを実行するだけで,簡単にデモアプリを作成することができます.
画像分類アプリの作成
それでは本題の画像分類アプリの使い方に入ろうと思います.
今回はPyTorchと,torchvisionで提供されているResNet50を用いた画像分類デモを作成します.
まずライブラリをimportします.必要に応じて各種ライブラリはインストールしてください.
import gradio as gr import requests import torch import torch.nn as nn from PIL import Image from torchvision.models import resnet50 from torchvision.transforms import functional as F
続いてメインの処理を記述していきます.
def main(): # モデルの準備 model = resnet50(pretrained=True) model.eval() # ImageNetのラベルの取得 response = requests.get("https://git.io/JJkYN") labels = response.text.split("\n") # 画像分類を行う関数を定義 @torch.no_grad() def inference(gr_input): img = Image.fromarray(gr_input.astype("uint8"), "RGB") # 前処理 img = F.resize(img, (224, 224)) img = F.to_tensor(img) img = img.unsqueeze(0) img = F.normalize( img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) # 推論 output = model(img).squeeze(0) probs = nn.functional.softmax(output, dim=0).numpy() # ラベルごとの確率をdictとして返す return {labels[i]: float(probs[i]) for i in range(1000)} # 入力の形式を画像とする inputs = gr.inputs.Image() # 出力はラベル形式で,top5まで表示する outputs = gr.outputs.Label(num_top_classes=5) # サーバーの立ち上げ interface = gr.Interface(fn=inference, inputs=inputs, outputs=outputs) interface.launch() if __name__ == "__main__": main()
入力を画像,出力をラベル形式とし,画像に対してラベルごとの予測確率を出力しています.
デモの実行結果
以上の内容を実行した結果が以下です.
こちらの画面から画像をアップロードし,submitを押すことで推論可能です.
試しに画像をアップロードしてみます.
無事推論できました!
ウェルシュ・コーギー・ペンブローク(Pembroke Welsh Corgi)と正しく認識できています.
まとめ
Gradioを使用して,画像分類アプリのデモを作成しました.
入出力の形と処理を書くだけで,簡単にデモの実装ができるので,非常に便利だと感じました.
社内プレゼンでの簡単なデモや,研究発表でのデモなんかにも良いと思うので,ぜひ積極的に使っていきたいです.