Unity ML-Agentsで強化学習してみる

はじめに

Unity環境で強化学習ができる「ML-Agents」が気になっていたので試してみることにしました。
今回は公式の手順に沿って環境構築から学習まで行った内容を記載します。

 

[Unity ML-Agents Toolkit]

github.com


準備

作業は全てWindows10で行いました。

ML-Agentsのバージョンとしては、Release20がちょうど出たばかりのようでしたが、
今回はRelease19(現時点でのlatest sbable版)を使用しました。

 

まず、以下からML-Agents一式を取得して適当な場所に展開。

GitHub - Unity-Technologies/ml-agents at release_19

Unity Hubを起動してプロジェクトの「開く」プルダウンで「ディスクから加える」を選び、上で展開したフォルダ内の"Project"を指定。これをUnityエディタで開けばOK。

 

Python側の準備としては、Anacondaでpython3.7の環境を作成し、以下をインストール。
pip3 install torch~=1.7.1 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install mlagents==0.28.0


学習環境の作成

以下の手順に沿って簡単な学習環境を作成してみました。ボールを箱にぶつけるだけのシンプルなタスクです。

ml-agents/Learning-Environment-Create-New.md at release_19_docs · Unity-Technologies/ml-agents · GitHub

C#スクリプトにてAgentクラスを継承し、観測や報酬、エピソード終了条件などを記述します。

(Unityエディタ上のインスペクタでも学習に関するパラメータをいくつか設定します)

最後に、作成した環境をPrefab化して複製することで、経験の収集を並行して進められるようになり学習を効率化できます。

 

学習・テスト

学習のハイパーパラメータをyamlファイルに記述し、以下のコマンドを実行します。
mlagents-learn config/rollerball_config.yaml --run-id=RollerBall

 

そして、Unityエディタで再生を押すと学習が開始します。(学習中の様子↓)

 

学習中の報酬の推移等はTensorBoardで表示可能です。
学習終了後、生成されるonnxファイルをAgentにアタッチすると学習後の動きを確認できます。


Environment Executable

学習環境をビルドして実行する方法も試しました。手順は以下の通りです。

ml-agents/Learning-Environment-Executable.md at release_19_docs · Unity-Technologies/ml-agents · GitHub

 

学習中の様子↓

 

今回はWindows向けにビルドして実行しましたが、これをLinux用にビルドしてGPUサーバ等に持って行くことで、学習をより高速化できるものと思います。


おわり

Unity ML-Agentsで環境作成と学習を動かしてみました。
今回のはごく単純なタスクですが、もっと複雑なゲーム環境への適用も時間があれば試してみたいです。

 

 

四元数の勉強

はじめに

四元数(クォータニオン)は3D空間における回転等を表す方式として、航空宇宙業界やCG/ゲーム業界などで広く活用されています。

今回は、この四元数について勉強したことを備忘録としてまとめます。

 

 

基本的な演算や特徴など

以下、四元数について定義されている演算やその特徴をいくつか列挙します。

表記については参考文献[1]をベースとします。

 

・表記方法

 [w, (x, y, z)]や [w, \boldsymbol{v}]などとあらわす。

また、3つの虚数i, j, kを使い  w + xi + yj + zkと書ける。

 

・大きさ

 ||q|| = \sqrt{w^{2} + x^{2} + y^{2} + z^{2}}

※3D回転を扱う際は、大きさが1の四元数(単位四元数)を考える。

 

・共役と逆数

四元数の共役は、虚数部分の符号を反転することで得られる。

 q = [w, \boldsymbol{v}]に対し、 q* = [w, -\boldsymbol{v}]

逆数は共役を大きさで割ることで定義される。

 q^{-1} = \dfrac{q*}{||q||}

単位四元数の場合は、大きさ=1なので共役と逆数は同値。

 

・乗算

四元数同士の乗算は以下の形となる。

 [w1, \boldsymbol{v1}] [w2, \boldsymbol{v2}] = [w1w2 - \boldsymbol{v1}・\boldsymbol{v2},   w1\boldsymbol{v2} + w2\boldsymbol{v1} + \boldsymbol{v1} \times \boldsymbol{v2}]

※乗算について結合法則は成り立つが、交換法則は成り立たない。

 

