- vừa được xem lúc

Sinh dữ liệu với mô hình diffusion và mô hình dạng SDE tổng quát.

0 0 42

Người đăng: Quang

Theo Viblo Asia

Trong bài viết này, mình sẽ giới thiệu về mô hình diffusion, một mô hình sinh với sự đột phá gần đây, cùng với mô hình score matching đã vượt qua GAN trong việc sinh dữ liệu. Hai mô hình này có thể xem như trường hợp đặc biệt của phương trình vi phân ngẫu nhiên, và được tổng quát thành mô hình dạng SDE, đưa ra một góc nhìn mới cũng như việc kết hợp hai loại mô hình này. Mô hình diffusion cũng như mô hình dạng SDE khi sinh dữ liệu không điều kiện thậm chí còn cho kết quả tốt hơn GAN khi sinh dữ liệu với nhãn cho trước.

Do nội dung khá dài nên phần cài đặt mình sẽ để sang bài khác nếu có thời gian, các bạn có thể xem trước notebook tutorial của tác giả tại đây. Một số chứng minh chi tiết mình sẽ để ở cuối để tránh đi xa khỏi nội dung chính, các bạn quan tâm có thể đọc thêm.

Mô hình diffusion

Ý tưởng của phương pháp này là biến đổi phân bố dữ liệu thành một phân bố có thể lấy mẫu được. Việc sinh dữ liệu sẽ bắt đầu từ phân bố này, sau đó biến đổi ngược về phân bố ban đầu. Mô hình cần học ở đây sẽ là phép biến đổi ngược đó. Quá trình biến đổi này được mô tả bằng một chuỗi các phân bố, cụ thể hơn chúng ta sẽ sử dụng quá trình ngẫu nhiên để mô tả chuỗi này.

Định nghĩa: Quá trình ngẫu nhiên là một họ các biến ngẫu nhiên {Xt}tT\{X_t\}_{t\in T} từ cùng một không gian xác suất sang cùng một không gian trạng thái. Ở đây tập chỉ số TT có thứ tự, ví dụ T=R+T=\mathbb{R}^+ hoặc T=Z+T=\mathbb{Z}^+.

Quá trình ngẫu nhiên được gọi là quá trình Markov nếu nó thỏa mãn tính chất Markov. Một cách trực quan, xác suất của trạng thái tại tương lai khi biết trạng thái hiện tại không phụ thuộc vào quá khứ. Đối với chuỗi Markov, tính chất này có thể được viết thành

P(Xn+m=iX1,,Xn)=P(Xn+m=iXn)\mathbb{P}(X_{n+m}=i|X_1,\dots,X_n)=\mathbb{P}(X_{n+m}=i|X_n)

Quá trình thuận

Để cho đơn giản, xác suất chuyển từ thời điểm tt sang thời điểm ss sẽ được kí hiệu là q(xsxt)q(x_s|x_t).Từ tính chất Markov, xác suất liên hợp được phân tích thành

q(x0xt)=q(x0)i=1Tq(xixi1)q(x_0\dots x_t)=q(x_0)\prod_{i=1}^T q(x_i|x_{i-1})

Xác suất chuyển q(xtxt1)q(x_t|x_{t-1}) được mô hình bởi N(xt;1βtxt1,βtI)\mathcal{N}(x_t;\sqrt{1-\beta_t} x_{t-1}, \beta_tI). Xác suất khi biết trạng thái x0x_0 cũng là phân bố Gaussian, đạt được nhờ tính chất Markov. Đặt αt=1βt,αtˉ=i=1tαi\alpha_t = 1-\beta_t,\,\bar{\alpha_t}=\prod_{i=1}^t \alpha_i, ta có q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I).

Phân bố tại trạng thái xTx_T được xem như prior, sao cho có thể lấy mẫu được. Nhờ vào tính chất của xác suất điều kiện trên, với βt\beta_t phù hợp, q(xT)N(0,I)q(x_T)\approx \mathcal{N}(0, I).

Quá trình nghịch

Lúc này, quá trình sẽ bắt đầu từ phân bố p(xT)p(x_T) tại xTx_T, biến đổi ngược lại để quay về phân bố gốc của dữ liệu p(x0)p(x_0). Quá trình này có thể xem như một chuỗi Markov với chiều ngược lại, do đó xác suất liên hợp được phân tích thành

p(x0xT)=p(xT)i=1Tp(xi1xi)p(x_0\dots x_T) = p(x_T)\prod_{i=1}^{T}p(x_{i-1}|x_i)

Mục tiêu lúc này là tìm xác suất chuyển p(xt1xt)p(x_{t-1}|x_t) của chuỗi Markov này. Ta sẽ mô hình xác suất này bởi phân bố Gaussian, có dạng N(xt1;μθ(xt,t),Σθ(xt,t))\mathcal{N}(x_{t-1};\mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t, t)).

Huấn luyện

Mục tiêu của quá trình huấn luyện là cực đại likelihood của phân bố dữ liệu của mô hình sinh

p(x0)=p(x0xT)dx1xT=p(x0xT)q(x1xTx0)q(x1xTx0)dx1xT=p(xT)i=1Tp(xi1xi)q(xixi1)dQ(x1xTx0)\begin{aligned} p(x_0)&=\int p(x_0\dots x_T)dx_1\dots x_T\\ &=\int \frac{p(x_0\dots x_T)}{q(x_1\dots x_T|x_0)}q(x_1\dots x_T|x_0)dx_1\dots x_T\\ &= \int p(x_T)\prod_{i=1}^T \frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})} dQ(x_1\dots x_T|x_0) \end{aligned}

