DDPM: Denoising Diffusion Probabilistic Models

この記事では近年の画像生成AIに用いられている基礎技術「DDPM」の解説を行う。

拡散モデル(Diffusion Models)は、有限時間で入力データを再現するように学習したMarkov連鎖。2015年にSohl-Dicksteinらによって発表された。
2020年6月にHo, Jain, Abbeelが公開した論文Denoising Diffusion Probabilistic Models(DDPM)では拡散モデルをニューラルネットワークによる画像生成に適用することで、高品質な画像を出力するオートエンコーダーを構築することに成功した。

TensorFlowで実装された公式のソースコードも公開されている。

関連記事

目次

理論

順課程と逆過程

拡散モデルの理論は順過程と逆過程によって構成される。


順過程と逆過程の概要

順過程(forward process)では、教師データ\(x_0\)にランダムな微小ノイズを加える操作を\(T\)回繰り返し、完全なノイズに近い画像\(x_T\sim\mathcal{N}(0,I)\)へ変化させる。
それぞれのタイムステップにおける画像を\(x_t\)とし、その分布を\(q(x_t|x_{t-1})\)と表す。

逆過程(reverse process)では、ニューラルネットワークを用いてステップごとに入力画像のノイズを除去していく。
ニューラルネットワークには入力\(x_t\)に対して、タイムステップ1回分のノイズを除去した\(x_{t-1}\)を出力する確率\(p_\theta(x_{t-1}|x_t)\)を学習させる。

forward process

画像\(x_0\)に対して、以下の分布に従って\(x_{1:T}=\{x_t\}_{t=1}^T\)をサンプルする。

\begin{align} q(x_t|x_{t-1}) = p_\mathcal{N}\left(x_t|\sqrt{1-\beta_t}x_{t-1}, \beta_t I\right) \end{align}

reverse process

各時刻の\(x_t\)に対して、パラメーター\(\theta\)の分布を通して以下のように\(x_{t-1}\)を予測する。

\begin{align} \begin{cases} p_(x_T) = p_\mathcal{N}(x_T|0,I) \\ p_\theta(x_{t-1}|x_t) = p_\mathcal{N}\left(x_{t-1}|\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)\right) \end{cases} \end{align}

ここで、\(\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)\)はニューラルネットワークが出力する値。


Encoder-Decoderモデルとして見たときの概要図

順課程において、平均と分散の係数は\(\mathbb{E}(\|x_t\|^2) = \beta_t + (1-\beta_t)\|x_{t-1}\|^2\)となるように\(\sqrt{1-\beta_t},\beta_t\)が選ばれている。
\(\beta_t\)は学習可能な値にすることも可能だが、本論文では\(t\)のみに依存する学習不可能な固定値としている。

\(\alpha_t := 1-\beta_t, \bar{\alpha}_t := \prod_{s=1}^t \alpha_s\)とすると、\(q(x_1|x_0)\)は次のようにも表すことができる。

\begin{align} q(x_1|x_0) &= p_\mathcal{N}(x_1|\sqrt{1-\beta_1}x_0,\beta_1I) \\ &= p_\mathcal{N}(x_1|\sqrt{\bar{\alpha}_1}x_0,(1-\bar{\alpha}_1)I) \\ \end{align}

正規分布の再生性を用いることで、帰納的に任意の時刻\(t\)に対して\(q(x_t|x_0)\)を次のように求めることができる。

