- vừa được xem lúc

SAM: Giải thuật tối ưu đang dần được ứng dụng rộng rãi

0 0 16

Người đăng: Nguyen Tung Thanh

Theo Viblo Asia

Trong quá trình training, chúng ta thường chỉ quan tâm đến giá trị training loss mà không quan tâm đến độ dốc (sharpness) của đồ thị loss quanh điểm đó. Mối quan hệ giữa hình dạng của đồ thị loss và tính tổng quát hoá (generalization) của mô hình đã được nghiên cứu trong các nghiên cứu trước. Một nghiên cứu thực nghiệm với hơn 40 phương pháp đo phức tạp chỉ ra rằng: phép đo sharpness-based có mối tương quan cao nhất (so với những phép đo còn lại) với tính tổng quát hoá. Từ đó, nhiều nghiên cứu tối ưu mô hình có tính đến dộ dốc của loss được công bố. Tuy nhiên, những phương pháp này hoặc chưa hiệu quả về mặt tính toán (efficient) hoặc chưa đem lại hiệu quả cao hoặc cả hai. SAM có lẽ là phương pháp sharpness-based đầu tiên đề xuất được cách áp dụng đem lại hiệu quả rõ rệt với chi phí tính toán tăng thêm chấp nhận được (so với độ chính xác tăng thêm thì mình sẵn sàng đánh đổi). Kể từ khi SAM ra mắt năm 2020, nhiều nghiên cứu cải thiện SAM về cả mặt performance và effciency.

Hình 1: bên trái là đồ thị loss mà các minima của nó có độ dốc cao. Bên phải là độ thị loss mà minima phẳng hơn

Tóm tắt một số đặc điểm của SAM:

  • SAM là một objective function.
  • SAM hướng tới tìm trọng số vừa thoả mãn 2 điều kiện
    • có loss trên tập train nhỏ (như objective thông thường)
    • loss của tất cả các trọng số hàng xóm gần đó đều phải nhỏ (W+ϵW + \epsilon, với WW là giá trị trọng số của mô hình và ϵ2ρ||\epsilon||_2 \leq \rho).
  • Ưu điểm của SAM:
    • Cải thiện tính tổng quát hoá của mô hình từ đó cho kết quả tốt hơn trên tập test.
    • Robust với label noise
    • Tái hiện kết quả training tốt hơn (reproducible)
    • Attention map có tính interpretable cao
    • Dễ implement, và chi phí tính toán hiệu quả (so với những phương pháp sharpness-based khác)
  • Nhược điểm:
    • Thời gian training lâu gấp đôi ( so với phương pháp không dùng SAM)

Cơ sở của SAM

Bảng 1: Kết quả train mô hình CNN trên tập CIFAR10

Không phải tất cả các minima có giá trị loss trên tập train bằng nhau đều lại kết quả trên tập test tương đương nhau. Bảng 1 cho thấy kết quả huẩn luyện mô hình CNN trên tập CIFAR10 với các batch size khác nhau. Trong cả 4 trường hợp train loss đều xấp xỉ bằng 0 và train accuracy đều là 100%. Tuy nhiên, test accuracy ở các trường hợp lại có sự khác nhau rõ rệt. Như vậy, chỉ dựa vào loss để đánh giá một mô hình được huấn luyện tốt hay chưa là chưa đủ. Một mô hình được cho kết quả rất tốt trên tập train có thể cho kết quả rất tệ trên tập test. Trong trường hợp đó, mô hình có tính tổng quát hoá không tốt.

Sự kết nối giữa hình dạng của đồ thị loss và tính tổng quát hoá của mô hình đã được nghiên cứu rộng rãi cả về mặt lý thuyết và thực nghiệm. Cụ thể, những minima có hình dạng loss phẳng hơn (flatness) sẽ những tổng quát hoá hơn minima có độ dốc lớn (sharpness). Như ở hình 1, ta có thể dự đoán rằng minima ở hình bên phải sẽ có tính tổng quát hoá tốt hơn so với mô hình ở bên trái. Tận dụng trên sự liên kết giữa hình dạng đồ thị loss và tính tổng quát hoá của mô hình, SAM cải thiện tính tổng quát hoá của mô hình bằng cách tối ưu đồng thời giá trị loss và độ dốc của loss.

Giải thuật SAM

Trong phần này mình sẽ trình bày giải thuật tối ưu SAM.

Hình 2: Minh hoạ giả thuật tối ưu SAM.

