- vừa được xem lúc

Đã đến lúc học về Diffusion Models

0 0 26

Người đăng: Nguyen Tung Thanh

Theo Viblo Asia

Sau một thời gian tìm hiểu và làm việc với mô hình Diffusion, mình viết bài này với hy vọng bài viết này sẽ có ích phần nào với các bạn muốn tìm hiểu về Diffusion Models.

Một số từ tạm dịch

  • Diffusion: khuếch tán
  • Quasi-static process: quá trình chuẩn tĩnh
  • Thermodynamic Equilibrium: cân bằng nhiệt động học (NĐH)
  • Forward Diffusion Process: quá trình khuếch tán thuận
  • Reverse Diffusion Process: quá trình đảo ngược
  • Transition kernel: nhân biến đổi.

Lời mở đầu

Diffusion Models đứng sau một số sản phẩm đình đám trong lĩnh vực sinh ảnh. Ứng dụng của loại mô hình này đã được mở rộng sang Object Detection, Image Segmentation,... Danh sách này nhiều khả năng sẽ còn mở rộng hơn nữa. Bài viết này sẽ giới thiệu từ ý tưởng cơ bản, những khái niệm (cực kỳ đơn giản) trong nhiệt động học đã truyền cảm hứng cho Diffusion Models trong học sâu. Sau đó, ta sẽ đi đến lý thuyết và cách huấn luyện Diffusion Models trong học sâu.

I. Từ Nhiệt động học

"Ý tưởng mới thường có tính cách liên ngành"

GS Nguyễn Văn Tuấn
Artificial Neural Network (Mạng nơ-ron nhân tạo), Genetic Algorithm (Giải thuật di truyền), Attention Mechanism (Cơ chế tập trung) là những ví dụ cho câu nói trên. Những phương pháp kể trên được lấy cảm hứng từ Khoa học thần kinh, Sinh học, Hệ thống thị giác của con người. Từ 2015, danh sách kể trên có thêm Diffusion Models. Ý tưởng cơ bản của phương pháp này được truyền cảm hứng từ Non-equilibrium Thermodynamics (Nhiệt động học không cân bằng). Cho đến hiện tại, những phương pháp SOTA đã có nhiều thay đổi làm cho cách hoạt động ngày càng khác xa so với Duffision trong Nhiệt động học. Dù vậy, việc hiểu Diffusion trong Nhiệt động học sẽ giúp ta hình dung được cách hoạt động của Diffusion Models trong Deep Learning. Trong phần này, chúng ta sẽ tìm hiểu về Diffusion trong nhiệt động học.

1. Diffusion là gì?

Khuếch tán là hiện tượng chuyển động của các phân tử (hoặc ion, năng lượng...) từ vùng có mật độ cao hơn sang vùng có mật độ thấp hơn.

Hình 1: Ví dụ của hiện tượng khuếch tán.
Hình 1. minh hoạ hiện tượng khuếch tán khi ta nhỏ một giọt thuốc nhuộm vào một cốc nước. Ban đầu giọt thuốc nhuộm tập trung ở 1 vùng nhỏ trong cốc nước với mật độ cao. Trong quá trình khuếch tán, thuốc nhuộm dần lan sang nhiều vùng trong cốc nước và mật độ của nó cũng loãng hơn. Sau một thời gian đủ lâu, thuốc nhuộm gần như sẽ phân bố đều trong cốc nước.

2. Thermodynamic Equilibrium là gì?

Bản thân từ Equilibrium mang nghĩa là cân bằng. Một hệ (system) để được gọi là Cân bằng nhiệt động học cần đồng thời đạt được các điều kiện:

  1. Cân bằng về nhiệt
  2. Cân bằng cơ học
  3. Cân bằng hoá học
  4. Cân bằng pha

Hình 2: Hai hệ gồm khí bên trong bình. Bên trái: hệ cân bằng nhiệt. Bên phải: hệ không cân bằng nhiệt.

