Giới thiệu
Nếu trong tay có dữ liệu lớn và muốn tăng hiệu suất mô hình, hầu như chúng ta sẽ nghĩ ngay tới việc scale kích thước mô hình lên. Điều này được chứng minh là cải thiện hiệu suất của mô hình và hiệu quả của việc sử dụng mẫu (sample efficiency). Tuy nhiên, nếu chỉ scale model size thì vẫn chưa đủ để giải quyết các task khó liên quan đến suy luận như: Toán học, commonsense reasoning và symbolic reasoning.
Các nghiên cứu trước đây đề xuất phương pháp tăng khả năng suy luận của LLM bằng 2 ý tưởng đơn giản sau:
- Thứ nhất, để cải thiện "arithmetic reasoning" (suy luận toán học) của LLM, ta có thể tạo ra các lý do hoặc giải thích bằng ngôn ngữ tự nhiên dẫn đến kết quả cuối cùng. Nói một cách dễ hiểu hơn là đưa ra những gợi ý để model suy luận từ từ cho tới khi cho ra kết quả cuối cùng. Một số nghiên cứu trước đây đã làm cho model có khả năng tạo ra các bước trung gian trong quá trình giải toán bằng cách sử dụng ngôn ngữ tự nhiên để giải thích các bước tính toán hoặc suy luận. Điều này được thực hiện bằng cách train lại model từ đầu hoặc finetuning pretrained model.
- Cách thứ hai, thay vì train lại model làm tốn tài nguyên và thời gian thì ta có thể thực hiện prompting. Cụ thể, ta chỉ cần cung cấp một số ví dụ minh họa cho model về cách thức mà task được thực hiện. Ví dụ, cho model biết một vài cặp câu hỏi và câu trả lời tương ứng có liên quan đến task, và từ đó model có thể học được cách thức giải quyết các câu hỏi tương tự mà nó chưa từng được huấn luyện trước đó.
Tuy nhiên, hai ý tưởng trên đều có những hạn chế:
- Với cách train với finetuning model thì rất là khó để xây dựng một dataset suy luận chất lượng. Tức là bình thường khi training một model, ta sẽ chỉ cần xây dựng bộ dataset với các cặp input-output ngắn gọn, nếu phải làm thêm cả cách suy luận từ input ra được output thì cực kì tốn nhiều công sức và thời gian.
- Với cách few-shot prompting ở trên thì kết quả cho được còn khá tệ trên các task yêu cầu suy luận và thường không cho hiệu suất ổn định nếu như tăng kích thước model lên.
Bài báo đề xuất phương pháp giải quyết các vấn đề trên bằng cách sử dụng prompting gồm 3 thành phần <input, chain of thought, output>
. Chain of thought là một chuỗi các bước suy luận bằng ngôn ngữ tự nhiên để dẫn tới output cuối cùng. Đây chính là phương pháp chain-of-though prompting.
Chain-of-Thought Prompting
Giống như các bài toán có từ 2 lời giải trở lên mà ta đã học từ hồi tiểu học Ý tưởng ở đây là tách một bài toán thành những bài toán nhỏ và giải chúng, trang bị cho model khả năng tạo các "chuỗi suy luận" để đưa ra kết quả.
Hình trên là ví dụ của một model thực hiện chain-of-thought để giải quyết một bài toán đố mà trước đó nếu không áp dụng chain-of-thought thì kết quả bị không chính xác.
Chain-of-thought có các đặc điểm làm cho LLM có thể giải quyết tốt hơn các bài toán cần suy luận:
- Chain of thought cho phép model phân tách một bài toán thành các bài toán nhỏ trung gian và giải chúng. Với bài toán cần nhiều bước suy luận sẽ yêu cầu lượng tính toán nhiều hơn.
- Chain of thought cung cấp khả năng giải thích hành vi của mô hình (tức là cách mô hình đưa ra kết quả), từ đó ta có thể debug các bước suy luận của mô hình, kiểm tra xem bước nào thì mô hình làm sai,...
- Chain of thought có thể ứng dụng (về mặt lý thuyết) cho tất cả các task mà con người có thể giải quyết bằng ngôn ngữ.
- Chain of thought dễ dàng được sử dụng trong các mô hình ngôn ngữ lớn bằng cách đưa vào few-shot prompting.
Arithmetic Reasoning
Với task này, tác giả sử dụng 5 benchmark:
- GSM8K https://huggingface.co/datasets/gsm8k
- SVAMP https://opendatalab.com/OpenDataLab/SVAMP
- ASDiv https://paperswithcode.com/dataset/asdiv
- AQuA https://paperswithcode.com/dataset/aqua
- MAWPS https://paperswithcode.com/dataset/mawps
Language model sử dụng là:
- GPT-3: Sử dụng các phiên bản text-ada-001, text-babbage-001, text-curie-001, và text-davinci-002. Các phiên bản này tương ứng với các mô hình InstructGPT với số lượng tham số tương ứng là 350M, 1.3B, 6.7B, và 175B.
- LaMDA: Có các mô hình với số lượng tham số lần lượt là 422M, 2B, 8B, 68B, và 137B.
- PaLM: Bao gồm các mô hình có số lượng tham số là 8B, 62B, và 540B.
- UL2 20B: Được xây dựng với 20 tỷ tham số.
- Codex: Sử dụng phiên bản code-davinci-002 trong OpenAI API.
Kết quả được thể hiện trong hình sau:
Tổng quan, việc sử dụng chain-of-thought cho kết quả tốt ở nhiều benchmark, model và cả các size khác nhau. Ta có một số nhận xét sau:
- Thứ nhất, sử dụng chain-of-thought chỉ tốt với các model có số lượng tham số từ 100B trở lên. Về mặt định tính, với các model nhỏ, sử dụng chain-of-thought đưa ra những suy luận mượt mà nhưng lại không được logic, từ đó dẫn đến hiệu suất thấp hơn so với cách sử dụng prompt thông thường.
- Thứ hai, chain-of-thought cho hiệu suất tốt ở các bài toán phức tạp. Cụ thể, tại benchmark GSM8K, việc sử dụng chain-of-though cho kết quả tốt gấp đôi tại model GPT và PaLM.
- Thứ ba, sử dụng chain-of-thought với model GPT-3 175B và PaLM 540B đạt tới kết quả SOTA, so sánh với việc finetuning model trên một task cụ thể.
Commonsense Reasoning
Benchmark được sử dụng là:
- CSQA: https://paperswithcode.com/dataset/csqa
- StrategyQA: https://paperswithcode.com/dataset/strategyqa
- Date
- Sports
- SayCan: https://say-can.github.io/
Kết quả được thể hiện trong hình dưới, với tất cả các task, sử dụng chain-of-though đạt kết quả tốt vượt qua standard prompting và thậm chí cả con người.
Symbolic Reasoning
Nhóm tác giả sử dụng 2 task để thử nghiệm khả năng symbolic reasoning của chain-of-thought như sau:
- Concat 2 kí tự cuối của tên. Như tên gọi, task yêu cầu concat 2 kí tự cuối của 1 tên. Ví dụ: Viblo GPT => oT. Task này khó hơn so với phiên bản concat 2 kí tự đầu tiên của tên (task này đã được hoàn thành tốt mà không cần đến chain of thought).
- Lật đồng xu. Task này yêu cầu model suy luận xem đồng xu đang sấp hay ngửa qua một loạt các thao tác. Ví dụ: Một đồng xu đang mặt ngửa. Minh lật đồng xu 1 lần, sau đó Huy tiếp tục lật đồng xu 1 lần tiếp. Hỏi đồng xu lúc này đang xấp hay ngửa => Câu trả lời là ngửa.
Ở đây, ta có 2 loại là in-domain test và out-of-domain test. In-domain test nghĩa là các ví dụ được đưa vào trong prompt có số step giống với tập test. Với out-of-domain test, các ví dụ được đưa vào trong prompt có số step khác với tập test. Chẳng hạn trong task concat 2 kí tự cuối, ví dụ trong prompt là với tên có 2 từ, nhưng khi test sẽ test trên tên có 3-4 từ. Điều này tương tự với task Lật đồng xu.
Kết quả được thể hiện trong hình dưới.
Việc sử dụng Chain-of-thought prompting tiếp túc outperform so với việc sử dụng prompting tiêu chuẩn, kể cả in-domain test hay out-of-domain test.
Kết luận
Giống như con người, việc hướng dẫn model đưa ra một chuỗi các suy luận giúp làm tăng hiệu suất của model mà không cần phải finetuning trên một task cụ thể nào đó. Tuy nhiên, mặc dù giống con người ở khoản đưa ra các suy luận nhưng chưa chắc model thực sự đang suy luận để đưa ra câu trả lời. Ngoài ra, như trình bày trong phần thực nghiệm, việc sử dụng chain-of-thought thường chỉ ngon khi sử dụng với model kích thước lớn (khoảng 100B tham số), việc sử dụng với model kích thước nhỏ hơn không mang lại kết quả tốt hơn so với cách prompting tiêu chuẩn. Bên cạnh đó, không có gì đảm bảo về hướng suy luận đúng khi áp dụng phương pháp này. Khi mô hình được kích hoạt để thực hiện một tác vụ nào đó dựa trên chuỗi suy nghĩ, nó tự sinh ra một loạt các bước, tương tự như cách con người suy nghĩ về vấn đề. Tuy nhiên, điều quan trọng là không có đảm bảo rằng mỗi bước trong chuỗi suy nghĩ của mô hình là hoàn toàn chính xác. Điều này có thể dẫn đến việc mô hình tạo ra cả các câu trả lời đúng và không đúng cho cùng một vấn đề. Trong thực tế, điều này có thể diễn ra khi mô hình tạo ra các chuỗi suy luận mà dẫn đến kết quả không chính xác hoặc không phản ánh đúng với thực tế. Vấn đề này là một hướng nghiên cứu tiềm năng trong tương lai, để cải thiện khả năng của mô hình tạo ra thông tin chính xác hơn, giúp nâng cao độ tin cậy của các mô hình ngôn ngữ trong việc suy luận và giải quyết vấn đề.
Tham khảo
[1] Chain-of-Thought Prompting Elicits Reasoning in Large Language Models