四元数による回転

標準的な3Dの点(x, y, z)に対し、四元数p = [0, (x, y, z)]を定義する。これをある軸 \boldsymbol{n}の周りで \thetaだけ回転させる場合、四元数q = [cos  \theta/2,   \boldsymbol{n} sin \theta/2]を用いて以下の乗算によりpを回転させることができる。

 p' = qpq^{-1}

 

四元数の長所と短所 (角変位の表現形式として)

角変位を表す主要な形式として、「行列・オイラー角・四元数」の3つがあります。

ここでは、四元数の主な長所と短所について、他の二つと比較しつつ記載します。

 

長所

  • 滑らかな補間が可能

    slerpやsquadといった演算により、四元数間の滑らかな補間が可能。

    ※行列やオイラー角の場合、滑らかな補間は不可。

  • 角変位の連結が容易

    四元数の乗算の結合法則により、一連の角変位を連結できる。

    ※行列形式も連結可能だが比較的低速。オイラー角の場合は連結は容易でない。

  • メモリ消費が比較的少ない

    四元数は4つの数で表せるので、9つの数を使う行列よりは経済的。

    オイラー角は3つなので、この点はオイラー角の方が良い。

短所

  • 人間による解釈が難しい
    直感的な解釈のしやすさという点で、四元数オイラー角に大きく劣る。
  • 無効になる場合がありえる

    入力データのミスや浮動小数点の丸め誤差によりエラークリープが起こりえる。(正規化により対処可能)

    ※行列についても無効になる場合はありえる(正規直交性を満たさない場合)。オイラー角については、どんな3つの数値でも有効。

 

実際には、「行列・オイラー角・四元数」は相互に変換可能であるため、状況に応じてそれぞれの長所を組み合わせて使用されます。

(ユーザが直接指定する部分はオイラー角形式として、内部的には四元数形式で回転の補間を計算したり)

 

参考

[1] O'Reilly Japan - 実例で学ぶゲーム3D数学

[2] クォータニオン (Quaternion) を総整理! ~ 三次元物体の回転と姿勢を鮮やかに扱う

 

 

 

 

 

 

 

 

 

強化学習でゾンビに挑む

はじめに

マインクラフト上での強化学習に関する記事を読んでいて、自分でも試してみたくなったので実験してみました。

 

やったこと

概要

イクラ環境で強化学習を行った例は検索するといくつか出てきますが、敵Mobと戦うことをテーマにしている例は調べた限り見当たらなかったので、今回は強化学習で敵Mobを倒すことを目標にしてみました。ターゲットは代表的な敵であるゾンビです。

 

タスク設定としては以下のようなイメージになります。

 

環境準備

 

実装

作成したコードは以下に置いています。(ワールドデータは除く)

GitHub - ramu-igo/Minecraft-RL

ポイントとしては、

  • 行動空間の制限
    エージェントの行動の選択肢は以下の7通りとしました。
    - 移動 (前、後、左、右)
    - 攻撃
    - 視点移動 (左、右)
    行動空間はDiscrete(7)です。*1

  • 報酬設計
    なるべくダメージを受けずにゾンビを倒せるようになってほしかったので、以下のような報酬の与え方にしてみました。*2
    正の報酬
     - ゾンビにダメージを与えたとき (+3)
     - ゾンビを倒したとき (+10)
    負の報酬
     - ダメージを受けたとき (-0.5)
     - やられたとき (-3)

  • 学習部分
    強化学習ライブラリとしてはStable Baselines3を、アルゴリズムはPPOを使いました。

  • 観察者の追加
    評価時にエージェントの挙動を観察しやすくするため、観察者(別クライアント)を追加しました。これは、mission xmlにObserverセクションを追加することで実現しています。

 

学習と評価

train.pyを実行し、エピソード毎の平均報酬の上昇が落ち着くまで学習を行いました。(1~2時間程度)

評価としては、学習後のモデルを読み込んでエージェントを動かした状態で、別クライアントの観察者で以下のコマンドによりゾンビを召喚してエージェントと戦わせました。*3

summon zombie -2358 4 -237 {ArmorItems:[{},{},{},{id:"iron_helmet",Count:1}]}

 

