- vừa được xem lúc

Callback trong fastai (P3)

0 0 9

Người đăng: Hieu Bui

Theo Viblo Asia

Intro

Tiếp tục chuỗi bài viết về thư viện fastai, trong bài viết hôm nay, chúng ta sẽ cùng nhau tìm hiểu về hệ thống callback - nguyên liệu chính của training loop trong class Learner.

Một chút về Callback

Callback là gì? Callback về cơ bản chỉ là một function được gọi khi một sự kiện nào đó xảy ra. Ví dụ khi các bạn code 1 trang web bằng HTML với một nút trên đó. Nếu bạn muốn có 1 tác vụ nào đó được thực hiện khi người dùng bấm nút, khi đó bạn sẽ viết một function làm những việc cần thiết và truyền vào thuộc tính onClick của thẻ HTML. Hàm này được gọi là callback.

Callback trong fastai

Basic

Callback trong fastai được sử dụng để customize training loop của mô hình. Thông thường, training loop viết bằng pytorch sẽ giống như đoạn code ở dưới

def train(train_dl, model, epochs, optimizer, loss_func): for _ in range(epochs): model.train() for xb, yb in train_dl: out = model(xb) loss = loss_func(out, yb) loss.backward() optimizer.step() optimizer.zero_grad() for xb, yb in val_dl: validate(xb, yb) model.eval()

Trong nhiều trường hợp ta sẽ cần phải thêm tính năng cho training loop, ví dụ như:

  • Thêm regularization
  • Hyperparameter scheduling (learning rate, momentum, ...)
  • Log metrics

Cho mối trường hợp ta sẽ phải viết lại training loop để thực hiện những chức năng trên. Fastai giải quyết vấn đề này bằng 1 hệ thống callback. Sau mỗi bước của training loop trong fastai (hàm fit của Learner) sẽ có 1 đoạn code gọi tới hàm callback.

Fastai training loop với callback

Dưới đây là một ví dụ cực đơn giản về cách tạo và sử dụng callback:

from fastai.test_utils import synth_learner
from fastai.callback.core import Callback # 
class CountParamCallback(Callback): def before_fit(self): print("Num param:", self.count_parameters(self.learn.model)) def count_parameters(self, model): return sum(p.numel() for p in model.parameters() if p.requires_grad) # khởi tạo learner
learn = synth_learner(cbs=[CountParamCallback()])
with learn.no_mbar(): learn.fit(2)

Trong ví dụ trên mình đã tạo một callback đơn giản để đếm số parameter của mô hình trước khi train. Để định nghĩa một callback, ta kế thừa class Callback và định nghĩa một số method đặc biệt :

class CountParamCallback(Callback): def before_fit(self): print("Num param:", self.count_parameters(self.learn.model))

Tên method cũng khá là dễ hiểu: before_fit nghĩa là trước khi sự kiện fit (train mô hình) thì hãy làm các việc trong method này. Để sử dụng callback thì khi khởi tạo Learner, chúng ta chỉ cần set tham số cbs là một list các callback cần thiết:

learn = synth_learner(cbs=[CountParamCallback(), ...])
#hoặc 
learn = Learner(dls, model, cbs=[...])

Về cơ bản thì hệ thống callback trong fastai cho phép ta truy cập và sửa đổi tất cả mọi thứ trong quá trình huấn luyện mô hình (dữ liệu , optimizer, learning rate, ...), một trong những tác giả của thư viện đã gọi đây là "infinitely customizable training loop".

Các sự kiện trong fastai training loop

