- vừa được xem lúc

Bacteria classification bằng thư viện fastai

0 0 37

Người đăng: Hieu Bui

Theo Viblo Asia

Giới thiệu

fastai là 1 thư viện deep learning hiện đại, cung cấp API bậc cao để giúp các lập trình viên AI cài đặt các mô hình deep learning cho các bài toán như classification, segmentation... và nhanh chóng đạt được kết quả tốt chỉ bằng vài dòng code. Bên cạnh đó, nhờ được phát triển trên nền tảng thư viện Pytorch, nên fastai còn cung cấp các cấu phần bậc thấp cho các nhà nghiên cứu phát triển mô hình mới, cũng như hoàn toàn tương thích với các thành phần của pytorch.

Trong bài viết này, mình sẽ giới thiệu về 1 số tính năng của fastai và áp dụng chúng để xây dựng 1 mô hình phân lớp. Let's get started !!!

Cài đặt fastai

Các bạn có thể cài đặt fastai trên máy mình bằng câu lệnh sau:

pip install fastai --upgrade -q

Sau khi cài đặt thì chạy đoạn code sau để import fastai và các thư viện cần thiết:

import os
import requests
import urllib.request
import zipfile
import matplotlib.pyplot as plt
from torchsummary import summary
from fastai.vision.all import *
from bs4 import BeautifulSoup

Khi import fastai thì một số thư viện phổ biến như numpy, pandas, matplotlib cũng được import cùng nên không cần import lại nữa

Dữ liệu

Mình sẽ sử dụng bộ dữ liệu ảnh chụp vi khuẩn lấy từ trang web này. Các bạn có thể tải dữ liệu thẳng từ trên website về máy sau đó giải nén hoặc nếu ai dùng google colab thì có thể dùng đoạn code này:

os.makedirs('dibas_zip')
os.makedirs('dibas_images') url = 'http://misztal.edu.pl/software/databases/dibas/'
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser") links = [tag['href'] for tag in soup.findAll('a')]
for link in links: if ".zip" in link: file_name = link.partition("/dibas/")[2] urllib.request.urlretrieve(link, 'dibas_zip/' + file_name) zip_ref = zipfile.ZipFile('dibas_zip/' + file_name, 'r') zip_ref.extractall('dibas_images/') zip_ref.close() print("Downloaded and extracted: " + file_name)

Dữ liệu của chúng ta gồm 692 ảnh:

fns = []
for root,dirs,files in os.walk(path, topdown=true): for f in files: fns.append(root/Path(f))
len(fns), fns[0]

Tạo Dataloader

fastai cung cấp API cho việc tạo Dataloader của pytorch 1 cách đơn giản và nhanh chóng

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock), get_y=RegexLabeller(r'/(.+)_\d+.tif$'), splitter=RandomSplitter(valid_pct=0.1), item_tfms=[Resize(512)])

Đoạn lệnh trên sẽ trả về object DataBlock. Cùng tìm hiểu xem từng tham số dùng để làm gì nhé

  • block: Định nghĩa xem Dataloader sẽ trả về gì. Do bài toán của chúng ta là bài toán phân lớp nên Dataloader sẽ trả 2 thứ: ảnh và label tương ứng của nó.
  • get_y: lấy label từ tên file như thế nào. Label của mỗi bức ảnh là 1 phần trong tên file của nó. fastai cung cấp class RegexLabeller sử dụng regular expression để tách label từ tên file. VD:
RegexLabeller(r'/(.+)_\d+.tif$')('dibas_images/Lactobacillus.delbrueckii_0019.tif')

  • splitter: chia dataset thành 2 tập train/validation
  • item_tfms: do ảnh ở các kích thước khác nhau nên cần resize lại thành cùng kích thước mới có thể đóng gói thành từng batch.

Sau khi có Datablock thì chỉ cần chạy:

dls = dblock.dataloaders(source=fns, bs=16)

cùng với các tham số: nguồn dữ liệu (list các file ảnh) và batch size. Phương thức trên sẽ trả về object Dataloaders. Như tên gọi của nó, Dataloaders bao gồm nhiều Dataloader (1 train và 1 validation). Mọi người có thể index vào dls để truy cập các Dataloader: dls[0], dls[1].

Ta có thể kiểm tra xem bộ dữ liệu có bao nhiêu class:

print(dls.vocab)

Training

Việc luyện mô hình được xử lý bằng class Learner. Với bài toán phân lớp các bạn có thể tạo Learner bằng hàm cnn_learner:

learn = cnn_learner(dls, resnet50, metrics=accuracy)

