拡散モデルのサンプラー (3) - UniPC

この記事では拡散モデルの生成過程で使われるサンプラーの一種であるUniPCについて解説する。
これまでと同様に、SMLDの前提・表記を用いる。

関連記事

目次

UniPC

UniPCはDPM-Solverの考え方を拡張したアルゴリズム。

前提

以下の定義・表記を用いる。

  • s=ti+1<t=ti<ti1<<tip+1
  • rj:=tij(j{1,2,,p1})
  • rp:=ti+1
    • したがって、s=rp<t<r1<<rp1
  • h:=λsλt=λti+1λti
  • kj:=λrjλth(j{1,2,,p})
    • したがって、kp=1

本投稿における表記

UniP

DPM-Solverによると、拡散モデルのODEは

dyt=ε(yt)dςt=eλtε^(λt)dλt

であり、

ys=ytλtλseλε^(λ)dλ=ytλtλseλn=0p1(λλt)nn!ε^(n)(λt)dλ+O(hp+1)=ytςsn=0p1hn+1ϕn+1(h)ε^(n)(λt)+O(hp+1)=ytςs(eh1)ε^(yt)ςsn=1p1hn+1ϕn+1(h)ε^(n)(λt)+O(hp+1)=yt+(ςsςt)ε^(yt)ςsn=1p1hn+1ϕn+1(h)ε^(n)(λt)+O(hp+1)

が成り立つことを既に解説した。ただし、ϕn(h)=m=0hm(n+m)!

既に計算済みの{ε(yrj)}j=1p1ε(yt)を用いて、

n=1p1hn+1ϕn+1(h)ε^(n)(λt)=j=1p1uj(h)(ε(yrj)ε(yt))+O(hp+1)

と表すことを試みる。この条件を満たす関数の集合{uj}j=1p1を探す必要がある。

式を変形すると、

n=1p1hn+1ϕn+1(h)ε^(n)(λt)=j=1p1uj(h)(ε(yrj)ε(yt))+O(hp+1)=j=1p1uj(h)(ε^(λt+kjh)ε^(λt))+O(hp+1)=j=1p1uj(h)n=1p1(kjh)nn!ϵ^(n)(λt)+O(hp+1)=n=1p1hnn!(j=1p1kjnuj(h))ϵ^(n)(λt)+O(hp+1)

なので、各n{1,2,,p1}に対して、

hn+1ϕn+1(h)=hnn!j=1p1kjnuj(h)+O(hp+1)

n!hϕn+1(h)=j=1p1kjnuj(h)+O(hp+1n)

となれば良い。これはつまり、

h(1!ϕ2(h)2!ϕ3(h)(p1)!ϕp(h))=(111k1k2kp1k1p2k2p2kp1p2)(k1u1(h)k2u2(h)kp1up1(h))+(O(hp)O(hp1)O(h2))

となることを意味する。ここで、

Vp1:=(111k1k2kp1k1p2k2p2kp1p2)

Vandermonde行列と呼ばれる行列である。

式を解くと、

(k1u1(h)k2u2(h)kp1up1(h))=hVp11((1!ϕ2(h)2!ϕ3(h)(p1)!ϕp(h))+(O(hp1)O(hp2)O(h)))

となる。
この式から、各uj(h)uj(h)=O(h)となることがわかる。

各成分のO(hpn)の部分はそれぞれそのオーダー以上の任意の関数に置き換えることができるが、特に全て0とした場合を考え、

(k1u1(h)k2u2(h)kp1up1(h))=hVp11(1!ϕ2(h)2!ϕ3(h)(p1)!ϕp(h))

とすれば良い。

このようにして求められたアルゴリズムをUniP(Unified Predictor)と言う。収束精度はpとなる。


UniP
UniP
  • s=ti+1<t=ti<ti1<<tip+1
  • rj:=tij(j{1,2,,p1})
  • h:=λsλt
  • kj:=λrjλth(j{1,2,,p1})
  • {y¯rj}j=1p1が既に生成されていて、真の解yrjとの差がy¯rjyrj=O(hp)
  • y¯tが既に生成されていて、真の解ytとの差がy¯tyt=O(hp+1)
  • 既に計算済みの{ε(y¯rj)}j=1p1ε(y¯t)を保持している
  • εθ(yt,t)ytに対してLipschitz連続

とする。
このとき、y¯sを次のように生成する。

y¯s=y¯t+(ςsςt)ε(y¯t)ςsj=1p1vj(h)kj(ε(y¯rj)ε(y¯t))

ただし、vj(h)は以下から求められる値。

(v1(h)v2(h)vp1(h)):=hVp11(1!ϕ2(h)2!ϕ3(h)(p1)!ϕp(h))

Vp1:=(111k1k2kp1k1p2k2p2kp1p2)

ϕn(h):=m=0hm(n+m)!=1hn(ehm=0n1hmm!)

例 (p=1)

p=1の場合、総和の項がなくなりEuler法と完全に一致する。

y¯ts=y¯t+(ςsςt)ε(y¯t)

例 (p=2)

p=2の場合、行列やベクトルは全て1次元の値になる。

v1(h):=h11!ϕ2(h)=ehh1h

となるので、UniPのアルゴリズムは次のようになる。

y¯ts=y¯t+(ςsςt)ε(y¯t)ςtsehh1k1h(ε(y¯r1)ε(y¯t))=y¯t+(ςsςt)(ε(y¯t)+12k12(ehh1)h(eh1)(ε(y¯r1)ε(y¯t)))

