- vừa được xem lúc

CycleGAN cho bài toán image-to-image translation

0 0 27

Người đăng: Hieu Bui

Theo Viblo Asia

Introduction

Image-to-image translation là một lớp bài toán computer vision mà mục tiêu là học một ánh xạ giữa ảnh input và ảnh output. Bài toán này có thể áp dụng vào một số lĩnh vực như style transfer, tô màu ảnh, làm nét ảnh, sinh dữ liệu cho segmentation, face filter,...

Thông thường để huấn luyện một mô hình Image-to-image translation, ta sẽ cần một lượng lớn các cặp ảnh input và label. Ví dụ như: ảnh màu và ảnh grayscale tương ứng với nó, ảnh mờ và ảnh đã được làm nét, ....Tuy nhiên, việc chuấn bị dataset theo kiểu này có thể khá tốn kém trong một số trường hợp như: style transfer ảnh từ mùa hè sang mùa đông (kiếm được ảnh phong cảnh trong các điều kiện khác nhau), biến ngựa thường thành ngựa vằn (khó mà kiếm được ảnh của 1 con ngựa thường và ảnh của nó nhưng là ngựa vằn ?).

Do các bộ dataset theo cặp gần như là không tồn tại nên mới nảy sinh như cầu phát triển một mô hình có khả năng học từ dữ liệu unpaired. Cụ thể hơn là có thể sử dụng bất kỳ hai tập ảnh không liên quan và các đặc trưng chung được trích xuất từ mỗi bộ sưu tập và sử dụng trong quá trình image translation. Đây được gọi là bài toán unpaired image-to-image translation.

Một cách tiếp cận thành công cho unpaired image-to-image translation là CycleGAN.

CycleGAN architecture

CycleGAN được thiết kế dựa trên Generative Adversarial Network (GAN). Kiến trúc GAN là một cách tiếp cận để huấn luyện một mô hình sinh ảnh bao gồm hai mạng neural: một mạng generator và một mạng discriminator. Generator sử dụng một vector ngẫu nhiên lấy từ latent space làm đầu vào và tạo ra hình ảnh mới và Discriminator lấy một bức ảnh làm đầu vào và dự đoán xem nó là thật (lấy từ dataset) hay giả (được tạo ra bởi generator). Cả hai mô hình sẽ thi đấu với nhau, Generator sẽ được huấn luyện để sinh ảnh có thể đánh lừa Discriminator và Discriminator sẽ được huấn luyện để phân biệt tốt hơn hình ảnh được tạo.

CycleGAN là một mở rộng của kiến trúc GAN cổ điển bao gồm 2 Generator và 2 Discriminator. Generator đầu tiên gọi là G, nhận đầu vào là ảnh từ domain X (ngựa vằn) và convert nó sang domain Y (ngựa thường). Generator còn lại gọi là Y, có nhiệm vụ convert ảnh từ domain Y sang X. Mỗi mạng Generator có 1 Discriminator tương ứng với nó

  • DYD_Y: phân biệt ảnh lấy từ domain Y và ảnh được translate G(x).
  • DXD_X: phân biệt ảnh lấy từ domain X và ảnh được translate F(y).

Generator

Generator của CycleGAN dựa trên được lấy từ paper này, bao gồm 3 thành phần: encoder, transformer và decoder

Phần encoder bao gồm 3 lớp tích chập, 2 lớp sau có stride = 2 để làm giảm kích thước đầu vào của ảnh và tăng số channel. Output của encoder được sử dụng làm đầu vào cho transformer bao gồm 6 khối residual như trong resnet. Lớp batch normalization trong khối residual được thay bằng instance normalization. Cuối cùng phần decoder bao gồm 3 lớp transposed convolution sẽ biến đổi ảnh về kích thước ban đầu và số channel phụ thuộc vào domain đầu ra.

Discriminator

Discriminator sử dụng kiến trúc PatchGAN. Thông thường trong bài toán classification, output của mạng sẽ là một giá trị scalar - xác suất thuộc class nào đó. Trong mô hình CycleGAN, tác giả thiết kế Discriminator sao cho output của nó là một feature map N×N×1N\times N\times1. Có thể xem là Discriminator sẽ chia ảnh đầu vào thành 1 lưới N×NN \times N và giá trị tại mỗi vùng trên lưới sẽ là xác suất để vùng tương ứng trên ảnh là thật hay giả.

Loss function

Adversarial loss

Trong quá trình huấn luyện, generator G cố gắng tối thiểu hóa hàm adversarial loss bằng cách translate ra ảnh G(x) (với x là ảnh lấy từ domain X) sao cho giống với ảnh từ domain Y nhất, ngược lại Discriminator DYD_Y cố gắng cực đại hàm adversarial loss bằng cách phân biệt ảnh G(x) và ảnh thật y từ domain

Ladv(G,DY,X,Y)=1n[logDY(y)]+1n[log(1DY(G(x))]L_{adv}(G, D_Y, X, Y) = \frac{1}{n}[ logD_{Y}(y)] + \frac{1}{n}[log(1- D_Y(G(x))]

Adversarial loss được áp dụng tương tự đối với generator F và Discriminator

Ladv(F,DX,Y,X)=1n[logDX(x)]+1n[log(1DX(F(y))]L_{adv}(F, D_X, Y, X ) = \frac{1}{n}[ logD_{X}(x)] + \frac{1}{n}[log(1- D_X(F(y))]

Cycle consistency loss

Chỉ riêng adversarial loss là không đủ để mô hình cho ra kết quả tốt. Nó sẽ lai generator theo hướng tạo ra được ảnh output bất kỳ trong domain mục tiêu chứ không phải output mong muốn. Ví dụ với bài toán biến ngựa vằn thành ngựa thường, generator có thể biến con ngựa vằn thành 1 con ngựa thường rất đẹp nhưng lại không có đặc điểm nào liên quan tới con ngựa vằn ban đầu.

Để giải quyết vấn đề này, cycle consistency loss được giới thiệu. Trong paper, tác giả cho rằng nếu ảnh x từ domain X được translate sang domain Y và sau đó translate ngược lại về domain Y lần lượt bằng 2 generator G, F thì ta sẽ được ảnh x ban đầu: xG(x)F(G(x))xx\rightarrow G(x) \rightarrow F(G(x)) \approx x

Lcycle(G,F)=1nF(G(xi))xi+G(F(yi))yiL_{cycle}(G, F) = \frac{1}{n}\sum|F(G(x_i)) - x_i|+|G(F(y_i)) - y_i|

Full loss

L=Ladv(G,DY,X,Y)+Ladv(F,DX,Y,X)+λLcycle(G,F)L = L_{adv}(G, D_Y, X, Y) + L_{adv}(F, D_X, Y, X) + \lambda L_{cycle}(G, F)

trong đó λ\lambda là siêu tham số và được chọn là 10.

Một số kết quả

Style transfer tranh vẽ sang ảnh chụp

Ngựa vằn sang ngựa thường

Táo thành cam

Mặt người thành búp bê

References

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 28

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 198

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 213

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 134

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 32

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 122