Stable Diffusion 3

この記事では2024年2月に発表され、6月にモデルが公開されたStable Diffusion 3(SD3)の理論やモデル構造について解説する。

目次

Stable Diffusion 3の理論

Stable Diffusion 3(SD3)は、Stability AI社が開発した画像生成AI「Stable Diffusion (SD)」の後継モデル。同社が2024年2月22日に発表し、6月12日にソースコードと学習済みモデルの一部が公開された。

Stable Diffusion 3の理論は主に「Diffusion Transformer」・「Rectified Flow」の2つのアイデアによって構成される。

Diffusion Transformer

Diffusion Transformer(DiT)は、2022年12月に発表された拡散画像生成モデルの実装例。
拡散画像生成の実装には主にCNNで構成されるU-Net構造が採用されることが多かったが、DiTではU-NetやCNNではなくVision Transformer(ViT)をベースとした構造が採用されている。

元々Transformerは言語入力を処理するために考案されたモデルだが、ViTでは画像処理にTransformerを適用するために、入力画像をタイル状に分割してそれぞれの領域(=パッチ)を言語モデルの1トークンとして扱う。


ViTの仕組み (ViTの論文から引用)

SD3におけるDiTの構造については後述。

Rectified Flow

Rectified Flow(RF)は、2022年9月に発表された異なる2つの確率分布を結ぶ経路を学習する理論。最適輸送問題に関連のある「Flow」と呼ばれる分類に属する理論。
拡散モデルもまたデータ分布と標準正規分布の2つの分布を結ぶ経路を学習していると捉えることができるので、Rectified Flowの理論を当てはめて生成モデルを学習することができる。

Rectified Flow

2つの分布\(\pi_0(x), \pi_1(x)\)から\(x_0\sim\pi_0, x_1\sim\pi_1\)を取り出し、\(x_0, x_1\)の内分点を\(x_t=(1-t)x_0+tx_1\)とする。
Rectified Flowでは、\(\frac{dx_t}{dt}=x_1-x_0\)の\(\pi_0\times\pi_1\)に対する期待値をニューラルネットワークで学習する。そのため、損失関数は次のように設計されている。

\begin{align} Loss = \mathbb{E}_{t,x_0\sim\pi_0,x_1\sim\pi_1}\left( \left\| (x_1-x_0)-v_\theta(x_t,t) \right\|^2 \right) \end{align}

\(\frac{dx_t}{dt}\)や\(v_\theta\)のことを速度(velocity)と呼ぶ。

これは\(\alpha_t=1-t, \sigma_t=t\)の拡散モデルに対応していて、生成過程のODEは次のような簡潔な形になる。
(本投稿では拡散モデルの\(t\)に合わせて\(x_t\)を定義しているが、RFの論文では\(x_0\)から\(x_1\)を生成するという前提になっていて添字の1と0が逆なので注意)

\begin{align} \frac{dx_t}{dt} = v_\theta(x_t,t) \end{align}

Reflow

RFの論文では、RFによって学習したモデルを使って\(x_1\)(ノイズに相当)から\(\bar{x}_0\)(データに相当)を生成し、\((x_0,\bar{x}_1)\)の集合を新たなデータセットとして2段階目の学習をすることが提案されている。このような手法をReflowと言う。

RFでは\(x_0,x_1\)は取り得る全ての組み合わせで学習され、\(v_\theta(x_t,t)\)は\(x_1-x_0\)の\(\pi(x_0|x_t)\)に対する期待値を近似していた。
一方で、Reflowでは\(x_0,x_1\)がProbability Flowに沿って一対一に対応していて、PFが交叉しないという性質により、その内分点である\(x_t\)に対して\(x_0(x_t),x_1(x_t)\)は一意に定まる。したがって、Reflowの\(v_\theta(x_t,t)\)は期待値ではなく\(x_1(x_t)-x_0(x_t)\)そのものを近似するように学習される。
このことから、Reflowを実行することで、Consistency Modelsと同様にProbability Flowが直線状になる効果が得られることがわかる。

下図にRectified FlowとReflowを適当なモデルで学習した結果を示す。
紫の点は\(x_1\)、薄い赤の点は学習に利用した教師データ\(x_0\)、濃い赤の点は学習後モデルから生成した\(x_0\)、青い線はProbability Flow。
Rectified Flowだけでは対角線上の経路も学習されてしまうが、更にReflowを行うことで、対応する点同士の最短距離の経路が学習されることがわかる。


Rectified Flowで学習した例

