DNNにおける蒸留を提案「Distilling the Knowledge in a Neural Network」を読みました

CTOの幅野です。
今回はDNNにおいてはじめて提案された蒸留手法の論文について解説していきます。

arxiv.org

蒸留のモチベーション

1. 背景

DeepLearningモデルは層が深く、パラメータが多くすることでモデルの表現力を高め精度を向上させやすくなることが知られています。 さらにモデルを複数作成し、複数モデルの出力を利用して予測をすることで精度を向上させるアンサンブル方法も提案されています。 しかし、複数のニューラルネットワークを運用することは処理速度が遅くなったり、マシンリソースが多く必要だったりとあまり、現実的ではないです。 なので、Singleモデルでアンサンブルモデルや巨大なモデルと同等の表現力を持たせることが必要です。

2. 蒸留の方法

蒸留とは大きなモデルやアンサンブルモデルなどの教師モデルの出力結果を元に1つのモデルを学習する方法です。学習させるモデルを生徒モデル、生徒モデルを学習するときに使うモデルを教師モデルと呼びます。

具体的には2つの損失関数で生徒モデルを学習します。

  • Hard Target: 用意した教師データとの損失関数
  • Soft Target: 教師モデルの出力結果との損失関数

蒸留を利用することで、生徒モデルは通常の学習をしたときよりも高い精度を出すことができます。

3. 教師モデルの出力を学習すること

生徒モデルは教師モデルの出力結果を学習することによって教師モデルの出力の確率分布を学習することができます。 クラス分類タスクにおける教師データは該当するクラスの値を1とし、それ以外を0として学習を行います。 しかし教師モデルの全ての出力は教師データのようなスパースなベクトルにはなりません。 MNISTの分類を例にすると、「2」の画像が与えられたときに「2」クラスの確率が高くなることが求められますが、画像として似ている「7」クラスも同時に確率がある程度あるかもしれません。 このように正解クラス以外のクラスの出力確率を利用する、つまり教師モデルの出力確率分布を学習します。

学習方法

本手法ではSoft Targetをクロスエントロピー誤差で学習することを考えます。 Soft Targetを蒸留に利用する理由は教師モデルの確率分布をより学習しやすくするためにHard Targetより多くの情報量を提供することです。
蒸留をする際に、パラメータ$T$をSoftMax関数に導入しています。 $$ q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)} $$ $q$は生徒モデルのクラス$i$の確率を表しており、$z_i$はSoftMax関数を適応する直前のニューラルネットワークの出力です。教師モデルにもこのSoftMax関数を導入します。通常のSoftMax関数では$T=1$ですが、$T$の値を上げることで$q$の確率分布を滑らかにすることができます。逆に$T$の値を下げることによって、クラス間の確率の差が大きくなります。 確率分布を滑らかにできるというのは本来低かった確率のクラスが大きくなっています。
これにより、正解クラス以外のクラスの確率を大きくすることができ、教師モデルに反映しやすくすることができます。
Soft Targetでは教師モデルの確率分布を学習したいのでその確率分布をどこまで学習するかを$T$によって調整することができます。

評価

ARSタスクの精度比較

Automatic Speech Recognition (ASR)における2000hほどの英語のspeaking datasetで検証しています。 タスクは各フレームの音声データを元に事前にHMMでクラスタリングした14000種類のクラスを予測する問題です。 Base Lineはhard targetのみのSingleモデルで、10xEnsembleは10個のアンサンブルモデルです。 Deistilled Single modelはHard Targetと10個のアンサンブルモデルの出力結果を元にしたSoft Targetの誤差をそれぞれ0.5に重み付けした値の和をLossとして学習したものです。$T=20$で学習をしています。

ネットワークは全て同じ構造です。 蒸留手法がBaseLineより高くWord Error Rate(WER)ではアンサンブルとほぼ同等の精度を出しています。

またSoft TargetはHard Targetより学習に役立つ情報を持っていると考えており、それらの情報は小さいデータセットでの学習に対しても寄与すると主張しています。 ASRでデータセットを減らして精度評価をしています。

Soft Targetsを利用することでテストデータでは3%Baselineより高い精度を出しています。 これらのことから小さいデータに対しても寄与する→正則化がかかっていることがわかります。

JFTデータセットでの検証

JFTデータセットGoogleが用意した1億の画像を15000クラス割り当てられているデータセットです。 このようなデータセットで精度を上げるためにアンサンブルを使おうと考えると、全てのデータセットを使って複数モデルを学習することは計算量が膨大になり現実的ではありません。 そこで本手法の蒸留方法を利用したアンサンブルモデルの作成方法を利用して評価しています。

具体的にはデータセットを$k$個に分割して、それぞれのデータセットでSpetialistモデルを学習し、Spetialistモデルでアンサンブルを行います。

Spetialistモデルの作成手順は以下の通りです。

  1. 全てのデータセットで作成したモデルを1つ学習する(Generalist Model)
  2. Generalist Modelの出力結果を元に計算したクラスのConfusion Matrixからクラスタリングしてデータを分割
  3. Generalist Modelと分割したデータを利用して学習

Generalist Model

最初に全てのデータセットでGeneralist Modelを学習します。Generalist ModelがSpetialistモデルを学習するときの教師モデルとなります。

クラスタリングによるデータ分割

クラスのクラスタリングはK-Meansを利用してクラスタリングをしています。

クラスタリング例は以下のようになっています。

このクラスタリングによってデータセットを分割して、分割したデータセットそれぞれでSpetialist Modelmを学習します。分割したデータセットS_mとします。

Generalistモデルと分割したデータによるSpetialist Modelの作成

S_mからGeneralist Modelを利用してSpetialist Modelを作成します。

まず最初にデータS_mに対してGeneralist Modelで予測をし、最も確率の高かったクラスをkとします。 このkS_mクラスのグループの中に入っているデータを抽出します。 抽出したデータをA_kとします。

このA_kを元にGeneralist ModelとSpetialist modelを学習します。 そのときに利用する損失関数は以下です。 $$ K L\left(\mathbf{p}^{g}, \mathbf{q}\right)+\sum_{m \in A_{k}} K L\left(\mathbf{p}^{m}, \mathbf{q}\right) $$

p^gはGeneralist Modelの予測確率でqT=1としたGeneralist Modelの予測確率です。p^mA_kに対して推論したSpetialist Modelmの予測確率を表しています。

これによりSpeticalist Modelを61個作成した結果が以下のようになります。