DPM-Solver-2の導出過程でehh1=h(eh1)2+O(h3)と近似する場面があるが、その近似を実行しなかった場合に上の式と一致する。

例 (p=3)

p=3の場合、

(v1(h)v2(h)):=h(11k1k2)1(1!ϕ2(h)2!ϕ3(h))=hk2k1(k21k11)(ϕ2(h)2ϕ3(h))=hk2k1(k2ϕ2(h)2ϕ3(h)k1ϕ2(h)+2ϕ3(h))

となるので、UniPのアルゴリズムは次のようになる。

y¯ts=y¯t+(ςsςt)ε(y¯t)ςtshk2ϕ2(h)2ϕ3(h)k1(k2k1)(ϵ(y¯ti1)ε(y¯t))ςtshk1ϕ2(h)+2ϕ3(h)k2(k2k1)(ϵ(y¯ti2)ε(y¯t))

UniC

UniPCではUniPに加えて、UniC(Unified Corrector)というアルゴリズムを組み合わせて利用する。
UniCでは、UniPで一旦予測したysを利用して再びysを予測(補正)する。UniPの出力と区別するために、UniCの出力をyscと表すことにする。

UniCのアルゴリズムの考え方はUniPと全く同じである。UniPではp個の値

{ε(yt),ε(yr1),ε(yr2),,ε(yrp1)}={ε(yt),ε(yti1),ε(yti2),,ε(ytip+1)}

を利用していたが、UniCではp+1個の値

{ε(yt),ε(yr1),ε(yr2),,ε(yrp1),ε(yrp)}={ε(yt),ε(yti1),ε(yti2),,ε(ytip+1),ε(yti+1)}

を利用してp個の関数vj(h)を求める。

総和の個数が1つ増えたことで、最終的な収束次数はp+1になる。


UniC
UniC
  • s=ti+1<t=ti<ti1<<tip+1
  • rj:=tij(j{1,2,,p1})
  • rp:=ti+1
  • h:=λsλt
  • kj:=λrjλth(j{1,2,,p})
  • {y¯rj}j=1pが既に生成されていて、真の解yrjとの差がy¯rjyrj=O(hp+1)
  • y¯tが既に生成されていて、真の解ytとの差がy¯tyt=O(hp+1)
  • y¯tcが既に生成されていて、真の解ytとの差がy¯tcyt=O(hp+2)
  • 既に計算済みの{ε(y¯rj)}j=1pε(y¯t)を保持している
  • εθ(yt,t)ytに対してLipschitz連続

とする。
このとき、y¯scを次のように生成する。

y¯sc=y¯tc+(ςsςt)ε(y¯t)ςsj=1pvj(h)kj(ε(y¯rj)ε(y¯t))

ただし、vj(h)は以下から求められる値。

(v1(h)v2(h)vp(h)):=hVp1(1!ϕ2(h)2!ϕ3(h)p!ϕp+1(h))

Vp:=(111k1k2kpk1p1k2p1kpp1)

ϕn(h):=m=0hm(n+m)!=1hn(ehm=0n1hmm!)

UniPとUniCを交互に実行することでytcを生成していくアルゴリズムがUniPCである。

一方、ytの予測には必ずしもUniPを使う必要はなく、DPM-Solverなど他のアルゴリズムとUniCを組み合わせることもできる。論文ではUniPCだけでなく、DDIM + UniCDPM-Solver++(2M) + UniCなどについても実験が行われ、性能が比較されている。

データ予測モデルのUniPC

データ予測モデルのUniPも同様に導出できる。
DPM-Solver++によると、

ys=ςsςtyt+ςsλsλteλx^θ(λ)dλ=ςsςtyt+n=0p1hn+1ψn+1(h)x^θ(n)(λt)+O(hp+1)=ςsςtyt+(1ςsςt)x^θ(yt)+n=1p1hn+1ψn+1(h)x^θ(n)(λt)+O(hp+1)

が成り立つ。ただし、ψn(h)は次のように定義される関数。

ψn(h)=ϕn(h)=m=0(h)m(n+m)!

既に計算済みの{xθ(yrj)}j=1p1xθ(yt)を用いて、

n=1p1hn+1ψn+1(h)xθ^(n)(λt)=j=1p1uj(h)(xθ(yrj)xθ(yt))+O(hp+1)

と表すことを試みる。

ノイズ予測モデルの場合と記号が置き換わっただけなので、以降は同様に求めることができる。UniCの導出も同様。

UniP (for data prediction model)
  • s=ti+1<t=ti<ti1<<tip+1
  • rj:=tij(j{1,2,,p1})
  • h:=λsλt
  • kj:=λrjλth(j{1,2,,p1})
  • {y¯rj}j=1p1が既に生成されていて、真の解yrjとの差がy¯rjyrj=O(hp)
  • y¯tが既に生成されていて、真の解ytとの差がy¯tyt=O(hp+1)
  • 既に計算済みの{xθ(y¯rj)}j=1p1xθ(y¯t)を保持している
  • εθ(yt,t)ytに対してLipschitz連続

とする。
このとき、y¯sを次のように生成する。

y¯s=ςsςtyt+(1ςsςt)x^θ(yt)+j=1p1vj(h)kj(xθ(yrj)xθ(yt))

ただし、vj(h)は以下から求められる値。

(v1(h)v2(h)vp1(h)):=hVp11(1!ψ2(h)2!ψ3(h)(p1)!ψp(h))

Vp1:=(111k1k2kp1k1p2k2p2kp1p2)

ψn(h):=m=0(h)m(n+m)!=1(h)n(ehm=0n1(h)mm!)

参考