結果

比較のため、まずは学習前の様子を以下に載せます。小さい方のウィンドウがエージェント視点、大きい方のウィンドウが観察者視点です。

www.youtube.com

学習前なのでエージェントはランダムに行動しています。たまに偶然攻撃がゾンビに当たることはありますが、倒すまではいかずにすぐやられます。

 

続いて学習後の様子が以下です。

www.youtube.com

余裕でゾンビに勝てるようになりました。

ゾンビが視界に入ると、なるべく視界にとらえつづけるようにうまく視点移動をしているようです。

ゾンビの攻撃を避けるように若干後ずさりしつつ反撃するような挙動も見られました。

(自分でプレイするより上手いかもしれない)

 

おわり

以上、強化学習でゾンビと戦う実験をしてみました。報酬の与え方を変えることで、また違った戦い方をするエージェントを育てることもできそうです。

ゾンビは自らエージェントに接近してくるため、強化学習が割と上手くいきやすかったものと思います。遠距離攻撃してくる敵(スケルトン等)が相手になると、もっと難しい問題になる気がします。

 

*1:malmoのContinuousコマンドのOn/OffをActionMgrで管理することで、シンプルなDiscrete(n)の行動空間として扱えるようにしました。

*2:報酬を与える判断に必要な情報(エージェントのライフやゾンビの残数など)はこの辺りを参考にして取得しました。

*3:ゾンビが日光で燃えないように、ヘルメットを被せた状態で召喚しています。

iPadでLiDAR & フォトグラメトリを試してみる

はじめに

先日iPad proを購入したので、以前から気になっていた3Dスキャンを試してみることにしました。

 

実験

iPadiPhoneで3次元構造を得る方法としては以下の二つがあります。

  • LiDAR (Light Detection and Ranging)
    測定対象にレーザー光を照射し、反射光を観測することで距離を計測する。
    レーザーの発光から受光までの時間を計測して距離を算出するToF (Time of Flight) 方式が一般的。
  • フォトグラメトリ (写真測量法)
    多面的に撮影した複数の画像を解析し、特徴点マッチング等を用いて3Dモデルを構築する方法。
    ※こちらはLiDARスキャナ非搭載のiPad/iPhoneでも実行できる。

今回は「WIDAR」というアプリを使い、iPadでLiDARとフォトグラメトリを両方試してみました。

WIDAR - 3D Scan & Edit

 

LiDAR

WIDARアプリの「LiDARスキャン」モードを使い自宅の一室をスキャンしました。

スキャン作業自体は簡単で、1分程度で終わりました。

 

スキャンした結果はWIDARアプリ上で編集することもできますが、複数の形式でエクスポートすることができるので、UnityやBlender、3D CADソフトなど様々なツールで利用できます。

今回はply形式で点群情報を出力し、CloudCompare(3Dデータ編集ソフト)で読み込んでみました。

読み込んだ様子が下図になります。

 

点群の間引き、ノイズ除去、メッシュ化までやってみた結果が以下です。

(わかりにくいですが、真ん中あたりに洗濯機があります)

 

フォトグラメトリ

続いてフォトグラメトリです。

iPadのLiDARは建造物などの比較的大きな物のスキャンを得意とする一方、小さなフィギュアなどのスキャンには適していません。

これに対し、フォトグラメトリではフィギュアや食品などの細かい3D構造もとらえることができます。

 

今回対象物にしたのは、たまたま手元にあった洗濯ばさみです。

 

WIDARアプリの「Photoスキャン」モードにて、洗濯ばさみの周囲でiPadを少しずつずらし、全方向から60枚ほど撮影しました。(手ぶれに注意しつつ慎重に)

 

結果をfbx形式で3Dモデルとして出力し、CloudCompareで開いてみた様子が以下です。

おそらく撮影技術不足により所々テクスチャに違和感はありますが、全体的には洗濯ばさみの3D構造をいい感じにとらえられたかと思います。

 

まとめ

以上、簡単にですがiPadでLiDARとフォトグラメトリを試してみました。

手軽に3Dモデルを生成できることがわかったので、今度は得られた3DモデルをCGソフト上で動かすような事もやってみたいと思います。

 

 

 

 

 

 

