Giới thiệu
LLEMMA là một LLM cho một miền cụ thể (domain specific) là toán học. Llemma gồm 2 phiên bản là phiên bản 7 tỷ tham số và phiên bản 34 tỷ tham số. Điểm hay của LLEMMA là có khả năng sử dụng các công cụ tính toán để giải quyết các vấn đề toán học ví dụ như Python interpreter hoặc các định lý, định luật. Llemma cũng đạt hiệu suất SOTA so với các model public trong task về toán học.
Chi tiết phương pháp
Phương pháp để train model LLEMMA là tiếp tục pretraining Code Llama trên tập dữ liệu Proof-Pile-2.
Data: Proof-Pile-2
Nhóm tác giả xây dựng Proof-Pile-2, một tập dữ liệu gồm 55 tỷ token từ các bài báo khoa học, dữ liệu web có chứa thông tin về toán học, các code liên quan đến toán học.
Code. Nhóm tác giả xây dựng AlgebraicStack, một tập dữ liệu source code gồm 11 tỷ token từ 17 ngôn ngữ, bao gồm số học, ký hiệu và toán học hình thức (formal math). Tập dữ liệu bao gồm code được lọc từ Stack, GitHub và dữ liệu chứng minh chính thức (kiểu chứng minh định lý,... ).
Web data. Sử dụng OpenWebMath, là một tập dữ liệu có 15 tỷ token từ các web liên quan tới toán học chất lượng.
Bài báo khoa học. Sử dụng ArXiv là tập dữ liệu con của RedPajama với 29 tỷ token.
Dữ liệu ngôn ngữ và code tổng quát. Trong các tập dữ liệu được sử dụng để train LLEMMA thì có một lượng nhỏ dữ liệu tổng quát, lượng dữ liệu này đóng vai trò như một regularization. Vì tập dữ liệu pretrained của LLaMA 2 không được public nên nhóm tác giả sử dụng tập dữ liệu Pile thay thế. Dữ liệu training cuối cùng gồm:
- 95% Proof-Pile-2
- 2% từ Pile (với dữ liệu ArXiv bị loại bỏ, vì nó đã nằm trong Proof-Pile-2 rồi)
- 3% từ tập dữ liệu GitHub thuộc RedPajama
Thông tin về model và cách training
Model LLEMMA được khởi tạo từ model Code Llama, một model ngôn ngữ chỉ sử dụng phần decoder. Model Code Llama lại xuất phát từ Llama 2 và được tiếp tục training trên 500 tỷ token về code. Nhóm tác giả tiếp tục training model Code Llama trên tập dữ liệu Proof-Pile-2 với mục tiêu huấn luyện ngôn ngữ theo phong cách autoregressive. Trong đó, model 7 tỷ tham số được training với 200 tỷ token, trong khi model 34 tỷ tham số được training cho 50 tỷ token.
Các thông tin khởi tạo cho việc training model như sau:
- Cả 2 model 7 tỷ tham số và 34 tỷ tham số đều được train với bfloat16 mixed precision sử dụng thư viện GPT-NeoX
- Sử dụng 256 card GPU A100 với 40GB memory cho việc training
- Sử dụng Tensor Parallelism với world size 2 cho LLEMMA-7B và 8 cho LLEMMA-34B
- Sử dụng Flash Attention 2 để cải thiện băng thông và giảm lượng memory sử dụng
Thông tin training LLEMMA 7B:
- Được train 42,000 step với global batch size là 4 triệu token và độ dài context là 4096 token
- Mất khoảng 23,000 giờ trên trên GPU A100
- Learning rate được warm up đến giá trị sau 500 step
- Cosin decay được sử dụng để giảm maximum learning rate xuống 30 lần sau 48,000 step
Nhận thấy là số step ở scheduler (48,000 step) khác với số step cho việc training (42,000 step). Lý do là nhóm tác giả dự định train với 48,000 step nhưng bị loss NaN tại step 42,000. Nguyên nhân có thể là do quá trình tối ưu không ổn định hoặc do lỗi từ phần cứng.
Thông tin training LLEMMA 34B:
- Được train 12,000 step với global batch size là 4 triệu token và độ dài context là 4096 token
- Mất khoảng 47,000 giờ trên trên GPU A100
- Learning rate được warm up đến giá trị sau 500 step, sau đó giảm maximum learning rate xuống 30 lần
Đánh giá model
Chain-of-thought trong giải quyết vấn đề toán học
Đây là các task liên quan đến việc độc lập đưa ra lời giải mà không sử dụng thêm bất kì công cụ bên ngoài nào. Dữ liệu được sử dụng để đánh giá là các bộ: MATH, GSM8k, OCWCourse, MMLU-STEM và SAT.
Kết quả khi so sánh với các model SOTA được thể hiện trong bảng dưới.
Bảng dưới là kết quả voting cho 2 model LLEMMA và Minerva.
Giải quyết vấn đề toán học bằng cách sử dụng công cụ
Bộ dữ liệu được sử dụng để đánh giá là MATH và GSM8k kết hợp với sử dụng Python.
Formal math
Phần này gồm 2 task Informal-to-formal proving và Formal-to-formal proving. Một số kết quả của model trong các bài toán chứng minh được thể hiện dưới đây.
Tác động của Data Mixture
Khi training một mô hình ngôn ngữ, ta thường upsample các tập con chất lượng cao của dữ liệu training theo trọng số mixture. Ban đầu, nhóm tác giả thực hiện training trong thời gian ngắn với các tỷ lệ trọng số khác nhau, sau đó chọn một tỷ lệ giúp giảm thiểu sự phức tạp trên một tập hợp văn bản được tổ chức chất lượng cao (sử dụng tập huấn luyện MATH). Kết quả được thể hiện trong hình dưới:
Kết luận
Qua bài viết, ta đã tìm hiểu về LLM LLEMMA được sử dụng cho các task liên quan tới toán học. LLEMMA đạt hiệu suất hàng đầu trong đa dạng các task liên quan tới toán, không chỉ tính toán đơn thuần mà còn là các bài toán chứng minh step by step. Bên cạnh đó, ý tưởng xây dựng một tập dữ liệu training Proof-Pile-2 cũng rất hay, được kết hợp và trích chọn từ nhiều nguồn data khác nhau. LLEMMA và Proof-Pile-2 giúp mở ra nhiều ứng dụng hay về ứng dụng LLM cho toán học, các nhà toán học có thể sử dụng LLM như một công cụ hữu ích cho việc nghiên cứu và những người dùng như chúng ta có thể tận dụng LLM này cho việc học tập