Chúng ta không cần thiết phải hiểu về 4 yếu tố kể trên. Nhưng sẽ tốt hơn nếu chúng ta hình dung được những yếu tố này. Ví dụ, xét đến yếu tố đầu tiên là Cân bằng nhiệt. Một hệ được gọi là Cân bằng nhiệt nếu nhiệt độ tại mọi điểm của hệ là giống nhau. Hình 2 minh hoạ một hệ cân bằng nhiệt(bên trái) và một hệ không cân bằng nhiệt (bên phải). Ta có thể thấy ở hình bên trái nhiệt độ tại các điểm của hệ khá "giống" nhau, còn ở hình bên phải nhiệt độ ở các điểm là rất khác nhau.

3. Quasi-static process

Hình 3: Một hệ nén bằng vật nặng

Hình 4: Quá trình chuẩn tĩnh: lấy dần các hạt nặng

Hình 5: Quá trình đảo ngược: bỏ dần vào lại các hạt nặng

Như vậy ta đã biết một hệ cân bằng NĐH phải cân bằng về nhiều mặt. Xét một hệ cân bằng NĐH sử dụng các hạt đặt trên đỉnh pít tông để nén khí như trong hình 3. Giả sử ban đầu hệ có thể tích là V0V_0 là áp suất là P0P_0. Mỗi cặp (PtP_t, VtV_t) sẽ xác định một trạng thái.

Giả sử các hạt này là rất nhẹ và có rất nhiều hạt trên đỉnh pít tông. Nếu ta lấy ra một hạt trên đỉnh pít tông. Thể tích khí ở bên dưới sẽ nới rộng ra một chút là đạt được trạng thái cân bằng mới (P1P_1, V1V_1). Lấy ta tiếp tục việc lấy ra dần từng hạt một cách chậm rãi ta sẽ thu được (P2P_2, V2V_2), (P3P_3, V3V_3), ...(PTP_T, VTV_T). Với T là số lần lấy hạt ra, và các trạng thái (PtP_t, VtV_t) đều cân bằng NĐH (1tT1 \leq t \leq T). Quá trình chuyển từ trạng thái (P0P_0, V0V_0) sang (PTP_T, VTV_T) cực kỳ chậm như vậy là quasi-static.

Lý do người ta quan tâm đến quá trình quasi-static là tất cả Quá trình có thể đảo ngược đều là quá trình quasi-static. Nếu ta thêm lại dần dần các hạt đã lấy ra (từng chút từng chút một như lúc lấy ra) ta sẽ thu được các trạng thái trung gian. Trong điều kiện lý tưởng (các hạt vô cùng nhẹ, không có ma sát, thả không vận tốc ban đầu,...), các trạng thái trung gian sẽ chính là đảo ngược quá trình lấy ra (PTP_T, VTV_T), (PT1P_{T-1}, VT1V_{T-1}),...(P1P_1, V1V_1), (P0P_0, V0V_0). Quá trình đảo ngược này được minh hoạ trong hình 4.

II. Diffusion trong Deep Learning

Trong phần này chúng ta sẽ tìm hiểu cách ý tưởng Diffusion trong Deep Learning được triển khai.

1. Ý tưởng tổng quan

Ý tưởng cơ bản của Diffusion trong Deep Learning là phá hủy cấu trúc của dữ liệu một cách có hệ thống và cực kỳ chậm thông qua Quá trình khuếch tán thuận. Quá trình này có tính lặp lại và được minh họa ở hình 6. Sau đó chúng ta sẽ học cách để đảo ngược quá trình này. Quá trình đảo ngược được minh họa ở hình 7.