Áp dụng bất đẳng thức Jensen ta có

logp(x0)log(p(xT)i=1Tp(xi1xi)q(xixi1))dQ(x1xTx0)\begin{aligned} \log p(x_0) &\geq\int \log(p(x_T)\prod_{i=1}^T \frac{p(x_{i-1}|x_i)}{q(x_i|x_{i-1})}) dQ(x_1\dots x_T|x_0) \end{aligned}

với t>1t>1, ta có thể tính posterior như sau

q(xtxt1)=q(xtxt1,x0)tıˊnh chaˆˊt Markov=q(xt1xt,x0)q(xtx0)q(xt1x0)\begin{aligned} q(x_t|x_{t-1}) &= q(x_t|x_{t-1}, x_0)\quad \text{tính chất Markov}\\ &=\frac{q(x_{t-1}|x_t, x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)} \end{aligned}

Chặn dưới của log likelihood trở thành

L(x0)=Eq(x1xTx0)[logp(XT)+t=2Tlogp(Xt1Xt)q(XtXt1)+logp(x0X1)+logq(X1x0)]=Eq(x1xTx0)[logp(XT)q(XTx0)+t=2Tlogp(Xt1XT)q(Xt1Xt,x0)+logp(x0X1)+logq(X1x0)]\begin{aligned} L(x_0)&=\mathbb{E}_{q(x_1\dots x_{T}|x_0)}[\log p(X_T)+\sum_{t=2}^T\log\frac{p(X_{t-1}|X_t)}{q(X_t|X_{t-1})} + \log p(x_0|X_1) + \log q(X_1|x_0)]\\ &=\mathbb{E}_{q(x_1\dots x_{T}|x_0)}[\log\frac{p(X_T)}{q(X_T|x_0)}+\sum_{t=2}^T\log \frac{p(X_{t-1}|X_T)}{q(X_{t-1}|X_t, x_0)} + \log p(x_0|X_1) + \log q(X_1|x_0)]\\ \end{aligned}

Thành phần logq(x1x0)\log q(x_1|x_0) là xác suất chuyển của quá trình thuận, do đó không có tham số và có thể loại bỏ trong quá trình huấn luyện. Mục tiêu của chúng ta là cực đại chặn dưới của log likelihood, tương đương với việc cực tiểu hàm mục tiêu sau

L=Eq[KL(q(xTX0)p(xT))+t=2TKL(q(xt1Xt,X0)p(xt1Xt))logp(X0X1)]L=\mathbb{E}_q[KL(q(x_T|X_0)||p(x_T)) +\sum_{t=2}^TKL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))-\log p(X_0|X_1)]

với kì vọng được lấy theo q(x0xT)q(x_0\dots x_T).

Các xác suất ở trên đều là phân bố Gaussian, do đó khoảng cách KL có thể tính từ kì vọng và phương sai. Đối với posterior, q(xt1xt,x0)q(x_{t-1}|x_t, x_0) sẽ là phân bố Gaussian N(xt1;αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt,β~tI)\mathcal{N}(x_{t-1}; \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t, \tilde\beta_tI), với β~t=1αˉt11αˉtβt\tilde \beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t.

Mô hình denoise diffusion

Để cho đơn giản, Σθ(xt,t)\Sigma_{\theta}(x_t, t) sẽ được đặt là σt2I\sigma_t^2I, với σt\sigma_t được chọn trước, do đó không tham gia vào quá trình huấn luyện. Tác giả đưa ra hai lựa chọn σt2=βt\sigma_t^2=\beta_tσt2=β~t\sigma_t^2=\tilde \beta_t, tương đương với việc entropy H(q(xt1xt))H(q(x_{t-1}|x_t)) lớn nhất và nhỏ nhất, qua thực nghiệm hai cách chọn này cho kết quả tương đương.

Kí hiệu μ~t(xt,x0)\tilde\mu_t(x_t, x_0) là kì vọng của q(xt1xt,x0)q(x_{t-1}|x_t, x_0), với khoảng cách KL giữa hai phân bố Gaussian ta có

Lt1=Eq[KL(q(xt1Xt,X0)p(xt1Xt))]=Ex0,xt[12σt2μ~t(Xt,X0)μθ(Xt,t)2]+C\begin{aligned} L_{t-1} &= \mathbb{E}_q[KL(q(x_{t-1}|X_t, X_0)||p(x_{t-1}|X_t))]\\ &=\mathbb{E}_{x_0, x_t}[\frac{1}{2\sigma_t^2}||\tilde\mu_t(X_t,X_0)-\mu_{\theta}(X_t,t)||^2] + C \end{aligned}

Hàm μθ(xt,t)\mu_{\theta}(x_t,t) dự đoán kì vọng μ~(xt,x0)=αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt\tilde\mu(x_t,x_0)=\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t của q(xt1xt,x0)q(x_{t-1}|x_t, x_0) khi biết xtx_ttt. Điều này tương đương với việc dự đoán x0x_0 khi biết xtx_t. Tuy nhiên, từ thực nghiệm, tác giả thấy việc tham số như vậy không đưa ra kết quả tốt. Từ xác suất chuyển của quá trình thuận, chúng ta có xt(x0,ϵ)=αtˉx0+1αtˉϵx_t(x_0,\epsilon) = \sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon, trong đó ϵN(0,I)\epsilon\sim\mathcal{N}(0,I). Nói cách khác, trong quá trình thuận, x0x_0 có thể được tham số bởi xt(x0,ϵ)x_t(x_0,\epsilon) và một biến ngẫu nhiên độc lập ϵ\epsilon thông qua x0=1αˉt(xt1αˉtϵ)x_0=\frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar\alpha_t}\epsilon). Như vậy, thay vì đoán x0x_0 khi biết xtx_t, chúng ta có thể xây dựng mô hình ϵθ(xt,t)\epsilon_{\theta}(x_t,t) đoán nhiễu ϵ\epsilon khi biết xtx_t (đây là lí do cho từ denoise trong tên gọi).