\begin{align} q(x_t|x_0) &= \int q(x_t|x_{t-1})q(x_{t-1}|x_0)dx_{t-1} \\ &= \int p_\mathcal{N}(x_t|\sqrt{1-\beta_t}x_{t-1},\beta_tI)p_\mathcal{N}(x_{t-1}|\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\bar{\alpha}_{t-1})I)dx_{t-1} \\ &= \int p_\mathcal{N}\left(\frac{x_t}{\sqrt{1-\beta_t}}\middle|x_{t-1},\frac{\beta_t}{1-\beta_t}I\right)p_\mathcal{N}(x_{t-1}|\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\bar{\alpha}_{t-1})I)dx_{t-1} \\ &= p_\mathcal{N}\left(\frac{x_t}{\sqrt{1-\beta_t}}\middle|\sqrt{\bar{\alpha}_{t-1}}x_0,(\frac{\beta_t}{1-\beta_t}+1-\bar{\alpha}_{t-1})I\right) \\ &= p_\mathcal{N}(x_t|\sqrt{1-\beta_t}\sqrt{\bar{\alpha}_{t-1}}x_0,(\beta_t + (1-\beta_t)(1-\bar{\alpha}_{t-1}))I) \\ &= p_\mathcal{N}(x_t|\sqrt{\alpha_t}\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\alpha_t + \alpha_t(1-\bar{\alpha}_{t-1}))I) \\ &= p_\mathcal{N}(x_t|\sqrt{\alpha_t\bar{\alpha}_{t-1}}x_0,(1-\alpha_t\bar{\alpha}_{t-1})I) \\ &= p_\mathcal{N}(x_t|\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)I) \\ \end{align}

これにより、タイムステップを無視していきなり\(x_t\)を生成し、学習に利用することができるようになる。


\(\varepsilon \sim \mathcal{N}(0, I)\)を取り、\(x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon\)を生成する

また、\(q(x_T|x_0)=p_\mathcal{N}(x_T|\sqrt{\bar{\alpha}_T}x_0,(1-\bar{\alpha}_T)I)\)が完全なノイズの分布\(p_\mathcal{N}(0,I)\)に近くなる必要があるので、\(\bar{\alpha}_T\)が微小な値となるように\(T\)と\(\beta_t\)を選ばなければならない。

変分下限

VAEの表記と対比すると、\(x_0\)がVAEの\(x^D\)に相当し、\(x_{1:T}=\{x_t\}_{t=1}^T\)全てがVAEの潜在変数\(z\)に相当することになる。

VAEと同様に、負の変分下限を最小化することを目的とする。
VAEの変分下限に\(x^D=x_0, z=x_{1:T}\)を当てはめると以下のようになる。

\begin{align} Loss &= -F(\theta|x_0) \\ &= -\int q(x_{1:T}|x_0) \log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} dx_{1:T} \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\right) \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log \frac{p(x_T)\prod_{t=1}^T p_\theta(x_{t-1}|x_t)}{\prod_{t=1}^T q(x_t|x_{t-1})}\right) \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p(x_T) + \sum_{t=1}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\right) \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p(x_T) + \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})} + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)}\right) \\ \end{align}

第2項を詳しく見る。
順過程はMarkov連鎖となるので、\(q(x_t|x_{t-1}) = q(x_t|x_{t-1},x_0)\)とできる。また、\(q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0) = q(x_{t-1}|x_t,x_0)q(x_t|x_0)\)となることを利用して以下のように式を変形する。

\begin{align} \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})} &= \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1},x_0)} \\ &= \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \\ &= \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} + \sum_{t=2}^T\left(\log q(x_{t-1}|x_0)- \log q(x_t|x_0)\right) \\ &= \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} + \log q(x_1|x_0)- \log q(x_T|x_0) \\ \end{align}

元の式へ反映させる。

\begin{align} Loss &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p(x_T) + \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})} + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)}\right) \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p(x_T) + \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} + \log q(x_1|x_0)- \log q(x_T|x_0) + \log \frac{p_\theta(x_0|x_1)}{q(x_1|x_0)}\right) \\ &= -\mathbb{E}_{q(x_{1:T}|x_0)}\left(\log\frac{p(x_T)}{q(x_T|x_0)} + \sum_{t=2}^T \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} + \log p_\theta(x_0|x_1)\right) \\ &= \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log\frac{q(x_T|x_0)}{p(x_T)}\right) + \sum_{t=2}^T \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}\right) - \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p_\theta(x_0|x_1)\right) \\ &=: L_T + \sum_{t=2}^T L_{t-1} +L_0 \\ \end{align}

ここで、\(L_T, L_{t-1}, L_0\)の定義は以下。

\begin{align} \left\{ \begin{array}{ll} L_T &= \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log\frac{q(x_T|x_0)}{p(x_T)}\right), \\ L_{t-1} &= \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}\right), \\ L_0 &= - \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p_\theta(x_0|x_1)\right). \\ \end{array} \right. \end{align}

