Hi mọi người, mình là Hiếu, hiện đang nghiên cứu trí tuệ nhân tạo. Hiện tại mình đang đi theo hướng NLP và để khởi đầu cho series Deeplearning thì mình sẽ chia sẻ đôi chút về RNN nhé:
I. Khái niệm RNN
RNN (Recurrent Neural Network) là một loại mạng nơ-ron nhân tạo chuyên xử lý dữ liệu tuần tự (tức là dạng chuỗi) như văn bản, âm thanh hoặc dữ liệu theo thời gian. RNN có khả năng nhớ thông tin từ các bước trước đó thông qua trạng thái ẩn, giúp mô hình hiểu được ngữ cảnh và mối quan hệ giữa các phần tử trong chuỗi. Sơ đồ hoạt động của RNN được biểu thị như dưới đây:
- Với mỗi ô vuông xanh, ta có là các trạng thái ẩn (hidden state).
- Dữ liệu đầu vào () tác động lên trạng thái ẩn thông qua trọng số .
- Trạng thái ẩn cũng nhận thông tin từ chính nó từ bước trước thông qua trọng số (vòng lặp).
- Trạng thái ẩn sinh ra đầu ra thông qua trọng số .
- Unfold là quá trình mở rộng cấu trúc hồi quy theo thời gian.
Ta có thể thấy rằng mỗi bước thời gian có đầu vào , trạng thái ẩn , đầu ra . Trạng thái ẩn nhận thông tin từ đầu vào và trạng thái ẩn trước đó là rồi sinh ra đầu ra .
Đối với mỗi bước thời gian t, ta có các công thức:
Cập nhật trạng thái ẩn
Trong đó:
- : Đầu vào tại thời điểm .
- : Là trạng thái ẩn trước đó.
- : Là ma trận trọng số để nhân với vector đầu vào . Tham số này quyết định mức độ ảnh hưởng thông tin đầu vào hiện tại lên trạng thái ẩn mới.
- : Ma trận trọng số nhân với vector trạng thái ẩn trước đó . Quyết định mức độ “ghi nhớ” thông tin từ quá khứ (từ các bước ngay trước) vào trạng thái hiện tại.
- : Là vector bias được cộng vào sau khi thực hiện các phép nhân trọng số để tăng linh hoạt cho mô hình, bias giúp dịch chuyển hàm kích hoạt, tăng khả năng học.
là hàm kích hoạt (thường là tanh hoặc ReLU). Ở đây mình sử dụng hàm tanh, vì hàm tanh chấp nhận giá trị âm, trong khi hàm sigmoid chỉ lấy giá trị trong khoảng (0,1). Nên trạng thái ẩn sẽ là:
Hàm tanh có công thức:
- Đầu vào và đầu ra .
Lý do dùng hàm tanh:
- Để ép giá trị đầu ra về khoảng , giúp các giá trị trạng thái ẩn không bị tăng quá lớn (ổn định quá trình lan truyền tín hiệu qua nhiều bước).
- Nếu không có hàm kích hoạt phi tuyến thì mô hình chỉ giống như phép biến đổi tuyến tính (không học các quan hệ phức tạp). Hàm tanh giúp RNN học được các quan hệ phức tạp.
- Giảm nguy cơ bị Exploding/Vanishing Gradient khi lan truyền qua nhiều bước.
Tính đầu ra
- là hàm kích hoạt đầu ra (softmax cho bài toán phân loại, identity cho hồi quy).
- : Là ma trận trọng số để biến đổi trạng thái ẩn với đầu ra.
Ưu điểm của RNN
- Xử lý dữ liệu tuần tự (chuỗi) tốt: Có khả năng ghi nhớ thông tin từ các bước trước, phù hợp cho các bài toán liên quan đến chuỗi.
- Giúp giảm số lượng tham số đáng kể so với nơ-ron truyền thống: Mô hình tổng quát tốt hơn.
- RNN có thể xử lý các chuỗi có độ dài khác nhau: Không cố định như CNN hoặc MLP.
Nhược điểm của RNN
- Khó học các quan hệ dài hạn: Khi chuỗi quá dài, sẽ xảy ra hiện tượng vanishing gradient hoặc exploding gradient khiến các mô hình khó nhớ thông tin từ các bước xa về trước.
- Các bước tính toán đều là nối tiếp: Quá trình huấn luyện và dự đoán chậm hơn so với các mô hình khác (CNN, Transformer,…).
- Khó song song hóa trên GPU như các mạng khác, yêu cầu lượng dữ liệu đủ lớn.
Các hiện tượng có thể xảy ra trong RNN
- Exploding Gradient: Là hiện tượng xảy ra khi các giá trị Gradient trong quá trình huấn luyện mạng nơ-ron trở lên rất lớn (có thể tiến dần đến ) do trong lúc huấn luyện RNN, quá trình lan truyền ngược sẽ nhân nhiều lần các ma trận trọng số qua các bước thời gian. Hậu quả là mô hình không hội tụ được, cho kết quả rác.
- Vanishing Gradient: Ngược lại với Exploding Gradient, đây là hiện tượng xảy ra khi giá trị gradient trở lên rất nhỏ (gần bằng 0) trong quá trình huấn luyện. Điều này khiến cho gradient gần như bằng 0 ở đầu các lớp khiến cho mô hình không học được gì thêm.
2. Gradient, Loss và tối ưu cho mô hình RNN
Ở mục trên, chúng ta có nhắc đến lan truyền ngược, vậy lan truyền ngược là gì?
Trước khi có ngược thì ta phải có xuôi, hay gọi cách thông thường là lan truyền tiến. Cơ chế của lan truyền ngược và tiến là gì?
Lan truyền tiến (forward pass)
- Dữ liệu đầu vào được truyền qua từng lớp của mạng để đưa ra dự đoán.
- So sánh đầu ra dự đoán với giá trị thực để tính toán ra hàm mất mát (Loss Function).
Lan truyền ngược (Backpropagation Through Time)
- Tính Gradient của hàm mất mát với từng trọng số trong mạng (Dùng Chain Rule).
- Gradient cho biết cần thay đổi trọng số bao nhiêu để giảm lỗi.
- Trọng số được cập nhật theo hướng giảm Loss bởi các thuật toán tối ưu như SGD, Adam,…
Chain-Rule (Quy tắc chuỗi)
Nếu:
Thì:
Gradient trong RNN là công cụ giúp đánh giá và tối ưu hóa loss, bằng cách lan truyền ngược thông tin lỗi qua thời gian để cập nhật tham số của mạng.
Mất mát được tính trên dự đoán và giá trị thật tại mỗi bước thời gian :
- Loss là hàm tính độ mất mát ở từng bước, nhưng khi cập nhật tham số thì sẽ lấy tổng hoặc trung bình loss của cả chuỗi.
- Gradient tính tổng hợp của tất cả các bước, dùng để cập nhật tham số một lần sau mỗi chuỗi.
- Hàm loss sẽ tùy vào loại bài toán (Phân loại nhị phân thì có thể dùng Cross-Entropy Loss, hồi quy có thể sử dụng MSE).
- : là các tham số khác nhau, trong RNN nó có thể là: .
Để điều chỉnh các tham số, ta có thể sử dụng SGD (Stochastic Gradient Descent):
SGD là biến thể của Gradient Descent truyền thống, là thuật toán tối ưu để tìm giá trị của tham số tốt nhất (trọng số mô hình) bằng cách giảm dần hàm mất mát.
Công thức Gradient Descent:
Trong đó:
- : Vector tham số của mô hình (trọng số, bias).
- : Tốc độ học.
- : Gradient của loss theo tham số .
Đặt vào trong bài toán RNN, ta có các tham số sau:
Ta có thể dùng quy tắc chuỗi (Chain-Rule) để tính cho từng tham số ở trên:
- Gradient theo :
Với
- Gradient theo :
Với
(do )
- Gradient theo :
- Gradient theo :
Tóm lại, quy trình của RNN sẽ là:
- Forward pass: Cho dữ liệu đầu vào chạy qua RNN để sinh ra các dự đoán ở từng bước thời gian.
- Tính Loss: So sánh 2 giá trị dự đoán và thực tế tại từng bước, rồi cộng lại thành giá trị Loss.
- Backpropagation Through Time: Tính Gradient của Loss theo từng tham số của mạng (trọng số, bias) lan truyền ngược qua nhiều bước thời gian.
- Dùng các thuật toán tối ưu (Như SGD, Adam,...) để cập nhật tham số.
- Lặp lại đến khi nào Loss giảm về mức tối ưu chúng ta mong muốn.!
Kết luận
Vậy là chúng ta đã tìm hiểu xong cơ bản một chút toán từ RNN, bài viết tiếp theo mình sẽ viết về LSTM, một kỹ thuật có thể giải quyết nhược điểm của RNN. Cảm ơn mọi nguowfid dã dành thời gian đọc.