Xây Dựng Mô Hình Ngôn Ngữ Lớn (Phần 6): Tinh chỉnh cho nhiệm vụ phân loại văn bản

0 0 0

Người đăng: Le Thanh Cong

Theo Viblo Asia

Ở các bài viết trước, ta đã triển khai gần như hoàn chỉnh việc xây dựng 1 mô hình ngôn ngữ lớn, từ việc xử lý dữ liệu đầu vào, cơ chế attention, logic trong khối Transformer cho đến tiền huấn luyện. Bây giờ, chúng ta sẽ cùng đi đến bước cuối trong quá trình xây dựng một mô hình ngôn ngữ lớn: Tinh chỉnh mô hình.

Mô hình sau khi trải qua quá trình tiền huấn luyện đã có thể sinh văn bản khá mượt mà.

Tuy nhiên, nó có thể vẫn đang còn hạn chế ở một số nhiệm vụ chuyên biệt như phân loại văn bản, dịch thuật ...

=> Do đó, tinh chỉnh là bước để xử lý và cải thiện các vấn đề trên.

File Juputer NoteBook của bài viết này nằm tại đây

1. Các phương pháp tinh chỉnh mô hình

Hai phương pháp phổ biến nhất là Classification fine-tuningInstruction fine-tuning.

Điểm chung của 2 phương pháp là tập dữ liệu tinh chỉnh có nhãn dán (câu hỏi và đáp án).

Sự khác nhau về mục đích:

  • Classification fine-tuning giúp mô hình phân loại dữ liệu tốt hơn (ví dụ cho đọc 1 email và xác định là spam hay không spam)
  • Instruction fine-tuning giúp mô hình trả lời tốt hơn khi gặp các câu hỏi phức tạp

Ví dụ về Classification fine-tuning



Ví dụ về Instruction fine-tuning

=> Ở bài viết này chúng ta sẽ tìm hiểu Classification fine-tuning trước.

2. Chuẩn bị tập dữ liệu

Ở bước đầu tiên này chúng ta tiến hành tải về, xem thử bên trong tệp dữ liệu.

import urllib.request
import zipfile
import os
from pathlib import Path
import pandas as pd url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv" def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path): if data_file_path.exists(): print(f"{data_file_path} already exists. Skipping download and extraction.") return # Downloading file with urllib.request.urlopen(url) as response: with open(zip_path, "wb") as out_file: out_file.write(response.read()) # Giải nén file zip with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extracted_path) # Add .tsv file extension original_file_path = Path(extracted_path) / "SMSSpamCollection" os.rename(original_file_path, data_file_path) print(f"File downloaded and saved as {data_file_path}") try: download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e: print(f"Primary URL failed: {e}. Trying backup URL...") url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip" download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path) # Đọc dữ liệu trong file và in ra
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
print(df) # In ra số lượng bản ghi chia theo nhãn dán
print(df["Label"].value_counts())
"""
Label
ham 4825
spam 747
Name: count, dtype: int64
"""
	Label	Text
0	ham Go until jurong point, crazy.. Available only ...
1	ham Ok lar... Joking wif u oni...
2	spam Free entry in 2 a wkly comp to win FA Cup fina...
3	ham U dun say so early hor... U c already then say...
4	ham Nah I don't think he goes to usf, he lives aro...
...	... ...
  • Dễ thấy có sự chênh lệch giữa số email spam và ham
  • Để đơn giản hóa, ta sẽ lấy tập con sao cho chứa 747 mẫu từ mỗi lớp.
  • Còn có nhiều cách khác để xử lý sự mất cân bằng lớp, nhưng chúng nằm ngoài phạm vi của chhương. Có thể tìm thấy ví dụ và thêm thông tin trong hướng dẫn sử dụng imbalanced-learn.
def create_balanced_dataset(df): # Đếm số lượng "spam" num_spam = df[df["Label"] == "spam"].shape[0] # Random lấy số lượng "ham" bằng với số lượng "spam" ở trên ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123) # Gộp lại balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]]) # bộ dữ liệu gồm 747 spam và ham return balanced_df balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())
  • Chuyển đổi nhãn dữ liệu sang dạng số
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1}) print(balanced_df)
4307 0 Awww dat is sweet! We can think of something t...
4138 0 Just got to <#>
4831 0 The word "Checkmate" in chess comes from the P...
4461 0 This is wishing you a great day. Moji told me ...
5440 0 Thank you. do you generally date the brothas?
... ... ...
  • Chia dữ liệu thành các phần training, validationtest lần lượt theo tỷ lệ 70%, 10% và 20%
