Giới thiệu về LSTM
Hi mọi người, mình là Hiếu, ở bài viết trước mình có giới thiệu về RNN. Tuy RNN là một mô hình khá hay nhưng nó vẫn sẽ rất có thể bị Exploding Gradient hoặc Vanishing Gradient, từ đó gây ra hiện tượng đầu ra sai lệch cho mô hình. Hôm nay mình sẽ giới thiệu cho các bạn về LSTM – một mô hình có thể coi là một sự nâng cấp của RNN, giúp hạn chế hiện tượng Exploding Gradient và Vanishing Gradient.
1. Định nghĩa về LSTM
LSTM (Long Short-Term Memory) là một loại mạng nơ-ron hồi tiếp (RNN) được thiết kế để giải quyết vấn đề về việc ghi nhớ thông tin trong chuỗi dữ liệu dài hạn, vốn là điểm yếu của các RNN truyền thống do hiện tượng mất dần/lan truyền gradient. LSTM nổi bật bởi khả năng lưu trữ và truy xuất thông tin trong thời gian dài thông qua cấu trúc đặc biệt của các "cổng" (gates).
Sơ đồ hoạt động của LSTM như sau:
Nhìn rất quen đúng chứ, chúng ta có thể nói vui với nhau rằng LSTM là một RNN với những trang bị xịn hơn. Điểm đặc trưng của LSTM là các cổng, giúp kiểm soát việc truy xuất thông tin và việc lưu trữ.
LSTM bao gồm các đơn vị bộ nhớ (hay còn gọi là các memory cell) và 3 cổng chính để điều chỉnh luồng thông tin:
- Cổng quên (Forget Gate): Quyết định thông tin nào từ trạng thái ô (cell state) sẽ bị loại bỏ. Cổng quên sử dụng hàm sigmoid để tạo ra giá trị từ 0 (quên hoàn toàn) đến 1 (giữ hoàn toàn).
- Cổng nhập (Input Gate): Quyết định xem thông tin mới nào sẽ được thêm vào trạng thái ô. Cổng này cũng sử dụng hàm sigmoid để quyết định và hàm tanh để tạo ra giá trị mới.
- Cổng xuất (Output Gate): Quyết định thông tin nào từ trạng thái ô sẽ được sử dụng để tạo đầu ra. Kết hợp sigmoid và tanh để lọc và truyền thông tin.
Trạng thái ô (cell state) là "bộ nhớ dài hạn" của LSTM, cho phép lưu trữ thông tin qua nhiều bước thời gian. Trạng thái ẩn (hidden state) là đầu ra ngắn hạn, truyền thông tin đến bước tiếp theo.
Cơ chế hoạt động của LSTM
- Bước 1: Quên thông tin không cần thiết: Cổng quên sử dụng đầu vào hiện tại và trạng thái ẩn trước đó để quyết định thông tin nào trong trạng thái ô nên được bỏ đi.
- Bước 2: Cập nhật trạng thái ô: Cổng nhập chọn thông tin mới để thêm vào trạng thái ô, kết hợp với giá trị được tạo từ hàm tanh.
- Bước 3: Tạo đầu ra: Cổng xuất lọc trạng thái ô qua hàm tanh và sigmoid để tạo trạng thái ẩn mới, đồng thời là đầu ra của bước hiện tại.
Ưu điểm của LSTM
- Xử lý chuỗi dài: Nhờ cơ chế trạng thái ô, LSTM có thể lưu giữ thông tin qua nhiều bước thời gian, phù hợp với các bài toán như dịch máy, phân tích cảm xúc, hay dự đoán chuỗi thời gian.
- Khắc phục vấn đề gradient: Cổng quên và cổng nhập giúp điều chỉnh luồng thông tin, giảm thiểu vấn đề gradient biến mất hoặc bùng nổ.
- Linh hoạt: Có thể áp dụng cho nhiều loại dữ liệu chuỗi, từ văn bản, âm thanh đến dữ liệu tài chính.
Hạn chế
- Phức tạp tính toán: LSTM yêu cầu nhiều tài nguyên tính toán hơn RNN thông thường do cấu trúc cổng phức tạp.
- Khó tối ưu hóa: Việc huấn luyện LSTM có thể phức tạp, đòi hỏi kỹ thuật điều chỉnh siêu tham số cẩn thận.
- Bị thay thế bởi Transformer: Trong nhiều ứng dụng hiện đại (như NLP), mô hình Transformer (như BERT, GPT) tỏ ra hiệu quả hơn nhờ cơ chế attention.
2. Chi tiết hoạt động của LSTM
Ok, sơ lược về cách hoạt động của LSTM là như vậy, nhưng mà để hiểu rõ ta cần sử dụng thêm một chút toán nữa, sau đây là cách các thành phần trong LSTM được tính toán.
Trước khi đi vào chi tiết từng công thức, mình sẽ chú thích các thành phần trong LSTM, cũng như RNN, đa số có các ký hiệu mà chúng ta rất quen:
-
: Đầu vào tại thời điểm (vector đầu vào). Đây là dữ liệu được nhận ở bước hiện tại (ví dụ như là 1 từ trong câu, giá trị đo tại thời điểm ,...).
-
: Trạng thái ẩn của bước trước đó (). Chứa thông tin của toàn bộ chuỗi dữ liệu trước thời điểm , như một bộ nhớ ngắn hạn.
-
: Ma trận trọng số cho từng thành phần của mạng:
- : Trọng số của cổng quên (forget gate).
- : Trọng số cổng nhập (input gate).
- : Trọng số tạo giá trị ứng viên (candidate state).
- : Trọng số cổng xuất (output gate).
Các ma trận là các tham số học được trong quá trình huấn luyện, quyết định độ quan trọng của đầu vào và các trạng thái ẩn trước đó để tính tín hiệu cho từng cổng.
-
: Các vector bias cho từng cổng:
- : Bias cho cổng quên.
- : Bias cho cổng nhập.
- : Bias cho giá trị ứng viên.
- : Bias cho cổng xuất.
Bias giúp tăng tính linh hoạt cho mô hình khi kết hợp với các phép nhân ma trận.
-
: Trạng thái ô (Cell state) tại thời điểm , là bộ nhớ dài hạn của LSTM có chức năng lưu trữ thông tin quan trọng xuyên suốt chuỗi dữ liệu.
-
: Hàm kích hoạt sigmoid, cho đầu ra từ 0 đến 1 giúp điều khiển mức độ thông tin đi qua các cổng.
-
: Hàm kích hoạt tanh, cho đầu ra từ -1 đến 1 để chuẩn hóa giá trị cập nhật trạng thái ô và trạng thái ẩn.
Cổng quên (Forget Gate)
Hàm sigmoid cho đầu ra từ 0 đến 1 để cho cổng quên quyết định nên giữ lại hay loại bỏ thông tin nào từ trạng thái của ô trước :
- Nếu gần 1: Giữ lại thông tin cũ.
- Nếu gần 0: Quên thông tin đó.
Cổng nhập (Input Gate)
Cổng nhập cũng sử dụng hàm sigmoid để xác định mức độ để cập nhật trạng thái mới, cổng này sẽ kiểm soát lượng thông tin mới sẽ được thêm vào trạng thái ô hiện tại.
Giá trị ứng viên trạng thái ô (Candidate Cell State)
Hàm tanh cho giá trị từ -1 đến 1 nhằm tạo ra một giá trị ứng viên để bổ sung thông tin mới vào trạng thái ô.
Cập nhật trạng thái ô (Cell State Update)
Trạng thái ô tại thời điểm là sự kết hợp giữa thông tin cũ (sau khi đã được chọn lọc bởi cổng quên) và thông tin mới (được xác định qua cổng nhập và giá trị ứng viên). Phép toán sẽ nhân từng thành phần để chọn lọc thông tin giữ lại và thông tin mới.
Cổng xuất (Output Gate)
Cổng xuất quyết định phần nào của trạng thái ô sẽ được đưa ra làm trạng thái ẩn , là đầu ra của LSTM tại thời điểm hiện tại.
Trạng thái ẩn (Hidden State)
Trạng thái ẩn là đầu ra cuối cùng của ô, sẽ làm đầu vào cho bước tiếp theo hoặc lớp phía sau (Chẳng hạn như lớp dự đoán).
3. Quy trình huấn luyện của LSTM
Ok, nếu mọi người đã đọc đến đây, thì hoàn toàn các bạn đã hiểu về cơ chế của LSTM rồi đó, bởi vì quy trình huấn luyện LSTM bao gồm các bước cơ bản giống với RNN: Lan truyền thuận -> Tính hàm mất mát -> Lan truyền ngược -> Cập nhật trọng số. Cái khác là LSTM có cấu trúc phức tạp nên các bước này có những đặc điểm riêng.
3.1 Lan truyền thuận (Forward Propagation)
Giống với RNN, tại mỗi bước thời gian , mạng nhận đầu vào , kết hợp với trạng thái ẩn trước , tính trạng thái ẩn mới và tạo đầu ra thông qua một phép biến đổi (Ở đây mình dùng hàm như đã trình bày ở trên).
Tuy nhiên với cấu trúc đặc biệt, LSTM sử dụng trạng thái ô và trạng thái ẩn cùng với 3 cổng để xử lý thông tin, quy trình của chúng mình đã trình bày ở cơ chế hoạt động ở mục 1.
3.2 Tính hàm mất mát (Loss)
Giống như RNN, hàm mất mát được tính dựa trên sự khác biệt giữa đầu ra dự đoán (Dựa trên ) và giá trị thực tế. Các hàm mất mát phổ biến gồm MSE cho bài toán hồi quy hoặc Cross-Entropy Loss cho bài toán phân loại.
Bước này ở LSTM không có khác biệt lớn, chỉ là LSTM thường được dùng trong các bài toán phức tạp (Dịch máy, dự đoán chuỗi), hàm mất mát có thể được tính trên toàn bộ chuỗi hoặc trên các đầu ra tại mỗi bước thời gian.
3.3 Lan truyền ngược qua thời gian (Backpropagation Through Time – BPTT)
Điểm giống với RNN là Gradient của hàm mất mát được tính ngược từ đầu ra về các tham số (Các trọng số và bias ) qua các bước thời gian. Gradient sẽ được truyền ngược qua các trạng thái ẩn để cập nhật trọng số.
Tuy nhiên, điểm “ăn tiền” của LSTM là:
- Tránh “biến mất” gradient: Nhờ cơ chế trạng thái ô và các cổng, LSTM duy trì thông tin qua các bước thời gian dài hiệu quả hơn. Trạng thái ô () truyền thông tin gần như tuyến tính, giảm thiểu việc gradient bị thu hẹp quá mức.
- Lan truyền ngược qua trạng thái ô: Trạng thái ô () được cập nhật liên tục qua các bước thời gian, và gradient được truyền ngược qua và , giúp duy trì thông tin dài hạn.
Nhưng mà, với một đầu ra xịn hơn so với RNN, thì nó cũng phải bắt buộc thực hiện nhiều phép tính toán phức tạp hơn và nhiều hơn mà cụ thể là trong việc tính gradient:
- Gradient của hàm mất mát được tính qua , sau đó lan truyền ngược qua , , , và .
- Các cổng sử dụng hàm sigmoid và tanh, nên gradient được tính bằng các áp dụng chain rule.
Nếu như bạn đã quên, thì công thức chain rule là:
Nếu:
Thì:
3.4 Cập nhật trọng số
LSTM cũng sử dụng các thuật toán tối ưu (Như SGD, Adam,…) Để cập nhật các tham số dựa trên Gradient:
Với là learning rate, là tham số.
Trong LSTM có nhiều tham số hơn do có các cổng và các ô, nên việc cập nhật cũng sẽ nhiều hơn:
- Trọng số và bias của cổng quên ()
- Trọng số và bias của cổng nhập ()
- Trọng số và bias của cổng xuất ()
Do số lượng tham số nhiều như vậy nên LSTM thường cần nhiều tài nguyên tính toán hơn RNN.
3.5 Gradient cụ thể cho từng tham số của LSTM
Ta sử dụng quy tắc chuỗi (Chain rule) để lan truyền ngược gradient của hàm mất mát qua các bước thời gian. Gradient của ảnh hưởng đối với từng tham số cho đến các cổng và trạng thái ô. Vì yêu cầu của từng bài toán khác nhau nên sẽ được tính tùy theo đề bài.
3.5.1 Gradient của trạng thái ẩn và trạng thái ô
Gradient của
Gradient của
Gradient của
3.5.2 Gradient của cổng xuất ()
Từ lúc này, nhiều chỗ mình sẽ để để cho gọn nhé
Gradient của
Lý do là vì ta có hàm sigmoid có công thức: mà
Mà
Gradient đối với
Mà:
Nên:
Kết quả sẽ là ma trận ( là số chiều của hidden state, là số chiều của input vector).
Gradient đối với :
Kết quả là vector
3.5.3 Gradient của trạng thái ô trước () và các cổng khác
Gradient của cổng quên ()
Gradient đối với
Gradient của cổng nhập () và giá trị ứng viên ()
Gradient đối với
Gradient đối với
####3.5.4 Gradient truyền ngược qua thời gian
Gradient của và
Mỗi cổng () đều phụ thuộc vào , nên gradient đối với là tổng đóng góp từ tất cả các cổng:
Trong đó là phần của ma trận tương ứng với . Gradient này sẽ được truyền ngược về các bước trước ().
3.6 Các điểm cần lưu ý khi huấn luyện LSTM
- Khởi tạo tham số: Trọng số và bias thường được khởi tạo ngẫu nhiên hoặc sử dụng kỹ thuật như Xavier initialization. Bias của cổng quên đôi khi được khởi tạo lớn hơn (ví dụ: 1.0) để khuyến khích giữ thông tin ban đầu.
- Gradient Clipping: Để tránh bùng nổ gradient, gradient thường được cắt nếu vượt quá một ngưỡng).
- Tính toán phức tạp: Do LSTM có nhiều tham số, việc tính gradient thủ công rất phức tạp. Trong thực tế, các framework như TensorFlow hoặc PyTorch tự động tính gradient bằng autograd.
- Chuẩn hóa gradient (Gradient Clipping): Để tránh bùng nổ gradient, gradient thường được cắt (clip) nếu vượt quá ngưỡng nhất định.
- Chuẩn bị dữ liệu:
- Dữ liệu chuỗi cần được xử lý thành các batch với độ dài cố định hoặc sử dụng padding cho các chuỗi có độ dài khác nhau.
- Với bài toán như NLP, đầu vào thường là các vector embedding (như Word2Vec, GloVe).
- Tối ưu hóa siêu tham số: Learning rate, số đơn vị LSTM, số lớp (layers), và kích thước batch cần được điều chỉnh cẩn thận để đạt hiệu quả tốt.
Đoạn kết
Vậy là chúng mình đã đi hết LSTM, một mô hình RNN cải tiến, tuy giải quyết khá tốt vấn đề bùng nổ gradient hoặc vanishing gradient, tuy nhiên do phải tính toán khá nhiều, nên là bài viết sau mình sẽ giới thiệu cho mọi người về Transformer.