- vừa được xem lúc

RNN và những người bạn

0 0 1

Người đăng: Huey Anthony Disward

Theo Viblo Asia

Hi mọi người, mình là Hiếu, hiện đang nghiên cứu trí tuệ nhân tạo. Hiện tại mình đang đi theo hướng NLP và để khởi đầu cho series Deeplearning thì mình sẽ chia sẻ đôi chút về RNN nhé:

I. Khái niệm RNN

RNN (Recurrent Neural Network) là một loại mạng nơ-ron nhân tạo chuyên xử lý dữ liệu tuần tự (tức là dạng chuỗi) như văn bản, âm thanh hoặc dữ liệu theo thời gian. RNN có khả năng nhớ thông tin từ các bước trước đó thông qua trạng thái ẩn, giúp mô hình hiểu được ngữ cảnh và mối quan hệ giữa các phần tử trong chuỗi. Sơ đồ hoạt động của RNN được biểu thị như dưới đây:

  • Với mỗi ô vuông xanh, ta có hh là các trạng thái ẩn (hidden state).
  • Dữ liệu đầu vào (xx) tác động lên trạng thái ẩn thông qua trọng số WhxW_hx.
  • Trạng thái ẩn cũng nhận thông tin từ chính nó từ bước trước thông qua trọng số WhhW_hh (vòng lặp).
  • Trạng thái ẩn sinh ra đầu ra thông qua trọng số WyhW_yh.
  • Unfold là quá trình mở rộng cấu trúc hồi quy theo thời gian.

Ta có thể thấy rằng mỗi bước thời gian tt có đầu vào xtx_t, trạng thái ẩn hth_t, đầu ra yty_t. Trạng thái ẩn hth_t nhận thông tin từ đầu vào xtx_t và trạng thái ẩn trước đó là ht1h_{t-1} rồi sinh ra đầu ra yty_t.


Đối với mỗi bước thời gian t, ta có các công thức:

Cập nhật trạng thái ẩn

ht=f(Whxxt+Whhht1+bh)h_t = f(W_{hx} x_t + W_{hh} h_{t-1} + b_h)

Trong đó:

  • xtx_t: Đầu vào tại thời điểm tt.
  • ht1h_{t-1}: Là trạng thái ẩn trước đó.
  • WhxW_{hx}: Là ma trận trọng số để nhân với vector đầu vào xtx_t. Tham số này quyết định mức độ ảnh hưởng thông tin đầu vào hiện tại lên trạng thái ẩn mới.
  • WhhW_{hh}: Ma trận trọng số nhân với vector trạng thái ẩn trước đó ht1h_{t-1}. Quyết định mức độ “ghi nhớ” thông tin từ quá khứ (từ các bước ngay trước) vào trạng thái hiện tại.
  • bhb_h: Là vector bias được cộng vào sau khi thực hiện các phép nhân trọng số để tăng linh hoạt cho mô hình, bias giúp dịch chuyển hàm kích hoạt, tăng khả năng học.

ff là hàm kích hoạt (thường là tanh hoặc ReLU). Ở đây mình sử dụng hàm tanh, vì hàm tanh chấp nhận giá trị âm, trong khi hàm sigmoid chỉ lấy giá trị trong khoảng (0,1). Nên trạng thái ẩn sẽ là:

ht=tanh(Whxxt+Whhht1)h_t = \tanh(W_{hx} x_t + W_{hh} h_{t-1})

Hàm tanh có công thức:

tanh(z)=ezezez+ez\tanh(z) = \frac{e^z - e^{-z}}{e^z + e^{-z}}

  • Đầu vào z(,+)z \in (-\infty, +\infty) và đầu ra tanh(z)[1,1]\tanh(z) \in [-1, 1].

Lý do dùng hàm tanh:

  • Để ép giá trị đầu ra về khoảng [1,1][-1, 1], giúp các giá trị trạng thái ẩn hth_t không bị tăng quá lớn (ổn định quá trình lan truyền tín hiệu qua nhiều bước).
  • Nếu không có hàm kích hoạt phi tuyến thì mô hình chỉ giống như phép biến đổi tuyến tính (không học các quan hệ phức tạp). Hàm tanh giúp RNN học được các quan hệ phức tạp.
  • Giảm nguy cơ bị Exploding/Vanishing Gradient khi lan truyền qua nhiều bước.

Tính đầu ra

yt=g(Wyhht)y_t = g(W_{yh} h_t)

  • gg là hàm kích hoạt đầu ra (softmax cho bài toán phân loại, identity cho hồi quy).
  • WyhW_{yh}: Là ma trận trọng số để biến đổi trạng thái ẩn với đầu ra.