Từ cách tham số này, chúng ta có thể thay vào μ~(xt,x0)\tilde\mu(x_t,x_0) để được

μ~(xt,x0)=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉt(1αˉt(xt1αˉtϵ))=1αt(xtβt1αˉtϵ)\begin{aligned} \tilde\mu(x_t,x_0)&=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}(\frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar\alpha_t}\epsilon))\\ &=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon) \end{aligned}

Tương tự như vậy, μθ(xt,t)\mu_{\theta}(x_t,t) lúc này sẽ được tham số như sau

μθ(xt,t)=1αt(xtβt1αˉtϵθ(xt,t))\mu_{\theta}(x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t))

Nhắc lại, trong quá trình huấn luyện (quá trình thuận), xtx_t có thể tính từ x0x_0 thông qua xt(x0,ϵ)=αtˉx0+1αtˉϵx_t(x_0,\epsilon) = \sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon. Lúc này, hàm mục tiêu sẽ trở thành

Lt1=Ex0,ϵ[βt22σt2αt(1αtˉ)ϵϵθ(αtˉx0+1αtˉϵ,t)2]L_{t-1}=\mathbb{E}_{x_0,\epsilon}[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha_t})}||\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon,t)||^2]

và hàm mục tiêu cho tại toàn bộ vị trí sẽ là L=Et[Lt1]L=\mathbb{E}_t[L_{t-1}] với tt tuân theo phân bố đều U{1,T}\mathcal{U}\{1,T\}.

Để cho đơn giản, chúng ta có thể tối ưu với phiên bản không trọng số của hàm mục tiêu bên trên

L=Ex0,ϵ,t[ϵϵθ(αtˉx0+1αtˉϵ,t)2]L=\mathbb{E}_{x_0,\epsilon,t}[||\epsilon-\epsilon_{\theta}(\sqrt{\bar{\alpha_t}}x_0+ \sqrt{1-\bar{\alpha_t}}\epsilon,t)||^2]

Lấy mẫu

Thay vì mô hình trực tiếp kì vọng của p(xt1xt)p(x_{t-1}|x_t), chúng ta đã mô hình nhiễu ϵθ(xt,t)\epsilon_{\theta}(x_t,t). Do đó, ở bước lấy mẫu, giả sử đã biết xtx_t, chúng ta sẽ tính lại kì vọng này qua công thức

μ~(xt,t)=1αt(xtβt1αˉtϵθ(xt,t))\tilde \mu(x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t))

Lúc này xt1x_{t-1} sẽ được tính bởi

xt1=μ~(xt,t)+σtz,zN(0,I)x_{t-1}=\tilde\mu(x_t,t)+\sigma_t z,\,z\sim\mathcal{N}(0,I)

Bắt đầu từ xTN(0,I)x_T\sim \mathcal{N}(0,I), chúng ta thực hiện tuần tự TT bước đến khi tìm được x0x_0.

Mô hình SDE tổng quát

Liên hệ giữa mô hình diffusion và score matching

Hàm mục tiêu của mô hình denoise diffusion có thể xem như denoise score matching. Với phân bố q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I), score logq(xtx0)\nabla\log q(x_t|x_0) của phân bố này sẽ là αˉtx0xt1αˉt\frac{\sqrt{\bar\alpha_t}x_0-x_t}{1-\bar\alpha_t}. Chú ý αˉtx0xt1αˉtN(0,I)\frac{\sqrt{\bar\alpha_t}x_0-x_t}{\sqrt{1-\bar\alpha_t}}\sim\mathcal{N}(0,I), nếu ta thay biến ngẫu nhiên này cho ϵ\epsilon trong thành phần Lt1L_{t-1} của hàm mục tiêu không trọng số trong mô hình denoise diffusion, ta có

Lt1=Ex0,xt[1αˉtlogq(xtx0)ϵθ(xt,t)2]=(1αˉt)Ex0,xt[logq(xtx0)sθ(xt,t)2]\begin{aligned} L_{t-1}&=\mathbb{E}_{x_0,x_t}[||\sqrt{1-\bar\alpha_t}\nabla\log q(x_t|x_0)-\epsilon_{\theta}(x_t,t)||^2]\\ &=(1-\bar\alpha_t)\mathbb{E}_{x_0,x_t}[||\nabla\log q(x_t|x_0)-s_{\theta}(x_t,t)||^2] \end{aligned}

với sθ(xt,t)=ϵθ(xt,t)1αˉts_{\theta}(x_t,t)=-\frac{\epsilon_{\theta}(x_t,t)}{\sqrt{1-\bar\alpha_t}}. Lúc này, hàm mục tiêu sẽ là

L=t=1T(1αˉt)Ex0,xt[logq(xtx0)sθ(xt,t)2]L=\sum_{t=1}^T(1-\bar\alpha_t)\mathbb{E}_{x_0,x_t}[||\nabla\log q(x_t|x_0)-s_{\theta}(x_t,t)||^2]