カルマンフィルタの勉強

はじめに

画像認識関連の開発や調査をしているとカルマンフィルタに出くわす事がありますが、よくわかっていない部分があったので改めて勉強することにしました。

前半は勉強したことのまとめ、後半は具体的な数値シミュレーション例を記載します。

 

カルマンフィルタとは

概要

  • 観測値からシステムの状態を推定するアルゴリズム。逐次ベイズフィルタの一種。
  • 1960年代にカルマン(Rudolf Emil Kalman)によって提案され、アポロ計画において人工衛星の軌道推定に使われたことで有名になった。
  • 現在でも、航空・宇宙工学、ロボット工学、画像処理、計量経済学、生物学といった幅広い分野においてカルマンフィルタおよびそれを拡張したアルゴリズムが活用されている。


処理の流れ

以下にカルマンフィルタを利用する流れを簡単にまとめます。

1. 対象とする時系列システムのモデリング
一般にカルマンフィルタでは「状態空間モデル」を用いて状態を推定します。
※状態空間モデル:測定などで直接得られる「観測値」と、直接得られない潜在的な「状態」を設定するモデル

ここでは、以下のような離散時間の状態空間モデルを考えます。

 \boldsymbol{x}(k+1)=\boldsymbol{Ax}(k) + \boldsymbol{b}v(k)

 y(k) = \boldsymbol{c}^{T}\boldsymbol{x}(k) + w(k)

1つ目の式はある時刻の状態から次の状態を求める式であり、「状態方程式」とも呼ばれます。
2つ目の式は状態から観測値を得る式であり、「観測方程式」とも呼ばれます。

ここで \boldsymbol{x}(k)はn次元の状態ベクトル y(k)は観測値(スカラー)です。

 v(k)はシステムノイズと呼ばれ、平均0, 分散 \sigma^{2}_{v}の正規白色ノイズ、 w(k)は観測ノイズと呼ばれ、平均0, 分散 \sigma^{2}_{w}の正規白色ノイズです。

また、 \boldsymbol{A}, \boldsymbol{b}, \boldsymbol{c}は既知の係数行列・ベクトルであり時間的に一定とします。


2. カルマンフィルタによる状態推定

上で設定した状態空間モデルに対し、カルマンフィルタで状態推定を行うステップを以下に記載します。 ※真の状態 xに対し、状態の推定値は \hat{x}で表します。 

 

(1) 予測ステップ

 \boldsymbol{\hat{x}}^{-}(k) = \boldsymbol{A\hat{x}}(k-1)

 \boldsymbol{P}^{-}(k) = \boldsymbol{AP}(k-1)\boldsymbol{A}^{T} + \sigma^{2}_{v}\boldsymbol{bb}^{T}

予測ステップでは事前状態推定値と事前誤差共分散行列を求めます。

※上付きのマイナス記号は"事前"であることを示しています。

 

(2) フィルタリングステップ
 \boldsymbol{g}(k) = \dfrac{ \boldsymbol{P}^{-}(k)\boldsymbol{c} }{ \boldsymbol{c}^{T} \boldsymbol{P}^{-}(k)\boldsymbol{c} + \sigma^{2}_{w} }

 \boldsymbol{\hat{x}}(k) = \boldsymbol{\hat{x}}^{-}(k) + \boldsymbol{g}(k)(y(k) - \boldsymbol{c}^{T}\boldsymbol{\hat{x}}^{-}(k) )
 \boldsymbol{P}(k) = ( \boldsymbol{I} -  \boldsymbol{g}(k)\boldsymbol{c}^{T} )\boldsymbol{P}^{-}(k)

フィルタリングステップではまず「カルマンゲイン」を計算し、これを用いて状態推定値と誤差共分散行列を更新します。

※カルマンゲインは予測ステップで求めた予測値に対し、"観測"による補正をどれだけ入れるかをコントロールする役割。

 

以上のように、「予測」と「フィルタリング」から成る一連の更新式を、観測値を受け取る度に逐次的に実行していくのがカルマンフィルタの特徴です。

 

簡単な具体例

ここでは以下のシステムを考えます。

 x(k+1) = x(k) + v(k)

 y(k) = x(k) + w(k)