\(L_T\)

\begin{align} L_T &= \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log\frac{q(x_T|x_0)}{p(x_T)}\right) \\ &= \int \prod_{t=1}^T q(x_t|x_{t-1},x_0)\log\frac{q(x_T|x_0)}{p(x_T)}dx_{1:T} \\ &= \int \int \prod_{t=1}^T q(x_t|x_{t-1},x_0)\log\frac{q(x_T|x_0)}{p(x_T)} dx_{1:T-1}dx_T \\ &= \int \left(\int \prod_{t=1}^T q(x_t|x_{t-1},x_0)dx_{1:T-1}\right)\log\frac{q(x_T|x_0)}{p(x_T)} dx_T \\ &= \int q(x_T|x_0)\log\frac{q(x_T|x_0)}{p(x_T)} dx_T \\ &= D_{KL}\left(q(x_T|x_0)\|p(x_T)\right) \end{align}

\(\beta_t\)は学習可能にすることもできると先述したが、DDPMでは\(t\)のみに依存する学習不可能な固定値とする。
こうすることで\(L_T\)には学習可能なパラメーターが残らないようになり、損失関数の中で定数として無視することができるようになる。

ところで、この項は\(q(x_T|x_0)\)と\(p(x_T)\)の近さを表すので、順過程でノイズを加えた最終的な画像が完全なノイズにどのくらい近付いているかを評価する項と見ることもできる。
正規分布同士のKL情報量を用いて更に詳しく調べる。\(x_0\in[-1,1]^D\)とする。

\begin{align} L_T &= D_{KL}\left(q(x_T|x_0)\|p(x_T)\right) \\ &= D_{KL}\left(p_\mathcal{N}(x_T|\sqrt{\bar{\alpha}_T}x_0,(1-\bar{\alpha}_T)I)\|p_\mathcal{N}(x_T|0,I)\right) \\ &= \frac{1}{2}\left( \log \frac{1}{1-\bar{\alpha}_T} - D + (1-\bar{\alpha}_T)D + \|\sqrt{\bar{\alpha}_T}x_0\|^2 \right) \\ &= \frac{1}{2}\left( - \log (1-\bar{\alpha}_T) - \bar{\alpha}_T D + \bar{\alpha}_T\|x_0\|^2 \right) \\ &\leq \frac{1}{2}\left( - \log (1-\bar{\alpha}_T) - \bar{\alpha}_T D + \bar{\alpha}_T D \right) \\ &= -\frac{1}{2}\log (1-\bar{\alpha}_T) \\ \end{align}

したがって、\(\bar{\alpha}_T\)が十分小さければ、順過程の設定が適切であると言える。

\(L_{t-1}\)

\begin{align} L_{t-1} &= \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}\right) \\ &= \int \prod_{s=1}^T q(x_s|x_{s-1},x_0)\log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{1:T} \\ &= \int \int \int \int \left(\prod_{s=1}^{t-1} q(x_s|x_{s-1},x_0)\right) q(x_t|x_{t-1},x_0) \left(\prod_{s=t+1}^T q(x_s|x_{s-1},x_0)\right) \log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{1:t-2}dx_{t-1}dx_t dx_{t+1:T} \\ &= \int \int \left(\int \prod_{s=1}^{t-1} q(x_s|x_{s-1},x_0)dx_{1:t-2}\right) q(x_t|x_{t-1},x_0) \left(\int \prod_{s=t+1}^T q(x_s|x_{s-1},x_0)dx_{t+1:T}\right) \log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{t-1}dx_t \\ &= \int \int q(x_{t-1}|x_0) q(x_t|x_{t-1},x_0) \cdot 1 \cdot \log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{t-1}dx_t \\ &= \int \int q(x_t,x_{t-1}|x_0) \log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{t-1}dx_t \\ &= \int \int q(x_{t-1}|x_t,x_0)q(x_t|x_0) \log \frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}dx_{t-1}dx_t \\ &= \int q(x_t|x_0) D_{KL}\left(q(x_{t-1}|x_t,x_0)\|p_\theta(x_{t-1}|x_t)\right) dx_t \\ &= \mathbb{E}_{q(x_t|x_0)}\left( D_{KL}\left(q(x_{t-1}|x_t,x_0)\|p_\theta(x_{t-1}|x_t)\right) \right) \\ \end{align}

