- vừa được xem lúc

Training thêm data cho GPT-2 model, version thử nghiệm thực tế :smile

0 0 14

Người đăng: Phan Ngoc

Theo Viblo Asia

Mở

Vừa qua cũng thấy nhiều bạn tìm hiểu về chatGPT, các kiến thức chủ đạo như mô hình Tranformer https://viblo.asia/p/tu-transformer-den-language-model-bai-1-bat-dau-voi-kien-truc-mo-hinh-transformer-38X4EN1gJN2 . Mình dù không hiểu nhiều lắm, nhưng dựa trên các thông tin trên internet, cũng vọc vạch làm một số demo nhỏ cho việc train model GPT-2, ở mức newbie , hy vọng a.e chạy phát cho vui để lấy cảm hứng học được nhiều hơn 😊

Load the dataframe

import pandas as pd df = pd.read_csv('../data/Restaurant_Reviews.tsv', sep='\t')
df = df.rename(columns={'Review': 'text'})
df 

Clean text một chút.

import nltk
nltk.download('stopwords')
import re
from nltk.corpus import stopwords def clean_text(text): # Make text lowercase text = text.lower() # Remove text in square brackets text = re.sub('\[.*?\]', '', text) # Remove links text = re.sub('https?://\S+|www\.\S+', '', text) # Remove punctuation text = re.sub('[^a-zA-Z0-9\s]+', '', text) # Remove words containing numbers text = re.sub('\w*\d\w*', '', text) # Remove stop words stop_words = set(stopwords.words('english')) words = text.split() filtered_words = [word for word in words if word not in stop_words] text = ' '.join(filtered_words) # Remove extra whitespace text = re.sub('\s+', ' ', text).strip() return text # Apply the clean_text function to all text in the 'text' column
df['text'] = df['text'].apply(clean_text) # Show the updated dataframe
df.head()

Phần tách thành những token nhỏ hơn

import re contents = df['text'].values.tolist()
def santilize(x): t = x.split(' ') new_list = [item for item in t if item is not None] return new_list
content_tokens = list(map(santilize, contents))
print('content_tokens', content_tokens[0]) 

Load model pretrained của GPT2

from transformers import AutoTokenizer, AutoModelWithLMHead
modelMaskedLM = AutoModelWithLMHead.from_pretrained('gpt2')
tokenizerVI = AutoTokenizer.from_pretrained('gpt2')

Add thêm vocabulary mới cho model.

def flatten_list(lst): flattened = [] for item in lst: if isinstance(item, list): flattened.extend(flatten_list(item)) else: flattened.append(item) return flattened sentence_tokens = flatten_list(content_tokens)
print('sentence_tokens', set(sentence_tokens[0:30]))
new_tokens = set(sentence_tokens) - set(tokenizerVI.get_vocab().keys())
print('length before add:', len(tokenizerVI.vocab))
tokenizerVI.add_tokens(list(new_tokens))
print('length after add:', len(tokenizerVI.vocab))
modelMaskedLM.resize_token_embeddings(len(tokenizerVI)+1)

Viết file những câu + token sẽ train để review.

# Open the file in write mode
with open("../data/news_output.txt", "w") as file: # Truncate the file to the current position of the file pointer file.truncate() print(content_tokens[:20])
with open("../data/news_output.txt", "w") as file: for t in content_tokens[:2]: file.write(' '.join(t) + "\n") tokenizerVI.add_special_tokens({'pad_token': '[PAD]'})

Encode thành list, seq lấy max=512

max_seq_length = 512
encoded_texts = [tokenizerVI.encode(text, truncation=True, max_length=max_seq_length) for text in contents]
print(encoded_texts[:2])

Train thôi ✌️

from transformers import LineByLineTextDataset, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments # training_data = LineByLineTextDataset(
# tokenizer=tokenizerVI,
# file_path='../data/news_output.txt',
# block_size=1024,
# ) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizerVI, mlm=False, mlm_probability=0.15
) training_args = TrainingArguments( output_dir="./results-text", overwrite_output_dir=True, num_train_epochs=12, per_device_train_batch_size=16, per_device_eval_batch_size=8, warmup_steps=1000, logging_steps=500,
) trainer = Trainer( model=modelMaskedLM, args=training_args, train_dataset=encoded_texts, data_collator=data_collator
) trainer.train()
trainer.save_model()

Và rồi xem thử model hoạt động thế nào 🌸

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel text = """
write comment about "Loved this", for restaurant
"""
input_ids = tokenizerVI.encode(text, return_tensors='pt')
print('input_ids',input_ids)
max_length = 100 sample_outputs = modelMaskedLM.generate(input_ids,pad_token_id=tokenizerVI.eos_token_id, bos_token_id=tokenizerVI.bos_token_id, eos_token_id=tokenizerVI.eos_token_id, do_sample=True, max_length=max_length, min_length=max_length, top_k=40, num_beams=5, early_stopping=True, no_repeat_ngram_size=2, num_return_sequences=3) for i, sample_output in enumerate(sample_outputs): print(">> Generated text {}\n\n{}".format(i+1, tokenizerVI.decode(sample_output.tolist()))) print('\n---') 
  • Kết quả cùng thầy zui zui, dù không chắc có hiệu quả không ✌️

Thank for reading 😃

Bình luận

Bài viết tương tự

- vừa được xem lúc

Hành trình AI của một sinh viên tồi

Mình ngồi gõ những dòng này vào lúc 2h sáng (chính xác là 2h 2 phút), quả là một đêm khó ngủ. Có lẽ vì lúc chiều đã uống cốc nâu đá mà giờ mắt mình tỉnh như sáo, cũng có thể là vì những trăn trở về lý thuyết chồng chất ánh xạ mình đọc ban sáng khiến không tài nào chợp mắt được hoặc cũng có thể do mì

0 0 131

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 201

- vừa được xem lúc

Tìm hiểu về YOLO trong bài toán real-time object detection

1.Yolo là gì. . Họ các mô hình RCNN ( Region-Based Convolutional Neural Networks) để giải quyết các bài toán về định vị và nhận diện vật thể.

0 0 272

- vừa được xem lúc

Encoding categorical features in Machine learning

Khi tiếp cận với một bài toán machine learning, khả năng cao là chúng ta sẽ phải đối mặt với dữ liệu dạng phân loại (categorical data). Khác với các dữ liệu dạng số, máy tính sẽ không thể hiểu và làm việc trực tiếp với categorical variable.

0 0 244

- vừa được xem lúc

TF Lite with Android Mobile

Như các bạn đã biết việc đưa ứng dụng đến với người sử dụng thực tế là một thành công lớn trong Machine Learning.Việc làm AI nó không chỉ dừng lại ở mức nghiên cứu, tìm ra giải pháp, chứng minh một giải pháp mới,... mà quan trọng là đưa được những nghiên cứu đó vào ứng dụng thực tế, được sử dụng để

0 0 54

- vừa được xem lúc

Xây dựng hệ thống Real-time Multi-person Tracking với YOLOv3 và DeepSORT

Trong bài này chúng ta sẽ xây dựng một hệ thống sử dụng YOLOv3 kết hợp với DeepSORT để tracking được các đối tượng trên camera, YOLO là một thuật toán deep learning ra đời vào tháng 5 năm 2016 và nó nhanh chóng trở nên phổ biến vì nó quá nhanh so với thuật toán deep learning trước đó, sử dụng YOLO t

0 0 303