Ưu điểm của RNN

  • Xử lý dữ liệu tuần tự (chuỗi) tốt: Có khả năng ghi nhớ thông tin từ các bước trước, phù hợp cho các bài toán liên quan đến chuỗi.
  • Giúp giảm số lượng tham số đáng kể so với nơ-ron truyền thống: Mô hình tổng quát tốt hơn.
  • RNN có thể xử lý các chuỗi có độ dài khác nhau: Không cố định như CNN hoặc MLP.

Nhược điểm của RNN

  • Khó học các quan hệ dài hạn: Khi chuỗi quá dài, sẽ xảy ra hiện tượng vanishing gradient hoặc exploding gradient khiến các mô hình khó nhớ thông tin từ các bước xa về trước.
  • Các bước tính toán đều là nối tiếp: Quá trình huấn luyện và dự đoán chậm hơn so với các mô hình khác (CNN, Transformer,…).
  • Khó song song hóa trên GPU như các mạng khác, yêu cầu lượng dữ liệu đủ lớn.

Các hiện tượng có thể xảy ra trong RNN

  • Exploding Gradient: Là hiện tượng xảy ra khi các giá trị Gradient trong quá trình huấn luyện mạng nơ-ron trở lên rất lớn (có thể tiến dần đến \infty) do trong lúc huấn luyện RNN, quá trình lan truyền ngược sẽ nhân nhiều lần các ma trận trọng số qua các bước thời gian. Hậu quả là mô hình không hội tụ được, cho kết quả rác.
  • Vanishing Gradient: Ngược lại với Exploding Gradient, đây là hiện tượng xảy ra khi giá trị gradient trở lên rất nhỏ (gần bằng 0) trong quá trình huấn luyện. Điều này khiến cho gradient gần như bằng 0 ở đầu các lớp khiến cho mô hình không học được gì thêm.

2. Gradient, Loss và tối ưu cho mô hình RNN

Ở mục trên, chúng ta có nhắc đến lan truyền ngược, vậy lan truyền ngược là gì?
Trước khi có ngược thì ta phải có xuôi, hay gọi cách thông thường là lan truyền tiến. Cơ chế của lan truyền ngược và tiến là gì?

Lan truyền tiến (forward pass)

  • Dữ liệu đầu vào được truyền qua từng lớp của mạng để đưa ra dự đoán.
  • So sánh đầu ra dự đoán với giá trị thực để tính toán ra hàm mất mát (Loss Function).

Lan truyền ngược (Backpropagation Through Time)

  • Tính Gradient của hàm mất mát với từng trọng số trong mạng (Dùng Chain Rule).
  • Gradient cho biết cần thay đổi trọng số bao nhiêu để giảm lỗi.
  • Trọng số được cập nhật theo hướng giảm Loss bởi các thuật toán tối ưu như SGD, Adam,…

Chain-Rule (Quy tắc chuỗi)

Nếu:

  • y=f(u)y = f(u)
  • u=g(x)u = g(x)

Thì:

dydx=dydududx\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}

Gradient trong RNN là công cụ giúp đánh giá và tối ưu hóa loss, bằng cách lan truyền ngược thông tin lỗi qua thời gian để cập nhật tham số của mạng.

Mất mát được tính trên dự đoán y^t\hat{y}_t và giá trị thật yty_t tại mỗi bước thời gian tt:

L(θ)=t=1Tloss(y^t,yt)L(\theta) = \sum_{t=1}^T loss(\hat{y}_t, y_t)

  • Loss là hàm tính độ mất mát ở từng bước, nhưng khi cập nhật tham số thì sẽ lấy tổng hoặc trung bình loss của cả chuỗi.
  • Gradient tính tổng hợp của tất cả các bước, dùng để cập nhật tham số một lần sau mỗi chuỗi.
  • Hàm loss sẽ tùy vào loại bài toán (Phân loại nhị phân thì có thể dùng Cross-Entropy Loss, hồi quy có thể sử dụng MSE).
  • θ\theta: là các tham số khác nhau, trong RNN nó có thể là: Whh,Whx,Why,bhW_{hh}, W_{hx}, W_{hy}, b_h.

Để điều chỉnh các tham số, ta có thể sử dụng SGD (Stochastic Gradient Descent):

SGD là biến thể của Gradient Descent truyền thống, là thuật toán tối ưu để tìm giá trị của tham số tốt nhất (trọng số mô hình) bằng cách giảm dần hàm mất mát.

Công thức Gradient Descent:

θθηθL(θ)\theta \leftarrow \theta - \eta \nabla_\theta L(\theta)

Trong đó:

  • θ\theta: Vector tham số của mô hình (trọng số, bias).
  • η\eta: Tốc độ học.
  • θL(θ)\nabla_\theta L(\theta): Gradient của loss theo tham số θ\theta.

Đặt vào trong bài toán RNN, ta có các tham số sau: Whh,Whx,Why,bhW_{hh}, W_{hx}, W_{hy}, b_h

