- vừa được xem lúc

Style Transfer - explain and code

0 0 23

Người đăng: Trung Thành Nguyễn

Theo Viblo Asia

Bạn nghĩ sao về một bức ảnh chụp (máy ảnh) Hà Nội nhưng lại mang phong cách tranh thiên tài Picasso. Với sự ra đời của thuật toán Style Transfer, chuyện đó là hoàn toàn có thể."

1. Thuật toán.

Dưới đây là hình minh họa cho thuật toán. Chúng ta có 3 ảnh gồm:

  • input_image: Được khởi tạo random, lúc đầu là ảnh nhiễu bất kì, sau quá trình update, tối ưu thành kết quả ta muốn (output image)
  • content_image: chứa nội dung mà ảnh output_image sẽ chứa.
  • style_image: chứa style (phong cách) mà output_image sẽ chứa sau quá trình update sẽ thành ảnh mong muốn.

Ý tưởng thuật toán rất đơn giản: cả 3 ảnh cùng đưa vào 1 pretrained-CNN để trích xuất ra các feature_map. Những Feature map này là những thông tin đặc trưng, chứa đựng thông tin về nội dung, đường nét, màu sắc của ảnh, hay còn gọi là content_feature. Style của một họa sĩ, 1 bức tranh thực tế chính là mối quan hệ giữa các đường nét, màu sắc trong tranh. Như vậy, bằng một phép biến đổi (gram_matrix), ta có thể tính ra được style_feature dựa trên content_feature. Thuật toán cụ thể như sau:

  • Gọi 3 ảnh input_image, content_image, style_image lần lượt là A, B, C
  • Đưa cả 3 ảnh vào cùng 1 pretrained_CNN để trích xuất ra các feature cho 3 ảnh, lần lượt là: A_content, B_content, C_content. Các feature này chính là feature đặc trưng về nội dung (content)
  • Để biến đổi content_feature sang style_feature, ta dùng 1 thuật toán gọi là hàm GramMatrix()
  • A_style = GramMatrix(A_content), C_style = GramMatrix(C_content)
  • Ta có ContentLoss = MSE(A_content, B_content), StyleLoss = MSE(A_style, C_style). Với MSE là hàm MeanSquareError được dùng phổ biến trong Machine learning.
  • Tính CombineLoss = w1ContentLoss + w2 StyleLoss
  • Tính đạo hàm input_image theo CombineLoss. Update input_image với thuật toán gradient descent. Chú ý, thứ chúng ta tối ưu ở đây là input_image chứ không phải các params của pretrained_model.
  • Quy trình sẽ lặp đi lặp lại nhằm tối ưu giá trị CombineLoss. Ảnh input_image dần được update sao cho ContentLoss và StyleLoss giảm dần. Tức có nội dung giống với content_image, phong cách giống với style_image. Dừng thuật toán khi input_image đủ tốt

1.2 Style_feature

Mình sẽ nói kĩ hơn về Style_feature. Như bạn đã biết, content_feature đã bao hàm thông tin về đường nét, nội dung, hình ảnh của 1 ảnh. style_feature chính là mối quan hệ tương quan giữa các thông tin này với nhau. Trong Đại số, ta có khái niệm GramMatrix. GramMatrix chính là thứ chúng ta cần.

GramMatrix(A)=AATGramMatrix(A) = A * A^T

Phép nhân trong công thức này là phép nhân ma trận (dot product). Thuật toán về cơ bản là thế, bắt tay vào code bạn sẽ dễ hiểu hơn.

2. Thực hành code

2.1 Chuẩn bị

Để tiện cho việc code và train, mình sẽ code trên Google Colab. Hãy download và đọc code ở link sau:

Torch khá tiện khi tính gradient của 1 biến và update biến đó theo các thuật toán optimizer. Vì vậy, mình quyết định dùng Torch thay vì tensorflow. Cũng bài này, 2 năm trước mình code bài này bằng Tensorflow thuần khá vất vả nên mình quyết định đổi sang torch.

2.2 Gõ code ?

Chắc hẳn đây là phần được mong đợi nhất. Mình sẽ tiến hành code.

Trước hết, do cả project bao gồm code và dataset được lưu trên drive, ta cần mount drive vào môi trường ảo của colab. Bước này có thể bỏ qua nếu bạn run trên máy local

## mount drive to notebook (dataset is saved in mydrive)
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/Project by me/Colab/Code in AI Blog/Style_transfer_1'

Import lib and init some variable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from PIL import Image
from torchvision import transforms as T
from torchvision import models
from matplotlib import pyplot as plt img_size = 512
device = torch.device('cuda')
img_transforms = T.Compose([ T.Resize((img_size, img_size)), T.ToTensor()
])

Một vài hàm phụ trợ:

def load_img(img_fn): global device global img_transforms image = Image.open(img_fn) image = img_transforms(image) # to batch image = image.unsqueeze(0).to(device, torch.float) return image def to_image(img_tensor): image = img_tensor.cpu().clone() image = image.squeeze(0) image = T.ToPILImage()(image) return image def plot_imgs(imgs, cols=3, size=7, title=""): rows = len(imgs)//cols + 1 fig = plt.figure(figsize=(cols*size, rows*size)) for i, img in enumerate(imgs): image = to_image(img) if isinstance(img, torch.Tensor) else img fig.add_subplot(rows, cols, i+1) plt.imshow(image) plt.suptitle(title) plt.show()

Các function và module quan trọng, bao gồm: ContentLoss, StyleLoss, GramMatrix, Normalization