\(q(x_{t-1}|x_t,x_0)\)を調べる。

\begin{align} q(x_{t-1}|x_t,x_0) &= \frac{q(x_{t-1},x_t|x_0)}{q(x_t|x_0)} \\ &= q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \\ &\propto q(x_t|x_{t-1},x_0)q(x_{t-1}|x_0) \\ &= p_\mathcal{N}(x_t|\sqrt{\alpha_t}x_{t-1},(1-\alpha_t)I)p_\mathcal{N}(x_{t-1}|\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\bar{\alpha}_{t-1})I) \\ &= p_\mathcal{N}\left(x_{t-1}\middle|\frac{x_t}{\sqrt{\alpha_t}},\frac{1-\alpha_t}{\alpha_t}I\right)p_\mathcal{N}(x_{t-1}|\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\bar{\alpha}_{t-1})I) \\ \end{align}

一般に、2つの正規分布の積は正規分布の定数倍となる。

\begin{align} & p_\mathcal{N}(x|\mu_1,\Sigma_1)p_\mathcal{N}(x|\mu_2,\Sigma_2) \\ &\propto \exp\left( -\frac{1}{2}\left( (x-\mu_1)^T\Sigma_1^{-1}(x-\mu_1) + (x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2) \right) \right) \\ &\propto \exp\left( -\frac{1}{2}\left( x^T(\Sigma_1^{-1}+\Sigma_2^{-1})x - x^T(\Sigma_1^{-1}\mu_1+\Sigma_2^{-1}\mu_2) - (\Sigma_1^{-1}\mu_1+\Sigma_2^{-1}\mu_2)^Tx \right) \right) \\ &\propto \exp\left( -\frac{1}{2}\left( (x-(\Sigma_1^{-1}+\Sigma_2^{-1})^{-1}(\Sigma_1^{-1}\mu_1+\Sigma_2^{-1}\mu_2))^T(\Sigma_1^{-1}+\Sigma_2^{-1})(x-(\Sigma_1^{-1}+\Sigma_2^{-1})^{-1}(\Sigma_1^{-1}\mu_1+\Sigma_2^{-1}\mu_2)) \right) \right) \\ &\propto p_\mathcal{N}\left(x\middle|(\Sigma_1^{-1}+\Sigma_2^{-1})^{-1}(\Sigma_1^{-1}\mu_1+\Sigma_2^{-1}\mu_2),(\Sigma_1^{-1}+\Sigma_2^{-1})^{-1}\right) \\ &\propto p_\mathcal{N}\left(x\middle|(\Sigma_1+\Sigma_2)^{-1}(\Sigma_2\mu_1+\Sigma_1\mu_2),\Sigma_1\Sigma_2(\Sigma_1+\Sigma_2)^{-1}\right) \\ \end{align}

この性質を利用すると、

\begin{align} q(x_{t-1}|x_t,x_0) &\propto p_\mathcal{N}\left(x_{t-1}\middle|\frac{x_t}{\sqrt{\alpha_t}},\frac{1-\alpha_t}{\alpha_t}I\right)p_\mathcal{N}(x_{t-1}|\sqrt{\bar{\alpha}_{t-1}}x_0,(1-\bar{\alpha}_{t-1})I) \\ &\propto p_\mathcal{N}\left(x_{t-1}\middle| \frac{1}{\frac{1-\alpha_t}{\alpha_t}+(1-\bar{\alpha}_{t-1})}\left((1-\bar{\alpha}_{t-1})\frac{x_t}{\sqrt{\alpha_t}} + \frac{1-\alpha_t}{\alpha_t}\sqrt{\bar{\alpha}_{t-1}}x_0\right) , \frac{\frac{1-\alpha_t}{\alpha_t}(1-\bar{\alpha}_{t-1})}{\frac{1-\alpha_t}{\alpha_t}+(1-\bar{\alpha}_{t-1})}I \right) \\ &= p_\mathcal{N}\left(x_{t-1}\middle| \frac{\alpha_t}{1-\bar{\alpha}_{t}}\left((1-\bar{\alpha}_{t-1})\frac{x_t}{\sqrt{\alpha_t}} + \frac{1-\alpha_t}{\alpha_t}\sqrt{\bar{\alpha}_{t-1}}x_0\right) , \frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}I \right) \\ &= p_\mathcal{N}\left(x_{t-1}\middle| \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_{t}}x_0 , \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_tI \right) \\ &=: p_\mathcal{N}\left(x_{t-1}\middle| \tilde{\mu}_t(x_t,x_0) , \tilde{\beta}_tI \right) \\ \end{align}

