- vừa được xem lúc

Finetuning Stable Diffusion hiệu quả với LoRA

0 0 13

Người đăng: Bách Lê

Theo Viblo Asia

Giới thiệu

Có thể nói image generation (gen ảnh/ sinh ảnh) là vua của mọi nghề!

Hiện nay, image generation ở Việt Nam mới chỉ ở giai đoạn khởi đầu. Tại sao lại nói đây là vua của mọi nghề? Vì vừa có tiền, vừa có quyền, vừa kiếm được nhiều $ lại vừa được xã hội trọng vọng.

Và đi cùng với sự phát triển của generative AI nói chung và image generation nói riêng, Stable Diffusion đóng vai trò không hề nhỏ trong việc đưa công nghệ đến gần hơn với cộng đồng. Gọi khả năng ứng dụng của Stable Diffusion là cái ghế, vì nó không phải bàn. Bạn có thể dùng Stable Diffusion để tạo app alime hóa ảnh như Loopsie, làm phim hoạt hình như Artisto, gen ảnh để bỏ vào slides mai lên lớp thuyết trình, tạo ảnh quảng cáo cho sản phẩm của công ty hoặc ướm thử một bộ quần áo trên mạng nào đó lên ảnh bản thân xem có hợp không vì giá nó bằng 30 bát phở lận.

Nếu là người quan tâm đến Stable Diffusion hay các model gen ảnh, chắc hẳn các bạn đã không còn cảm thấy xa lạ với LoRA. Cho dù sinh sau đẻ muộn, LoRA đã bỏ xa cả Dreambooth và Textual Inversion trong cuộc đua chiếm lấy tình yêu của cộng đồng gen ảnh, thống trị các nền tảng train model gen ảnh với sự vượt trội trong thời gian train và kích thước checkpoint. Biểu đồ dưới từ Google Trends là minh chứng rõ ràng nhất cho sự áp đảo của LoRA (đường màu xanh) so với Dreambooth (đường màu đỏ) và Textual Inversion (đường màu vàng).

Vậy điều gì đã giúp cho LoRA vượt xa hai đồng nghiệp và trở nên được yêu thích đến vậy? Hãy cùng mình khám phá, tìm hiểu về ý tưởng và những gì LoRA có thể đem lại nhé.

Bối cảnh ra đời

Khác với nhiều người tưởng rằng LoRA là phương pháp sinh ra để tối ưu quá trình training Stable Diffusion, LoRa hay Low-Rank Adaptation vốn bắt nguồn từ những cố gắng train các mô hình ngôn ngữ lớn (Large Language Model - LLM).

Ngày nay, fine-tuning đang dần trở nên phổ biến như một phương pháp hiệu quả để train các mô hình deep learning với ít data nhưng vẫn cho ra kết quả tốt. Do chi phí và tài nguyên bỏ ra tỷ lệ thuận với độ lớn của mô hình, có thể nói các công ty ngày càng phụ thuộc vào pretrained weight. Tuy nhiên, khác với một số kiến trúc đặc trưng như Convolution networks, Attention mechanism (linh hồn của LLMs) không cho phép các mô hình ngôn ngữ freeze toàn bộ layer trước và chỉ train một vài layer sau cùng. Chính vì thế, kể cả khi có thể tận dụng được các checkpoint đã có, việc training các mô hình lớn vẫn là điều gì đó xa xỉ khi mà chúng ta vẫn phải train từng ấy parameters cho các downstream task. Thậm chí đối với các công ty lớn, việc retrain lại mô hình sau một khoảng thời gian để cập nhật kiến thức cho model cũng gặp phải rào cản chi phí và vấn đề lưu trữ.

Trước khi LoRA ra đời, AdapterPrefix tuning cũng đã cố gắng cản thiện các vấn đề trên. Tuy nhiên, Adapter làm tăng độ trễ (latency) khi infer model, trong khi Prefix tuning lại khiến cho model khó tối ưu hơn, hai giải pháp này không phải những phương án tốt trong việc training LLM. Để đọc thêm về các phương pháp finetuning LLM hiệu quả, các bạn có thể tìm đọc thêm về Parameter Effcient Fine-Tuning (PEFT).

