VAE: 変分オートエンコーダー

この記事では変分Bayesの考えを応用したオートエンコーダー「変分オートエンコーダー」について解説する。

関連記事

目次

変分オートエンコーダー

Auto-Encoding Variational Bayes(AEVB)は、2013年にKingma, Wellingによって考案された認識モデルの最適化アルゴリズム。AEVBの著者は、その応用例として変分オートエンコーダー(Variational Auto-Encoder, VAE)というニューラルネットワークのモデル・損失設計を提案した。従来のオートエンコーダーと比較すると、VAEは変分Bayesの考えを応用していることが大きな特徴となっている。

VAEではエンコーダー・デコーダーの出力の仕様や、損失については定めているが、ニューラルネットワークの具体的な構造は限定していない。

変分下限

潜在変数\(z\)がパラメーター\(\theta\)の確率分布\(p(z|\theta)\)に従って発生し、\(x\)は\(\theta\)と\(z\)によって特徴付けられた確率分布\(p(x|z,\theta)\)から出力されるとする。
また、\(z\)の事後分布がパラメーター\(\phi\)の確率分布\(q(z|x, \phi)\)として表現されているとする。

前提
  • \(p(z|\theta), p(x|z,\theta)\)が関数として既知。
    (つまりニューラルネットワークが具体的に実装されているということ)
  • \(z\)の事後分布\(q(z|x, \phi)\)もまた関数として既知。
  • 分布\(p\)に従って生成された値\((x, z)\)の内、\(x\)だけが観測されていてデータの集合\(X^D=\{x_n^D\}_{n=1}^N\)として値が得られている。
目的

観測データに最も適合する\(\theta, \phi\)をただ一つ求める。

観測済みのデータ\(X^D\)の対数尤度を調べる。

\begin{align} \ell(\theta|X^D) &= \log p(X^D|\theta) \\ &= \int q(Z|X^D,\phi)\log p(X^D|\theta)dZ \\ &= \int q(Z|X^D,\phi)\log\frac{p(X^D,Z|\theta)}{p(Z|X^D,\theta)}dZ \\ &= \int q(Z|X^D,\phi)\log\frac{q(Z|X^D,\phi)}{p(Z|X^D,\theta)}dZ + \int q(Z|X^D,\phi)\log\frac{p(X^D,Z|\theta)}{q(Z|X^D,\phi)}dZ \\ &= D_{KL}\left(q(\cdot|X^D,\phi)\|p(\cdot|X^D,\theta)\right) + F(\theta,\phi|X^D) \end{align}

オートエンコーダーの目的上第1項は最小化されるべきなので、対数尤度を最大化するためには変分Bayesのときと同様に変分下限\(F(\theta,\phi|X^D)\)を最大化する。
変分下限は次のようにも表すことができる。

\begin{align} F(\theta,\phi|X^D) &= \int q(Z|X^D,\phi)\log\frac{p(X^D,Z|\theta)}{q(Z|X^D,\phi)}dZ \\ &= \int q(Z|X^D,\phi)\log\frac{p(X^D|Z,\theta)p(Z|\theta)}{q(Z|X^D,\phi)}dZ \\ &= \int q(Z|X^D,\phi)\log p(X^D|Z,\theta)dZ - \int q(Z|X^D,\phi)\log\frac{q(Z|X^D,\phi)}{p(Z|\theta)}dZ \\ &= \mathbb{E}_{q(Z|X^D,\phi)}\left(\log p(X^D|Z,\theta)\right) - D_{KL}\left(q(Z|X^D,\phi)\|p(Z|\theta)\right) \\ &= \sum_n \left( \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p(x_n^D|z_n,\theta)\right) - D_{KL}\left(q(z_n|x_n^D,\phi)\|p(z_n|\theta)\right) \right) \\ \end{align}

変分下限の第1項を再構成誤差と言い、第2項を正則化項と言う。

損失関数

変分オートエンコーダーでは、\(p(z|\theta), q(z|x^D,\phi)\)を以下のような多次元正規分布であるとする。

