- vừa được xem lúc

Skorch: Cách để Pytorch trở nên đơn giản

0 0 11

Người đăng: Trinh Quang

Theo Viblo Asia

Giới thiệu Skorch

PyTorch luôn là một lựa chọn của hầu hết các anh em AI engineer để xây dựng bất kỳ mô hình học sâu nào. Tuy nhiên, có một điều đặc biệt mà anh em nào từng code PyTorch cũng thường xuyên gặp phải như hình bên dưới:

Trong đoạn code trên, với mỗi epoch thì chính ta cần lặp qua hết các batch dữ liệu. Mỗi batch dữ liệu chúng ta cần forward qua mô hình, tính loss và backward để cập nhập trọng số cho các layer trong mô hình, đôi khi còn phải tính performance để kiểm tra xem mô hình có đang học đúng nữa không chứ 😂 . Việc này đối với các anh em beginer thì rất tốt để học và nắm được từng bước thì code sẽ hoạt động như thế nào. Nhưng đối với các anh em đã vững thì nhiều khi việc code hẳn hoi ra như vậy thì tốn cũng kha khá thời gian, vậy thì tại sao không thử với một số thư viện mạnh mẽ giúp chúng ta chỉ cần vài dòng code đã có thể có được kết quả mình cần. Mình thì thường hay cần chạy một số mô hình benchmark hoặc đơn giản là thử nghiệm một kiến trúc hoặc bộ dữ liệu nào đó để biết kết quả ban đầu. Để tối ưu về thời gian và nắm bắt các kết quả ban đầu thì mình thường sử dụng Skorch. Vậy thì Skorch là gì và hoạt động như thế nào, mình và anh em sẽ tìm hiểu qua nhé.

Skorch (Sklearn + PyTorch) là một thư viện mã nguồn mở cung cấp tính tương thích đầy đủ với Scikit-learn cho PyTorch qua đó giúp đơn giản hóa rất nhiều quá trình huấn luyện mạng neural với PyTorch

Điều này có nghĩa là chúng ta có thể huấn luyện các mô hình PyTorch một cách tương tự như Scikit-learn, sử dụng các hàm như fit(), predict(), score(), v.v.

Với Skorch, bạn có thể sử dụng những lợi ích của PyTorch như tính linh hoạt cao và hiệu suất tốt, cùng với khả năng triển khai mô hình nhanh chóng và dễ dàng.

Để cài đặt Skorch, bạn chỉ cần lệnh pip đơn giản

pip install skorch