となる。ここで、

\begin{align} \begin{cases} \tilde{\mu}_t(x_t,x_0) := \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_{t}}x_0, \\ \tilde{\beta}_t := \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_{t}}\beta_t. \\ \end{cases} \end{align}

\(q(x_{t-1}|x_t,x_0)\)は\(x_{t-1}\)と\(p_\mathcal{N}(x_{t-1}|\tilde{\mu}_t(x_t,x_0), \tilde{\beta}_t I)\)はどちらも\(x_{t-1}\)で積分すると1になるので、比例定数は1で\(q(x_{t-1}|x_t,x_0) = p_\mathcal{N}(x_{t-1}|\tilde{\mu}_t(x_t,x_0), \tilde{\beta}_t I)\)となる。

\(L_{t-1}\)の式に戻す。
ここで、ニューラルネットワークは正規分布の平均\(\mu_\theta(x_t,t)\)のみ出力するものとし、分散共分散行列は\(t\)のみに依存する対角行列\(\Sigma_\theta(x_t,t) = \sigma_t^2 I\)とする。\(\sigma_t^2\)の具体的な値は、\(\sigma_t^2=\beta_t\)あるいは上述の\(\sigma_t^2=\tilde{\beta}_t\)などとすることが提案されている。
正規分布同士のKL情報量の結果を用いて次のようになる。

\begin{align} L_{t-1} &= \mathbb{E}_{q(x_t|x_0)}\left( D_{KL}\left(q(x_{t-1}|x_t,x_0)\|p_\theta(x_{t-1}|x_t)\right) \right) \\ &= \mathbb{E}_{q(x_t|x_0)}\left( D_{KL}\left(p_\mathcal{N}\left(x_{t-1}|\tilde{\mu}_t(x_t,x_0), \tilde{\beta}_t I\right)\|p_\mathcal{N}\left(x_{t-1}|\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)\right) \right) \right) \\ &= \mathbb{E}_{q(x_t|x_0)}\left( D_{KL}\left(p_\mathcal{N}\left(x_{t-1}|\tilde{\mu}_t(x_t,x_0), \tilde{\beta}_t I\right)\|p_\mathcal{N}\left(x_{t-1}|\mu_\theta(x_t,t),\sigma_t^2 I\right) \right) \right) \\ &= \mathbb{E}_{q(x_t|x_0)}\left( \frac{1}{2}\left( D\log \frac{\sigma_t^2}{\tilde{\beta}_t} - D + \frac{\tilde{\beta}_t}{\sigma_t^2}D + \frac{1}{\sigma_t^2}\|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) \right) \\ &= \mathbb{E}_{q(x_t|x_0)}\left( \frac{1}{2}\left( D\left(\frac{\tilde{\beta}_t}{\sigma_t^2}-\log \frac{\tilde{\beta}_t}{\sigma_t^2}-1\right) + \frac{1}{\sigma_t^2}\|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) \right) \\ &= \frac{1}{2\sigma_t^2} \mathbb{E}_{q(x_t|x_0)}\left( \|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) + C \\ \end{align}


\(x_{t-1}\)の平均\(\mu_\theta(x_t,t)\)を予測するモデル。\(\mu_\theta\)と\(\tilde{\mu}\)の差で損失を取る。

この損失関数は次のように考えることもできる。
\(q(x_t|x_0) = p_\mathcal{N}(x_t|\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)I)\)なので、\(\varepsilon \sim \mathcal{N}(0,I)\)に対して\(x_t\)を次のように表すこともできる。

\begin{align} x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon \end{align}

したがって、

\begin{align} x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\varepsilon\right) \end{align}