Các tham số bao gồm:

  • Dataloaders
  • KIến trúc CNN. Ở đây mình dùng Resnet50 nhưng mọi người có thể dùng các mạng CNN pretrain có sẵn trên torchvision
  • List các metrics

Khi sử dùng mô hình pretrain, learner sẽ tự thêm các 1 số lớp Linear ở cuối phần CNN

learn.model
(1): Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten(full=False) (2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=4096, out_features=512, bias=False) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=33, bias=False) )

Mặc định phần trọng số của CNN sẽ đóng băng và không update trong quá trình train.

Việc train model rất đơn giản:

learn.fit_one_cycle(8, 1e-3)

fit_one_cycle(8, 1e-3) sẽ train model trong 8 epoch sử dụng 1-cycle policy. Nếu không muốn sử dụng learning rate scheduler, các bạn có thể dùng phương thức fit

Chỉ sau 8 epoch, accuracy đã đạt 98.5%. Giờ ta sẽ phá băng để train phần CNN:

learn.unfreeze()
learn.fit_one_cycle(5, 1e-7)

Để xem kết quả của mô hình, mọi người có thể chạy learn.show_results()

Toàn bộ quá trình từ load dữ liệu đến train model chỉ mất chưa đến 10 dòng code

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock), get_y=RegexLabeller(r'/(.+)_\d+.tif$'), splitter=RandomSplitter(valid_pct=0.1), item_tfms=[Resize(512)]) dls = dblock.dataloaders(source=fns, bs=16) learn = cnn_learner(dls, resnet50, metrics=accuracy) learn.fit_one_cycle(8, 1e-3) learn.unfreeze()
learn.fit_one_cycle(5, 1e-7)

Lời kết

Trên đây, mình đã hướng dẫn các bạn cài đặt cài đặt mô hình phân loại vi khuẩn với độ chính xác 98.5% bằng thư viện fastai. Chỉ với chưa tới 10 dòng code, ta đã vượt qua kết quả SOTA 97% trên bộ dữ liệu này (các bạn có thể kiểm tra tại đây). Nếu mọi người thấy bài viết có ích, xin hãy để lại cho mình 1 upvode nhé. Cảm ơn mọi người đã quan tâm và hẹn gặp lại trong những bài tiếp theo.

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tìm hiểu về YOLO trong bài toán real-time object detection

1.Yolo là gì. . Họ các mô hình RCNN ( Region-Based Convolutional Neural Networks) để giải quyết các bài toán về định vị và nhận diện vật thể.

0 0 285

- vừa được xem lúc

Xây dựng hệ thống Real-time Multi-person Tracking với YOLOv3 và DeepSORT

Trong bài này chúng ta sẽ xây dựng một hệ thống sử dụng YOLOv3 kết hợp với DeepSORT để tracking được các đối tượng trên camera, YOLO là một thuật toán deep learning ra đời vào tháng 5 năm 2016 và nó nhanh chóng trở nên phổ biến vì nó quá nhanh so với thuật toán deep learning trước đó, sử dụng YOLO t

0 0 317

- vừa được xem lúc

Sơ lược về bài toán Person Re-identification

Với những công nghệ hiện đại của thế kỷ 21 chúng ta đã có những phần cứng cũng như phần mềm mạnh mẽ để giải quyết những vấn đề và bài toán nan giải như face recognition, object detection, NLP,... Một trong những vấn đề nan giải cũng được mọi người chú ý ko kém những chủ đề trên là Object Tracking, v

0 0 63

- vừa được xem lúc

Xác định ý định câu hỏi trong hệ thống hỏi đáp

Mục tiêu bài viết. Phân tích câu hỏi là pha đầu tiên trong kiến trúc chung của một hệ thống hỏi đáp, có nhiệm vụ tìm ra các thông tin cần thiết làm đầu vào cho quá trình xử lý của các pha sau (trích c

0 0 94

- vừa được xem lúc

Epoch, Batch size và Iterations

Khi mới học Machine Learning và sau này là Deep Learning chúng ta gặp phải các khái niệm như Epoch, Batch size và Iterations. Để khỏi nhầm lẫn mình xin chia sẻ với các bạn sự khác nhau giữa các khái n

0 0 46

- vừa được xem lúc

Một số cải tiến của cross-entropy loss cho Face Recognition

Introduction. Bài toán face recognition trong vài năm trở lại đây đã đạt dược nhiều bước tiến lớn nhờ vào sự phát triển của học sâu (Deep learning), mà cụ thể hơn là mạng neural tích chập (Convolution

0 0 117