これは前述の状態空間モデルにおいて状態xをスカラーとし、A=b=c=1とおいたものです。

 

この場合、カルマンフィルタによる推定のステップは以下となります。

(1) 予測ステップ

 \hat{x}^{-}(k) = \hat{x}(k-1)

 p^{-}(k) = p(k-1) + \sigma^{2}_{v}

 

(2) フィルタリングステップ
 g(k) = \dfrac{ p^{-}(k) }{ p^{-}(k) + \sigma^{2}_{w} }

 \hat{x}(k) = \hat{x}^{-}(k) + {g}(k)(y(k) - \hat{x}^{-}(k) )
 p(k) = ( 1 -  {g}(k) )p^{-}(k)

 

この例で、実際にシミュレーションを行ってみます。 

以下は参考文献のMATLABコードを少し変更してpythonで書いた物です。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

# 問題設定
N = 150
var_v = 1 # システムノイズの分散
var_w = 2 # 観測ノイズの分散
# var_w = 20 

np.random.seed(0)
v = np.random.normal(0, np.sqrt(var_v), N) # システムノイズ
w = np.random.normal(0, np.sqrt(var_w), N) # 観測ノイズ

# 真の状態と、観測値を用意する
x = np.zeros(N) # 真の状態
y = np.zeros(N) # 観測値
for k in range(N):
    if k > 0:
        x[k] = x[k - 1] + v[k - 1]
    y[k] = x[k] + w[k]

# ----------------------------------------------

# カルマンフィルタの更新式
def kalman_filter(y, x_est_pre, p_pre, var_v, var_w):
    # 予測ステップ
    x_est = x_est_pre
    p = p_pre + var_v
    
    # フィルタリングステップ
    kalman_gain = p / (p + var_w)
    x_est = x_est + kalman_gain * (y - x_est)
    p = (1 - kalman_gain) * p

    return x_est, p

# カルマンフィルタによる状態推定
x_est = np.zeros(N) # 状態の推定値
p = 0
for k in range(N - 1):
    x_est[k + 1], p = kalman_filter(y[k], x_est[k], p, var_v, var_w)

# 結果の表示
t_ = np.arange(N)
fig, ax = plt.subplots(figsize=(12, 4))
sns.lineplot(t_, x, ax=ax, label="true state")
sns.lineplot(t_, y, color="gray", ax=ax, label="observation")
sns.lineplot(t_, x_est, color="red", ax=ax, label="estimated state")
ax.set_title("results")
ax.legend()
plt.show()

 

実行結果を以下に示します。

青が真の状態、グレーが観測値、赤がカルマンフィルタによる状態の推定値です。

 \sigma^{2}_{w} = 2の場合

 \sigma^{2}_{w} = 20の場合

観測ノイズを大きくした場合でも、カルマンフィルタによって真の状態に近い値が推定できていることがわかります。

 

今回の記事は以上です。

 

参考

 

 

KAPAOによる姿勢推定を試してみる

はじめに

姿勢推定について最近の手法を調べていたところ、KAPAOという物を見つけて気になったので、簡単に調査・動作確認をしてみました。


KAPAOについて

  • KAPAOはKeypoints and Poses as Objectsの略
  • タスク設定としては一般的な2D Pose Estimationであり、画像中の人物の主要な関節点の2次元座標を推定する
  • 分類としては"single-stage"の手法

    「人物検知」→「各人物の姿勢推定」という2段階の推論ではなく、1段階で姿勢推定まで行う

  • モデルの構造は以下(論文中のFig. 2から引用)
    物体検出用のモデルを拡張する形で実現されている

 

KAPAOは姿勢推定にヒートマップを使っていないのが大きな特徴です。これは論文中でも繰り返し主張されています。

(一般的に姿勢推定ではヒートマップを使った手法が多い。有名なOpenPoseもそう。)

 

計算コストの高いヒートマップを使わないことで、処理時間(forward pass + 後処理)の短縮を実現したと述べられています。
実際、論文中のTable 1を見るとヒートマップを用いる既存手法と比べて速度(特に後処理)が大きく改善しています。


動作確認

公式の実装と学習済みモデルを用いて推論処理を試してみました。

 

静止画での推論

サンプル画像を処理するコマンド例