Rectified Flowで学習後にReflowを実行した例

Reflowによって経路は直線になるので、Reflowを実行した後に蒸留を行うとより効率的に1ステップ生成モデルを学習できると論じられている。

なお、SD3ではRectified Flowのみを取り入れていてReflowを実行していないようなので、SD3のProbability Flowは直線状にはなっていない。

RFのSD3への適用

広範な設定の実験を行うため、\(\alpha_t=1-t,\sigma_t=t\)だけでなく、より一般の\(\alpha_t,\sigma_t\)についてもRectified Flowと同様に\(\frac{dx_t}{dt}\)の\(\pi_0\times\pi_1\)に対する期待値を考える。
以前の考察(RF)より、拡散過程

\begin{align} x_t = \alpha_tx_0 + \sigma_t\varepsilon \end{align}

に対して、\(\frac{dx_t}{dt}\)の期待値は次のようになる。(\(\lambda_t:=\log\frac{\alpha_t^2}{\sigma_t^2}\))

\begin{align} v_\theta(x_t,t) &\sim \mathbb{E}_{x_0|x_t}\left(\frac{dx_t}{dt}\right) \\ &= \frac{\alpha'_t}{\alpha_t}x_t - \frac{\sigma_t\lambda'_t}{2}\mathbb{E}_{x_0|x_t}(\varepsilon) \end{align}

拡散モデルのノイズ学習モデルでは、\(\varepsilon_\theta(x_t,t)\sim\mathbb{E}_{x_0|x_t}(\varepsilon)\)となるように学習するので、

\begin{align} v_\theta(x_t,t) = \frac{\alpha'_t}{\alpha_t}x_t - \frac{\sigma_t\lambda'_t}{2}\varepsilon_\theta(x_t,t) \end{align}

という関係が成り立つ。
したがって、RFの損失関数はノイズ予測モデルを用いると次のように表すこともできる。

\begin{align} Loss &= \mathbb{E}_{t,x_0,\varepsilon}\left( \left\| \frac{dx_t}{dt}-v_\theta(x_t,t) \right\|^2 \right) \\ &= \mathbb{E}_{t,x_0,\varepsilon}\left( \left\| \left(\frac{\alpha'_t}{\alpha_t}x_t - \frac{\sigma_t\lambda'_t}{2}\varepsilon\right) - \left(\frac{\alpha'_t}{\alpha_t}x_t - \frac{\sigma_t\lambda'_t}{2}\varepsilon_\theta(x_t,t)\right) \right\|^2 \right) \\ &= \mathbb{E}_{t,x_0,\varepsilon}\left( \left\| \frac{\sigma_t\lambda'_t}{2}\left(\varepsilon - \varepsilon_\theta(x_t,t)\right) \right\|^2 \right) \\ &= \mathbb{E}_{t,x_0,\varepsilon}\left( \left(\frac{\sigma_t\lambda'_t}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_t,t) \right\|^2 \right) \\ \end{align}