これを\(\tilde{\mu}(x_t,x_0)\)に代入すると、\(\tilde{\mu}(x_t,x_0)\)を\(x_t,x_0\)ではなく\(x_t,\varepsilon\)で表すことができるようになる。

\begin{align} \tilde{\mu}(x_t,x_0) &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 \\ &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\varepsilon\right) \\ &= \frac{\alpha_t(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t\sqrt{\alpha_t}}x_t + \frac{\beta_t}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\varepsilon\right) \\ &= \frac{1}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\left(\left(\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t\right)x_t - \beta_t\sqrt{1-\bar{\alpha}_t}\varepsilon\right) \\ &= \frac{1}{(1-\bar{\alpha}_t)\sqrt{\alpha_t}}\left(\left(1 - \bar{\alpha}_t\right)x_t - \beta_t\sqrt{1-\bar{\alpha}_t}\varepsilon\right) \\ &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon\right) \\ \end{align}

ニューラルネットワークの出力\(\mu_\theta(x_t,t)\)も同様の形に変形させる。

\begin{align} \varepsilon_\theta(x_t,t) := \frac{\sqrt{1-\bar{\alpha}_t}}{\beta_t}(x_t - \sqrt{\alpha_t}\mu_\theta(x_t,t)) \end{align}

とすると、

\begin{align} \mu_\theta(x_t,t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon_\theta(x_t,t)\right) \\ \end{align}

これらの結果を用いると、\(L_{t-1}\)は以下のようになる。

\begin{align} L_{t-1} &= \frac{1}{2\sigma_t^2} \mathbb{E}_{q(x_t|x_0)}\left( \|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) + C \\ &= \frac{1}{2\sigma_t^2} \mathbb{E}_{p_\mathcal{N}(\varepsilon|0,I)}\left( \left\|\frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon\right) - \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon_\theta(x_t,t)\right)\right\|^2 \right) + C \\ &= \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \mathbb{E}_{p_\mathcal{N}(\varepsilon|0,I)}\left( \|\varepsilon - \varepsilon_\theta(x_t,t)\|^2 \right) + C \\ \end{align}


ノイズ\(\varepsilon_\theta(x_t,t)\)を予測するモデル。\(\varepsilon_\theta\)と\(\tilde{\varepsilon}\)の差で損失を取る。

\(\varepsilon_\theta(x_t,t)\)は上の式を用いることで\(\mu_\theta(x_t,t)\)から求めることもできるが、ニューラルネットワークが直接\(\varepsilon_\theta(x_t,t)\)を出力しても良い。実験によると、ニューラルネットワークが\(x_{t-1}\)の平均\(\mu_\theta(x_t,t)\)を予測するよりも、ノイズ\(\varepsilon_\theta(x_t,t)\)を予測して損失に利用した方が性能が高いことが報告されている。

\(L_0\)

\begin{align} L_0 &= - \mathbb{E}_{q(x_{1:T}|x_0)}\left(\log p_\theta(x_0|x_1)\right) \\ &= - \mathbb{E}_{q(x_1|x_0)}\left(\log p_\theta(x_0|x_1)\right) \\ &= - \mathbb{E}_{q(x_1|x_0)}\left(\log p_\mathcal{N}(x_0|\mu_\theta(x_1,1),\sigma_1^2 I)\right) \\ &= - \mathbb{E}_{q(x_1|x_0)}\left(\log \prod_d^D p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2)\right) \\ &= - \mathbb{E}_{q(x_1|x_0)}\left(\sum_d^D \log p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2)\right) \\ \end{align}

学習に使われる画像は各成分(=ピクセル・チャンネル)が\(\{0, 1, \cdots , 255\}\)の値を取るが、ニューラルネットワークの入出力は、各成分が\([-1, 1]\)の値を取るようにスケーリングされる。つまり、\(x_0 \in \{-1, -\frac{253}{255}, \cdots , -\frac{1}{255}, \frac{1}{255}, \cdots , \frac{253}{255}\}\)となる。

\(x_0\)は離散な値を取るので、\(p_\theta(x_0|x_1)\)を離散化する。
具体的には、\(p_\theta(x|x_1)\)からサンプルした値を丸めた値が\(x_0\)と一致する確率に置き換える。