Đây chính là hàm mục tiêu của NCSN với trọng số (1αˉt)(1-\bar\alpha_t) khi sử dụng denoise score matching. Tương tự như NCSN, trọng số (1αˉt)(1-\bar\alpha_t) có tính chất (1αˉt)1/E[logq(xtx0)2](1-\bar\alpha_t)\propto1/\mathbb{E}[||\nabla\log q(x_t|x_0)||^2]. Cách nhìn này cho thấy sự liên hệ giữa phương pháp score matching và mô hình diffusion, đó là thay đổi phân bố dữ liệu bằng một họ các nhiễu, và học mô hình khử nhiễu lần lượt. Từ đây, ta có thể tổng quát cả hai phương pháp này, bằng cách mô hình họ các nhiễu bởi quá trình ngẫu nhiên liên tục, biểu diễn bởi một phương trình vi phân ngẫu nhiên (SDE).

Mô hình với SDE

Cụ thể hơn, với phân bố dữ liệu p0p_0 ban đầu, ta mong muốn biến đổi nó thành một phân bố đơn giản pTp_T, theo nghĩa có thể lấy mẫu một cách dễ dàng, ví dụ như N(0,I)\mathcal{N}(0,I) trong mô hình diffusion. Nói cách khác, ta cần một quá trình ngẫu nhiên Xt{X_t} với t[0,T]t\in[0,T] sao cho p(x0)=p0,p(xT)=pTp(x_0)=p_0, p(x_T)=p_T. Quá trình ngẫu nhiên này có thể mô tả bởi phương trình vi phân ngẫu nhiên Itô (từ bây giờ khi nhắc đến SDE, chúng ta sẽ hiểu đó là Itô SDE)

dxt=f(x,t)dt+g(t)dwdx_t=f(x,t)dt + g(t)dw

trong đó f(x,t):Rd×R+Rdf(x,t):\mathbb{R}^d\times\mathbb{R}^+\mapsto \mathbb{R}^d, g(t):R+Rg(t):\mathbb{R}^+\mapsto\mathbb{R}, dwdw kí hiệu một cách hình thức vi phân của chuyển động Brown. Một cách trực quan, dw=N(0,Δt)dw=\mathcal{N}(0,\Delta t) với Δt0\Delta t\to 0. Để cho đơn giản, chúng ta chỉ xét g(t)g(t) có dạng trên, tuy nhiên tất cả kết quả bên dưới đều có thể mở rộng cho hàm g(t)g(t) trả về ma trận.

SDE của mô hình diffusion

Nhắc lại quá trình thuận của mô hình diffusion có thể được mô tả bởi quá trình ngẫu nhiên {xt}t=0T\{x_t\}_{t=0}^T. Giả sử chúng ta dùng σt2=βt\sigma_t^2=\beta_t, chuỗi Markov có dạng

xt=1βtxt1+βtzt1,zN(0,1)x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}z_{t-1},\, z\sim\mathcal{N}(0,1)

Quá trình ngẫu nhiên này có thể xem như rời rạc của một quá trình ngẫu nhiên liên tục, chúng ta sẽ tìm quá trình này bằng cách cho TT\to\infty. Đặt βˉt=Tβt\bar\beta_t=T\beta_t, chuỗi này sẽ tiến về một hàm β(t):[0,1]R\beta(t):[0,1]\mapsto\mathbb{R}, β(tT)=βˉt\beta(\frac{t}{T})=\bar\beta_t. Tương tự quá trình ngẫu nhiên của xix_iziz_i cũng tiến tới quá trình ngẫu nhiên liên tục x(tT)=xt,z(tT)=ztx(\frac{t}{T})=x_t, z(\frac{t}{T})=z_t. Đặt Δt=t/T\Delta t=t/T, dùng khai triển Taylor bậc 1, phương trình trên có thể viết lại thành

x(t+Δt)=1β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)x(t)12β(t)Δtx(t)+β(t)Δtz(t)\begin{aligned} x(t+\Delta t)&=\sqrt{1-\beta(t+\Delta t)\Delta t}x(t)+\sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t)-\frac{1}{2}\beta(t)\Delta tx(t)+\sqrt{\beta(t)\Delta t}z(t) \end{aligned}

Khi Δt0\Delta t\to 0, phương trình này hội tụ tới SDE

dxt=12β(t)xtdt+β(t)dwdx_t=-\frac{1}{2}\beta(t)x_tdt+\sqrt{\beta(t)}dw

SDE của mô hình NCSN

Nhắc lại, mô hình NCSN thêm lần lượt nhiễu với phương sai {σt}t=1N\{\sigma_t\}_{t=1}^N vào phân bố dữ liệu. Quá trình này có thể viết lại thành

xt=xt1σt2σt12z,zN(0,I)x_{t}=x_{t-1}-\sqrt{\sigma_{t}^2-\sigma_{t-1}^2}z,\qquad z\sim\mathcal{N}(0,I)

với σ0=0\sigma_0=0. Lập luận tương tự như trên, ta có thể tính giới hạn khi NN\to\infty

x(t+Δt)=x(t)+σ2(t+Δt)σ2(t)z(t)dσ2(t)dtΔtz(t)x(t+\Delta t)=x(t)+\sqrt{\sigma^2(t+\Delta t)-\sigma^2(t)}z(t)\approx\sqrt{\frac{d \sigma^2(t)}{dt}\Delta t}z(t)

sử dụng khai triển Taylor bậc 1 của σ2(t)\sigma^2(t). Khi Δt0\Delta t\to 0, chuỗi xtx_t hội tụ tới quá trình ngẫu nhiên mô tả bởi

dxt=dσ2(t)dtdwdx_t=\sqrt{\frac{d \sigma^2(t)}{dt}}dw

Lấy mẫu

Việc lấy mẫu tương đương với đảo chiều thời gian của quá trình ngẫu nhiên. Quá trình nghịch này được mô tả bởi SDE sau

