3DPose推定モデル「RepNet」を読んでみました
CTOの幅野です。
CVPR2019で発表された3DPose推定の論文RepNetを解説します。
概要
2DPoseから3DPoseを推定するモデルを提案した論文です。
本論文では既存モデルは学習データに類似したシーンは3DPoseをうまく推定できるものの、カメラの位置やPoseが学習データと異なるものに対する3DPose推定がうまくいっていない問題をあげております。
この過学習問題への解決方法として著者らは2つの方法を考案しています。
一つは2DPoseの分布から3DPoseの分布への写像をAdversarial Lossを利用して学習する方法で、もう一つは推定した3DPoseが元の2DPoseへ再構成の誤差を付与して学習する方法です。
この2つの方法を元にRepNetというモデルを提案し、弱教師あり学習において高い精度を出しています。
モデルアーキテクチャ
以下の図がRepNetのアーキテクチャ図です。
RepNetは4つのネットワークで構成されています。
- Pose Generator Network
- Camera Network
- Critic Network
- Reprojection Network
それぞれのモデルについて説明していきます。
Pose Generator Network
Pose Generator Networkは2DPose座標を入力し、3DPose座標を出力するネットワークです。このネットワークを学習するためにその他のネットワークを付与しています。
ネットワーク自体はFull Connected Layerで構成されたネットワークで活性化関数はleaky ReLUを利用しています。
Camera Network
2DPose座標を入力し、2×3の行列であるカメラパラメータKを出力するネットワークです。
このカメラパラメータと推定した3DPoseを利用して元の2DPoseを再構成するために利用します。
Critic Network
Pose Generator Networkで生成された3DPoseが3DPoseの分布に従っているかどうかを判定するDiscriminator Networkです。Wasserstein Lossを利用して学習します。
Reprojection Network
Pose Generator Networkで生成された3DPoseとCamera Networkで出力されてカメラパラメータを利用して元の2DPoseを再構成するネットワークです。
モデルの学習方法
次にそれぞれのネットワークの詳細な内容と学習方法について説明していきます。
2DPoseの再構成によるReprojection Loss
既存の3DPose推定モデルが過学習している一つの要因として、著者らは推定した3DPoseが元の2DPoseへ再構成できるという制約を無視していると主張しています。
RepNetは3DPoseとその3DPoseをどの位置から投影したかを表すカメラパラメータがあれば2DPoseを再構成できるという仮説のもと構成されています。
この再構成の学習方法について説明します。
まずPose Generator Networkで生成した3DPoseを$\boldsymbol{X} \in \mathbb{R}^{3 \times n}$とし、Camera Networkで出力されたカメラパラメータを$\boldsymbol{K} \in \mathbb{R}^{2 \times 3}$)とします。
この2つの行列の積を推定した2DPoseとします。
$$ \boldsymbol{W}^{\prime}=\boldsymbol{K} \boldsymbol{X} $$
この推定した2DPoseと入力した2DPoseの再構成誤差を以下のように定義します。
$$ \mathcal{L}_{r e p}(\boldsymbol{X}, \boldsymbol{K})=|\boldsymbol{W}-\boldsymbol{K} \boldsymbol{X}|_{F} $$
Fはフロベニウスノルムを表しています。
KCS Layerを導入したWasserstein Loss
Pose Generator Networkの学習はReprojection Layerに加えて、Critic Networkを利用してAdversarial Lossで学習を行います。 Adversarial LossはWGANで利用されているWasserstein Lossを利用します。 Critic Networkの図を以下で示します。
そして、このCritic NetworkにはKCS Layerという層を利用しています。
このKCS LayerはHuman Poseの性質を考慮した情報を抽出する層です。
具体的には人間の関節点の角度や長さなどを抽出します。
そしてその抽出を2つの行列のみを利用して計算をすることができます。
$$ B=X C $$
$$ \boldsymbol{B}=\left(\boldsymbol{b}_{1}, \boldsymbol{b}_{2}, \ldots, \boldsymbol{b}_{b}\right) $$
この計算について詳細に説明します。
$$
\begin{array}{l}{\qquad b_{k}=p_{r}-p_{t}=X c} \ {\text { where }} \ {\qquad c=(0, \ldots, 0,1,0, \ldots, 0,-1,0, \ldots, 0)^{T}}\end{array}
$$
pは各関節のPoseの座標を表しており、PrとPtは結合している関節点同士です。
この2つの座標の差分を計算したいのですが、直接計算をするのではなく、cというベクトルを利用してbを計算します。
このcをすべてのBoneに対して用意することでC行列を作成することができ、上記の計算を行列積として計算できます。
これにより計算されたBを追加してWasserstein Lossで学習を行います。
Camera Networkの制約を考慮したCamera Loss
Camera Networkで出力されるカメラパラメータKはコンピュータグラフィックスの透視投影変換という手法で利用される透視投影行列とみなしています。
そして透視投影行列は以下の性質を持っています。
$$
\boldsymbol{K} \boldsymbol{K}^{T}=s^{2} \boldsymbol{I}_{2}
$$
sは投影するPoseのスケールを表しており、この式の制約に従うようにCamera Networkを学習する必要があります。
しかしこのsは学習時には未知な値なのでsを損失関数に入れないように変形をおこないます。
sについて式に変換したものが以下となります。
$$
s=\sqrt{\operatorname{trace}\left(\boldsymbol{K} \boldsymbol{K}^{T}\right) / 2}
$$
この式を元の式に代入して、定義した損失関数が以下になります。
$$
\mathcal{L}_{c a m}=\left|\frac{2}{\operatorname{trace}\left(\boldsymbol{K} \boldsymbol{K}^{T}\right)} \boldsymbol{K} \boldsymbol{K}^{T}-\boldsymbol{I}_{2}\right|_{F}
$$
まとめるとRepNet以下の3つの損失関数を元に学習を行います。
- Wasserstein Loss
- Reprojection Loss
- Camera Loss
検証
3DPoseデータセットであるHuman3.6とMPI-INF-3DHPを利用して検証を行っています。
評価方法はMean Per Joint Positioning Error(MPJPE)を利用して行動ごとに算出しています。
Protocol1
Protocol1はPose Generator Networkで出力された3DPoseをそのままMPJPEで評価したものです。
比較しているモデルは教師ありのタスクで学習してモデルでそれらと比較すると高い精度を出せているとは言えません。
しかし、Ground Truthな2DPoseを利用すれば同等の精度を出せるモデルであることを示しています。
Protocol2
Protocol2は推定した3DPoseに対してrigid alignmentという手法を利用して3Dから2Dに戻したものをMPJPEで評価をしています。
[44], [35]は弱教師あり学習のタスクで提案されたモデルでそれよりも高い精度を出しています。
また、KCS Layerをいれることによって精度が向上しています。
Boneの長さの検証
KCS Layerを導入することによってHuman Poseの左、右のBoneの長さの誤差が低くなっているかどうかを検証しています。
本論文ではSymmetry Errorとして左、右のBoneの長さの差分を計算しています。
結果をみるとKCS LayerをいれることでBoneのSymmetry Errorが下がっていることがわかります。
3DPCK, AUCの検証
MPJPEに加えて3DPCKとAUCの評価を加えて行っています。
3DPCKは関節点が正しい位置を推定している比率を計算するための手法で、関節点が一定の値(引用している論文では150mmとしています)をthresholdとして正しい位置にあるかどうかをずれを利用して判定します。
AUCは3DPCKのaccuracy計算をAUCに置き換えたものです。
MPI-INF-3DHPのテストデータセットに対して、3DHPもしくはHuman3.6Mで学習したモデルを利用して評価しています。