Mọi điều chỉnh đối với training loop đều được thực hiện thông qua Callback với các method có tên tương ứng với các sự kiện trong training loop. Ta cũng có thể dễ dàng kết hợp các kỹ thuật khác nhau được định nghĩa trong các callback khác nhau. Một callback có thể implement các sự kiện sau:

  • after_create: gọi sau khi khởi tạo Learner
  • before_fit: gọi trước khi bắt đầu training hoặc inference
  • before_epoch: gọi ở đầu mỗi epoch, hữu ích khi cần reset trạng thái nào đó sau mỗi epoch
  • before_train: gọi trước khi bắt đầu quá trình train của mỗi epoch
  • before_batch: gọi ở đầu mỗi batch, sau khi lấy batch ra từ data loader. Có thể dùng để thay đổi input trước khi đi qua mô hình (data augmentation chẳng hạn).
  • after_pred: gọi sau khi gọi phương thức forward của mô hình. Có thể dùng để thay đổi output trước khi cho qua hàm loss (reshape, ...)
  • after_loss: gọi sau khi tính loss nhưng trước khi gọi backward. Có thể dùng để thêm regularization cho loss (L2, L1, ...)
  • before_backward: gọi sau khi tính loss
  • after_backward: gọi sau khi gọi backward của hàm loss, nhưng trước khi update tham số mô hình.
  • before_step: tương tự after_backward nhưng trong docs khuyến khích dùng cái này thay vì after_backward. Có thể dùng để cập nhật lại gradient (gradient clipping, ...)
  • after_step: gọi sau khi cập nhật tham số mô hình (opimizer.step()) và trước khi gọi optimizer.zero_grad()
  • after_batch: gọi ở cuối mỗi batch
  • after_train: gọi ở cuối mỗi epoch
  • before_validate: gọi ở đầu quá trình validation của mỗi epoch
  • after_validate: gọi ở cuối quá trình validation của mỗi epoch
  • after_epoch: gọi ở cuối mỗi epoch
  • after_fit: gọi ở cuối quá trình training

Các attribute có thể truy cập trong callback

Khi viết callback, ta có thể truy cập một số attribute của class Learner. Sử dụng bằng cách viết: self.learn.attr ( thay attr bằng attribute tương ứng.

  • model: mô hình hiện dùng để train hoặc validate
  • dls: object DataLoaders
  • loss_func: hàm loss truyền vào khi khởi tạo Learner
  • opt: object optimizer
  • cbs: danh sách tất cả callback
  • dl: dataloader hiện đang sử dụng (train hoặc val dataloader)
  • x/xb: input của mô hình lấy từ dl. Chỉ có thể assign giá trị cho attribute xb
  • y/yb: output của mô hình lấy từ dl. Chỉ có thể assign giá trị cho attribute yb
  • pred: prediction của model
  • loss_grad: giá trị hàm loss
  • loss: bản copy của loss_grad. Dùng cho logging
  • n_epoch: số epoch
  • n_iter: độ dài của dl
  • epoch: epoch hiện tại (từ 0 - n_epoch-1)
  • iter: index hiện tại của dl (từ 0 - n_iter - 1)

Một số callback có sẵn trong fastai

Gradient clipping

from fastai.test_utils import synth_learner
from fastai.callback.training import GradientClip learn = synth_learner()
learn.fit(3, cbs=[GradientClip])

Mix Precision training

Chắc sẽ hữu ích cho bạn nào máy ít VRAM

from fastai.test_utils import synth_learner
from fastai.callback.fp16 import MixedPrecision fp16 = MixedPrecision()
learn = synth_learner(lr=1.1)
learn.fit(3, cbs=[fp16])

Một cách khác để dùng mix precision trong fastai

learn = synth_learner()
learn.to_fp16()

Weights & Biases

W&B là một công cụ dùng để visualize và theo dõi các thí nghiệm học máy. Chỉ cần thêm lệnh khởi tạo callback bạn có thể log tất tần tật về mô hình của bạn lên W&B

from fastai.callback.wandb import * wb = WandbCallback( log_preds=True, log_model=True, log_dataset=True
)
# To log only during one training phase
learn.fit(..., cbs=[wb]) # To log continuously for all training phases
learn = learner(..., cbs=[wb])

Ngoài ra, còn có rất nhiều callback hữu ích khác. Các bạn có thể tham khảo trong documentation: https://docs.fast.ai/

Bonus: GAN training loop

Phần này không có sẵn trong fastai nhưng mình thấy khá hay nên giới thiệu ở đây. Hệ thống callback trong fastai khá là linh hoạt nên ta có thể tận dụng nó để implement các training loop phức tạp hơn một chút, điển hình là GAN. Trong repo https://github.com/tmabraham/UPIT, tác giả đã viết một callback chỉ để train phần discriminator của mạng GAN (cụ thể là CycleGAN) và để training loop bình thường lo việc train phần generator.

class CycleGANTrainer(Callback): """`Learner` Callback for training a CycleGAN model.""" run_before = Recorder def before_train(self, **kwargs): self.crit = self.learn.loss_func.crit if not getattr(self,'opt_G',None): self.opt_G = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))), self.learn.lr) else: self.opt_G.hypers = self.learn.opt.hypers if not getattr(self, 'opt_D',None): self.opt_D = self.learn.opt_func(self.learn.splitter(nn.Sequential(*flatten_model(self.D_A), *flatten_model(self.D_B))), self.learn.lr) else: self.opt_D.hypers = self.learn.opt.hypers self.learn.opt = self.opt_G def before_batch(self, **kwargs): self._set_trainable() self._training = self.learn.model.training self.learn.xb = (self.learn.xb[0],self.learn.yb[0]), self.learn.loss_func.set_input(*self.learn.xb) def after_step(self): self.opt_D.hypers = self.learn.opt.hypers def after_batch(self, **kwargs): "Discriminator training loop" if self._training: # Obtain images fake_A, fake_B = TensorBase(self.learn.pred[0].detach()), TensorBase(self.learn.pred[1].detach()) (real_A, real_B), =self.learn.xb real_A, real_B = TensorBase(real_A), TensorBase(real_B) self._set_trainable(disc=True) # D_A loss calc. and backpropagation loss_D_A = 0.5 * (self.crit(self.D_A(real_A), 1) + self.crit(self.D_A(fake_A), 0)) loss_D_A.backward() self.learn.loss_func.D_A_loss = loss_D_A.detach().cpu() # D_B loss calc. and backpropagation loss_D_B = 0.5 * (self.crit(self.D_B(real_B), 1) + self.crit(self.D_B(fake_B), 0)) loss_D_B.backward() self.learn.loss_func.D_B_loss = loss_D_B.detach().cpu() # Optimizer stepping (update D_A and D_B) self.opt_D.step() self.opt_D.zero_grad() self._set_trainable()