Cụ thể hơn, chúng ta sẽ định nghĩa quá trình khuếch tán thuận. Quá trình này chuyển phân phối phân phối dữ liệu (vốn phức tạp) sang một phân phối đơn giản và có thể dễ dàng làm việc (như lấy mẫu). Sau đó ta sẽ học cách để đảo ngược quá trình này. Nếu làm được điều này, Quá trình đảo ngược sẽ được sử dụng để sinh dữ liệu. Trong hai quá trình này, chúng ta chỉ cần sử dụng mạng neural để học cách thực hiện quá trình đảo ngược. Quá trình thuận hoàn toàn được cố định (fully defined, fixed) trước hoặc chỉ cần học một số biến phụ. Trong bài viết này, chúng ta sẽ cố định hoàn toàn Quá trình thuận.

Hình 6: Quá trình khuếch tán thuận

Hình 7: Quá trình đảo ngược

Hy vọng đến đây chúng ta đã hình dung được phần nào ý tưởng diffusion trong Deep Learning. Sau đây chúng ta sẽ đi vào tìm hiểu chi tiết của hai quá trình trên.

2. Quá trình khuếch tán thuận

Hình 8: Mô hình đồ thị của quá trình khuếch tán và quá trình đảo ngược

Quá trình thuận xuất phát từ phân phối của dữ liệu q(x0)q(x_0) và chuyển đổi dần dần thành phân phối có thể dễ dàng làm việc hơn q(xT)N(xt;0,I)q(x_T) \approx \mathcal{N}(x_t; 0,\,\bold{I}). Phân bố của q(xT)q(x_T) được chọn trước là một prior.

Trong hình 8, ở quá trình thuận, x0x_0 là dữ liệu; x1,x2,...,xTx_1, x_2,...,x_T là các biến ẩn (latent) có cùng chiều với x0x_0, T là số bước biến đổi. Quá trình khuếch tán được mô tả bằng một chuỗi Markov, nghĩa là trạng thái xtx_{t} chỉ phụ thuộc vào xt1x_{t-1}. Nhân biến đổi (transition kernel) q(xtxx1)q(x_{t}|x_{x-1}) được chọn sao cho:

q(xtxt1)=N(xt;xt11βt,βtI)q(x_{t}|x_{t-1}) = \mathcal{N}(x_t; x_{t-1}\sqrt{1 - \beta_t},\,\beta_t\bold{I})

Với βt\beta_t là tốc độ khuếch tán ở bước t. Đặt αt:=1βt\alpha_t := 1- \beta_t, αtˉ:=s=1tαs\bar{\alpha_t} := \prod_{s=1}^{t} \alpha_s và qua các phép biến đổi, ta có:

q(xtx0)=N(xt;αtˉx0,(1αtˉ)I)q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}}x_{0},\,(1 - \bar{\alpha_t})\bold{I})

Đây là một tính chất quan trọng của quá trình thuận. Tính chất này trên cho phép ta lấy mẫu được xtx_t ở một bước t bất kỳ một cách trực tiếp (mà không phải đi từ bước 0, 1, 2,... rồi mới đến t). Phần biến đổi để thu được công thức trên các bạn có thể xem ở đây.

Phân phối của quá trình thuận thu được bằng cách bắt đầu từ q(x0)q(x_0) và áp dụng nhân biến đổi bên trên qua T bước là:

q(x0...T)=q(x0)t=1Tq(xtxx1)q(x_{0...T}) = q(x_0) \prod_{t=1}^{T} q(x_{t}|x_{x-1})

Như đã nói ở trên, quá trình khuếch tán thuận mà chúng ta lựa chọn đã hoàn toàn được cố định. Như vậy chúng ta cần chọn trước các giá trị β1,β2,...βT\beta_1, \beta_2, ...\beta_T, còn được gọi là lịch trình phương sai. Các giá trị này cần thoả mãn hai điều kiện:

  1. Tổng lượng nhiễu β1,β2,...βT\beta_1, \beta_2, ...\beta_T phải đủ lớn giúp chuyển phân phối dữ liệu trở thành nhiễu đẳng hướng Gaussian.
  2. Lượng nhiễu ở mỗi bước βt\beta_t phải đủ nhỏ để có thể đảo ngược được. Điều này tương tự như điều kiện để một quá trình là quasi-static ở phần I.

