Ở bài viết trước chúng ta đã cùng tìm hiểu qua các kỹ thuật tiền xử lý dữ liệu văn bản. Với đầu ra input embeddings thu được, chúng ta sẽ tiếp tục dành toàn bộ bài viết để hiểu cơ chế Attention, thứ được coi như trái tim của một mô hình ngôn ngữ.
File Jupyter NoteBook của bài viết này được thể hiển ở đây
1. Vấn đề khi dịch thuật và mạng nơ-ron RNN
Ta đang phát triển một mô hình ngôn ngữ lớn để dịch thuật, nếu dịch từng từ một theo thứ tự như ảnh dưới đây thì chắc chắn câu văn dịch ra sẽ rất rời rạc, khó hiểu.
Do ngữ pháp của các ngôn ngữ khác nhau, thứ tự các từ sau khi dịch có thể khác nhiều so với câu gốc.
Để giải quyết vấn đề trên, cách tiếp cận phổ biến là là mạng nơ-ron với hai mô-đun: encoder và decoder. Nhiệm vụ của encoder là đọc và xử lý toàn bộ văn bản đầu vào trước, sau đó decoder tạo ra văn bản đã được dịch
Trước khi kiến trúc Transformer ra đời, RNN (Recurrent Neural Networks) là phương pháp phổ biến nhất áp dụng cho các tác vụ dịch thuật.
RNN là một mạng nơ-ron trong đó đầu ra từ các bước trước được đưa vào làm đầu vào cho bước hiện tại, khiến chúng rất phù hợp với dữ liệu tuần tự như văn bản. Trạng thái tại mỗi bước xử lý của RNN được gọi là hidden state.
Tuy nhiên, RNN thường gặp khó khăn với các chuỗi dài, khi chúng có thể bị "quên" đi các trạng thái cách xa trước đó. Các xử lý tuần tự cũng khiến cho RNN chậm hơn so với các phương pháp xử lý song song như Transformer.
2. Attention
Năm 2014, một nghiên cứu có tên Bahdanau attention được công bố, với nội dung chính là đề xuất một phương pháp chỉnh sửa encoder-decoder của RNN sao cho decoder có thể truy cập được vào tất cả các token của văn bản đầu vào. Dựa vào trọng số (attention weight) được tính toán, mô hình sẽ chọn ra đâu là từ tiếp theo phù hợp để dịch và đưa vào kết quả.
Attention (tạm dịch là chú ý) là một cơ chế giúp mô hình tập trung vào những phần quan trọng nhất của dữ liệu đầu vào, thay vì xử lý tất cả thông tin một cách đồng đều.
Hình trên mô tả bước thứ 2 trong tác vụ dịch, trọng số lớn nhất mà mô hình tính toán được ở bước này nằm ở từ du => Từ tiếp theo được trả về sẽ lẽ you (từ tương ứng của du trong tiếng Anh).
Chúng ta sẽ không đi chi tiết về cách cơ chế Bahdanau attention hoạt động như thế nào ở bài viết này, sẽ khiến bài viết trở nên rất dài.
Ba năm sau, vào năm 2017. Nhóm nghiên cứu của Google đã đề xuất kiến trúc Transformer sử dụng cơ chế Self-Attention lấy cảm hứng từ cơ chế Bahdanau attention. Từ đây, RNN đã dần dần đi vào dĩ vãng trong các tác vụ dịch thuật nói riêng và xử lý ngôn ngữ tự nhiên nói chung.
Self-attention là cơ chế giúp mô hình hiểu mối quan hệ giữa các từ trong cùng một câu, bất kể khoảng cách giữa chúng. Đây là thành phần cốt lõi của Transformer, LLMs giúp nó vượt trội so với RNN.
Chúng ta sẽ lần lượt nghiên cứu các phương pháp khác nhau của cơ chế Self-Attention từ đơn giản tới phức tạp
- Self-attention đơn giản: Loại đơn giản nhất với các trọng số cố định => không huấn luyện được
- Self-attention phiên bản có trọng số có thể huấn luyện được (trainable weights)
- Causal attention (hay masked attention): Đảm bảo mô hình chỉ "nhìn thấy" các token trước đó trong văn bản.
- Multi-head attention: Phương pháp được ứng dựng thực tiễn trong các mô hình ngôn ngữ lớn, gồm nhiều bước xử lý song song.
Vậy tại sao gọi là self-attention, khác gì với attention ?
- Trong Self-Attention, từ
self
thể hiện rằng việc tính toán trọng số là giữa các phần tử trong cùng một chuỗi đầu vào (input sequence). - Attention truyền thống như phương pháp Bahdanau attention thì tính toán trọng số giữa hai chuỗi với nhau (đầu vào và đầu ra)
3. Self-Attention đơn giản (simple self-attention)
Trọng số attention trong phương pháp này được tính trực tiếp từ dữ liệu đầu vào. Cùng xem ví dụ minh họa dưới đây.
Với văn bản “Your journey starts with one step.”
đã được embeddings hóa, ta cần tính các vector các z
thể hiện mối quan hệ giữa các từ trong câu với nhau.
Trên hình là ví dụ trong việc tính mối quan hệ giữa từ journey
với các từ còn lại trong câu. Để tính được vector z
từ các token embeddings sẽ cần kết hợp với các trọng số attention α
.
Để tính được trọng số attention α
, ta cần qua một bước trung gian tính toán giá trị attention score (ω
).
Attention score
Để tính được các giá trị attention score với từ journey
, ta lấy vector nhân vô hướng với các vector còn lại. vector trong trường hợp này sẽ được gọi với cái tên embedded query token.
Biểu diễn với python
import torch inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs): attn_scores_2[i] = torch.dot(x_i, query) # dot là phép vô hướng giữa 2 vector
print(attn_scores_2) # Kết quả: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Chuẩn hóa
Việc chuẩn hóa Attention Scores là một bước quan trọng trong cơ chế Self-Attention.
Tại sao lại cần chuẩn hóa attention scores ?
- Không có ý nghĩa xác suất: Các giá trị Attention Scores không được chuẩn hóa không có ý nghĩa xác suất, trong khi các trọng số attention cần biểu diễn mức độ "tập trung" của mỗi phần tử vào các phần tử khác.
- Không ổn định: Các giá trị attention scores có thể là những số rất lớn hoặc rất bé. Nếu dùng các giá trị này làm trọng số có thể xảy ra tình trạng thiên lệch khi các số lớn chi phối và làm cho các giá trị nhỏ trở nên không đáng kể.
Phương pháp phổ biến để chuẩn hóa được sử dụng là hàm softmax, các kết quả sẽ được điều chỉnh lại sao cho tổng của tất cả chúng bằng 1
.
import torch inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs): attn_scores_2[i] = torch.dot(x_i, query) attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2) # Kết quả: Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
# Sum: tensor(1.)
Context vector z
Tính context vector bằng cách lấy tổng của các tích trọng số attention và token embeddings.
import torch inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs): attn_scores_2[i] = torch.dot(x_i, query) attn_weights_2 = torch.softmax(attn_scores_2, dim=0) context_vec_2 = torch.zeros(query.shape) # context_vec_2 có số chiều bằng vector x(2)
for i,x_i in enumerate(inputs): context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)
# Kết quả: tensor([0.4419, 0.6515, 0.5683])
Công thức tổng quát
Ví dụ ở trên chúng ta chỉ đang tính context vector z cho trường hợp token thứ 2 là journey. Để tính toàn bộ giá trị của các vector z, ta thực hiện các bước như sau:
-
- Tính attention scores bằng cách lấy ma trận token embeddings nhân cho ma trận chuyển vị (transpose) của nó.
-
- Chuẩn hóa attention scores thành trọng số attention
-
- Tính context vector z bằng cách lấy trọng số attention nhân cho token embeddings
Lưu ý: Việc tính attention score mà chúng ta thực hiện ở các mục trên bản chất là nhân dòng thứ 2 của ma trận inputs
lần lượt với các cột của ma trận inputs^T
(ma trận chuyển vị của inputs)
Dưới đấy là minh họa bằng Python:
import torch
import torch.nn.functional as F inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) # Bước 1: Tính toán Attention Scores
# Attention Scores = inputs * inputs^T
attention_scores = torch.matmul(inputs, inputs.T) # (6, 6) # Bước 2: Chuẩn hóa Attention Scores bằng hàm softmax
attention_weights = F.softmax(attention_scores, dim=-1) # (6, 6) # Bước 3: Tính toán Context Vector z
# z = attention_weights * inputs
context_vector = torch.matmul(attention_weights, inputs) # (6, 3) # In kết quả
print("\nContext Vector z:\n", context_vector) """Kết quả Context Vector z: tensor([[0.4421, 0.5931, 0.5790], [0.4419, 0.6515, 0.5683], [0.4431, 0.6496, 0.5671], [0.4304, 0.6298, 0.5510], [0.4671, 0.5910, 0.5266], [0.4177, 0.6503, 0.5645]])
"""
4. Self-Attention với trọng số có thể huấn luyện được (self-attention with trainable weights)
Hay có tên gọi khác là scaled dot-product attention.
- Với trường hợp self-attention đơn giản, các trọng số được tính toán theo một công thức toán học nhất định => không thể điều chỉnh hay thay đổi.
- Trọng số có thể huấn luyện được là trọng số có thể thay đổi, tối ưu hoá.
Tính toán trọng số attention
Mục này chúng ta tiếp tục minh họa việc tính toán context vector z
của từ đứng thứ 2 trong câu là journey
.
- Ta có , , lần lượt là ba ma trận.
- Sử dụng 3 ma trận đã cho trong quá trình tính toán trọng số attention.
- Thông số trong các ma trận có thể được điều chỉnh qua quá trình huấn luyện.
Bước 1: Tính các vector query, key, value (q, k, v)
Ta có các công thức là các phép nhân ma trận sau đây:
với i = 2 (chỉ tính ở trường hợp này)
Tổng quát hơn, ta có công thức tính vector k
, v
import torch
import torch.nn.functional as F # Đầu vào (inputs)
inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 torch.manual_seed(123) """Để requires_grad=False
có nghĩa rằng các ma trận này khởi tạo ngẫu nhiên
mà không qua quá trình huấn luyện
ví mục đích của chúng ta chỉ là tính toán""" W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_q
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_k
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_v query_2 = x_2 @ W_query
keys = inputs @ W_key
values = inputs @ W_value print(query_2)
print(keys)
print(values)
# Kết quả
"""
tensor([0.4306, 1.4551])
tensor([[0.3669, 0.7646], [0.4433, 1.1419], [0.4361, 1.1156], [0.2408, 0.6706], [0.1827, 0.3292], [0.3275, 0.9642]])
tensor([[0.1855, 0.8812], [0.3951, 1.0037], [0.3879, 0.9831], [0.2393, 0.5493], [0.1492, 0.3346], [0.3221, 0.7863]])
"""
Bước 2: Tính attention score
Sau khi đã thu được các vector q, k, v
, chúng ta đến bước tiếp theo là tính attention score
Tính cả vector attention score: Với là chuyển vị của k
# ... attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)
# Kết quả: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
Bước 3: Tính trọng số attention
Chuẩn hóa attention score với hàm softmax
, ta thu được trọng số attention.
# ...
d_k = keys.shape[1] # số chiều của vector k
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)
# Kết quả: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
Bước 4: Tính context vector z
Cuối cùng chúng ta thực hiện phép nhân ma trận trọng số attention
với ma trận v
# ...
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
# Kết quả: tensor([0.3061, 0.8210])
Tổng quát
Sau khi đã tính được context vector .Chúng ta cùng xem các bước tổng quát để tính cả context vector z
Bước 1. Với X là vector input embeddings và là các hệ số tự do (bias)
,
,
Bước 2.
Bước 3: với là số chiều của k (ví dụ trên thì k = 2)
Bước 4:
Bias thường được bỏ qua để đơn giản hóa công thức, nhưng trong triển khai thực tế của các mô hình, bias thường được sử dụng để tăng khả năng biểu diễn và linh hoạt của mô hình.
Tại sao lại là Query, Key và Value ?
Các thuật ngữ key, query và value trong cơ chế Attention được mượn từ những cái tên tương ứng trong bộ môn cơ sở dữ liệu.
- Query hàm ý như một truy vấn tìm kiếm trong cơ sở dữ liệu. Với một từ hoặc token trong câu mà mô hình đang tập trung phân tích, cần phải "truy vấn" dữ liệu về mức độ liện quan của nó với các từ còn lại.
- Key giống như một khóa trong cơ sở dữ liệu được sử dụng để lập chỉ mục và tìm kiếm. Trong cơ chế attention, mỗi mục trong chuỗi đầu vào (ví dụ: mỗi từ trong câu) cần có 1 key để có thể truy vấn. Do đó, tính atention scores dùng Q và T.
- Value như là dữ liệu bản ghi trong cơ sở dữ liệu. Thực hiện câu truy vấn với một giá trị key, nó sẽ trả về các giá trị tương ứng.
5. Causal attention
Causal Attention (hay còn gọi là Masked Attention) là một kỹ thuật được nhằm che dấu các thông tin phía sau khiến mô hình chỉ nắm được các dữ liệu ở phía trước.
Nói đơn giản hơn, điều này đảm bảo rằng việc dự đoán từ tiếp theo chỉ nên phụ thuộc vào các từ đứng trước nó.
Triển khai Causal attention
Cách 1
- Bước 1: Tính trọng số attention
- Bước 2: Tạo một ma trận tam giác dưới có số chiều bằng số chiều của ma trận trọng số
- Bước 3: Nhân ma trận trọng số với ma trận tam giác dưới
- Bước 4: Chuẩn hóa thêm một lần nữa
import torch
import torch.nn.functional as F inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 torch.manual_seed(123) W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_q
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_k
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_v queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value attn_scores = queries @ keys.T # Chuẩn hóa attention scores
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) # Số chiều của ma trận trọng số (context_length = 6)
context_length = attn_scores.shape[0] # Tạo ma trận tam giác dưới kích thước 6x6
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple) # Nhân 2 ma trận `mask_simple` và `attention_weights` với nhau
masked_simple = attn_weights * mask_simple
print(masked_simple) # Chuẩn hoá 1 lần nữa
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
Kết quả in ra:
Cách 2
- Bước 1: Thay thay các giá trị nằm trên đường chéo chính của ma trận attention scores thành
-inf
- Bước 2: Chuẩn hóa với softmax cho ra trọng số attention
So với cách 1, cách này sử dụng ít phép toán hơn, không cần thêm bước chuẩn hóa không cần thiết. Khác biệt cơ bản giữa 2 cách là việc thay thế các giá trị trên đường chéo chính thành -inf
.
import torch
import torch.nn.functional as F inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 torch.manual_seed(123) W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_q
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_k
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # W_v queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value attn_scores = queries @ keys.T
print(attn_scores)
print("\n")
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
print("\n") attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
Dropout
Ngoài việc che giấu thông tin bằng casual attention. Trong thực tế, người ta còn áp dụng thêm kỹ thuật Dropout nhằm loại bỏ ngẫu nhiên một phần tham số của mô hình.
Cần lưu ý rằng dropout chỉ được sử dụng trong quá trình huấn luyện.
Hình dưới đây minh họa kỹ thuật dropout với tỷ lệ 50%
Lưu ý: -Khi áp dụng dropout, các giá trị không bị loại bỏ sẽ được bù lại cách nhân với một hệ số có giá trị 1 / (1 - dropout_rate)
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout 50%
example = torch.ones(6, 6) # Ma trận đơn vị 6x6 với các phần tử là 1
print(example)
print(dropout(example))
Kết quả
tensor([[1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1.]]) tensor([[2., 2., 0., 2., 2., 0.], [0., 0., 0., 2., 0., 2.], [2., 2., 2., 2., 0., 2.], [0., 2., 2., 0., 0., 2.], [0., 2., 0., 2., 0., 2.], [0., 2., 2., 2., 2., 0.]])
6. Multi-head Attention
Sự khác biệt cơ bản giữa single-head attention và multi-head attention là ở số lượng ma trận , , .
Dưới đây là hình minh họa cho mô-đun multi-head attention với 2 single-head attention
import torch
import torch.nn as nn inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6)
) """
- d_in: Kích thước đầu vào của mỗi token (ở đây bằng 3).
- d_out: Kích thước đầu ra sau khi biến đổi (ở đây bằng 2).
- context_length: Độ dài tối đa của chuỗi (ở đây bằng 6).
- dropout: Tỷ lệ dropout
"""
class CausalAttention(nn.Module): def __init__(self, d_in, d_out, context_length, dropout): super().__init__() self.d_out = d_out self.W_query = nn.Linear(d_in, d_out, bias=False) self.W_key = nn.Linear(d_in, d_out, bias=False) self.W_value = nn.Linear(d_in, d_out, bias=False) self.dropout = nn.Dropout(dropout) self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Tính trọng số attention def forward(self, x): b, num_tokens, d_in = x.shape keys = self.W_key(x) queries = self.W_query(x) values = self.W_value(x) attn_scores = queries @ keys.transpose(1, 2) attn_scores.masked_fill_( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) attn_weights = torch.softmax( attn_scores / keys.shape[-1]**0.5, dim=-1 ) attn_weights = self.dropout(attn_weights) context_vec = attn_weights @ values return context_vec # num_heads: số lượng head
class MultiHeadAttentionWrapper(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, num_heads): super().__init__() # Lặp qua từng head và tính trọng số self.heads = nn.ModuleList( [CausalAttention(d_in, d_out, context_length, dropout) for _ in range(num_heads)] ) # Ghép các context vectors từ các head lại với nhau def forward(self, x): return torch.cat([head(x) for head in self.heads], dim=-1) torch.manual_seed(123) batch = torch.stack((inputs, inputs), dim=0)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper( d_in, d_out, context_length, 0.0, num_heads=2
) context_vecs = mha(batch) print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
Kết quả in ra:
tensor([[[-0.4519, 0.2216, 0.4772, 0.1063], [-0.5874, 0.0058, 0.5891, 0.3257], [-0.6300, -0.0632, 0.6202, 0.3860], [-0.5675, -0.0843, 0.5478, 0.3589], [-0.5526, -0.0981, 0.5321, 0.3428], [-0.5299, -0.1081, 0.5077, 0.3493]], [[-0.4519, 0.2216, 0.4772, 0.1063], [-0.5874, 0.0058, 0.5891, 0.3257], [-0.6300, -0.0632, 0.6202, 0.3860], [-0.5675, -0.0843, 0.5478, 0.3589], [-0.5526, -0.0981, 0.5321, 0.3428], [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
Multi-head attention with weight splits
Thay vì nhiều như ở trên, chúng ta tạo ra các 1 bộ ma trận duy nhất rồi sau đó chia chúng thành các ma trận riêng lẻ cho từng head attention.
Mô tả các tính ma trận Q, ma trận K, V cũng sẽ tương tự
class MultiHeadAttention(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): super().__init__() # Đảm bảo d_out chia hết cho số lượng heads assert (d_out % num_heads == 0), \ "d_out must be divisible by num_heads" self.d_out = d_out self.num_heads = num_heads self.head_dim = d_out // num_heads # Kích thước cho mỗi head # Các ma trận W self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # Lớp để gộp kết quả từ các head lại self.out_proj = nn.Linear(d_out, d_out) self.dropout = nn.Dropout(dropout) # causal mask self.register_buffer( "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1) ) def forward(self, x): b, num_tokens, d_in = x.shape # batch size, số token, kích thước đầu vào # Biến đổi input thành các tensor query, key, value keys = self.W_key(x) # Kích thước (shape): (b, num_tokens, d_out) queries = self.W_query(x) values = self.W_value(x) # Chia nhỏ thành các head: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) print("Split Keys: ", keys) values = values.view(b, num_tokens, self.num_heads, self.head_dim) queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) # Đưa chiều num_heads lên trước: -> (b, num_heads, num_tokens, head_dim) keys = keys.transpose(1, 2) print("Transpose Keys: ", keys) queries = queries.transpose(1, 2) values = values.transpose(1, 2) # Tính attention scores bằng cách nhân Q với K^T attn_scores = queries @ keys.transpose(2, 3) # (b, num_heads, num_tokens, num_tokens) # Áp dụng causal attention mask_bool = self.mask.bool()[:num_tokens, :num_tokens] attn_scores.masked_fill_(mask_bool, -torch.inf) # Gán -inf vào các vị trí bị che # Chuẩn hóa attention scores bằng softmax attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) # Áp dụng dropout attn_weights = self.dropout(attn_weights) # Tính toán context vector: (b, num_heads, num_tokens, head_dim) context_vec = (attn_weights @ values).transpose(1, 2) # Đổi lại về (b, num_tokens, num_heads, head_dim) # Gộp các head lại: (b, num_tokens, num_heads * head_dim = d_out) context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # Dự đoán đầu ra cuối cùng (có thể học được) context_vec = self.out_proj(context_vec) return context_vec
Ví dụ cụ thể về cách chia và biến đổi các head
x = [ [[1,2,3,4,5,6], [7,8,9,10,11,12], [13,14,15,16,17,18]], [[19,20,21,22,23,24], [25,26,27,28,29,30], [31,32,33,34,35,36]]
] # Chuyển sang Tensor
x = torch.tensor(x, dtype=torch.float32) # Batch 0 có 3 token
# Batch 1 cũng có 3 token
# Mỗi token là 1 vector 6 chiều. b, num_tokens, d_in = x.shape # 2, 3, 6
num_heads = 2
d_out = 6
head_dim = d_out // num_heads # 3
dropout = 0.1
multihead_attention = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads)
output = multihead_attention(x)
Kết quả in ra keys, ta cùng xem ý nghĩa
Split Keys:
tensor([ [ # batch 0 [[ -1.4421, 0.5373, 0.5513],[ -1.6091, -0.8023, 0.4414]], # token 0 -> head 0: [-1.4421, 0.5373, 0.5513], head 1: [1.6091, -0.8023, 0.4414] [[ -3.8594, 3.2438, 1.5168], [ -5.4921, -0.9958, 3.2902]],# token1 [[ -6.2767, 5.9502, 2.4822],[ -9.3751, -1.1893, 6.1390]] # token2 ], [ # batch 1 [[ -8.6940, 8.6567, 3.4477], [-13.2580, -1.3828, 8.9877]], # token3 [[-11.1113, 11.3632, 4.4131], [-17.1410, -1.5763, 11.8365]], # token4 [[-13.5286, 14.0696, 5.3785], [-21.0240, -1.7698, 14.6853]] # token5 ] ], grad_fn=<ViewBackward0>)
Transpose Keys: Biến đổi lại để mỗi head xử lý các tokens của riêng nó.
tensor([[ # Các token do head0 xử lý [ [ -1.4421, 0.5373, 0.5513], [ -3.8594, 3.2438, 1.5168], [ -6.2767, 5.9502, 2.4822] ], # Các token do head1 xử lý [ [ -1.6091, -0.8023, 0.4414], [ -5.4921, -0.9958, 3.2902], [ -9.3751, -1.1893, 6.1390] ]],