dxt=(f(x,t)g(t)2xtlogpt(xt))dt+g(t)dwˉdx_t=(f(x,t)-g(t)^2\nabla_{x_t}\log p_t(x_t))dt + g(t)d\bar w

ở đây wˉ\bar w là chuyển động Brown theo chiều ngược lại, từ TT về 00.

Nếu biết được score của pt(x)p_t(x), chúng ta có thể mô phỏng lại quá trình ngược này. Bắt đầu từ xTpTx_T\sim p_T, từ phương trình trên, chúng ta sẽ biến đổi xTx_T thành x0x_0 tuân theo phân bố p0p_0 của dữ liệu. Như vậy, mục tiêu của chúng ta là xây dựng mô hình sθ(x(t),t)s_{\theta}(x(t),t) xấp xỉ xtlogpt(xt))\nabla_{x_t}\log p_t(x_t)).

Giải SDE

Quá trình lấy mẫu được thực hiện bằng cách giải phương trình SDE nghịch. Tương tự như khi rời rạc hóa quá trình thuận, chúng ta có thể giải bằng cách rời rạc hóa quá trình nghịch

xt=xt+1ft+1(xt+1)+gt+12sθ(xi+1,i+1)+gt+1z,zN(0,I)x_t=x_{t+1}-f_{t+1}(x_{t+1})+g_{t+1}^2s_{\theta}(x_{i+1},i+1)+g_{t+1}z,\,z\sim\mathcal{N}(0,I)

Quay lại với cách cập nhật của mô hình denoise diffusion, giả sử ta dùng σt2=βt\sigma_t^2=\beta_t

xt1=1αt(xtβt1αˉtϵθ(xt,t))+βtzx_{t-1} = \frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_{\theta}(x_t,t)) + \sqrt{\beta_t}z

Đặt sθ(xt,t)=ϵθ(xt,t)1αˉts_{\theta}(x_t,t)=-\frac{\epsilon_{\theta}(x_t,t)}{\sqrt{1-\bar\alpha_t}}, ta có thể biến đổi như sau

xt1=11βt(xt+βts(xt,t))+βtz(1+12βt)(xt+βts(xt,t))+βtzkhai triển Taylor=(1+12βt)xt+βts(xt,t)+12βt2s(xt,t)+βtzxt+12βtxt+βts(xt,t)+βtz\begin{aligned} x_{t-1}&=\frac{1}{\sqrt{1-\beta_t}}(x_t+\beta_ts(x_t,t))+\sqrt{\beta_t}z\\ &\approx (1+\frac{1}{2}\beta_t)(x_t+\beta_ts(x_t,t)) +\sqrt{\beta_t}z\qquad \text{khai triển Taylor}\\ &=(1+\frac{1}{2}\beta_t)x_t+ \beta_ts(x_t,t)+\frac{1}{2}\beta_t^2s(x_t,t)+\sqrt{\beta_t}z\\ &\approx x_t+\frac{1}{2}\beta_tx_t+ \beta_ts(x_t,t)+\sqrt{\beta_t}z \end{aligned}

Quá trình nghịch của SDE ứng với mô hình diffusion là

dxt=(12β(t)xtβtxlogpt(xt))dt+β(t)dwdx_t=(-\frac{1}{2}\beta(t)x_t-\beta_t\nabla_x\log p_t(x_t))dt+\sqrt{\beta(t)}dw

Ta có thể thấy thuật toán lấy mẫu của mô hình denoise diffusion gần giống với việc giải quá trình nghịch thông qua rời rạc hóa.

Lấy mẫu với Predictor-Corrector

Ở phần trước, ta đã biết quá trình lấy mẫu có thể thực hiện bằng việc giải phương trình SDE nghịch, và thuật toán lấy mẫu của mô hình diffusion thuộc loại này. Mặt khác, ta đang mô hình score của pt(xt)p_t(x_t), do đó ta cũng có thể lấy mẫu với (annealed) Langevin dynamics.

Để có thể sinh dữ liệu tốt hơn, chúng ta có thể kết hợp hai phương pháp này. Lấy mẫu thông qua giải SDE sẽ được xem như thuật toán chính, gọi là Predictor. Ở bước thứ ii trong Predictor, sau khi cập nhật xTix_{T-i} qua xTi+1x_{T-i+1}, chúng ta sẽ thực hiện Langevin dynamics MM lần với s(xTi,Ti)s(x_{T-i},T-i)

xTi=xTi+ϵis(xTi,Ti)+2ϵiz,zN(0,I)x_{T-i}=x_{T-i}+\epsilon_i s(x_{T-i},T-i) + \sqrt{2\epsilon_i}z,\,z\sim\mathcal{N}(0,I)

Từ góc nhìn này, cách sinh dữ liệu của NCSN có thể xem như Predictor là hàm đồng nhất, Corrector là Langevin dynamics, cách sinh dữ liệu của mô hình denoise diffusion có thể xem như Predictor là giải quá trình nghịch, Corrector là hàm đồng nhất.

Huấn luyện

Tương tự như hàm mục tiêu của NCSN cũng như mô hình denoise diffusion, hàm mục tiêu của mô hình SDE sẽ có dạng score matching trên tất cả mức độ nhiễu. Điểm khác biệt là biến thời gian tt lúc này là biến ngẫu nhiên liên tục tuân theo phân bố đều U[0,1]\mathcal{U}[0,1]

L=Et[λ(t)Ex0,xt[logp(xtx0)sθ(xt,t)2]]L=\mathbb{E}_t[\lambda(t)\mathbb{E}_{x_0,x_t}[||\nabla\log p(x_t|x_0)-s_{\theta}(x_t,t)||^2]]