def random_split(df, train_frac, validation_frac): df = df.sample(frac=1, random_state=123).reset_index(drop=True) train_end = int(len(df) * train_frac) validation_end = train_end + int(len(df) * validation_frac) train_df = df[:train_end] validation_df = df[train_end:validation_end] test_df = df[validation_end:] return train_df, validation_df, test_df train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1) train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)

3. Xử lý tập dữ liệu

Ta đang làm việc với một tập dữ liệu nội dung email với các độ dài khác nhau. Để nhóm các email này thành batch để xử lý, chúng ta có hai lựa chọn chính:

  • Cắt bớt (truncate) tất cả các email xuống độ dài của email ngắn nhất .
  • Đệm thêm (pad) vào tất cả sao cho đều độ dài bằng độ dài của email dài nhất.

Lựa chọn đầu tiên tiết kiệm chi phí tính toán và lưu trữ hơn, nhưng gây mất mát thông tin đáng kể và có thể làm giảm hiệu suất của mô hình. Do đó, chúng ta chọn cách thứ hai. Ký tự đệm thêm là <|endoftext|>

Như đã biết, mô hình không thể hiểu được dữ liệu dạng văn bản thô mà cần phải "số hóa". Tập dữ liệu dùng để tinh chỉnh cũng không ngoại lệ. Tiến hành chuyển dữ liệu sang dạng tokenID.

Phần gạch chân là phần đệm thêm vào, sao cho các chuỗi có độ dài bằng nhau