Giá trị của T phải được chọn trước. Để thoả mãn hai điều kiện trên thì T bắt buộc phải đủ lớn. T càng lớn thì ta có thể làm cho βt\beta_t càng nhỏ. Để tóm tắt lại phần này mình xin trích lại slide của tác giả paper ở hình 9.

Hình 9: Tóm tắt quá trình thuận

3. Quá trình đảo ngược

Quá trình đảo ngược cũng là một chuỗi Markov có những trạng thái như quá trình thuận nhưng theo chiều ngược lại, như được thể hiện ở hình 8. Quá trình đảo ngược còn được gọi quá trình sinh. Phân phối của quá trình sinh có được qua T bước áp dụng nhân biến đổi (tương tự quá trình thuận):

p(x0...T)=p(xT)t=1Tp(xt1xx)p(x_{0...T}) = p(x_T) \prod_{t=1}^{T} p(x_{t-1}|x_{x})

Với p(xT)=N(xT;0,I)p(x_T) = \mathcal{N}(x_T; 0,\,\bold{I}). Để đảo ngược được chúng ta chỉ cần xác định nhân biến đổi ngược p(xt1xx)p(x_{t-1}|x_{x}) và biến đổi T bước để thu được x0x_0. Vì βt\beta_t nhỏ nên ta biết rằng nhân biến đổi này cũng là một phân phối Gaussian. Còn mean và covariance của phân phối này thì chúng ta có thể sử dụng mạng neural để ước lượng. Ta có thể viết nhân biến đổi ngược dưới dạng tổng quát như sau:

p(xt1xx)=N(xt1;μθ(xt,t),Σθ(xt,t))p(x_{t-1}|x_{x}) = \mathcal{N}(x_{t-1}; \bold{\mu}_\theta (x_t, t),\,\Sigma_\theta(x_t, t))

Chi phí tính toán của Diffusion Model chủ yếu đến từ chi phí tính toán của hai mô hình μθ(xt,t)\bold{\mu}_\theta (x_t, t)Σθ(xt,t)\Sigma_\theta(x_t, t). Để tóm tắt lại phần này, mình lại xin trích slide của tác giả ở hình 10.

Hình 10: Tóm tắt quá trình đảo ngược.

Tips: hãy chú ý rằng các phân phối liên quan đến quá trình thuận là qq, với quá trình đảo ngược là pp

4. Huấn luyện

4.1 Hàm mục tiêu

Như đã nói ở trên với βt\beta_t đủ nhỏ, q(xt1,xt)q(x_{t-1}, x_t) là phân bố Gaussian. Chúng ta không thể dễ dàng ước lượng được phân phối này vì nó yêu cầu sử dụng cả tập dữ liệu. Tuy nhiên chúng ta có thể xác định được phân phối này khi đặt điều kiện trên x0x_0 bằng cách áp dụng quy tắc Bayes. Ta có q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0q(x_{t-1}| x_t, x_0) = \frac{q(x_t|x_{t-1}, x_0) q(x_t|x_0)}{q(x_{t-1}|x_0}. Qua các phép biến đổi ta có thể viết lại như sau:

q(xt1xt,x0)=N(xt1;μ~t(xt,x0),β~tI)q(x_{t-1}|x_t, x_0) = \mathcal{N}(x_{t-1}; \bold{\tilde{\mu}}_t(x_t, x_0),\,\tilde{\beta}_t \bold{I})

với β~t=1αˉt11αˉtβt\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_tμ~t(xt,x0)=1αt(xt1αt1αtˉϵt)\bold{\tilde{\mu}}_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha_t}}} \epsilon_t) với ϵtN(0,I)\epsilon_t \sim \mathcal{N}(0, \bold{I})