\begin{align} \begin{cases} p(z|\theta) = p_{\mathcal{N}}(z|0,I) \\ q(z|x^D,\phi) = p_{\mathcal{N}}\left(z|\mu(x^D,\phi),\Sigma(x^D,\phi)\right) \\ \end{cases} \end{align}

ここで、\(\mu(x^D,\phi),\Sigma(x^D,\phi)\)は\(x^D\)を入力とするニューラルネットワークの出力である。
分散共分散行列\(\Sigma(x^D,\phi)\)は対角行列であるものとし、その対角成分の平方根を取り出したベクトルを\(\sigma(x^D,\phi)\)と表す。

ニューラルネットワークの微分が繋がるように、\(q(z|x^D,\phi)\)から\(z\)をサンプルするときはまず\(\epsilon\sim\mathcal{N}(0,I)\)を取り、\(z = \mu(x^D,\phi) + \sigma(x^D,\phi)\odot\epsilon \)とする。このような手法をreparameterization trickと呼ぶ。

正則化項

\(p(z|\theta), q(z|x^D,\phi)\)が共に正規分布なので、正規分布同士のKL情報量の計算結果を用いて正則化項は以下のようになる。\(D\)は\(z\)の次元。

\begin{align} & D_{KL}\left(q(\cdot|X^D,\phi)\|p(\cdot|\theta)\right) \\ &= D_{KL}\left(p_{\mathcal{N}}\left(\cdot|\mu(x^D,\phi),\Sigma(x^D,\phi)\right)\|p_{\mathcal{N}}(\cdot|0,I)\right) \\ &= \frac{1}{2}\left( \log\frac{|I|}{|\Sigma(x^D,\phi)|} - D + {\rm tr}(I^{-1}\Sigma(x^D,\phi)) + (\mu(x^D,\phi)-0)^\top I^{-1}(\mu(x^D,\phi)-0) \right) \\ &= \frac{1}{2}\left( -\log\left(\prod_d\sigma_d^2\right) - D + \sum_d \sigma_d^2 + \mu^\top \mu \right) \\ &= \frac{1}{2}\left( -\sum_d \log\left(\sigma_d^2\right) - D + \sum_d \sigma_d^2 + \sum_d \mu_d^2 \right) \\ &= \frac{1}{2}\sum_d \left( -\log\left(\sigma_d^2\right) - 1 + \sigma_d^2 + \mu_d^2 \right) \\ \end{align}

ここで、\(v_d := \log\left(\sigma_d^2\right)\)とすると最終的に以下のようになる。

\begin{align} D_{KL}\left(q(\cdot|X^D,\phi)\|p(\cdot|\theta)\right) &= \frac{1}{2}\sum_d \left( \mu_d^2 - \log\left(\sigma_d^2\right) + \sigma_d^2 - 1 \right) \\ &= \frac{1}{2}\sum_d \left( \mu_d^2 - v_d + \exp(v_d) - 1 \right) \\ \end{align}

再構成誤差

再構成誤差はデコーダー\(p(x^D|z, \theta)\)の設計次第で具体的な式が変わるが、通常のオートエンコーダーの損失と同じように設計されることが多い。

例えば、\(p(x^D|z, \theta)\)を平均\(y(z,\theta)\), 分散共分散行列\(\sigma^2I\)の多次元正規分布とする。(デコーダーの出力を正規分布の平均\(y(z,\theta)\)とする)
このとき、再構成誤差は平均二乗誤差(MSE)になる。

\begin{align} \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p(x_n^D|z_n,\theta)\right) &= \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p_\mathbb{N}(x_n^D|y(z_n,\theta),\sigma^2I)\right) \\ &= \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(-\frac{1}{2\sigma^2}\|x_n^D-y(z_n,\theta)\|^2+C\right) \\ &= -\frac{1}{2\sigma^2}\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\|x_n^D-y(z_n,\theta)\|^2)\right)+C \\ \end{align}