import tiktoken
import torch
from torch.utils.data import Dataset
import pandas as pd # Dùng `DataLoader` để chia các tập dữ liệu theo các batch
from torch.utils.data import DataLoader tokenizer = tiktoken.get_encoding("gpt2")
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})) class SpamDataset(Dataset): def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256): self.data = pd.read_csv(csv_file) # Tiền xử lý: chuyển mỗi văn bản trong cột "Text" thành danh sách các token. self.encoded_texts = [ tokenizer.encode(text) for text in self.data["Text"] ] # Nếu không chỉ định max_length, tự động tìm độ dài của chuỗi dài nhất if max_length is None: self.max_length = self._longest_encoded_length() else: self.max_length = max_length # Nếu có chỉ định max_length, cắt bớt chuỗi nếu chúng có độ dài lớn hơn max_length self.encoded_texts = [ encoded_text[:self.max_length] for encoded_text in self.encoded_texts ] # Thêm phần tử cho tất cả chuỗi để đảm bảo chúng có cùng độ dài. self.encoded_texts = [ encoded_text + [pad_token_id] * (self.max_length - len(encoded_text)) for encoded_text in self.encoded_texts ] def __getitem__(self, index): encoded = self.encoded_texts[index] label = self.data.iloc[index]["Label"] return ( torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long) ) def __len__(self): return len(self.data) # Hàm tìm độ dài lớn nhất của 1 chuỗi trong tập dữ liệu def _longest_encoded_length(self): max_length = 0 for encoded_text in self.encoded_texts: encoded_length = len(encoded_text) if encoded_length > max_length: max_length = encoded_length return max_length train_dataset = SpamDataset( csv_file="train.csv", max_length=None, tokenizer=tokenizer
) val_dataset = SpamDataset( csv_file="validation.csv", max_length=train_dataset.max_length, tokenizer=tokenizer
) test_dataset = SpamDataset( csv_file="test.csv", max_length=train_dataset.max_length, tokenizer=tokenizer
) # Thiết lập số worker (0 nghĩa là không sử dụng đa luồng)
# Kích thước 1 lô là 8
num_workers = 0
batch_size = 8 torch.manual_seed(123) train_loader = DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True,
) val_loader = DataLoader( dataset=val_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False,
) test_loader = DataLoader( dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=False,
) # Xem số lượng batch ở mỗi tập
print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches") """
130 training batches
19 validation batches
38 test batches
"""

4. Thử nghiệm khả năng của mô hình trước khi tinh chỉnh

Cùng thử xem mô hình sẽ xử lý thế nào khi được yêu cầu phần loại tin nhắn spam.

# ... File đầy đủ: https://sal.vn/6SEtHM def main(): CHOOSE_MODEL = "gpt2-small (124M)" INPUT_PROMPT = "Every effort moves" BASE_CONFIG = { "vocab_size": 50257, # Vocabulary size "context_length": 1024, # Context length "drop_rate": 0.0, # Dropout rate "qkv_bias": True # Query-key-value bias } model_configs = { "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12}, "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16}, "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20}, "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25}, } BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")") settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2") model = GPTModel(BASE_CONFIG) load_weights_into_gpt(model, params) model.eval() text = ( "Is the following text 'spam'? Answer with 'yes' or 'no':" " 'You are a winner you have been specially" " selected to receive $1000 cash or a $2000 award.'" ) token_ids = generate_text_simple( model=model, idx=text_to_token_ids(text, tokenizer), max_new_tokens=23, context_size=BASE_CONFIG["context_length"] ) print(token_ids_to_text(token_ids, tokenizer)) if __name__ == "__main__": main()
Is the following text 'spam'? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.' The following text 'spam'? Answer with 'yes' or 'no': 'You are a winner

Với việc yêu cầu trả lời yes hoặc no, mô hình đang chưa hiểu và trả ra kết quả không liên quan. Giờ lại lúc bắt tay vào bước tinh chỉnh.

5. Điều chỉnh định dạng đầu ra của mô hình

Với nhiệm phụ phân loại spam, đầu ra của mô hình chỉ cần 2 giá trị là 01.

=> Do đó, chúng ta cần sửa lại sao cho số giá trị đầu ra giảm từ hơn 50k xuống còn 2.

Hình minh họa việc thay đổi lớp Linear output trong mô hình để phù hợp với nhiệm vụ phân loại spam.

# ... def main(): # ... # đóng băng mô hình, nghĩa là dừng quá trình cập nhật các tham số for param in model.parameters(): param.requires_grad = False torch.manual_seed(123) num_classes = 2 # Thay thế lớp Linear output để đầu ra chứa 2 giá trị model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes) # Cho phép khối Transformer cuối cùng (thứ 12) có thể cập nhật tham số khi huấn luyện for param in model.trf_blocks[-1].parameters(): param.requires_grad = True # Cho phép lớp chuẩn hóa cuối cùng có thể cập nhật tham số khi huấn luyện for param in model.final_norm.parameters(): param.requires_grad = True model.eval() inputs = tokenizer.encode("Do you have time") inputs = torch.tensor(inputs).unsqueeze(0) with torch.no_grad(): outputs = model(inputs) print("Outputs:\n", outputs) """ Outputs: tensor([[[-1.5854, 0.9904], [-3.7235, 7.4548], [-2.2661, 6.6049], [-3.5983, 3.9902]]]) Outputs dimensions: torch.Size([1, 4, 2]) """ print("Last output token:", outputs[:, -1, :]) # Last output token: tensor([[-3.5983, 3.9902]]) # Với TH 2 giá trị đầu ra thì có thể lược bỏ bước softmax đi # chỉ cần giá trị cái nào lớn hơn thì chọn cái đó là được probas = torch.softmax(outputs[:, -1, :], dim=-1) label = torch.argmax(probas) print("Class label:", label.item()) # Class label: 1 => Yes
  • Về mặt kỹ thuật, chỉ cần sửa out_head là đủ.
  • Tuy nhiên, thực tế cho thấy rằng việc tinh chỉnh thêm các khối khác có thể cải thiện hiệu suất đáng kể.
  • Vì vậy, a sẽ sửa thêm khối Transformer cuối cùng và khối Final LayerNorm.


Tại sao lại chỉ dùng thêm Final LayerNorm và khối Transformer cuối cùng ?

  • Tiết kiệm tài nguyên tính toán
  • Tinh chỉnh có nghĩa là thay đổi nhỏ chứ không phải là huấn luyện lại.
  • Bộ dữ liệu đặc thù nhỏ hơn nhiều so với dữ liệu pretrained, việc tinh chỉnh toàn bộ mô hình có thể dẫn đến overfitting

6. Tính toán hàm mất mát và độ chính xác

Tính hàm mất mát

  • Hàm mất mát vẫn được tính theo phương pháp Cross Entropy tương tự như giai đoạn tiền huấn luyện.
  • Hàm calc_loss_batch ở đây giống với trong chương 5, ngoại trừ việc chúng ta chỉ quan tâm đến việc tối ưu token cuối cùng model(input_batch)[:, -1, :] thay vì tất cả các token model(input_batch)
def calc_loss_batch(input_batch, target_batch, model, device): input_batch, target_batch = input_batch.to(device), target_batch.to(device) logits = model(input_batch)[:, -1, :] loss = torch.nn.functional.cross_entropy(logits, target_batch) return loss def calc_loss_loader(data_loader, model, device, num_batches=None): total_loss = 0. if len(data_loader) == 0: return float("nan") elif num_batches is None: num_batches = len(data_loader) else: # Reduce the number of batches to match the total number of batches in the data loader # if num_batches exceeds the number of batches in the data loader num_batches = min(num_batches, len(data_loader)) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: loss = calc_loss_batch(input_batch, target_batch, model, device) total_loss += loss.item() else: break return total_loss / num_batches
Training loss: 2.453
Validation loss: 2.583
Test loss: 2.322

Tính độ chính xác

def calc_accuracy_loader(data_loader, model, device, num_batches=None): model.eval() correct_predictions, num_examples = 0, 0 if num_batches is None: num_batches = len(data_loader) else: num_batches = min(num_batches, len(data_loader)) # Lặp qua từng batch trong data_loader và kiểm tra nếu chưa vượt quá số batch tối đa. for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) # Tính toán đầu ra with torch.no_grad(): logits = model(input_batch)[:, -1, :] # Logits of last output token predicted_labels = torch.argmax(logits, dim=-1) # Giá trị đầu ra # Số mẫu đã xử lý num_examples += predicted_labels.shape[0] # Số dự đoán chính xác correct_predictions += (predicted_labels == target_batch).sum().item() else: break # Tỷ lệ chính xác return correct_predictions / num_examples def main(): # ... device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10) val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10) test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10) print(f"Training accuracy: {train_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%") print(f"Test accuracy: {test_accuracy*100:.2f}%")
Training accuracy: 46.25%
Validation accuracy: 45.00%
Test accuracy: 48.75%

Độ chính xác còn chưa được 50%, do chúng ta vẫn chưa hề thực hiện huấn luyện.

7. Tinh chỉnh mô hình với dữ liệu gán nhãn

  • Trong phần này, chúng ta huấn luyện để cải thiện độ chính xác trong việc phân loại email của mô hình

  • Hàm train_classifier_simple dưới đây gần như giống với hàm train_model_simple mà chúng ta đã sử dụng ở chương 5

  • Chỉ có hai điểm khác biệt là:

    1. Theo dõi số lượng mẫu huấn luyện đã xử lý (examples_seen) thay vì số lượng token đã xử lý
    2. Tính toán độ chính xác sau mỗi chu kỳ huẩn luyện thay vì in ra một đoạn văn bản mẫu
# Giống ở phần pretraining
def evaluate_model(model, train_loader, val_loader, device, eval_iter): model.eval() with torch.no_grad(): train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter) val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter) model.train() return train_loss, val_loss def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, eval_freq, eval_iter): # Khởi tạo các biến lưu trữ giá trị hàm mất mát và độ chính xác của tập train và validation train_losses, val_losses, train_accs, val_accs = [], [], [], [] examples_seen, global_step = 0, -1 # Vòng lặp huấn luyện chính for epoch in range(num_epochs): model.train() # Đặt mô hình ở chế độ huấn luyện for input_batch, target_batch in train_loader: optimizer.zero_grad() # Đặt lại gradient lỗi từ lần lặp batch trước đó loss = calc_loss_batch(input_batch, target_batch, model, device) loss.backward() # Tính toán gradient lỗi optimizer.step() # Cập nhật trọng số mô hình bằng gradient lỗi examples_seen += input_batch.shape[0] # theo dõi số lượng mẫu thay vì token global_step += 1 # Làm việc với hàm mất mát if global_step % eval_freq == 0: train_loss, val_loss = evaluate_model( model, train_loader, val_loader, device, eval_iter) train_losses.append(train_loss) val_losses.append(val_loss) print(f"Ep {epoch+1} (Step {global_step:06d}): " f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}") # Tính độ chính xác sau mỗi chu kỳ train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter) val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter) print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") print(f"Validation accuracy: {val_accuracy*100:.2f}%") train_accs.append(train_accuracy) val_accs.append(val_accuracy) return train_losses, val_losses, train_accs, val_accs, examples_seen

Thực thi quá trình huấn luyện

import time
# ... File đầy đủ: https://sal.vn/ZijlcZ def main(): # ... start_time = time.time() torch.manual_seed(123) optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1) num_epochs = 5 train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( model, train_loader, val_loader, optimizer, device, num_epochs=num_epochs, eval_freq=50, eval_iter=5, ) end_time = time.time() execution_time_minutes = (end_time - start_time) / 60 print(f"Training completed in {execution_time_minutes:.2f} minutes.") if __name__ == "__main__": main()

Kết quả in ra sau 5 chu kỳ huấn luyện:

Ep 1 (Step 000000): Train loss 2.153, Val loss 2.392
Ep 1 (Step 000050): Train loss 0.617, Val loss 0.637
Ep 1 (Step 000100): Train loss 0.523, Val loss 0.557
Training accuracy: 70.00% | Validation accuracy: 72.50% Ep 2 (Step 000150): Train loss 0.561, Val loss 0.489
Ep 2 (Step 000200): Train loss 0.419, Val loss 0.397
Ep 2 (Step 000250): Train loss 0.409, Val loss 0.353
Training accuracy: 82.50% | Validation accuracy: 85.00% Ep 3 (Step 000300): Train loss 0.333, Val loss 0.320
Ep 3 (Step 000350): Train loss 0.340, Val loss 0.306
Training accuracy: 90.00% | Validation accuracy: 90.00% Ep 4 (Step 000400): Train loss 0.136, Val loss 0.200
Ep 4 (Step 000450): Train loss 0.153, Val loss 0.132
Ep 4 (Step 000500): Train loss 0.222, Val loss 0.137
Training accuracy: 100.00% | Validation accuracy: 97.50% Ep 5 (Step 000550): Train loss 0.207, Val loss 0.143
Ep 5 (Step 000600): Train loss 0.083, Val loss 0.074
Training accuracy: 100.00% | Validation accuracy: 97.50%
Training completed in 5.31 minutes.

  • Dựa vào độ dốc đi xuống của 2 giá trị mất mát, chúng ta thấy rằng mô hình học tốt
  • Hơn nữa, 2 đường màu xanh và cam giảm cùng nhau trong suốt 5 chu kỳ cho thấy rằng mô hình không có xu hướng overfit

8. Thử nghiệm thực tế để phân loại email

Thử nghiệm

Chúng ta cùng lại các câu hỏi cũ mà trước đó nó chưa hiểu xem mô hình trả lời thế nào ?

def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256): model.eval() # Đặt mô hình ở chế độ đánh giá # Chuẩn bị đầu vào cho mô hình input_ids = tokenizer.encode(text) # tokenizer hóa văn bản supported_context_length = model.pos_emb.weight.shape[0] # Độ dài tối đa mà mô hình hỗ trợ # Cắt ngắn chuỗi nếu quá dài input_ids = input_ids[:min(max_length, supported_context_length)] # Đệm chuỗi để đạt đến độ dài tối đa input_ids += [pad_token_id] * (max_length - len(input_ids)) input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # Thêm chiều batch (kích thước lô) # Suy luận mô hình with torch.no_grad(): # Không tính toán gradient trong quá trình suy luận để tiết kiệm bộ nhớ logits = model(input_tensor)[:, -1, :] # Logits của token đầu ra cuối cùng predicted_label = torch.argmax(logits, dim=-1).item() # Lấy nhãn có xác suất cao nhất # Trả về kết quả phân loại return "spam" if predicted_label == 1 else "not spam" # Chuyển đổi đầu ra thành dạng văn bản

Mô hình đã hiểu và phân biệt đúng 2 nội dung thuộc spam hay not spam

text_1 = ( "You are a winner you have been specially" " selected to receive $1000 cash or a $2000 award."
) print(classify_review( text_1, model, tokenizer, device, max_length=train_dataset.max_length
))
# spam text_2 = ( "Hey, just wanted to check if we're still on" " for dinner tonight? Let me know!"
) print(classify_review( text_2, model, tokenizer, device, max_length=train_dataset.max_length
))
# not spam

Lưu lại mô hình

Mô hình hoạt động khá tốt, ta tiến hành lưu lại thông số mô hình để có thể tái sử dụng với code sau:

torch.save(model.state_dict(), "review_classifier.pth")

Nạp lại mô hình từ file đã lưu

model_state_dict = torch.load("review_classifier.pth, map_location=device")
model.load_state_dict(model_state_dict)

Tài liệu tham khảo

https://github.com/rasbt/LLMs-from-scratch/tree/main/ch06

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 228

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 234

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 162

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 47

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 151