Hình 2 mình hoạ giải thuật gradient descent thông thường (từ WtW_t sang Wt+1W_{t+1}) và SAM (từ WtW_t sang Wt+1SAM)W^{SAM}_{t+1}). Giải thuật gradient descent thông thường được sẽ update trọng số WtW_t theo ngược chiều gradient bằng cách trừ tích của gradient với giá trị learning rate η\eta. Giải thuật SAM trước tiên tính WadvW_{adv} (adversarial) bằng cách cộng WtW_t với ρL(Wt)2L(Wt)\frac{\rho}{||\nabla L(W_t)||_2} \nabla L(W_t) (là gradient được scale norm theo ρ\rho). Mục đích tính WadvW_{adv} là vì SAM kỳ vọng giá trị loss tại giá trị này sẽ có giá trị gần với giá trị loss lớn nhất xung quanh WtW_t. Sau đó, ta tính gradient tại WadvW_{adv} sau đó apply gradient này tại W_t. Với các bước thực hiện như vậy, SAM hướng tới tìm W vừa có loss tại đó nhỏ và loss tại những giá trị W xung quanh cũng nhỏ.

Một câu hỏi mà một số bạn hỏi là tại sao SAM update gradient tại WtW_t mà không phải tại WadvW_{adv}. Để trả lời, ta hãy nhìn vào công thức đạo hàm của hàm loss SAM được ghi trong paper:

Gradient của hàm loss SAM tại W được tính xấp xỉ thông qua gradient của hàm loss thông thường tại W+ϵ^(W)W + \hat{\epsilon}(W) (được estimate bằng WadvW_{adv}). Vậy nên gradient được apply vào WtW_t chứ không phải WadvW_{adv}

SAM Pytorch

Tác giả implement SAM trên Jax. Mọi người cũng có thể sử dụng SAM (và cả ASAM, một phiên bản cải tiến về độ hiệu quả của SAM) được implement trên Pytorch (không official) ở link sau: https://github.com/davda54/sam. Trong phần README của repo đã hướng dẫn chi tiết cách sử dụng SAM vào trong project hiện tại của bạn. Dễ thấy, thời gian training với SAM sẽ lâu hơn gấp đôi so với baseline so phải forward và backward hai lần. Đây cũng chính là nhược điểm lớn nhất của SAM.

Hiệu quả

Hình 3: Kết quả evaluate trên tập train và tập test. Màu cam: baseline train với SGD không dùng SAM. Màu xanh: SAM+SGD. Màu tím ASAM+SGD

Để xác thực tính hiệu quả của giải thuật SAM. Mình đã áp dụng SAM và ASAM vào bài toán polyp segmentation. Hình 3 thể hiện kết quả đánh giá mô hình huấn luyện trên cả tập train và 5 tập test. Cả mô hình huấn luyện với SAM và ASAM đều cho độ chính xác thấp hơn trên tập train so với baseline chỉ dùng SGD. Tuy nhiên, cả hai mô hình này đều tất hơn baseline trên 4/5 tập. Cho thấy hiệu quả vượt trội của SAM so với baseline. Một điều nữa là mỗi thí nghiệm đều được thực hiện 5 lần, và mô hình train với SAM và ASAM có variance thấp hơn hẳn so với baseline. Đều này giống với tính chất mà tác giả nêu ra là tính reproducibility của mô hình train với SAM sẽ được cải thiện.

Các nghiên cứu sau SAM

Kể từ sau khi SAM ra mắt năm 2020. Nhiều nghiên cứu theo hướng sharpness-based đã xuất hiện với mục cải thiện về cả độ chính xác và nhược điểm về thời gian training. Một số trong những nghiên cứu đó có thể kể đến:

  • ASAM: một phiên bản adaptive của SAM
  • ESAM: phiên bản hiệu quả về mặt tính toán hơn của SAM
  • LookSAM: SAM nhanh hơn cho huấn luyện mô hình Vision Transformer
  • SAF: một phiên bản cải thiện về thời gian training của SAM được cho là gần như không lâu so với không dùng SAM.

Kết luận

SAM có lẽ là một trong số ít những paper gần đây mình biết mang liệu hiệu quả rõ rệt và có thể được áp dụng rộng rãi trong nhiều bài toán. Trong bài viết này mình đã trình bày về cách hoạt động cũng như cách áp dụng và kết quả của SAM. Một số phần quan trọng không được nhắc đến trong bài này lý thuyết, chứng minh và biến đổi các bạn có thể xem thêm trong paper. Team mình đã áp dụng SAM trong một số dự án và đều cải thiện kết quả đáng kể.

Hy vọng việc áp dụng SAM cũng sẽ đem lại kết quả tương tự với các bạn. Mình rất mong được biết kết quả của việc áp dụng SAM vào trong bài toán của các bạn. Cảm ơn các bạn đã đọc bài, nếu thấy hữu ích hãy cho mình 1 upvote nhé.

Tham khảo:

https://arxiv.org/abs/2010.01412 https://arxiv.org/abs/2106.01548 https://www.youtube.com/watch?v=QBiLph-r5Hw&t=2808s https://github.com/davda54/sam

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139