Ở đây λ(t)\lambda(t) là hàm trọng số, có thể chọn giống như NCSN và mô hình denoise diffusion là λ(t)1/E[logq(xtx0)2]\lambda(t)\propto1/\mathbb{E}[||\nabla\log q(x_t|x_0)||^2].

Việc tính hàm mất mát yêu cầu score của phân bố chuyển trong quá trình thuận. Đối với trường hợp SDE tổng quát, ta cần giải phương trình Kolmogorov tiến để tìm phân bố này. Khi f(x,t)=a(t)x+b(t)f(x,t)=a(t)x+b(t), phân bố chuyển là phân bố Gaussian, do đó chỉ cần biết kì vọng và phương sai để tính score. Kì vọng mtm_t và ma trận hiệp phương sai PtP_t sẽ thỏa mãn phương trình vi phân sau

dmtdt=a(t)mt+b(t)\frac{dm_t}{dt}=a(t)m_t+b(t)

dPtdt=2a(t)Pt+g(t)2\frac{dP_t}{dt}=2a(t)P_t+g(t)^2

Để tránh việc phải tính phân bố chuyển, chúng ta có thể dùng phương pháp score matching khác, ví dụ như sliced score matching, với hàm mục tiêu

L=Et[λ(t)Ex0ExtEv[12sθ(xt,t)2+vJs(.,t)(xt)v]]L=\mathbb{E}_t[\lambda(t)\mathbb{E}_{x_0}\mathbb{E}_{x_t}\mathbb{E}_v[\frac{1}{2}||s_{\theta}(x_t,t)||^2+v^\intercal J_{s(.,t)}(x_t)v]]

với Js(.,t)(xt)J_{s(.,t)}(x_t) là ma trận Jacobian của s(xt,t)s(x_t,t), vJs(.,t)(xt)vv^\intercal J_{s(.,t)}(x_t)v tính bởi v(vs(xt,t))v^\intercal\nabla(v^\intercal s(x_t,t)).

Kết luận

Trong bài này, mình đã giới thiệu mô hình diffusion và mô hình dạng SDE tổng quát mà trong đó score matching và mô hình diffusion là trường hợp đặc biệt. Cách tiếp cận này hiện đã cho kết quả tốt nhất hiện tại cho mô hình sinh.

Tuy nhiên cách tiếp cận này có các nhược điểm sau: Các trạng thái có cùng số chiều, do đó việc mô hình quá trình nghịch cần đảm bảo điều đó chứ không thể thay đổi số chiều dữ liệu. Việc lấy mẫu tốn khá nhiều thời gian, do cần phải đi từng bước để giải phương trình SDE nghịch, chưa tính đến việc kết hợp với Corrector trong quá trình lấy mẫu.

Tham khảo

Một số định nghĩa và chứng minh chi tiết

Công thức các phân bố trong quá trình thuận của mô hình diffusion

Tính chất: Với q(xtxt1)=N(xt;αtxt1,(1αt)I)q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)I), ta có q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_0, (1-\bar{\alpha_t})I),

trong đó αˉt=i=1tαi\bar\alpha_t=\prod_{i=1}^t\alpha_i.

Chứng minh:

Quá trình Markov thỏa mãn tính chất sau

Mệnh đề: Với t1>t2>t3t_1>t_2>t_3, xác suất chuyển thỏa mãn phương trình Chapman-Kolmogorov

pt3t1(xt1xt3)=pt3t2(xt2xt3)pt2t1(xt1xt2)dxt2p_{t_3t_1}(x_{t_1}|x_{t_3})=\int p_{t_3t_2}(x_{t_2}|x_{t_3})p_{t_2t_1}(x_{t_1}|x_{t_2})dx_{t_2}

Tính chất trên có thể chứng minh dễ dàng bằng tính chất Markov.

Chúng ta chỉ cần chứng minh cho t=2t=2, các trường hợp còn lại có thể suy ra theo quy nạp. Hơn nữa, ma trận hiệp phương sai có dạng βtI\beta_tI, do đó ta chỉ cần chứng minh cho trường hợp xRx\in\mathbb{R}.

Từ phương trình Chapman-Kolmogorov, ta có

q(x2x0)=q(x2x1)q(x1x0)dx1=12(1α1)(1α2)πexp((x2α2x1)22(1α2))exp((x1α1x0)22(1α1))dx1=12(1α1)(1α2)πexp(12((x2α1α2x0)21α1α2+(1α1α2)(x1α2(1α1)x2+α1(1α2)x01α1α2)2(1α1)(1α2)))dx1=12π(1α1α2)exp(12(x2α1α2x0)21α1α2).\begin{aligned} q(x_2|x_0) &= \int q(x_2|x_1)q(x_1|x_0)dx_1\\ &=\frac{1}{2\sqrt{(1-\alpha_1)(1-\alpha_2)}\pi}\int \exp(-\frac{(x_2-\sqrt\alpha_2x_1)^2}{2(1-\alpha_2)})\exp(-\frac{(x_1-\sqrt\alpha_1x_0)^2}{2(1-\alpha_1)})dx_1\\ &=\frac{1}{2\sqrt{(1-\alpha_1)(1-\alpha_2)}\pi}\int\exp(-\frac{1}{2}(\frac{(x_2-\sqrt{\alpha_1\alpha_2}x_0)^2}{1-\alpha_1\alpha_2} +\frac{(1-\alpha_1\alpha_2)(x_1-\frac{\sqrt\alpha_2(1-\alpha_1)x_2+\sqrt\alpha_1(1-\alpha_2)x_0}{1-\alpha_1\alpha_2})^2}{(1-\alpha_1)(1-\alpha_2)}))dx_1\\ &=\frac{1}{\sqrt{2\pi(1-\alpha_1\alpha_2)}}\exp(-\frac{1}{2}\frac{(x_2-\sqrt{\alpha_1\alpha_2}x_0)^2}{1-\alpha_1\alpha_2}). \end{aligned}