Reference

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tìm hiểu về YOLO trong bài toán real-time object detection

1.Yolo là gì. . Họ các mô hình RCNN ( Region-Based Convolutional Neural Networks) để giải quyết các bài toán về định vị và nhận diện vật thể.

0 0 284

- vừa được xem lúc

Xây dựng hệ thống Real-time Multi-person Tracking với YOLOv3 và DeepSORT

Trong bài này chúng ta sẽ xây dựng một hệ thống sử dụng YOLOv3 kết hợp với DeepSORT để tracking được các đối tượng trên camera, YOLO là một thuật toán deep learning ra đời vào tháng 5 năm 2016 và nó nhanh chóng trở nên phổ biến vì nó quá nhanh so với thuật toán deep learning trước đó, sử dụng YOLO t

0 0 316

- vừa được xem lúc

Sơ lược về bài toán Person Re-identification

Với những công nghệ hiện đại của thế kỷ 21 chúng ta đã có những phần cứng cũng như phần mềm mạnh mẽ để giải quyết những vấn đề và bài toán nan giải như face recognition, object detection, NLP,... Một trong những vấn đề nan giải cũng được mọi người chú ý ko kém những chủ đề trên là Object Tracking, v

0 0 63

- vừa được xem lúc

Bacteria classification bằng thư viện fastai

Giới thiệu. fastai là 1 thư viện deep learning hiện đại, cung cấp API bậc cao để giúp các lập trình viên AI cài đặt các mô hình deep learning cho các bài toán như classification, segmentation... và nhanh chóng đạt được kết quả tốt chỉ bằng vài dòng code. Bên cạnh đó, nhờ được phát triển trên nền tản

0 0 37

- vừa được xem lúc

Xác định ý định câu hỏi trong hệ thống hỏi đáp

Mục tiêu bài viết. Phân tích câu hỏi là pha đầu tiên trong kiến trúc chung của một hệ thống hỏi đáp, có nhiệm vụ tìm ra các thông tin cần thiết làm đầu vào cho quá trình xử lý của các pha sau (trích c

0 0 94

- vừa được xem lúc

Epoch, Batch size và Iterations

Khi mới học Machine Learning và sau này là Deep Learning chúng ta gặp phải các khái niệm như Epoch, Batch size và Iterations. Để khỏi nhầm lẫn mình xin chia sẻ với các bạn sự khác nhau giữa các khái n

0 0 45