python demos/image.py --pose --bbox

※ --faceをつけると顔パーツのkeypointも描画される
※ --kp-bboxをつけると、検知したkeypoint(膝、ひじ等)ごとにbounding boxが描画される


動画での推論

サンプル動画で推論実行し、結果をgifに保存するコマンド例

python demos/video.py --yt-id 2DiQUX11YaY --tag 136 --imgsz 1280 --color 255 0 255 --start 188 --end 193 --gif

 

実際に処理した結果は以下の通りです。手元のGPU環境(GeForce RTX 3060Ti)で45FPS程度で動作しました。

 


サンプル以外の動画でも試してみましたが、人が2人だけ映っているシーンでは60FPS以上出ていました。姿勢推定の精度についても、人物の服装や背景によらずかなり安定している印象でした。


こちらを見る限り、ONNXやTensorRTへの変換は公式では検討されていないようですが、TensorRTに持っていければさらに高速に動作させられるかもしれません。

 

今回の記事は以上です。


参考

・KAPAO論文

Rethinking Keypoint Representations: Modeling Keypoints and Poses as Objects for Multi-Person Human Pose Estimation

・公式実装

https://github.com/wmcnally/kapao

 

 

 

Swin Transformerに関するまとめ

はじめに

コンピュータビジョン分野においてTransformerを活用した有力なモデルであるSwin Transformerについて勉強しようと思い、今回の記事にまとめました。

まず基本となるTransformerおよびVision Transformerについておさらいし、その後Swin Transformerについて記載します。

 

Transformer

背景

Transformerは2017年に翻訳タスク用に提案されたネットワークです。

Transformer以前は、翻訳などの言語処理を行うモデルとしては再帰的な構造を持つネットワークであるRNN(LSTM, GRU等)がよく使われていました。
しかしこれらのモデルでは単語を逐次的に入力していくため並列処理ができず、学習時間が長いという欠点がありました。そのため、巨大で複雑なモデルを構築して高度な言語処理を行うことは難しかったとされています。

また、RNNではなくCNNベースのモデルを言語処理タスクへ応用する研究もされていました。CNNの場合は逐次処理ではなく文章を一度に処理できるため学習は高速化できましたが、文章が長くなると離れた単語同士の関係性を考慮できないという課題がありました。

 

これらの問題を改善したのがTransformerです。
TransformerではRNNやCNNは使われておらず、"Attention"(注意機構)が主役のネットワークになります。

構造

Transformerの構造は以下の通りです。(論文より引用)

主な特徴としては、

  • エンコーダ-デコーダモデルである (左側がエンコーダ、右側がデコーダ)
  • 入力データはEmbedding層により単語ごとにベクトルに変換され、Positional Encoding*1を経てTransformer blockに入力される
  • "Multi-Head Attention"*2を繰り返し実施
    エンコーダ側ではself attention、デコーダ側ではmasked self attention*3とsource-target attentionを行う

発展

Transformerは翻訳タスクで当時のSoTAを達成しました。
その後もTransformerをベースとした手法(BERT, GPT-nなど)が次々と発表され、自然言語処理分野における主流になりました。

また自然言語処理以外でも、音声認識や音楽生成などへの応用も研究されるようになりました。*4


そしてコンピュータビジョン領域にもTransformerを適用しようという研究が活発化し、その代表的な物がVision Transformerとなります。


Vision Transformer(ViT)

概要

Vision Transformer(通称ViT)はTransformerを画像分類タスクに適用したものです。
それ以前の画像認識モデルにおいてほぼ必須であったCNNを使わずにSoTAを更新したことで大きなインパクトを与えました。

 

ネットワーク構造としてはTransformerのエンコーダをほぼそのまま使っています。

元々のTransformerが単語の列を処理するのに対し、ViTでは入力画像を小さな「パッチ」に分割し、各パッチを単語のように扱うイメージになります。

  • それぞれのパッチをベクトル化(flatten)して埋め込みを行い、位置エンコードを加えてTransformerエンコーダに入力
  • エンコーダの中の処理は本家のTransformerとほぼ同様
    細かい点として、正規化層がMulti-Head Attentionより前に来ていること、MLPの活性化層にReLUではなくGELU*5を用いているといった差はある

