yiskw note

機械学習やプログラミングについて気まぐれで書きます

【Gradio / PyTorch】YOLOXで物体検出を行うデモアプリを作る


概要

今回はGradioでYOLOXを用いた物体検出を行うデモを作成してみます.
今回使用したコードはこちらで公開しています.

Gradioを用いた他のデモについてもメモを残しておりますので,
よければ以下の記事をご覧いただけたらと思います.

yiskw713.hatenablog.com

yiskw713.hatenablog.com

yiskw713.hatenablog.com

物体検出を行うデモの内容

2021年に発表されたアンカーフリーな物体検出器であるYOLOXを使用して,物体検出を行うデモアプリを作成します.
PyTorchのモデルを使用しても良いのですが,今回は推論の速度の観点から公式レポジトリで提供されている学習済みONNXモデルを使用しました.
今回のデモでは,確信度とIoUの閾値を変更できるようスライダーも追加してみます.
基本的には,画像を入力とし,画像を出力する処理なので,画像の前処理の可視化と同じような形で実装できます.

必要なパッケージ

今回使用するパッケージは以下の通りです.必要に応じてインストールしてください.
また,こちらに今回使用したDockerfileも追加しておきますので,よければ参考にしてみてください.

重みファイルのダウンロード

コードの解説に入る前に,公式レポジトリから重みファイルをダウンロードします.
以下のコマンドを実行し,weights/yolox_s.onnxという名前で保存します.

wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.onnx \
    -O weights/yolox_s.onnx

今回作成したコード

それでは実際のコードを解説していきたいと思います.

物体検出の推論では,後処理としてNMSを行ったり,バウンディングボックスを可視化する処理を記述する必要があります.
が,これらの処理は公式レポジトリで実装してくれていますので,
必要なものだけをyolox_utils.pyファイルに書き出して使用します.
こちらに記述すると長くなってしまうので,詳しくはこちらをご参照ください.

github.com

次にメインの処理の実装に移ります.
まず必要なパッケージのインポートします.

import gradio as gr
import numpy as np
import onnxruntime

from yolox_utils import COCO_CLASSES, demo_postprocess, multiclass_nms
from yolox_utils import preproc as preprocess
from yolox_utils import vis

次に物体検出を行う関数を書いていきます.
gradio.Interfaceに渡す関数は,入出力の画像はnp.ndarrayとなる点に注意します.

def main():
    # モデルの重みファイルのパス,入力サイズなどを設定
    MODEL = "./weights/yolox_s.onnx"
    INPUT_SHAPE = (640, 640)
    WITH_P6 = False

    # onnxruntimeのセッションを立ち上げる
    session = onnxruntime.InferenceSession(MODEL)

    def inference(
        gr_input: np.ndarray, score_thr: float, nms_iou_thr: float
    ) -> np.ndarray:
        """Inference with onnx model
        Reference:
        https://github.com/Megvii-BaseDetection/YOLOX/blob/main/demo/ONNXRuntime/onnx_inference.py
        """
        # 画像の前処理を行う
        img, ratio = preprocess(gr_input, INPUT_SHAPE)

        # 推論
        ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
        output = session.run(None, ort_inputs)

        # 後処理
        predictions = demo_postprocess(output[0], INPUT_SHAPE, p6=WITH_P6)[0]

        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]

        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
        boxes_xyxy /= ratio
        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=nms_iou_thr, score_thr=0.01)
        if dets is not None:
            final_boxes, final_scores, final_cls_inds = (
                dets[:, :4],
                dets[:, 4],
                dets[:, 5],
            )
            gr_input = vis(
                gr_input,
                final_boxes,
                final_scores,
                final_cls_inds,
                conf=score_thr,
                class_names=COCO_CLASSES,
            )

        return np.asarray(gr_input)

    # 入力を定義 (画像とスライダーx2)
    img_input = gr.inputs.Image()
    score_thr_input = gr.inputs.Slider(
        minimum=0,
        maximum=1.0,
        step=0.05,
        default=0.3,
        label="score threshold",
    )
    nms_iou_thr_input = gr.inputs.Slider(
        minimum=0,
        maximum=1.0,
        step=0.05,
        default=0.45,
        label="nms iou threshold",
    )

    # サーバーの立ち上げ
    interface = gr.Interface(
        fn=inference,
        inputs=[
            img_input,
            score_thr_input,
            nms_iou_thr_input,
        ],
        outputs="image",
    )

    interface.launch()


if __name__ == "__main__":
    main()

デモの実行結果

以上の内容を実行すると,以下のような画面が立ち上がります.

f:id:yiskw713:20211225195634p:plain:w800

こちらに画像を追加し,Submitしてみます.

f:id:yiskw713:20211225195420p:plain:w800
PexelsによるPixabayからの画像を使用

かなり正確に,かつ高速に物体検出ができていることが確認できました!

まとめ

GradioとYOLOXを用いて,物体検出のデモを作成しました.
非常に簡単にYOLOXの検証ができて,便利だと感じました.
今度は画像以外のデモアプリの作成も作ってみようと思います.

参考リンク