確率的勾配降下法を使うので、期待値の部分はミニバッチを用いた平均で近似する。
具体的にはデータセット\(\{x_n^D\}_{n=1}^N\)から入力のミニバッチ\(\{x_b^D\}_{b=1}^B\)をサンプルし、更にニューラルネットワークを通して上記のreparameterization trickによって\(\{z(x_b^D,\phi)\}_{b=1}^B\)をサンプルして計算に使用する。

\begin{align} \frac{1}{N}\sum_{n=1}^N \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\|x_n^D-y(z_n,\theta)\|^2)\right) &\sim \frac{1}{B}\sum_{b=1}^B \|x_b^D-y(z_b,\theta)\|^2 \\ &= \frac{1}{B}\sum_{b=1}^B \left\|x_b^D-y\left(z(x_b^D,\phi),\theta\right)\right\|^2 \\ &= \frac{1}{B}\sum_{b=1}^B Loss_{recon}(y_b,x_b^D) \\ \end{align}

\begin{align} Loss_{recon}(y_b,x_b^D) := \left\|x_b^D-y_b\right\|^2 \end{align}

一方、\(p(x^D|z, \theta)\)にBernoulli分布を仮定する場合、再構成誤差はBinary Cross Entropy Lossになる。
\(x^D\in\{0,1\}^I\)とし、\(p(x^D|z, \theta)\)は「各成分が1になる確率」で構成されたベクトル\(y(z,\theta)\in(0,1)^I\)を出力するように設計する。このとき、再構成誤差は以下のように計算される。

\begin{align} \frac{1}{N}\sum_n\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p(x_n^D|z_n,\theta)\right) &= \frac{1}{N}\sum_n\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p_{Bernoulli}(x_n^D|y(z_n,\theta))\right) \\ &= \frac{1}{N}\sum_n\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log\left(\prod_{i=1}^I y_{n,i}^{x_{n,i}^D}(1-y_{n,i})^{1-x_{n,i}^D}\right)\right) \\ &= \frac{1}{N}\sum_n\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\sum_i \left(x_{n,i}^D\log y_{n,i} + (1-x_{n,i}^D)\log(1-y_{n,i})\right)\right) \\ &= -\frac{1}{N}\sum_n\mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\sum_i BCELoss\left(y(z_n,\theta), x_n^D\right)_i\right) \\ &\sim -\frac{1}{B}\sum_{b,i}BCELoss\left(y(z(x_b^D,\phi),\theta), x_b^D\right)_i \\ &= \frac{1}{B}\sum_b Loss_{recon}(y_b,x_b^D) \\ \end{align}

\begin{align} Loss_{recon}(y_b,x_b^D) := \sum_{i}BCELoss\left(y_b,x_b^D\right)_i \end{align}

損失関数

損失関数は最小化する値なので、変分下限の符号を反転させる。

\begin{align} Loss &= -F(\theta,\phi|X^D) \\ &= -\sum_n \left( \mathbb{E}_{q(z_n|x_n^D,\phi)}\left(\log p(x_n^D|z_n,\theta)\right) - D_{KL}\left(q(z_n|x_n^D,\phi)\|p(z_n|\theta)\right) \right) \\ &\sim -\frac{N}{B}\sum_b \left( \log p(x_b^D|z_b,\theta) - D_{KL}\left(q(z_n|x_b^D,\phi)\|p(z_b|\theta)\right) \right) \\ &= \frac{N}{B}\sum_b \left( \frac{1}{2\sigma^2}Loss_{recon}(y_b,x_b^D) + \frac{1}{2}\sum_d \left( \mu_{b,d}^2 - v_{b,d} + \exp(v_{b,d}) \right) \right) \\ &= \frac{N}{2B}\sum_b \left( \frac{1}{\sigma^2}Loss_{recon}(y_b,x_b^D) + \sum_d \left( \mu_{b,d}^2 - v_{b,d} + \exp(v_{b,d}) \right) \right) \\ \end{align}

実験結果

AEVBを使った学習の例として、論文著者はFrey FaceデータセットとMNISTデータセットを学習し、潜在空間からデコードして可視化した結果を掲載している。

参考