def compute_gram_matrix(feature_map): batchsize, c, w, h = feature_map.size() feature = feature_map.view(batchsize*c, w*h) gram_matrix = torch.mm(feature, feature.T) gram_matrix = gram_matrix/(batchsize*c*w*h) #normalize return gram_matrix # ContentLoss (perceptual loss) is MSE between input_feature and target feature
class ContentLoss(nn.Module): def __init__(self, target_feature): super(ContentLoss, self).__init__() # target_feature is an constant, not variable  self.target_feature = target_feature.detach() def forward(self, input_feature): self.loss = F.mse_loss(input_feature, self.target_feature) return input_feature # styleLoss is MSE between gram_matrix of input_feature and target feature
class StyleLoss(nn.Module): def __init__(self, target_feature): super(StyleLoss, self).__init__() gram_matrix = compute_gram_matrix(target_feature) self.target_gram_matrix = gram_matrix.detach() def forward(self, input_feature): input_gram_matrix = compute_gram_matrix(input_feature) self.loss = F.mse_loss(input_gram_matrix, self.target_gram_matrix) return input_feature # Additionally, VGG networks are trained on images with
# each channel normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We will use them to normalize the image before sending it into the network.
class Normalization(nn.Module): def __init__(self): super(Normalization, self).__init__() self.mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device) self.std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device) def forward(self, image_tensor): return (image_tensor - self.mean) / self.std

Load pretrained-CNN model, build Style Transfer model


def build_ST_model( cnn_model, normalization, content_layer_names, style_layer_names, content_img, style_img): """Build a style_transfer model with specify style and content""" cnn_model.eval() ST_model = nn.Sequential(normalization) all_content_loss = [] all_style_loss = [] i = 0 for layer in cnn_model.children(): if isinstance(layer, nn.Conv2d): i += 1 name = 'conv_{}'.format(i) elif isinstance(layer, nn.ReLU): name = 'relu_{}'.format(i) layer = nn.ReLU(inplace=False) elif isinstance(layer, nn.MaxPool2d): name = 'pool_{}'.format(i) elif isinstance(layer, nn.BatchNorm2d): name = 'bn_{}'.format(i) ST_model.add_module(name=name, module=layer) if name in content_layer_names: target_feature = ST_model(content_img) c_loss = ContentLoss(target_feature=target_feature) ST_model.add_module(f'content_loss_{i}', module=c_loss) all_content_loss.append(c_loss) if name in style_layer_names: target_feature = ST_model(style_img) s_loss = StyleLoss(target_feature=target_feature) ST_model.add_module(f'style_loss_{i}', module=s_loss) all_style_loss.append(s_loss) return ST_model, all_content_loss, all_style_loss # load pretrain-model and build StyleTransfer model
cnn_model = models.vgg19(pretrained=True).features.to(device).eval()
normalization = Normalization()
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
ST_model, all_content_loss, all_style_loss = build_ST_model( cnn_model, normalization, content_layers, style_layers, content_img, style_img)

Function chính, chứa trương trình chính để transfer style từ một ảnh sang ảnh khác.

def run_style_transfer( cnn_model, normalization, content_layers, style_layers, content_img, style_img, num_steps=100, style_weight=10000, content_weight=1): """Run main program to transfer style from style_img to content_img""" ST_model, content_losses, style_losses = build_ST_model( cnn_model, normalization, content_layers, style_layers, content_img, style_img) input_image = content_img.clone() input_image.requires_grad_() optimizer = optim.RMSprop([input_image]) for i in range(num_steps): input_image.data.clamp_(0,1) optimizer.zero_grad() content_score, style_score = 0, 0 ST_model(input_image) for c_loss in content_losses: content_score += c_loss.loss for s_loss in style_losses: style_score += s_loss.loss total_loss = content_score*content_weight + style_score*style_weight total_loss.backward() optimizer.step() if (i+1)%50 == 0: print("content_loss: ", content_score.item(), "style_loss: ", style_score.item()) input_image.data.clamp_(0, 1) return input_image

Thay vì khởi tạo random ảnh input_image, mình quyết định gán ảnh content_image cho input_image. Như thế input_image sẽ dễ dàng và học nhanh hơn, không tốn thời gian trong việc khôi phục content. Cũng bởi vậy, trong số của StyleLoss được chọn lớn hơn rất nhiều trọng số của ContentLoss.

style_img = load_img('dataset/style-3.jpg')
content_img = load_img('dataset/content_1.jpg') output_img = run_style_transfer( cnn_model, normalization, content_layers, style_layers, content_img, style_img, num_steps=700, style_weight=1000000, content_weight=1) plot_imgs([content_img, style_img, output_img], size=10)

3 Kết luận

Như bạn thấy, thuật toán StyleTransfer này thực sự không khó, ý tưởng và cách làm rất đơn giản. Trong đề tài StyleTransfer, người ta còn áp dụng cả GAN. Trong thời gian tới, mình sẽ cố gắng viết bài về StyleTransfer với GAN. Hiện tại mình mới lập blog cá nhân tại https://trungthanhnguyen0502.github.io . Bạn có thể follow mình trên Viblo hoặc đón đọc bài trực tiếp từ Blog cá nhân. Cảm ơn bạn đã đọc

Tham khảo

Bài viết cách đây 2 năm của chính mình: https://forum.machinelearningcoban.com/t/style-transfer-tutorial/4026

https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

https://www.tensorflow.org/tutorials/generative/style_transfer

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139