つまり、一般RFとDDPMには学習時の\(t\)に対する損失の重み付けの違いしかないことがわかる。
また、特にRFのノイズスケジューラーでは\(\sigma_t=t, \lambda'_t=\frac{2}{t(1-t)}\)なので、損失関数は更に具体的に次のようになる。

\begin{align} Loss = \mathbb{E}_{t,x_0,\varepsilon}\left( \frac{1}{(1-t)^2}\left\| \varepsilon - \varepsilon_\theta(x_t,t) \right\|^2 \right) \\ \end{align}

損失関数の重み付け

SD3の論文では更に様々な設定で損失関数を重み付けして学習し、性能を比較している。

中でも rf/lognorm(0.00, 1.00) という設定が特に性能が高かったと述べられている。
rf/lognorm(0.00, 1.00)では\(\alpha_t=1-t,\sigma_t=t\)のノイズスケジュールを採用し、損失関数を\(\pi_{ln}\)によって次のように重み付けている。

\begin{align} Loss &= \mathbb{E}_{t,x_0,\varepsilon}\left( \pi_{ln}(t;0.00,1.00)\left(\frac{\sigma_t\lambda'_t}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_t,t) \right\|^2 \right) \\ &= \mathbb{E}_{t,x_0,\varepsilon}\left( \frac{1}{\sqrt{2\pi}}\frac{1}{t(1-t)}\exp\left( -\frac{\mathrm{logit}(t)^2}{2} \right)\frac{1}{(1-t)^2} \left\| \varepsilon - \varepsilon_\theta(x_t,t) \right\|^2 \right) \\ &= \mathbb{E}_{t,x_0,\varepsilon}\left( \frac{1}{\sqrt{2\pi}}\frac{1}{t(1-t)^3}\exp\left( -\frac{\mathrm{logit}(t)^2}{2} \right) \left\| \varepsilon - \varepsilon_\theta(x_t,t) \right\|^2 \right) \\ \end{align}

\(\pi_{ln}(t;m,s)\)はLogit-normal distributionで、\(\mathrm{logit}(t)=\log\frac{t}{1-t}\)が正規分布に従うような分布である。

\begin{align} \pi_{ln}(t;m,s) = \frac{1}{s\sqrt{2\pi}}\frac{1}{t(1-t)}\exp\left( -\frac{\left(\mathrm{logit}(t)-m\right)^2}{2s^2} \right) \end{align}

損失関数の重みは次の図のようになっている。元の重み付けに沿ってノイズの多い領域が重点的に学習されているが、\(t=1\)では重みが0となり全体が有限の値を取ることがわかる。
なお、SD3の実際の実装ではニューラルネットワークの直接の出力は\(\varepsilon_\theta\)ではなく\(v_\theta\)であることに注意。


rf/lognorm(0.00, 1.00)の損失の重み付け (横軸は\(t\)、縦軸は\(\pi_{ln}(t;0.00,1.00)\left(\frac{\sigma_t\lambda'_t}{2}\right)^2\))

Schedule Shift

画像サイズによるノイズレベルの違い

HxWの画像\(x_0\)に対して拡散過程でノイズを付与した画像を\(x_t\)として、\(x_t\)を\(1/s\)に縮小した画像を\(\tilde{y}_t\)とする。
また、逆に\(x_0\)を\(1/s\)に縮小した画像を\(y_0\)とし、\(y_0\)から拡散過程でノイズを付与した画像を\(y_t\)とする。
ただし、\(s\)は自然数であり、縮小は平均画素法によって行われるものとする。

このとき、\(\tilde{y}_t\)の1つの画素値は、その画素に対応する\(x_t\)の\(s^2\)個の画素値の平均になっている。
正規分布の再生成より、\(s^2\)個の独立した標準正規分布について\(\sum_j^{s^2} \varepsilon^{(j)} \sim \mathcal{N}(0,s^2)\)となるので、\(\tilde{y}_t\)の各画素値は次のように表される。

\begin{align} \tilde{y}_t^{(i)} &= \frac{1}{s^2}\sum_j^{s^2} x_t^{(j)} \\ &= \frac{1}{s^2}\sum_j^{s^2} (\alpha_t x_0^{(j)} + \sigma_t \varepsilon^{(j)}) \\ &= \alpha_t \frac{1}{s^2} \sum_j^{s^2} x_0^{(j)} + \frac{\sigma_t}{s^2}\sum_j^{s^2} \varepsilon^{(j)} \\ &= \alpha_t \frac{1}{s^2} \sum_j^{s^2} x_0^{(j)} + \frac{\sigma_t}{s^2}s\varepsilon \\ &= \alpha_t y_0^{(i)} + \frac{\sigma_t}{s} \varepsilon \\ \end{align}

一方\(y_t\)については、

\begin{align} y_t^{(i)} &= \alpha_t y_0^{(i)} + \sigma_t \varepsilon \\ \end{align}

なので、\(\tilde{y}_t\)と\(y_t\)には\(s\)倍の標準偏差比があることがわかる。

一般に、拡散過程では\(t=0\)に近い初期時刻で高周波成分(=細部の表現)が破損され、\(t\)が大きくなるにつれて低周波成分(=全体像)が破損され始めることが知られている。
しかし、上の式で確認したように画像サイズが大きい\(x_t\)ほど低周波成分にかかるノイズの係数が小さくなるので、拡散過程のほとんどの時間で低周波成分は破損されず、\(t=1\)に近い僅かな領域で急速に破損されるという挙動になる。生成過程においては、全タイムステップ数が少ない場合に低周波成分を構成するための十分なタイムステップを確保することができず、生成結果の品質に悪影響があると分析されている。

Schedule Shift

その問題を緩和するため、高解像度の画像\(x_0\)に対して拡散過程を適用する際、低解像度の場合とSNRが一致するようにノイズスケジュールの調整を行う方法が考案された。このテクニックは参考論文では「Resolution-dependent shifting of timestep schedules」・「shift schedules」などと呼ばれているが、本投稿ではSchedule Shiftと呼ぶことにする。

Schedule Shiftを上記の式に沿って表現すると、\(x_t\)側の\(\alpha_t,\sigma_t\)を調整して\(\tilde{\alpha}_t,\tilde{\sigma}_t\)に変更することになる。
このときのSNRが\(\tilde{y}_t\)と\(y_t\)で一致するように、\(\tilde{y}_t\)のノイズスケジュールはより早い時間で強いノイズを加えなければならない。

\begin{align} \frac{\alpha_t^2}{\sigma_t^2} &= SNR_y(t) \\ &= SNR_\tilde{y}(t) \\ &= \frac{\tilde{\alpha}_t^2}{\left(\frac{\tilde{\sigma}_t}{s}\right)^2} \\ &= \frac{\tilde{\alpha}_t^2}{\tilde{\sigma}_t^2}s^2 \end{align}

したがって、

\begin{align} \mathrm{SNR}_{shift}(t) &:= \frac{\tilde{\alpha}_t^2}{\tilde{\sigma}_t^2} \\ &= \frac{\alpha_t^2}{\sigma_t^2}\frac{1}{s^2} \\ &= \mathrm{SNR}(t)\frac{1}{s^2} \end{align}

となるように\(\tilde{\alpha}_t,\tilde{\sigma}_t\)を選べば良い。
\(s>1\)のとき、調整されたSNRは元のSNRより小さな値になる(=ノイズの比率が大きくなる)ことがわかる。

\(\tilde{\alpha}_t,\tilde{\sigma}_t\)が元の\(\alpha_t,\sigma_t\)に対して時間をずらした関数

\begin{align} \tilde{\alpha}_t = \alpha_{\tilde{t}} \\ \tilde{\sigma}_t = \sigma_{\tilde{t}} \end{align}

であるとすれば、\(\frac{\tilde{\alpha}_t^2}{\tilde{\sigma}_t^2}=\mathrm{SNR}(\tilde{t})\)なので\(\tilde{t}\)は

\begin{align} \tilde{t} = \mathrm{SNR}^{-1}\left(\mathrm{SNR}(t)\frac{1}{s^2}\right) \end{align}

のように表される。
画像の生成時には\(t\)を等間隔に取り出した上で、上記の式によって\(t\)を\(\tilde{t}\)に変換し、以後は\(t\)の代わりに\(\tilde{t}\)を使って生成を行うことで、サイズの大きい画像に対して品質向上の効果が得られるということになる。

特にRFのノイズスケジュール(\(\alpha_t=1-t,\sigma_t=t\))の場合、

\begin{align} \left(\frac{1}{\tilde{t}}-1\right)^2 &= \left(\frac{1-\tilde{t}}{\tilde{t}}\right)^2 \\ &= \mathrm{SNR}(\tilde{t}) \\ &= \mathrm{SNR}(t)\frac{1}{s^2} \\ &= \left(\frac{1-t}{t}\right)^2\frac{1}{s^2} \\ &= \left(\frac{1-t}{st}\right)^2 \end{align}

なので、

\begin{align} \tilde{t}(t) &= \frac{1}{\frac{1-t}{st} + 1} \\ &= \frac{st}{1+(s-1)t} \\ \end{align}

\begin{align} t(\tilde{t}) &= \frac{1}{s\frac{1-\tilde{t}}{\tilde{t}}+1} \\ &= \frac{\tilde{t}}{s+(1-s)\tilde{t}} \\ \end{align}

と表すことができる。


RFのSchedule Shift (timestep schedule)

SD3のShifted Scheduleとその他のSchedulerの比較 (横軸は\(t\)、縦軸は\(\mathrm{SNR}(t)\))

SD3の論文では\(s\)の変化に対する人間による評価を比較し、1024x1024の画像生成を行う際には(異なる画像サイズでノイズ比を合わせるという元の意味から離れて)固定値\(s=3.0\)を選択したと述べられている。

論文には\(s=3.0\)のtimestep scheduleを「学習とサンプリングの両方」に適用すると記されているので、前述の損失の重み付けに更にスケジューラー分の偏りが反映されるものと思われる。Schedule Shiftが学習にも使われている場合、最終的な損失関数は次のようになり、元の重み付けよりもノイズが多い方に偏っていることがわかる。

\begin{align} Loss &= \mathbb{E}_{t\sim\mathcal{U}(0,1),x_0,\varepsilon}\left( \pi_{ln}(\tilde{t}(t);0.00,1.00)\left(\frac{\sigma_{\tilde{t}(t)}\lambda'\left(\tilde{t}(t)\right)}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_{\tilde{t}(t)},\tilde{t}(t)) \right\|^2 \right) \\ &= \mathbb{E}_{x_0,\varepsilon}\left( \int_0^1 \pi_{ln}(\tilde{t}(t);0.00,1.00)\left(\frac{\sigma_{\tilde{t}(t)}\lambda'\left(\tilde{t}(t)\right)}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_{\tilde{t}(t)},\tilde{t}(t)) \right\|^2 dt \right) \\ &= \mathbb{E}_{x_0,\varepsilon}\left( \int_0^1 \pi_{ln}(\tilde{t};0.00,1.00)\left(\frac{\sigma_\tilde{t}\lambda'_\tilde{t}}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_\tilde{t},\tilde{t}) \right\|^2 \frac{dt}{d\tilde{t}}d\tilde{t} \right) \\ &= \mathbb{E}_{x_0,\varepsilon}\left( \int_0^1 \pi_{ln}(\tilde{t};0.00,1.00)\left(\frac{\sigma_\tilde{t}\lambda'_\tilde{t}}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_\tilde{t},\tilde{t}) \right\|^2 \frac{s}{(s+(1-s)\tilde{t})^2} d\tilde{t} \right) \\ &= \mathbb{E}_{\tilde{t}\sim\mathcal{U}(0,1),x_0,\varepsilon}\left( \frac{s}{(s+(1-s)\tilde{t})^2} \pi_{ln}(\tilde{t};0.00,1.00)\left(\frac{\sigma_\tilde{t}\lambda'_\tilde{t}}{2}\right)^2\left\| \varepsilon - \varepsilon_\theta(x_\tilde{t},\tilde{t}) \right\|^2 \right) \\ &= \mathbb{E}_{\tilde{t}\sim\mathcal{U}(0,1),x_0,\varepsilon}\left( \frac{s}{(s+(1-s)\tilde{t})^2} \frac{1}{\sqrt{2\pi}}\frac{1}{\tilde{t}(1-\tilde{t})}\exp\left( -\frac{\mathrm{logit}(\tilde{t})^2}{2} \right) \frac{1}{(1-\tilde{t})^2}\left\| \varepsilon - \varepsilon_\theta(x_\tilde{t},\tilde{t}) \right\|^2 \right) \\ &= \mathbb{E}_{\tilde{t}\sim\mathcal{U}(0,1),x_0,\varepsilon}\left( \frac{1}{\sqrt{2\pi}} \frac{s}{(s+(1-s)\tilde{t})^2\tilde{t}(1-\tilde{t})^3} \exp\left( -\frac{\mathrm{logit}(\tilde{t})^2}{2} \right) \left\| \varepsilon - \varepsilon_\theta(x_\tilde{t},\tilde{t}) \right\|^2 \right) \\ \end{align}


Schedule Shiftを適用した場合と適用しない場合のrf/lognorm(0.00, 1.00)の損失関数の重み付け
(それぞれ積分して1になるよう正規化)

SD3のモデル構造

Stability AIが実験した800Mから8BパラメーターのSD3の内、2BパラメーターにあたるStable Diffusion 3 Mediumのパラメーターファイルが2024年6月12日に公開された。
2024年6月12日時点では公式のソースコードが公開されていない代わりにComfyUIのworkflowが公式から配布されていて、ComfyUIでの実装が実質的な公式ソースコードとなっている模様。また、diffusersでもSD3が実装されている。

以下ではMediumの設定に沿ってモデル構造を図解する。図中のクラス名などはdiffusersの実装に準拠する。

予測モデル

SDのノイズ予測モデルに相当する部分は、SD3では前述の「RFの速度」を予測する。
SD3 Mediumでは、24個のMM-DiT(図中では「JointTransformerBlock」と表記)を主とした次の図のような構造を取る。


SD3の予測モデル (1024x1024の画像を生成する場合) (クリックで拡大)

各構成要素について以下で解説する。

MM-DiT

MM-DiT(multimodal transformer-based diffusion backbone)と呼ばれる構造のDiTを使用。
MM-DiTには文章情報と画像情報を処理する2つの経路が存在し、途中でTransformerを通して互いの情報を共有する仕組みになっている。


MM-DiTの構造 (クリックで拡大)

Positional Embedding

192x192のグリッドに対して、左上を\((0,0)\)として縦方向を添字を\(i\)、横方向の添字を\(j\)としする。
このとき、各座標\(i,j\)に対して次のように1536のベクトルを割り当てることで、192x192x1536のテンソルを構成する。

\begin{align} pos\_embed_{i,j} = \begin{pmatrix} \sin\left(\frac{j-64}{4\cdot10000^{0/384}}\right) \\ \sin\left(\frac{j-64}{4\cdot10000^{1/384}}\right) \\ \vdots \\ \sin\left(\frac{j-64}{4\cdot10000^{383/384}}\right) \\ \cos\left(\frac{j-64}{4\cdot10000^{0/384}}\right) \\ \cos\left(\frac{j-64}{4\cdot10000^{1/384}}\right) \\ \vdots \\ \cos\left(\frac{j-64}{4\cdot10000^{383/384}}\right) \\ \sin\left(\frac{i-64}{4\cdot10000^{0/384}}\right) \\ \sin\left(\frac{i-64}{4\cdot10000^{1/384}}\right) \\ \vdots \\ \sin\left(\frac{i-64}{4\cdot10000^{383/384}}\right) \\ \cos\left(\frac{i-64}{4\cdot10000^{0/384}}\right) \\ \cos\left(\frac{i-64}{4\cdot10000^{1/384}}\right) \\ \vdots \\ \cos\left(\frac{i-64}{4\cdot10000^{383/384}}\right) \\ \end{pmatrix} \end{align}

SD3でHxWx3の画像を生成するとき、初期ノイズは(H/8)x(W/8)x16のサイズで作成され、最初のパッチ化で(H/16)x(W/16)x1536のサイズになる。その(H/16)x(W/16)のテンソル(=latent)を上で定義した192x192x1536のグリッドの中央に配置して、latentのそれぞれの座標に対応するグリッドの1536次元値をPositional Embeddingとして足し合わせることで利用する。

Timestep Embedding

時刻\(t\)の埋め込みはSDSDXLと全く同じ方法で作られる。次元は256。

Unpatching

1024x1024の生成の場合、ノイズ予測モデルが出力すべき値のshapeは128x128x16だが、DiTを利用するために2x2でパッチ化を行ったのでニューラルネットワークの直接の出力は64x64=4096個のベクトルになる。
そこで、4096個の64次元のベクトルを4分割して次の図のように配置し直すことで、空間次元を128x128に戻すことができる。この処理はUnpatchingと名付けられている。


Unpatching

Text Encoder

SD3では、SDXLで使用したものと同じ2種類のCLIPと、T5(Text-to-Text Transfer Transformer)というモデルの合計3つのText Encoderが同時に使われる。T5はGoogleが2019年10月に発表した言語モデルで、翻訳や要約などの問いを入力に取り、その答えを出力するよう学習されたものである。


SD3のText Encoder全体 (クリックで拡大)
  • CLIP-L/14 - SDやSDXLに使われたものと同じ。SDXL同様にCLIP skip=2が適用される。
  • CLIP-G/14 - SDXLに使われたものと同じ。SDXL同様にCLIP skip=2が適用される。
  • T5 XXL - T5のバージョン1.1で追加された11Bパラメーターのモデル。Encoderのみを使う。

3つのText Encoderの出力は、次の図のように1つのベクトル列(154x4096)に統合されてDiTに入力される。


最終的な文章埋め込み

3つのText Encoderにはそれぞれ学習中に46.3%という高いドロップアウト率が設定されているので、どれか一つを取り除いて代わりに0で埋めても品質をそれほど下げることなく画像を生成することが可能。

CLIP

CLIPの構造は次の図の通りで、SDXLのときと同様。
SDXL同様にCLIP skip=2を選択。


SD3のCLIP-Gの構造 (クリックで拡大)
CLIP-Lについても次元(1280→768)とレイヤー数(32→12)が変わるだけで構造は同じ。

SDXLではCLIP-Gのみから最終層出力のpoolを行っていたが、SD3ではCLIP-Lでも同じようにpoolを行う。
poolされた文章情報はSDXLと同様にtimestep埋め込みに足し合わされ、DiTに入力される。

T5

T5ではCLIP skipのような方法は使わず、Text Encoderの出力をそのまま利用する。


SD3のCLIP-Gの構造 (クリックで拡大)

VAE

SD3のVAEの構造は下図の通り。ResやAttentionなどの構成要素はSDのVAEで使われたものと同じ。
SDとSDXLのVAEは出力チャンネルが4だったが、SD3では16チャンネルで学習される。
それ以外の部分はSDやSDXLのものとほとんど違いはない。


SD3のVAEの構造 (クリックで拡大)

参考

timestep schedules

実装