\begin{align} p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2) &= \int_{I(x_{0,d})} p_\mathcal{N}(x|\mu_d,\sigma_1^2) dx \\ &\sim \frac{2}{255} p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2) \end{align}

ここで、\(I(x_{0,d})\)は以下で定義される区間。

\begin{align} \begin{cases} I(-1) = [-\infty, -\frac{254}{255}], \\ I(x) = [x-\frac{1}{255}, x+\frac{1}{255}], \quad (-1 < x < 1) \\ I(1) = [\frac{254}{255}, \infty]. \\ \end{cases} \end{align}


積分区間

積分を更に近似する。

\begin{align} p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2) &\sim \int_{x_{0,d}-1/255}^{x_{0,d}+1/255} p_\mathcal{N}(x|\mu_d,\sigma_1^2) dx \\ &\sim \frac{2}{255} p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2) \end{align}

したがって、\(L_0\)は最終的に以下のようになる。定義から\(\tilde{\mu}_1(x_1,x_0) = x_0\)となることを利用している。

\begin{align} L_0 &= - \mathbb{E}_{q(x_1|x_0)}\left(\log p_\theta(x_0|x_1)\right) \\ &= - \mathbb{E}_{q(x_1|x_0)}\left(\sum_d^D \log p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2)\right) \\ &\sim - \mathbb{E}_{q(x_1|x_0)}\left(\sum_d^D \log\left( \frac{2}{255} p_\mathcal{N}(x_{0,d}|\mu_d,\sigma_1^2) \right)\right) \\ &\sim - \mathbb{E}_{q(x_1|x_0)}\left(\sum_d^D \left( -\frac{1}{2\sigma_1^2}|x_{0,d}-\mu_d|^2 \right) \right) + C \\ &= \frac{1}{2\sigma_1^2} \mathbb{E}_{q(x_1|x_0)}\left( \|x_{0}-\mu_\theta(x_1,1)\|^2 \right) + C \\ &= \frac{1}{2\sigma_1^2} \mathbb{E}_{q(x_1|x_0)}\left( \|\tilde{\mu}_1(x_1,x_0)-\mu_\theta(x_1,1)\|^2 \right) + C \\ \end{align}

損失関数の簡略化

以上の結果より、損失関数は定数を除いて次のようにまとめられる。