\square

Tính chất: q(xt1xt,x0)=N(xt1;αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt,β~tI)q(x_{t-1}|x_t, x_0) =\mathcal{N}(x_{t-1}; \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t, \tilde\beta_tI), với β~t=1αˉt11αˉtβt\tilde \beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t

Chứng minh: Tương tự như trên, chúng ta cũng chỉ cần chứng minh cho trường hợp xRx\in\mathbb{R}.

q(xt1xt,x0)=q(xtxt1)q(xtx0)q(xt1x0)=(2πβt)1/2(2π(1αˉt1))1/2(2π(1αˉt))1/2exp(xtαtxt122βtxt1αˉt1x022(1αˉt1)+xtαˉtx022(1αˉt))=(2πβ~t)1/2exp(1β~txt1αˉt1βt1αˉtx0+αt(1αˉt1)1αˉtxt2).\begin{aligned} q(x_{t-1}|x_t, x_0)&=\frac{q(x_t|x_{t-1})q(x_t|x_0)}{q(x_{t-1}|x_0)}\\ &=(2\pi\beta_t)^{-1/2}(2\pi(1-\bar\alpha_{t-1}))^{-1/2}(2\pi(1-\bar\alpha_t))^{1/2}\\ &\quad\exp\left(-\frac{||x_t-\sqrt{\alpha_t}x_{t-1}||^2}{2\beta_t}-\frac{||x_{t-1}-\sqrt{\bar\alpha_{t-1}}x_0||^2}{2(1-\bar\alpha_{t-1})}+\frac{||x_t-\sqrt{\bar\alpha_t}x_0||^2}{2(1-\bar\alpha_t)}\right)\\ &=(2\pi\tilde\beta_t)^{-1/2}\exp\left(-\frac{1}{\tilde\beta_t}||x_{t-1}-\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 +\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_{t}}x_t||^2\right). \end{aligned}

\square

Một số tính chất của phương trình vi phân ngẫu nhiên

Phương trình vi phân ngẫu nhiên Itô với điều kiện đầu x0=xx_0=x

dxt=f(xt,t)dt+g(t)dwdx_t=f(x_t,t)dt+g(t)dw

là biểu diễn hình thức của phương trình tích phân sau

xt=x+0tf(xt,t)dt+0tg(t)dwx_t=x+\int_0^tf(x_t,t)dt+\int_0^tg(t)dw

Tích phân đầu tiên là tích phân Riemann-Stieltjes thông thường. Tuy nhiên ta không thể tính tích phân thứ hai như vậy, do chuyển động Brown không thỏa mãn tính chất bounded variation. Thay vào đó, ta sẽ sử dụng tích phân Itô để tính đại lượng này

0tg(t)dw=limKk=0K1g(tk)(wtk+1wtk)\int_0^tg(t)dw=\lim_{K\to \infty}\sum_{k=0}^{K-1}g(t_k)(w_{t_{k+1}}-w_{t_k})

với tk=kΔt,t=KΔtt_k=k\Delta t, t=K\Delta t.

Từ đây chúng ta có tính chất sau

k=0K1E[g(tk)(wtk+1wtk)]=k=0K1E[g(tk)]E[wtk+1wtk]=0\sum_{k=0}^{K-1}\mathbb{E}[g(t_k)(w_{t_{k+1}}-w_{t_k})]=\sum_{k=0}^{K-1}\mathbb{E}[g(t_k)]\mathbb{E}[w_{t_{k+1}}-w_{t_k}]=0

theo định nghĩa của chuyển động Brown, do đó

E[0tg(t)dw]=0\mathbb{E}[\int_0^tg(t)dw]=0

Với quá trình ngẫu nhiên xtx_t và một hàm tất định u(x,t):Rd×R+Ru(x,t):\mathbb{R}^d\times \mathbb{R}^+\mapsto\mathbb{R}, chúng ta cũng không thể tính đạo hàm toàn phân du(xt,t)dt\frac{du(x_t,t)}{dt} bằng chain rule như thông thường, thay vào đó chúng ta sẽ dùng công thức Itô

du(xt,t)=u(xt,t)tdt+u(xt,t)dxt+12(dxt)Hxu(xt,t)dxtdu(x_t,t)=\frac{\partial u(x_t,t)}{\partial t}dt+\nabla u(x_t,t)^\intercal dx_t+\frac{1}{2}(dx_t)^\intercal H_xu(x_t,t)dx_t

trong đó Hxu(xt,t)H_xu(x_t,t) là ma trận Hessian của uu.

Chứng minh SDE nghịch

Từ công thức Itô, chúng ta có hai phương trình quan trọng. Đó là phương trình Kolmogorov tiến

p(xt,t)t=ifi(xt,t)p(xt,t)xi+12i,j2g(t)2p(xt,t)xixj\frac{\partial p(x_t,t)}{\partial t}=-\sum_i\frac{\partial f^i(x_t,t)p(x_t,t)}{\partial x^i}+\frac{1}{2}\sum_{i,j}\frac{\partial^2 g(t)^2p(x_t,t)}{\partial x^ix^j}

và phương trình Kolmogorov lùi