Chi tiết phần biến đổi để ra được công thức trên các bạn có thể xem thêm ở đây.

Việc huấn luyện được thực hiện bằng cách tối ưu chặn của Negative Log Likelihood:

LVLB=Eq[logq(x1:Tx0)pθ(x0:T)]E[logpθ(x0)]L_{VLB} = \mathbb{E}_q[log \frac{q(x_{1:T}|x_0)}{p_\theta (x_{0:T})}] \ge \mathbb{E} [-logp_\theta(x_0)]

Hàm mục tiêu này có thể được viết lại thành tổng của các thành phần KL-divergence và entropy để có thể tính toán dễ dàng.

Eq[DKL(q(xTx0)p(xT))LT+t>1DKL(q(xt1xt,x0)pθ(xt1xt))Lt1logpθ(x0x1)L0]\mathbb{E}_{q}[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \mid \mathbf{x}_{0}\right) \| p\left(\mathbf{x}_{T}\right)\right)}_{L_{T}}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)\right)}_{L_{t-1}} \underbrace{-\log p_{\theta}\left(\mathbf{x}_{0} \mid \mathbf{x}_{1}\right)}_{L_{0}}]

Trong đó, LTL_T là hằng số vì quá trình thuận không có gì để học cả nên ta có thể bỏ qua thành phần này.

Để tối ưu L0L_0, tác giả sử dụng một bộ decoder riêng, nhân biến đổi lúc này được tính theo cách riêng, bài viết này sẽ không nói kỹ về phần này.

Để tối ưu Lt1L_{t-1}, chúng ta cần một mạng neural học pθ(xt1xt)p_\theta(x_{t-1}|x_t) xấp xỉ q(xt1xt,x0)q(x_{t-1}|x_t, x_0). Vì βt~\tilde{\beta_t} là hằng số nên chúng ta chỉ cần huấn luyện μθ\mu_\theta để dự doán μt~\tilde{\mu_t}. Nhưng thay vì dự đoán trực tiếp μt~\tilde{\mu_t}, để đơn giản hơn, chúng ta có thể chỉ cần dự đoán ϵt\epsilon_t vì các thành phần khác của μt~\tilde{\mu_t} đã biết:

μθ(xt,t)=1αt(xt1αt1αtˉϵθ(xt,t))\bold{\mu}_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha_t}}} \epsilon_\theta(x_t, t))

Như vậy ta đã quy việc giải bài toán sinh dữ liệu về bài toán hồi quy. Đối với các thành phần KL-divergence, việc tính toán có thể thực hiện khá đơn giản vì các phân phối đều là Gaussian.

4.2 Kiến trúc mô hình

Nhìn chung Diffusion Models không cần kiến trúc chuyện dụng như Flow Models. Do yêu cầu về chiều của đầu ra phải giống với chiều của đầu vào, các mạng tương tự U-Net thường được sử dụng. Những mô hình này có một vài thay đổi so với thông thường để xét đến cả yếu tố bước thời gian (t). Một cách làm đơn giản là mã hoá thông tin thời gian tt dưới dạng position embedding có dạng hình sin rồi thêm vào mỗi khối residual.

Tra cứu (sẽ cập nhật sau)

Phần cuối cũng bài viết cung cấp mục tra cứu và giải thích các ký hiệu phổ biến được sử dụng.

Lời kết

Trong bài này mình đã trình bày cơ bản về Diffusion Models qua góc nhìn của mình. Phần kết nối giữa Khuếch tán trong Nhiệt động học và trong Học sâu mình không trình bày nhiều vì muốn các bạn có hình dung riêng. Hy vọng các bạn thấy bài viết này hữu ích. Nếu có vấn đề hay thắc mình gì đừng ngại cho mình biết ở phần comment nhé. Cảm ơn các bạn đã đọc bài. Hẹn gặp lại ở bài viết sau về chủ đề Diffusion Models.

Tài liệu tham khảo

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139