\begin{align} Loss &= L_T + \sum_{t=2}^T L_{t-1} +L_0 \\ &= \sum_{t=2}^T \frac{1}{2\sigma_t^2} \mathbb{E}_{q(x_t|x_0)}\left( \|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) + \frac{1}{2\sigma_1^2} \mathbb{E}_{q(x_1|x_0)}\left( \|\tilde{\mu}_1(x_1,x_0)-\mu_\theta(x_1,1)\|^2 \right) \\ &= \sum_{t=1}^T \frac{1}{2\sigma_t^2} \mathbb{E}_{q(x_t|x_0)}\left( \|\tilde{\mu}_t(x_t,x_0) - \mu_\theta(x_t,t)\|^2 \right) \\ &= \sum_{t=1}^T \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \mathbb{E}_{p_\mathcal{N}(\varepsilon|0,I)}\left( \|\varepsilon - \varepsilon_\theta(x_t,t)\|^2 \right) \\ &= \frac{T}{2}\mathbb{E}_{t,\varepsilon}\left( \frac{\beta_t^2}{\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \|\varepsilon - \varepsilon_\theta(x_t,t)\|^2 \right) \\ \end{align}

また、学習に用いる教師画像は1枚だけではないので、複数枚分の平均を取る。

\begin{align} Loss &= \frac{T}{2}\mathbb{E}_{t,x_0,\varepsilon}\left( \frac{\beta_t^2}{\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \|\varepsilon - \varepsilon_\theta(x_t,t)\|^2 \right) \\ \end{align}

ここで、\(x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon\)。
また、\(\mathbb{E}_{\varepsilon,t,x_0}\)は、\(\varepsilon\)を\(\mathcal{N}(0,I)\)から、\(t\)を\(\{0, 1, \cdots, T\}\)から一様に、\(x_0\)をデータセット\(X^D\)から一様にサンプルしたときの期待値を意味する。

各時刻の係数\(\frac{T\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\)は各時刻で以下のように変動する。


各時刻の損失関数の係数 (\(T=1000\)、\(\beta_t\)は\(\beta_1=10^{-4}\)から\(\beta_T=0.02\)に線形に増加、\(\sigma_t^2=\beta_t\)の場合)

DDPMではこれらの係数を全て無視し、損失関数を更に簡略化した損失についても検証している。

\begin{align} Loss_{simple} &= \mathbb{E}_{t,x_0,\varepsilon}\left( \|\varepsilon - \varepsilon(x_t,t)\|^2 \right) \\ \end{align}

上図のように、損失関数は浅い(\(t\)が小さい)タイムステップほど大きな値を取っていたので、係数の均一化によって浅いタイムステップの損失の重み付けを下げることになる。
論文によると、係数均一化を行うことで、より深いタイムステップで困難になるノイズ除去問題をより強く学習するような効果が得られるという。実験でも係数均一化によって生成画像の品質が向上することが確認されている。

DDPMの損失関数

\begin{align} \begin{cases} Loss &= \mathbb{E}_{t,x_0,\varepsilon}\left( \frac{\beta_t^2}{\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \|\varepsilon - \varepsilon_\theta(x_t,t)\|^2 \right) \\ Loss_{simple} &= \mathbb{E}_{t,x_0,\varepsilon}\left( \|\varepsilon - \varepsilon(x_t,t)\|^2 \right) \end{cases} \end{align}

学習・生成アルゴリズム

学習

DDPMの学習アルゴリズム
  1. \(x_0\)をデータセットから一様にサンプル。
  2. \(t\)を\(\{1, 2, \cdots , T\}\)から一様にサンプル。
  3. \(\varepsilon\)を\(\mathcal{N}(0,I)\)からサンプル。
  4. \(x_t := \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon\)を生成。
  5. \(x_t,t\)をニューラルネットワークに入力し、出力\(\varepsilon_\theta\)を得る。
  6. \(\varepsilon,\varepsilon_\theta\)を用いて損失関数を計算し、ニューラルネットワークを学習。

生成

DDPMの生成アルゴリズム
  1. \(x_T\)を\(\mathcal{N}(0,I)\)からサンプル。
  2. 以下の手順を\(t = T, \cdots, 1\)に対して繰り返し行う。
    1. \(x_t,t\)をニューラルネットワークに入力し、出力\(\varepsilon_\theta\)を得る。
    2. \(\mu_\theta := \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon_\theta\right)\)を計算。
    3. \(p_\mathcal{N}\left(x_{t-1}|\mu_\theta,\sigma_t^2I\right)\)に従って\(x_{t-1}\)をサンプル。ただし、\(t=1\)のときだけは\(x_0=\mu_\theta\)とする。
  3. \(x_0\)を画像として出力。

実験

論文には以下の設定で様々な実験を行った結果が掲載されている。

  • \(T=1000\)
  • \(\beta_t\)は\(10^{-4}\)から\(0.02\)に線形に増加。つまり、\(\beta_t = 10^{-4} + (0.02 - 10^{-4})\frac{t-1}{1000-1} \)。
  • ニューラルネットワークはU-Net型の独自モデル。内部でSelf-Attentionを使用。
  • 時刻\(t\)はTransformerと同様の位置符号化によってU-Netに与えられる。
  • 重みは全タイムステップで共有。

CelebA-HQデータセットで学習したモデルの生成例

LSUN Churchデータセットで学習したモデルの生成例

LSUN Bedroomデータセットで学習したモデルの生成例

CIFAR10データセットで学習したモデルの生成例
(完全なノイズから画像を生成するまでの中間生成画像を可視化)

CelebA-HQの画像に1000, 750, 500, 250, 0回のノイズを加え、それぞれのノイズ付与画像から\(x_0\)を復元した結果。
右下がノイズ付与画像で、それ以外は右下の画像から復元した3つの生成例。

CelebAの画像(左右Source)に各ステップ(縦軸)のノイズを加え、ノイズ付与画像同士の加重平均(横軸)から\(x_0\)を復元した結果。生成された平均画像も顔画像として自然な見た目になっていることがわかる。

参考

解説