Ta có thể dùng quy tắc chuỗi (Chain-Rule) để tính θL(θ)\nabla_\theta L(\theta) cho từng tham số ở trên:

  • Gradient theo WhyW_{hy}:

LWhy=t=1TLy^ty^tWhy=t=1TδtyhtT\frac{\partial L}{\partial W_{hy}} = \sum_{t=1}^T \frac{\partial L}{\partial \hat{y}_t} \cdot \frac{\partial \hat{y}_t}{\partial W_{hy}} = \sum_{t=1}^T \delta_t^y h_t^T

Với δty=l(y^t,yt)y^t\delta_t^y = \frac{\partial l(\hat{y}_t, y_t)}{\partial \hat{y}_t}

  • Gradient theo WhxW_{hx}:

LWhx=t=1TδthxtT\frac{\partial L}{\partial W_{hx}} = \sum_{t=1}^T \delta_t^h x_t^T

Với δth=Lht(1ht2)\delta_t^h = \frac{\partial L}{\partial h_t} \odot (1 - h_t^2)
(do tanh(z)=1tanh2(z)\tanh'(z) = 1 - \tanh^2(z))

  • Gradient theo WhhW_{hh}:

LWhh=t=1Tδthht1T\frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^T \delta_t^h h_{t-1}^T

  • Gradient theo byb_y:

Lby=t=1Tδty\frac{\partial L}{\partial b_y} = \sum_{t=1}^T \delta_t^y


Tóm lại, quy trình của RNN sẽ là:

  1. Forward pass: Cho dữ liệu đầu vào chạy qua RNN để sinh ra các dự đoán ở từng bước thời gian.
  2. Tính Loss: So sánh 2 giá trị dự đoán và thực tế tại từng bước, rồi cộng lại thành giá trị Loss.
  3. Backpropagation Through Time: Tính Gradient của Loss theo từng tham số của mạng (trọng số, bias) lan truyền ngược qua nhiều bước thời gian.
  4. Dùng các thuật toán tối ưu (Như SGD, Adam,...) để cập nhật tham số.
  5. Lặp lại đến khi nào Loss giảm về mức tối ưu chúng ta mong muốn.!

Kết luận

Vậy là chúng ta đã tìm hiểu xong cơ bản một chút toán từ RNN, bài viết tiếp theo mình sẽ viết về LSTM, một kỹ thuật có thể giải quyết nhược điểm của RNN. Cảm ơn mọi nguowfid dã dành thời gian đọc.

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tìm hiểu về YOLO trong bài toán real-time object detection

1.Yolo là gì. . Họ các mô hình RCNN ( Region-Based Convolutional Neural Networks) để giải quyết các bài toán về định vị và nhận diện vật thể.

0 0 298

- vừa được xem lúc

Xây dựng hệ thống Real-time Multi-person Tracking với YOLOv3 và DeepSORT

Trong bài này chúng ta sẽ xây dựng một hệ thống sử dụng YOLOv3 kết hợp với DeepSORT để tracking được các đối tượng trên camera, YOLO là một thuật toán deep learning ra đời vào tháng 5 năm 2016 và nó nhanh chóng trở nên phổ biến vì nó quá nhanh so với thuật toán deep learning trước đó, sử dụng YOLO t

0 0 323

- vừa được xem lúc

Sơ lược về bài toán Person Re-identification

Với những công nghệ hiện đại của thế kỷ 21 chúng ta đã có những phần cứng cũng như phần mềm mạnh mẽ để giải quyết những vấn đề và bài toán nan giải như face recognition, object detection, NLP,... Một trong những vấn đề nan giải cũng được mọi người chú ý ko kém những chủ đề trên là Object Tracking, v

0 0 75

- vừa được xem lúc

Bacteria classification bằng thư viện fastai

Giới thiệu. fastai là 1 thư viện deep learning hiện đại, cung cấp API bậc cao để giúp các lập trình viên AI cài đặt các mô hình deep learning cho các bài toán như classification, segmentation... và nhanh chóng đạt được kết quả tốt chỉ bằng vài dòng code. Bên cạnh đó, nhờ được phát triển trên nền tản

0 0 43

- vừa được xem lúc

Xác định ý định câu hỏi trong hệ thống hỏi đáp

Mục tiêu bài viết. Phân tích câu hỏi là pha đầu tiên trong kiến trúc chung của một hệ thống hỏi đáp, có nhiệm vụ tìm ra các thông tin cần thiết làm đầu vào cho quá trình xử lý của các pha sau (trích c

0 0 100

- vừa được xem lúc

Epoch, Batch size và Iterations

Khi mới học Machine Learning và sau này là Deep Learning chúng ta gặp phải các khái niệm như Epoch, Batch size và Iterations. Để khỏi nhầm lẫn mình xin chia sẻ với các bạn sự khác nhau giữa các khái n

0 0 55