Vậy LoRA hoạt động thế nào?

Cho bạn nào chưa biết, LLM - đối tượng chính của LoRA - hoạt động chủ yếu dựa trên Attention mechanism. Đây là kỹ thuật để các mô hình ngôn ngữ có thể tập trung vào các phần khác nhau của câu để hiểu được đúng ý nghĩa của cả câu. Chúng ta sẽ đi sâu hơn vào Cross attention:

Trong Cross attention mechanism, có ba ma trận WVW_V,WKW_KWQW_Q được dùng để biến các input thành Value, Key và Query. Tensor màu xám đi qua ma trận WVW_VWKW_K trở thành Value (V) và Key (K) mang thông tin của câu, trong khi ma trận màu xanh đi qua WQW_Q trở thành Query (Q), nói cho model biết rằng nên tập trung vào phần nào của câu và liên hệ giữa các từ. Ba ma trận trên là các weight matrix có size lớn nhất (tensor trong transformer thường có size 768x64, vậy nên mỗi ma trận sẽ có size 768x768 - xấp xỉ 600,000 tham số). Việc của LoRA là làm thế nào đó để có thể lưu ít tham số hơn nhưng vẫn giữ được performance của model.

Đúng như cái tên - Low-rank Adaptation, phương pháp LoRA ra đời dựa trên phép phân rã ma trận (matrix decomposition). Về bản chất, matrix decomposition là kỹ thuật phân tích một ma trận thành tích của hai (hoặc nhiều) ma trận. Xét ma trận WW có shape mnm*n, ta có thể tách WW thành hai ma trận AABB với shape của AAmkm*k và shape của BBknk*n (Tất nhiên kk phải nhỏ hơn mmnn rồi =)) nếu không thì lượng thông tin phải lưu của AABB còn lớn hơn WW ban đầu mất). VD một ma trận size 4x4 có thể được phân tích thành tích của ma trận 4x1 và ma trận 1x4. Điều này có ý nghĩa gì ư? Thay vì phải lưu thông tin của 16 số bên trong ma trận 4x4, giờ đây chúng ta chỉ cần lưu 8 số trong hai ma trận 4x1 và 1x4. Tương tự với 768x768, nếu chúng ta tách thành hai ma trận 768x4 và 4x768, chúng ta chỉ cần lưu 6144 tham số - chỉ bằng gần 1% lượng tham số cần lưu ban đầu.

KHOAN ĐÃ! Lừa trẻ con à??=D Xét phép phân rã W=ABW = A*B thì rõ ràng là cần rank(W)<=min(rank(A),rank(B))rank(W) <= min(rank(A), rank(B)). Với việc k<min(m,n)k < min(m, n), t cần rank(W)<krank(W) < k. Vậy thì nếu WW có rank cao hơn kk, việc phân rã WW sẽ làm mất mát thông tin của ma trận, làm sao có thể đảm bảo được performance cơ chứ?

Thực ra, theo quan sát của nghiên cứu từ Aghajanyan et al. (2020), weight của các model thường có rank thấp (low intrinsic rank). Từ đó, các nhà nghiên cứu cho rằng phần được update thêm sau training (gọi là ΔW\Delta W) cũng sẽ có rank thấp. Khi rank(W)rank(W) thấp thì chúng ta có thể có rr nhỏ, giờ thì không có trò lừa nào ở đây nữa nhé :”)

Tận dụng matrix decomposition, các tác giả “chèn” ΔW\Delta W ( =AB=A*B) vào các khối transformer trước khi finetune. Đúng thế, thay vì để WW update bình thường như bao mô hình khác, chúng ta biểu diễn W=W0+W=W0+BAW = W_0 + ∆W = W_0 + BA . Khi chạy input x0x_0 qua WW, thay vì tính x1=Wx0x_1 = Wx_0, ta tính x1=(W0+BA)x0.x_1 = (W_0 + BA)x_0.