Skorch có rất nhiều ví dụ và code mẫu để chạy với nhiều bài toán khác nhau từ transferlearning tới LLM. Trong bài viết này mình chỉ đi qua một triển khai cơ bản của bài toán classify, các bạn có thể đọc thêm trên documents (https://skorch.readthedocs.io/en/stable/) của skorch để tìm hiểu việc chạy cho các bài toán khác như thế nào nhé.

Triển khai với Skorch

Loading Data

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
# Load dataset
mnist = fetch_openml('mnist_784', as_frame=False, cache=False) # Preproces data
X = mnist.data.astype('float32')
y = mnist.target.astype('int64')

Build Neural Network with PyTorch

import torch
from torch import nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mnist_dim = X.shape[1]
hidden_dim = int(mnist_dim/8)
output_dim = len(np.unique(mnist.target)) class ClassifierModule(nn.Module): def __init__( self, input_dim=mnist_dim, hidden_dim=hidden_dim, output_dim=output_dim, dropout=0.5, ): super(ClassifierModule, self).__init__() self.dropout = nn.Dropout(dropout) self.hidden = nn.Linear(input_dim, hidden_dim) self.output = nn.Linear(hidden_dim, output_dim) def forward(self, X, **kwargs): X = F.relu(self.hidden(X)) X = self.dropout(X) X = F.softmax(self.output(X), dim=-1) return X

Define a classifier by skorch

Có rất nhiều tham số được truyền vào, tùy vào độ phức tạp của mô hình và các kỹ thuật training bạn muốn sử dụng thì bạn có thể tham khảo documents để biết thêm những tham số có thể truyền vào nha. Ở đây mình chỉ define một mô hình đơn giản nên các tham số đưa vào class cũng không quá phức tạp.

from skorch import NeuralNetClassifier
torch.manual_seed(0) net = NeuralNetClassifier( ClassifierModule, max_epochs=20, lr=0.1, device=device,
)

Training

Chỉ đơn giản băng 1 câu lệnh

net.fit(X_train, y_train)
 epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------ 1 0.8387 0.8800 0.4174 3.8169 2 0.4332 0.9103 0.3133 0.8510 3 0.3612 0.9233 0.2684 0.8208 4 0.3233 0.9309 0.2317 0.8079 5 0.2938 0.9353 0.2173 0.8074 6 0.2738 0.9390 0.2039 0.8277 7 0.2600 0.9454 0.1868 0.8224 8 0.2427 0.9484 0.1757 0.8623 9 0.2362 0.9503 0.1683 0.8312 10 0.2226 0.9512 0.1621 0.8221 11 0.2184 0.9529 0.1565 0.8158 12 0.2090 0.9541 0.1508 0.7974 13 0.2067 0.9570 0.1446 0.8123 14 0.1978 0.9570 0.1412 0.8304 15 0.1923 0.9582 0.1392 0.8421 16 0.1889 0.9582 0.1342 0.8153 17 0.1855 0.9612 0.1297 0.8458 18 0.1786 0.9613 0.1266 0.8827 19 0.1728 0.9615 0.1250 0.8335 20 0.1698 0.9613 0.1248 0.8112

Evaluate

from sklearn.metrics import accuracy_score
y_pred = net.predict(X_test)
accuracy_score(y_test, y_pred)
0.9631428571428572

Kết luận

Trong bài viết này, chúng ta có thể triển khai các mô hình ML & DL qua việc sử dụng Skorch. Chỉ vài bước xử lý dữ liệu và định nghĩa mô hình chúng ta đã có thể huấn luyện và đưa ra một số các kết quả ban đầu. Skorch còn hỗ trợ tìm kiếm siêu tham số của mô hình giúp bạn tuning và tìm ra mô hình tốt nhất. Một số điểm hạn chế mà mình thấy đó là skorch không có các module visualize một cách trực quan hóa, và sẽ khá khó debug khi quá trình forward gặp vấn đề.

Tài liệu tham khảo

[1] https://skorch.readthedocs.io/en/stable/

Bình luận

Bài viết tương tự

- vừa được xem lúc

TorchServe, công cụ hỗ trợ triển khai mô hình PyTorch

Lời mở đầu. Hôm nay tôi sẽ giới thiệu sơ qua cho các bạn công cụ triển khai mô hình dành riêng cho mô hình PyTorch.

0 0 37

- vừa được xem lúc

Bacteria classification bằng thư viện fastai

Giới thiệu. fastai là 1 thư viện deep learning hiện đại, cung cấp API bậc cao để giúp các lập trình viên AI cài đặt các mô hình deep learning cho các bài toán như classification, segmentation... và nhanh chóng đạt được kết quả tốt chỉ bằng vài dòng code. Bên cạnh đó, nhờ được phát triển trên nền tản

0 0 37

- vừa được xem lúc

Pytorch - Một số tips hay, tối ưu cho quá trình huấn luyện model của bạn

Xin chào các bạn, cũng lâu rồi mình mới quay trở lại ngồi viết mấy bài chia sẻ trên viblo. Chẹp, dạo này làm remote nên lười vận động, lười cả viết bài hẳn.

0 0 276

- vừa được xem lúc

Hướng dẫn tất tần tật về Pytorch để làm các bài toán về AI

Giới thiệu về pytorch. Pytorch là framework được phát triển bởi Facebook.

0 0 181

- vừa được xem lúc

Nhận diện khuôn mặt với mạng MTCNN và FaceNet (Phần 2)

Chào mừng các bạn đã quay lại với series "Nhận diện khuôn mặt với mạng MTCNN và FaceNet" của mình. Ở phần 1, mình đã giải thích qua về lý thuyết và nền tảng của 2 mạng là MTCNN và FaceNet.

0 0 733

- vừa được xem lúc

Video Understanding: Tổng quan

"Thợ lặn" hơi lâu, sau sự kiện MayFest thì đến bây giờ cũng là 3 tháng rồi mình không viết thêm bài mới. Thế nên là, hôm nay mình lại ngoi lên, đầu tiên là để luyện lại văn viết một chút, tiếp theo cũ

0 0 97