Mở đầu.
Halu, Xin chào mọi người, lâu quá rồi chưa quay trở lại viết 1 bài viết nào cả, nên là nhân cơ hội được nghỉ thư giãn 1 vài ngày cuối tuần ngồi đọc đọc xem có gì vui không 😁 thì có 1 paper khá hay vừa được public vào đầu năm, tên thì mọi người biết luôn rồi nhỉ LLaDA (Large Language Diffusion Model), paper này được đánh giá là một trong những hướng đi mới thú vị và đổi mới cách tư duy. Thế thì thử đi sâu vào xem các pháp sư đã làm gì model nào 😚
Paper : Large Language Diffusion Models
Hold up 🖐🏻 :
- Trước tiên thì bài viết này chỉ là một bài viết tham khảo mình tự ngồi đọc paper, tra chatgpt 😀, bla bla để viết nên sẽ có nhiều sai xót, mong mọi người đóng góp thêm.
- Toán mình khá dốt, và mình là engineer, không phải dân chuyên research, chỉ đơn giản sự tò mò đưa mình đi vọc vách mọi thứ for fun =)). Cho nên hãy đọc for fun thôi, muốn hiểu kỹ hơn thì mọi người phải trực tiếp ngồi vọc vạch paper.
- Đây chỉ đơn giản là thú vui ngoài giờ của mình để luyện khả năng đọc paper thôi.
Yep, bài viết dự là bài viết sẽ hơi dài một chút nên là upvote lấy động lực cho mình trước đi đã nào :v :v
Okee, Xong rồi, bắt đầu đọc hiểu xíu nào.
Lý do và động lực cải tiến.
We contend that the intelligence of LLMs—manifested in scalability, instruction-following, in-context learning, conversational ability, and compression—stems not from the autoregressive mechanism per se, but rather from the core principle of generative modeling: approximating the true language distribution through maximum likelihood estimation.
Như trên thì tác giả có đang nói đến trí tuệ hay khả năng
quan trọng của LLM nằm ở 5 ý chính :
- Scalability : Khả năng mở rộng của LLMs khi mô hình hoạt động hiệu quả hơn với số lượng tham chiếu lớn hơn, với lượng dữ liệu huấn luyện nhiều hơn hoặc tài nguyên tính toán nhiều hơn. Tuy nhiên càng gần trở lại đây thì pre-training đang dần
Hit the wall
, việc tăng kích thước mô hình lớn hơn, tài nguyên hơn không làm cho model trở nên tốt hơn hay outperform hẳn so với các model cùng kích thước khác. - Instruction-following : Là khả năng thực hiện nhiệm vụ thông qua fine tuning hoặc instruction tuning.
- In-context learning : Là khả năng học trong ngữ cảnh, tức là llm học được từ các ví dụ được cung cấp trong prompt mà không cần cập nhật tham số của mô hình (few shot learning). Vì llm được huấn luyện để
predict next token
nên chúng có thể tổng quát hoá và tái tạo mẫu đã thấy trong ngữ cảnh (prompt), Điều này không phải do tính autoregressive mà do trong quá trình post-training theo nguyên tắc mô hình sinh đã giúp LLM tối ưu hoá phân phối của mô hình theo dữ liệu đầu vào. - Conversational ability : Là khả năng tham gia vào các cuộc hội thoại một cách tự nhiên. hiểu ngữ cảnh và tiếp tục cuộc nói chuyện một cách mạch lạc. Điều này là do mô hình đã học được quy luật ngôn ngữ trong quá trình post training mô hình.
- Compression : Là khả năng nén kiến thức từ một lượng lớn dữ liệu vào các tham số của mô hình (Q, K, V, ....)
The key insight overlooked previously is: it is the generative modeling principles, rather than the autoregressive formulation
Tác giả nhấn mạnh rằng việc LLm có được các khả năng trên không chỉ đơn giản là hệ quả của việc sử dụng cơ chế autoregressive (hiểu đơn giản là cơ chế sinh text tuần tự từ trái qua phải, ... - hay next-tokens prediction paradigm
) mà quan trọng hơn là nguyên tắc cốt lõi của mô hình sinh : Việc tối ưu hoá xác suất của mô hình sinh sao cho nắm bắt dược phân phối ẩn thực sự của dữ liệu được train.
Generative Modeling Principles.
hừm, thế thì làm sao để tối ưu hoá xác suất của mô hình sao cho nó nắm bắt được phân phối ẩn của dữ liệu ?
Có hai phương pháp phổ biến giúp làm điều này : MLE (hay maximum likelikhood estimation) và KL divergence (Kullback-Leibler divergence). Rồi rồi đến toán rồi, làm nó lẹ tẹo nào 🥲. Dưới đây là công thức của generative modeling priciples
mà trong paper có nhắc đến. Nói một cách đơn giản hơn và dân dã thì nó là hàm mục tiêu để tối ưu hoá xác suất của model với xác suất ẩn của dữ liệu đầu vào.
Oke, Thì vì ở phía sau cũng sẽ phải động một chút đến toán khi diễn giải công thức tính xác suất của LLaDA, cho nên chúng ta đi diễn giải một chút công thức ở trên nhé.
MLE - Maximum Likelihood Estimation.
Thường thì mình không viết chi tiết như này 🫠, nhưng mà sau đọc lại đỡ mất công load não thì đi từ cơ bản, ôn lại bãi cũ một tẹo. Nếu bạn muốn skip đoạn toán này thì kéo xuống phần Note ngay sau đoạn toán nhé.
Với một tập huấn luyện gồm N mẫu, x1, x2, .. xn, mỗi xi là một đoạn văn, một câu hay một chuỗi tokens - độ dài max là L được lấy từ phân phối thực hàm likelihood cho mô hình cơ bản sẽ như sau :
- là hàm likelihood, là xác suất của toàn bộ tập dữ liệu theo mô hình
- là xác suất của mô hình với điểm dữ liệu xi.
Mục tiêu của MLE là tìm sao cho xác suất tổng của dữ liệu thực tế là cao nhất:
Vì xác suất nhỏ dễ dẫn đến tràn số nên người ta nên người ta thường dùng log likelihood
thay vì likelihood:
Cuối cùng công thức tối ưu thực tế của MLE là :
nhưng thay vì tìm cực đại hoá hàm log likelihood thì thường tối thiểu hoá negative log likelihood, từ đây có thể sử dụng các phương tối ưu như gradient descent hay adam, ...
Nốt nào, trong LLM, ta coi mỗi văn bản x là một chuỗi token (x1, x2, ... xL - L là độ dài tối đa như đã nói ở trên) với mô hình autoregressive thì xác suất cả chuỗi được tính như sau :
Thay vô thì công thức cuối cùng sẽ trông như này :
KL Divergence - Kullback-Leibler divergence.
Công thức đo lường sự khác biệt giữa hai phân phối xác suất P(x) và Q(x). Trong LLMs, thì nó giúp đánh giá mức độ khớp giữa phân phối thực tế của dữ liệu và phân phối của mô hình .
Thì công thức của nó như sau :
Thay P(x) = và Q(x) = thì công thức thành :
Để ý một chút thì ở đây là hằng số do phân phối data không đổi nên :
- là hằng số
- là hằng số.
Mục tiêu tối thiểu hoá KL divergence theo chỉ còn lại là : tương đương việc hay chính là maximum likelihood estimation mà đã trình bày sơ bộ ở bên trên.
Tóm gọm.
Việc tối ưu hoá của MLE, hay KL divergence chỉ cần hiểu đơn giản là tối ưu hoá xác suất của mô hình () sao cho gần với xác suất của dữ liệu nhất.
Trong autoregressive modeling, việc áp dụng MLE tương đương với huấn luyện mô hình dự đoán token tiếp theo
một cách chính xác nhất có thể
dựa trên các tokens trước đó
Kết quả là một mô hình LLM có phân phối xấp xỉ với giúp mô hình sinh văn bản giống với dữ liệu thực nhất.
Thì cuối cùng tác giả lập luận rằng, việc mô hình LLm đạt được những khả năng mà đã trình bày ở đầu, một phần không nhỏ là dựa vào Generative Modeling Principles chứ không hoàn toàn là cơ chế
autoregressive
. Và từ đó sinh ra LLaDA.
FACT : Mọi mô hình chỉ đơn giản là mô hình xác suất, cố gắng học dữ liệu đầu vào, nhưng mà đứng tin nó quá vì đơn giản nó chỉ là xác suất thôi 🫠
Large Language Diffusion Models
LLaDA, a diffusion model trained from scratch under the pre-training and supervised finetuning (SFT) paradigm. LLaDA models distributions through a forward data masking process and a reverse process, parameterized by a vanilla Transformer to predict masked tokens. The core of LLaDA is a mask predictor.
LLaDA như tên, mô hình được xây dựng theo ý tưởng diffusion models. Hãy quan sát hình dưới nhé.
LLaDA được hoạt động theo 2 cơ chế chính :
- Forward process : hiểu đơn giản thì là quá trình mask tokens
- Reverse process : ngược lại với forward process, quá trình này sẽ dự đoán các toàn bộ các mask tokens.
LLaDA được train trên 2 giai đoạn và quá trình Sampling:
- Pre-training : LLaDA được train với text mà mỗi tokens sẽ được
che đi
một cách ngẫu nhiên và độc lập với nhau. Tỷ lệ này được gọi là t từ 0 đến 1, với mỗi tokens sẽ có xác suất t bị che đi và độc lập với các từ khác. Nhiệm vụ là dự đoán tất cả các mask token dựa trên các tokens còn lại - SFT : Tương tự như pre-training, tuy nhiên sẽ chỉ mask phần
response
giúp mô hình hiểu được instruction following và các tính chất khác như in context learning trong prompt, ... Mục tiêu cũng như vậy, dự đoán tất cả các mask dựa trên các tokens còn lại. - Samling (hay cũng có thể là quá trình Inference) : Có thể hiểu đây là quá tình
denoise
từ t = 1 (toàn bộ các tokens đều bị che) cho đến t = 0 (Khi tất cả các tokens đều được dự đoán ra). Đầu tiên thì LLaDA sẽ dự đoán tất cả các từ bị che cùng một lúc sau đó áp dụng chiến lược remask để tiếp tục quá trình. Chiến lược remask có thể dựa trênlow confidence remasking
- những token nào có conf thấp thì cho remask lại,Random Remasking
- che ngẫu nhiên những token được dự đoán trước đó một các ngẫu nhiên, ngoài ra còn cóSemi-autoregressive Remasking
chia chuỗi response thành các khốiblocks
nhỏ và thực hiện samling trên từng blocks một và thực hiện từ trái qua phải như autoregressive.
Thì ở trên chúng ta cũng hình dung ra được việc LLaDA chạy như nào rồi nhỉ, đơn giản như quá trình samling thôi, forward làm nhiệm vụ mask tokens, sau đó reverse sẽ dự đoán toàn bộ mask tokens, lặp đi lặp lại cho đến khi nào hoàn thiện chiến lược remask thôi 😗.
Tiếp theo chúng ta sẽ đi sâu hơn về việc training và inference của mô hình LLaDA, cách mô hình học được biểu diễn phân phối của dữ liệu thông qua mục A. Formulation of Masked Diffusion Models
mà tác giả đã trình bày.
Training.
Như đã biết thì model LLaDA thông qua 2 cơ chế chính : forward và reverse process. Vậy thì quá trình đó được diễn ra như thế nào và tại sao mô hình lại học được phân phối xác suất của dữ liệu.
Forward Process.
Ở trên chúng ta đã đề cập đến việc forward là quá trình mask các tokens một cách độc lập với tỷ lệ t. Vậy thì quá trình này thực chất được diễn ra như thế nào ?
These models introduce a forward process {xt} indexed by a time . This process gradually and independently masks all tokens in the sequence x0. At time t = 0, the data point x0 is fully observed with no masks, while for t ∈ (0, 1], xt represents latent variables with varying mask ratios in expectation.
Ở đây, thì tác giả đã nói rằng forward process được thực hiện với t từ 0 đến 1. Với t = 0 thì x0 không bị mask, còn t = 1 thì x0 sẽ bị mask hoàn toàn. Để cho clear, thì ở đây tác giả xử lý t với N samling steps . Ví dụ với N = 5, thì t sẽ được tăng dần [0, 0.2, 0. 4, 0.6, 0.8, 1] , với mỗi t sẽ sinh ra một , tương ứng với đó thì mỗi tokens trong sẽ bị mask với tỷ lệ t và tỷ lệ này độc lập với từng tokens.
Từ đó chúng ta xây dựng ra được phân phối có điều kiện để sinh ra được từ :
Với xác suất có điều kiện cho mỗi token như sau :
M đại diện cho việc token đó bị mask. Còn t và 1-t thì đã giải thích ở trên, tỷ lệ token bị mask hay không.
Ở đây tác giả có nhắc đến một 2 khái niệm khá thú vị :
Notably, the linear masking probability is analogous to but distinct from, the noise schedule in continuous diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2020). This linearity is motivated by the assumption that the information in the text is proportional to the number of tokens on average, making it reasonable to lose information linearly during the forward process.
Linear Masking Probability : Xác suất mask tuyến tính, là việc làm mờ dần câu gốc, cụ thể mỗi token có xác suất bị mask tăng theo thời gian t, với t chạy từ 0 đến 1.
Ví dụ :
- Với t = 0 : không có tokens nào bị mask.
- Với t = 0. 5 : có khoảng 50% các tokens bị mask
- Với t = 1 : toàn bộ tokens đều bị mask.
hold up ... tỷ lệ t là độc lập với mỗi token trong câu, sao lại viết 50% tokens bị mask ?
Đúng thật, với t là một tỷ lệ xác suất, đương nhiên sẽ xảy ra các trường hợp không tokens nào bị mask hay thậm chí toàn bộ tokens bị mask. Điều này chỉ đơn giản là xuất phát từ Kỳ vọng trung bình (Expectation). Với t = 0.5 Xác suất để toàn bộ câu bị mask (giả sử độ dài L = 10) = = 1/1024 = 0.1% . Tỷ lệ mask toàn bộ rất nhỏ, với 1-t, hay tỷ lệ không mask cũng tương tự.
Với mỗi token độc lập với nhau thì trung bình sẽ có t x L token sẽ bị mask, hay tương đương với t% bị mask.
Noise Schedule : Đây là quá trình thêm noise Gaussain vào dữ liệu gốc theo một lịch trình, lịch trình này xác định mức độ noise được thêm vào ở mỗi bước thời gian t, mức độ noise được thêm vào có thể tuyến tính, bậc 2 hay nhiều dạng khác tùy thiết kế.
Tóm gọn
Cách tiếp cận của forward process dựa trên một giả định rằng information in the text is proportional to the number of tokens on average
.
Hay là thông tin trong văn bản tỷ lệ thuận với số lượng token trung bình. Mỗi token trong câu được xem là mang 1 lượng thông tin tương đương (các token là độc lập với nhau), việc mask đi một token, ta sẽ mất đi một phần thông tin tương ứng.
Việc thực hiện linear masking, tương đương với việc thêm noise trong diffusion model, giúp cho mô hình học cách khôi phực dữ liệu gốc từ trạng thái làm mờ
. 🫠
Reverse Process.
Thuật toán của Reverse Process như trên.
Thì Quá trình Reverse Process thì nghược lại, với t từ 1 về 0, model sẽ sinh ra hay dự đoán các token bị mask từ câu masked trước đó.
Nếu để ý trong thuật toán có một biến s = t - 1/N hay chính là bước tiếp theo dự đoán (timestep). Tại mỗi step t, mask predictor hay model sẽ lấy đầu vào là làm đầu vào và dự đoán toàn bộ masked tokens
. Sau đó áp dụng các chiến lược remask khác nhau để mask lại những token
theo chiến lược, ví dụ ở trong ảnh thì là Random Remask
, với tỷ lệ s / t, vị trí i của câu dự đoán sẽ bị mask lại.
Tại sao lại là s/t ? Mình hiểu đơn giản thì đây chỉ là 1 hệ số điều chỉnh khi dự đoán từ trang thái t sang s thôi, hoặc có thể đặt đơn giản hơn là s trực tiếp cũng được, vì s là step tiếp theo trực tiếp của t, vì s < t thì s / t luôn nhỏ hơn 1 và giá trị nằm trong khoảng từ s đến t luôn. Phù hợp để đạt làm xác suất remask lại.
Oke, thì từ đó dẫn đến phân phối xác suất có điều kiện cho quá trình reverse để dự đoán từ cũng gần gần giống với forward process như sau :
Với xác suất cho mỗi tokens là :
ở đây chúng ta có 4 trường hợp :
- : nếu vị trí i không bị mask và vị trí trong s và t giống nhau thì xác suất luôn bằng = 1. Luôn đúng, vì tại i trong xt mà không bị mask thì luôn gán lại i trong xs bằng luôn với i trong xt.
- : chính là trường hợp else trên thuật toán, vị trí i trong xs bị mask lại với tỷ lệ s/t
- : đây là nghược lại so với trường hợp trên chính là (1 - s/t), nhưng mà có thêm là xác suất có điều kiện để dự đoán vị trí i trong xs trên xt.
- otherwise : Nhưng trường hợp còn lại ví dụ như i trong t khác M, nhưng trong s lại dự đoán là M, ... thì xác suất = 0
Ngoài chiến thuật Random Remask
như đã trình bày ở trên thì tác giả còn trình bày một thuật toán khác là Low-confidence remask như ảnh ở dưới.
Nhìn chung thì cũng tương tự như Random Remask, tuy nhiên thay vì sử dụng một xác suất để quyết định xem có mask lại token được đự đoán hay không, thì thuật toán này sẽ mask lại các tokens mà bị confidence thấp.
Cross-Entropy Loss.
Cuối cùng thì model được train bằng cross entropy loss cho việc nhiệm vụ dự đoán mask, được mô tả bởi công thức sau :
- là chuỗi gốc, lấy mẫu từ dữ liệu huấn luyên.
- t là thời gian, hay xác suất được lấy mẫu đồng đều từ [0, 1]
- là chuỗi sau khi masking từ với xác suất t
- : là hàm chỉ thị, bằng 1 nếu token thứ i bị masked, 0 nếu nghược lại. Việc này giúp chỉ tính toán loss trên những token bị mask.
- : Xác suất dự đoán cho token gốc tại vị trí i, dựa trên
- 1/t : là việc chuẩn hóa loss dựa trên tỷ lệ token bị masked trung bình tại t.
Cuối cùng thì người ta chứng minh được hàm Loss function là một cận trên (Upper bound) của negative log likelihood (trong paper có đề link chứng minh - nhưng mà nhiều toán quá nên mình lấy kết quả ra thôi 🫠) :
Từ đó có thể thấy rằng việc Minimum hàm loss chính là maximum đối với hàm negative log likelihood, phù hợp với việc giúp mô hình học được phân phối xác suất của dữ liệu mà ta đã trình bày ở Generative Modeling Principles.
Ngoài ra thì chúng ta cũng có thể nhìn thuật toán pre-training của mô hình LLaDA được mô tả bởi hình dưới đây
Tuy nhiên việc sử dụng Loss với xác suất t của mỗi token, và các tokens độc lập với nhau, dẫn đến số lượng token bị mask có thể dao động ngẫu nhiên, đặc biệt là các chuỗi ngắn, gây ra phương sai cao.
Tác giả đề xuất công thức mới tương đương :
- l : là số lượng token được mask, được lấy mẫu đồng đều từ tập {1,2,3,..., L} với L là độ dài của chuỗi.
- : là chuỗi dữ liệu gốc.
- : là chuỗi bị mask ngẫu nhiên l token từ mà không thay thế.
- Còn lại thì tương tự với hàm loss cũ.
Công thức mới cố định tỷ lệ token bị mask là l/L giúp giảm phương sai. Và kết quả thực nghiệm với công thức cũ cần hơn 1000 monte carlo estimates để đạt được kết quả ổn định, còn công thức mới chỉ cần 128 mẫu để đạt ổn định.
Tóm gọn.
Thông qua việc che dần các token các câu bằng quá trình forward, và bắt mô hình học phân phối xác suất thông qua quá trình Reverse, tất cả đều nhằm thỏa mãn lý thuyết về Generative Modeling Priniples.
Ngoài ra có một điểm đặc biệt trong LLaDA,về việc sử dụng EOS :
- Padding : Token EOS được thêm vào cuối các cặp prompt - response ngắn hơn trong mini-batch để đảm bảo các chuỗi có cùng độ dài.
- Masking : Trong training, token EOS được coi là một phần của response và được mask để tính vào hàm Loss.
- Loại bỏ tỏng sampling : Khi inference, các token EOS được loại bỏ ra khỏi đầu ra để đảm bảo phản hồi cuối cùng không chứa các token EOS không cần thiết. Điều này giúp mô hình học được cách kiểm soát độ dài của response bằng việc sinh ra token EOS, nâng cao chất lượng, tính tự nhiên và phù hợp với prompt đầu vào.
Kết luận.
Bài viết cũng đã đủ dài và chúng ta đã đi qua một số lý thuyết điển hình của paper, cách tiếp cận của phương pháp mới, về việc sử dụng diffusion model để sinh text, nâng cao khả năng suy luận nghược, thay vì sử dụng autoregressive như các mô hình LLM trước đây. Về cơ bản thì đây là một tư tưởng khá hay, tuy nhiên vẫn còn một số hạn chế :
- Quá trình suy luận khá tốn kém về mặt tính toán, ở tại mỗi bước sampling, LLaDA sẽ dự đoán toàn bộ các tokens bị mask theo chiến lược. Điều này dẫn đến việc LLaDA chậm hơn và tốn nhiều tài nguyên hơn so với ARM, đặc biệt với số lượng sampling step lớn.
- Phải Fix cứng length đầu vào, không thể Arbitrary-length được.
- Không sử dụng được các kỹ thuật tối ưu như KV caching.
- Đồng thời chất lượng không đồng đều ở các nhiệm vụ khác nhau.
Để cải thiện việc này, ngay sau đó đã có luôn một paper cải thiện các nhược điểm ở trên đó là Block Diffusion . Một paper khá hay chúng ta sẽ để một bài viết nào đó trong tương lai phân tích nó sau nhé 😁
Cuối cùng, ở bài viết này mình lược bỏ khá nhiều thứ đã được tác giả đề cập đến trong paper do mình không thấy nó đủ quan trọng và chưa phải trọng tâm để hiểu quá trình hình thành của mô hình, nếu mà mọi người muốn đọc sâu hơn nữa, và muốn update thêm phần nào thì xin hãy bình luận để mình cập nhật. Đến đây thì đã khá dài rồi, hẹn gặp mọi người ở bài viết sau 😚😚.
Nếu thấy nó hữu ích, hãy upvote giúp mình với nào 😉 để mình có động lực viết bài tiếp theo nào.