Sao phải tính toán cho mất công nhỉ? Bởi vì trước khi finetune, chúng ta freeze toàn bộ các W0W_0, chỉ có các ma trận được chèn vào AABB (thường gọi là các adapter) là được update mà thôi. Như vậy, chúng ta không update - và các GPU cũng không cần lưu - gradients và activations của W0W_0. Nói cách khác, chúng ta chỉ học khoảng 6 nghìn params thay vì 600 nghìn params như bình thường.

LoRA và Stable Diffusion

Ừm,... Nghe thì cũng hay đấy, nhưng mà lấy đâu ra Attention trong Diffusion? Chẳng lẽ chỉ áp LoRA cho đoạn Text encoder thôi ư? Thực ra là không. Attention mechanism - cụ thể là Cross attention - còn được gắn vào các model Unet bên trong Stable Diffusion.

Trong quá trình denoise của Stable Diffusion, Cross attention được sử dụng khéo léo để chọc các thông tin từ prompt vào, hướng cho model gen ra các ảnh như theo yêu cầu. Kéo bài viết lên trên một chút, prompt ở đây sẽ đươc encode và đóng vai trò như Query, trong khi representation của ảnh trong latent space sẽ giống như Value và Key. Ứng dụng của LoRA trong Stable Diffusion chính là ở đây.

Mặc dù không được đề cập trong paper nhưng LoRA lần đầu được ứng dụng trong Stable Diffusion bởi Simo Ryu. Từ đó đến nay, các model LoRA được share liên tục trên CivitAI và Reddit, tạo nên một cộng đồng yêu thích gen ảnh lớn mạnh với vô số style ảnh độc đáo và mới lạ.

Sức mạnh của LoRA

Hiểu được ý tưởng rồi, giờ hãy cùng mình điểm qua những lợi ích mà LoRA có thể đem lại, đặc biệt là trong việc training model Stable Diffusion nhé!

  • Giảm thiểu tài nguyên tính toán (computing resource): Chắc chắn rồi. Sinh ra với sứ mệnh khả thi hóa quá trình training ở điều kiện vừa và thấp, người dùng giờ đây chỉ cần bật google colab thay vì các card đồ họa siêu cấp vip pro như A100 hay RTX4090 (Google colab còn không cần thuê Pro cơ =D).
  • Lưu trữ và chuyển đổi task dễ dàng: Nếu như ngày trước mỗi model Stable Diffusion cần cả chục GB để lưu thì nay chỉ vài MB là đủ để bạn thay đổi hoàn toàn style của bức ảnh. Việc có kích thước file nhỏ không chỉ giúp chúng ta lưu trữ được nhiều model mà còn giúp việc download/upload trở nên dễ dàng hơn đáng kể.
  • Sử dụng được chung với nhiều phương pháp khác: Mặc dù không được để ý nhiều nhưng LoRA có thể được dùng chung với các cách finetuning khác như Prefix tuning hay Dreambooth. Việc sử dụng chung với các phương pháp khác có thể khắc phục một số hạn chế của LoRA.

Kết quả là bây giờ với base model của Stable Diffusion có sẵn trong máy, bạn chỉ cần dạo quanh một vòng Reddit/CiviAI, nheo mắt ngắm xem style nào ưng mắt và tải thêm vài chục MB cho model LoRA. Giờ bạn có thể gen ảnh waifu của mình với đủ loại tư thế, gen ảnh Superman theo phong cách anime, gen ảnh của bản thân theo style Gigachad,… Nghe bánh cuốn đấy nhỉ? Nhưng mà người yêu mình không thích. Người yêu mình thích cún, nên mình sẽ demo thử vài ảnh cún theo nhiều style khác nhau nhé :”)

Tò mò về Hà Nội vào năm 3069? Hãy dùng LoRA.

Ao ước nhìn thấy Emma Watson dưới nét vẽ của Gosho Aoyama? Hãy dùng LoRA.

Mong muốn nhìn thấy thêm các tuyệt tác theo phong cách của cố họa sĩ Picasso? Hãy dùng LoRA.