pts(xsxt)t=ifi(xt,t)pts(xsxt)x+i,jg(t)222pts(xsxt)x2-\frac{\partial p_{ts}(x_s|x_t)}{\partial t}=\sum_if^i(x_t,t)\frac{\partial p_{ts}(x_s|x_t)}{\partial x}+\sum_{i,j}\frac{g(t)^2}{2}\frac{\partial^2 p_{ts}(x_s|x_t)}{\partial x^2}

với t<st<s, pts(xsxt)p_{ts}(x_s|x_t) là xác suất chuyển từ trạng thái xtx_t tại tt sang trạng thái xsx_s tại ss, fi,xif_i,x_i là chỉ số thứ ii của f,xf,x.

Phương trình SDE nghịch có thể suy ngược từ phương trình Kolmogorov như sau: Với xác suất liên hợp xs,xtx_s, x_t, ta có

p(xs,xt)=pts(xsxt)p(xt)p(x_s,x_t)=p_{ts}(x_s|x_t)p(x_t)

p(xs,xt)t=pts(xsxt)p(xt,t)t+p(xt)pts(xsxt)t\frac{\partial p(x_s,x_t)}{\partial t}= p_{ts}(x_s|x_t)\frac{\partial p(x_t,t)}{\partial t}+ p(x_t)\frac{\partial p_{ts}(x_s|x_t)}{\partial t}

Thay phương trình Kolmogorov vào phương trình trên, ta được

p(xs,xt)t=ifˉi(xt,t)p(xs,xt)xi+12i,j2g(t)2p(xs,xt)xixj-\frac{\partial p(x_s,x_t)}{\partial t}=\sum_i\frac{\partial \bar f^i(x_t,t)p(x_s,x_t)}{\partial x^i}+\frac{1}{2}\sum_{i,j}\frac{\partial^2 g(t)^2p(x_s,x_t)}{\partial x^ix^j}

trong đó

fˉ(xt,t)=f(x,t)g(t)21p(xt)pt(xt)=f(x,t)g(t)2logpt(xt)\bar f(x_t,t)= f(x,t)-g(t)^2\frac{1}{p(x_t)}\nabla p_t(x_t)=f(x,t)-g(t)^2\nabla\log p_t(x_t)

Tích phân cả hai vế cho xsx_s, ta được phương trình Kolmogorov tiến của quá trình nghịch, ứng với SDE

dxt=fˉ(xt,t)dt+g(t)dwˉdx_t=\bar f(x_t,t)dt+g(t)d\bar w

Kì vọng và phương sai của SDE tuyến tính

Với hàm uu bất kì, từ công thức Itô

du(xt,t)=u(xt,t)tdt+u(xt,t)dxt+12(dxt)Hxu(xt,t)dxt=(u(xt,t)t+u(xt,t)f(xt,t)+12g(t)2i,j2uxixj)dt+ug(t)dw\begin{aligned} du(x_t,t)&=\frac{\partial u(x_t,t)}{\partial t}dt+\nabla u(x_t,t)^\intercal dx_t+\frac{1}{2}(dx_t)^\intercal H_xu(x_t,t)dx_t\\ &=\left(\frac{\partial u(x_t,t)}{\partial t}+\nabla u(x_t,t)^\intercal f(x_t,t)+\frac{1}{2}g(t)^2\sum_{i,j}\frac{\partial^2u}{\partial x^i\partial x^j}\right)dt + \nabla u \,g(t)dw \end{aligned}

Lấy kì vọng hai vế và dùng tính chất kì vọng của tích phân Itô bằng 0, ta có

dE[u]dt=E[ut]+E[uf(xt,t)]+12g(t)2E[i,j2uxixj]\frac{d\mathbb{E}[u]}{dt}=\mathbb{E}[\frac{\partial u}{\partial t}]+\mathbb{E}[\nabla u^\intercal f(x_t,t)]+\frac{1}{2}g(t)^2\mathbb{E}[\sum_{i,j}\frac{\partial^2u}{\partial x^i\partial x^j}]

Thay u=xiu=x^i, ta tính được kì vọng của xix^i, từ đó suy ra kì vọng của xx. Với ma trận hiệp phương sai, ta thay u=xixjmi(t)mj(t)u=x^ix^j-m^i(t)m^j(t), với m(t)m(t) là kì vọng của xtx_t.

Bình luận

Bài viết tương tự

- vừa được xem lúc

Sinh dữ liệu với mô hình dựa trên score

Chúng ta đã tìm hiểu về cách huấn luyện mô hình score và cách lấy mẫu với Langevin dynamics. Tuy nhiên cách làm trực tiếp đó chưa đủ để sinh ra dữ liệu tốt.

0 0 15

- vừa được xem lúc

Giới thiệu về Diffussion model

1 . Giới thiệu về sơ lược về diffussion model.

0 0 20

- vừa được xem lúc

Giới thiệu về Diffussion model (series 2)

1. Variable Diffussion model (VDM). 1.1 Lịch sử hình thành.

0 0 25

- vừa được xem lúc

Diffussion model (Series 3)

1. Tổng quan bài viết này.

0 0 24

- vừa được xem lúc

Hình ảnh độ phân giải cao với Latent Diffusion Models

1. Giới thiệu vấn đề.

0 0 25

- vừa được xem lúc

Imagen - Mô hình SOTA giải quyết bài toán Text-to-Image

Imagen - mô hình mới được công bố gần đây bởi Google với khả năng generate hình ảnh với đoạn text mô tả bất kỳ, cho dù ảnh đó không có thật hoặc phi vật lý. Phía trên là một ví dụ của ảnh được sinh ra

0 0 33