性能について

論文のTable 2に示されているように、画像分類用の多くのデータセットにおいて、
当時のSoTAモデル(CNNベース)より高いスコアを記録しました。


またViTの特徴としてJFT-300M*6などの巨大なデータセットで事前学習させているという点があります。
事前学習用のデータセットが小規模の場合はむしろ既存のモデルより性能が低いが、巨大なデータセットを使うことでその真価を発揮することが示されています。

 

Swin Transformer

概要

ViTは画像分類で優れた結果を出しましたが、コンピュータビジョンにおける様々なタスクにTransformerを適用するためにはまだ次のような問題がありました。

  • 画像を固定サイズのパッチに分けて入力するだけだと、物体のスケールの変化に対応するのが難しい (物体検出タスクなど)
  • 全てのパッチ間で関連度を求めるため、サイズの大きい画像では計算量が大きくなりすぎてしまう

これらの問題を解決し、コンピュータビジョンの様々なタスクに使える汎用的なバックボーンとして提案されたのがSwin Transformerです。

Swin TransformerではViTから以下のような改善が行われています。

  • 階層的構造("パッチマージ"を繰り返して特徴マップを小さくしていく)により、マルチスケールの特徴を得る
  • "Window"でパッチをグループ分けし、各Window内でパッチ間の関連度を求めることで計算量を削減

(以下は論文Figure 1から引用)

 

構造と特徴

Swin Transformerのネットワーク構造を以下に示します。

上図の(b)に示されているように、Swin Transformer Blockの中身はViTのアーキテクチャとほぼ変わりません。
異なるのはW-MSA(Window-based Multi-head self-attention)とSW-MSA("Shifted" W-MSA)です。これらの"Window"のイメージは以下になります。

この図のように、パッチをMxMのwindowで分けて各window内でattentionを計算するのがViTとの違いです。

  • ViTの場合
    全パッチ間でattentionを計算するため、この部分の計算量は以下のようにパッチ数の2乗に依存する。(パッチ数[h x w]個とする)
     (hw) \times (hw) = (hw)^{2}
  • Swin Transformerの場合
    MxMのwindow内でattentionを計算するため、以下のようにパッチ数に対し線形となる。
     (M^{2} \times M^{2}) \times (h/M) \times (w/M) = M^{2}(hw)

このようにwindow-basedにすることで計算量は削減されますが、本来は隣接していて関連度の高いパッチ同士がwindowで分割され、これらのパッチ間での関連度が無視される可能性があります。
そこで、W-MSA(上図の左)とSW-MSA(上図の右)を交互に用いることで隣接するwindowに属するパッチ間の関連度も考慮されるようにしています。(この効果はTable 4に示されています)

性能について

論文中では、画像分類に加えて物体検出とセグメンテーションにおいても優れたスコアが示されており、Swin Transformerが様々なCVタスクに適用できる汎用的なバックボーンであることが主張されています。

 

まとめ

今回の記事作成を通して、TransformerからSwin Transformerまでの流れを自分の中でざっと整理することができました。
Swin Transformerの登場でTransformerの利用がより身近になったことで、CNNベースモデルの出番は今後減っていくのでしょうか...。

 

論文リンク

 

*1:Positional Encodingは、単語位置に一意の値を加算することで"単語の順番"が無視されないようにする役目

*2:Multi-Head Attentionは、query、key、valueをそれぞれheadの数に分割し、パラレルにattentionを計算してから結合するモジュール。複数のheadに分けてattentionを実施する方が、分けない場合より性能が上がったと報告されている。

*3:"masked"は、デコーダがある位置の単語を予測する時にそれより未来の単語をマスクしてカンニングを防ぐため。

*4:音声認識分野への応用例としては、「Conformer」や「wav2vec 2.0」などがある。
音楽生成への応用の一例としては「Music Transformer」がある。これは、曲の途中までを入力して続きを予測して生成したり、指定したメロディに対して伴奏を生成したりできる。

*5:GELU(Gaussian Error Linear Units)は標準正規分布を利用した活性化関数。ReLUを滑らかにしたような形。GPTやBERTでも使われている。

*6:Googleがプライベートで持つデータセット。画像3億枚!