Với những người làm computer vision, LoRA nói riêng và Stable Diffusion nói chung cũng có thể là công cụ hữu hiệu trong việc generate data, giúp tăng đáng kể accuracy và performance của model. Tất cả những gì chúng ta cần làm là nhặt một vài bức ảnh rồi bật YouTube xem các pháp sư dạy train trên Colab hoặc Local.

Thế LoRA là phương pháp finetune toàn năng rồi phải không?

Không =)))

Tuyệt vời là thế nhưng LoRA cũng có những hạn chế riêng của mình, bao gồm:

  • Không thể tự do chọn các adapter khác nhau cho các sample khác nhau cho cùng một batch. VD như 1 batch của bạn có 10 prompt gen ảnh idol: 2 prompt đầu style anime, 4 prompt sau sinh theo phong cách vẽ của Picasso, 4 prompt cuối cùng sinh theo kiểu truyện tranh Mỹ, thì rất tiếc là không được.
  • Cơ chế hoạt động khoa học của LoRA finetuning vẫn còn chưa rõ ràng.
  • Không phù hợp cho large-scale training: LoRA chỉ mạnh để dạy cho model gen ảnh với style khác nhau. Với các concept phức tạp (một nhóm người, nhiều nhân vật mới, …), sử dụng các phương pháp finetune khác như Dreambooth hoặc finetune thông thường sẽ có kết quả tốt hơn.

Kết

Trí tưởng tượng của chúng ta đôi khi bị ngăn cách với hiện thực bởi khả năng vẽ của bản thân (như mình :"( vẽ cái máy bay không khác gì con cá). Giờ đây với LoRA, vách ngăn ấy đã mẻ đi nhiều phần. Việc còn lại của chúng ta là nhón chân bước qua cái vách ấy và tận dụng sức mạnh của image generation trong các công việc sáng tạo liên quan đến hình ảnh. Mình hy vọng bài viết đã giải thích được cho các bạn cơ chế hoạt động và vai trò của LoRA đối với Stable Diffusion. Nếu có góp ý hay câu hỏi gì cho mình, đừng ngần ngại comment ở phía dưới. Mình sẽ cố gắng phản hồi trong thời gian sớm nhất. Cảm ơn bạn vì đã đọc đến những dòng cuối cùng này :”)

Reference

https://arxiv.org/abs/2106.09685

https://github.com/cloneofsimo/lora

https://huggingface.co/blog/lora

https://stable-diffusion-art.com/lora/

Tìm hiểu về Pixta Vietnam

Cập nhật tin tức mới nhất của Pixta Vietnam 👉 http://bit.ly/3kdkzvW

Bình luận

Bài viết tương tự

- vừa được xem lúc

Các chỉ số đánh giá được sử dụng cho bài toán Image Generation: IS, FID, PSNR, SSIM,...

1. Giới thiệu về bài toán Image Generation.

0 1 98

- vừa được xem lúc

Tìm hiểu về kiến trúc Transformer

Giới thiệu. Với sự ra đời của cơ chế attention thì vào năm 2017 paper Attention is all you need đã giới thiệu một kiến trúc mới dành cho các bài toán NLP mà không có sự xuất hiện của các mạng nơ-ron h

0 0 387

- vừa được xem lúc

Sinh tín hiệu hình sine với mô hình GAN

Giới thiệu. Các ứng dụng về GAN ở domain về ảnh thì vô cùng nhiều nhưng trong domain tín hiệu time-series thì chưa có nhiều.

0 0 41

- vừa được xem lúc

Imbalance Problem in Object Detection

1. Giới thiệu.

0 0 37

- vừa được xem lúc

Từ lý thuyết lượng tử đến Quantum Neural Network

Một số kiến thức cần nắm. Mình khuyến khích mọi người trước khi đọc bài này thì nên tìm hiểu Quantum Computing hoặc đọc bài giới thiệu cơ bản về tính toán lượng tử mà mình đã viết để có thể hiểu rõ hơ

0 0 34

- vừa được xem lúc

TensorFlow in 100 Seconds

TensorFlow is a tool for machine learning capable of building deep neural networks with high-level Python code. It provides developer-friendly APIs that help software engineers train, analyze, and dep

0 0 30