- vừa được xem lúc

[Paper Explain] MetaFormer: Khi Attention is NOT all you need cho bài toán phân loại ảnh

0 0 26

Người đăng: Nguyen Mai

Theo Viblo Asia

Yêu cầu nhỏ

Hiểu về các lớp Norm khác nhau hoạt động như nào: BatchNorm (BN), GroupNorm (GN), LayerNorm (LN) và biết cách sử dụng Pytorch

Hình 0. Cách hoạt động của các lớp Norm khác nhau

Mở đầu

Từ khi Transformer được áp dụng cho bài toán phân loại ảnh qua ViT, đã có rất nhiều models mới tập trung vào cải thiện phép Self-Attention. Một khối encoder của Transformer có kiến trúc chung như ở hình 1, bao gồm 2 thành phần. Phần đầu là nơi chứa Attention module dùng để "mix" thông tin của tokens với nhau, do đó ta đặt nó là "Token Mixer". Phần còn lại là các module nhỏ kiểu như Channel MLP và Skip-connection. Với sự thành công của Transformer, người ta cho rằng đó là do sự vượt bậc của phép Self-Attention. Do đó, có rất nhiều models mới tập trung vào cải thiện phép Self-Attention. Tuy nhiên, một số nghiên cứu lại chỉ đơn thuần sử dụng phép Spatial MLP làm Token Mixer và đạt được hiệu năng tương đương với phép Self-Attention. Hay thậm chí, có một nghiên cứu khác sử dụng Fourier Transform thay thế cho phép Attention và đạt được hiệu năng tới 97% một model ViT thuần.

Tác giả của paper MetaFormer nhận ra một điểm chung giữa các models được phát triển bây giờ, đó là các khối của chúng đều có kiến trúc 2 phần tương tự nhau. Phần đầu chứa phép Token Mixer, phần sau là Channel MLP và các skip-connection. Và tác giả cho rằng, thay vì tập trung quá nhiều vài cải tiến Token Mixer, kiến trúc tổng thể của một khối, cũng góp phần vô cùng lớn trong việc tạo ra một model có hiệu năng mạnh mẽ. Tác giả đóng gói kiến trúc tổng thể của một khối này lại, và đặt tên cho nó là MetaFormer (Hình 1). Tác giả đã tạo ra một hướng nghiên cứu mới, thay vì tập trung vào cải thiện Token Mixer, thì ta sẽ cải thiện kiến trúc tổng thể của nó, tức cải thiện MetaFormer.

Hình 1. Kiến trúc Transformer (giữa), MLP-like (phải) và kiến trúc đóng gói MetaFormer (trái)

Chứng minh

Ừ thì tác giả đề ra rằng kiến trúc tổng thể cũng vô cùng quan trọng, nếu không có gì để backup cho lời nói đó thì cũng chỉ là kết luận xàm xí thôi. Đầu tiên, tác giả sẽ chứng minh cho câu nói đó.

Một khối MetaFormer được chia ra làm 2 phần (Hình 1), phần đầu có Token Mixer và phần sau cho Channel MLP:

X=X+TokenMixer(Norm1(X))X' = X + TokenMixer(Norm_1(X))

X=X+ChannelMLP(Norm2(X))X'' = X' + ChannelMLP(Norm_2(X'))

Token Mixer

Để chứng minh rằng MetaFormer là thứ cực kì quan trọng, thì tất nhiên ta phải sử dụng một Token Mixer cực yếu để xem rằng hiệu năng của toàn bộ model với Token Mixer cực yếu này có còn tốt?
Vì vậy, Token Mixer được chọn sẽ là phép Pooling. cụ thể là Average Pooling. Cụ thể, phép Pooling sẽ được định nghĩa như sau:

Thuật toán 1. Phép Pooling được sử dụng

Modified Layer Norm

Thông thường trong kiến trúc họ Transformer, Norm được chọn để sử dụng là LayerNorm (LN). Tuy nhiên, tác giả nhận ra Norm này có vẻ không phù hợp cho lắm.

Trước tiên, ta phải hiểu về cách hoạt động của một số lớp Norm thường sử dụng và Norm của Pytorch API đã. Khi sử dụng kiến trúc họ Transformer, Feature maps sẽ có dạng (B,tokens,C)(B, tokens, C) với tokenstokensH×WH \times W. Khi thực hiện LN, ta sẽ norm ở trên chiều CC, và weight của LN sẽ có chiều CC luôn. Tuy nhiên, việc không sử dụng phép Self-Attention, tức feature maps sẽ có dạng (B,C,H,W)(B, C, H, W). Và nếu lúc này thực hiện LN thông qua nn.LayerNorm của Pytorch thì ta sẽ thực hiện norm ở trên chiều [C,H,W][C, H, W] thay vì chỉ CC, và weight của LN cũng sẽ có chiều là [C,H,W][C, H, W]. Nếu ta muốn áp dụng được LN chuẩn mực, tức là chỉ Norm trên chiều CC kể cả là feature maps có dạng (B,C,H,W)(B, C, H, W), thì vẫn hoàn toàn có thể, chỉ là giờ chúng ta sẽ phải tự viết code chứ không còn áp dụng được nn.LayerNorm của Pytorch nữa mà thôi.

Tóm cái váy lại thì:

  • LayerNorm chỉ trên chiều channel như Transformer: Phải tự viết code lại, chỉ thực hiện norm trên chiều channel và weight cũng có chiều channel
  • nn.LayerNorm: thực hiện norm trên chiều [C,H,W][C, H, W] và weight có chiều [C,H,W][C, H, W]

Oke thế giờ thì 2 Norm kể trên LayerNorm và nn.LayerNorm có gì chưa tốt nếu sử dụng với feature maps có dạng (B,C,H,W)(B, C, H, W). Nếu ta sử dụng nn.LayerNorm, ta sẽ phải khai báo một lớp LN như sau:

self.layer_norm = nn.LayerNorm([C, H, W])

Điều này yêu cầu nn.LayerNorm của ta phải có HHWW cố định thì mới có thể khai báo và forward được. Điều này khiến cho mạng chỉ nhận đầu vào là một kích cỡ ảnh cố định \rightarrow không khả thi cho các downstream task.

Còn sử dụng LayerNorm thì hiệu năng yếu.

Vì vậy, tác giả đã tạo ra Modified Layer Norm (MLN). MLN sẽ thực hiện Norm trên chiều [C,H,W][C, H, W], tuy nhiên, weight của MLN sẽ có chiều CC. Nó là sự kết hợp giữa nn.LayerNorm và LayerNorm. Khi khai báo, ta sẽ chỉ phải khai báo như sau:

self.modified_layer_norm = ModifiedLayerNorm(C)

Tức là giờ ta không còn phải cố định chiều HHWW nữa.

Dưới đây là bảng kết quả so sánh MLN với LN và BN:

Bảng 1. So sánh MLN (Baseline) với LN và BN

Kiến trúc của mạng

Hình 2. a) Kiến trúc của toàn bộ model PoolFormer. b) kiến trúc của một PoolFormer block. Gọi là PoolFormer vì Token Mixer được chọn là phép Pooling

Bảng kết quả

Phần đáng mong đợi nhất đây. Đây là điều khiến cho câu nói của tác giả không trở thành câu chém gió vô căn cứ.

Bảng 2. PoolFormer so sánh với những model khác

❓ Mình có một câu hỏi thú vị muốn đặt ra cho các bạn đọc bài này. Phía trên CNN cùi bắp thì không nói, nhưng cùng là MetaFormer, tại sao PoolFormer lại mạnh hơn PVT, ViT hay Swin với việc chỉ sử dụng phép Pooling làm Token Mixer?

Cải thiện MetaFormer

Như đã nói ở trên, thay vì tập trung sáng tạo ra một Token Mixer mới, thì ta sẽ sử dụng những Token Mixer đơn giản, và cải thiện những thứ còn lại của một MetaFormer Block.

StarReLU

Trong paper Transformer, ReLU được chọn làm activation function:

ReLU(x)=max(x,0)ReLU(x) = max(x, 0)

Activation function này có độ nặng tính toán là 1 FLOP. Sau đó, GELU được sử dụng làm activation function chính cho các model họ Transformer:

GELU(x)=xΦ(x)0.5×x(1+tanh(2/π(x+0.044715×x3)))GELU(x) = x \Phi(x) \approx 0.5 \times x(1+tanh(\sqrt{2/\pi}(x+0.044715 \times x^{3})))

Phép tính GELU(x)GELU(x) tốn mất 14 FLOPs, lớn hơn gấp 14 lần ReLU. Có một paper đã tìm ra phép thay thế gần đúng của GELU, gọi là SquaredReLU như sau:

SquaredReLU=xReLU(x)=(ReLU(x))2SquaredReLU = xReLU(x) = (ReLU(x))^2

SquaredReLU chỉ tốn có 2 FLOPs, tuy nhiên, hiệu năng của SquaredReLU vẫn không thể sánh ngang với GELU trên bài toán phân loại ảnh. Nhóm tác giả của MetaFormer cho rằng, việc hiệu năng tụt giảm có thể là do distribution shift (sự dịch chuyển phân phối) trên output của phép tính. Giả dụ xx tuân theo phân phối chuẩn với mean 0 và variance 1, ~N(0,1)N(0, 1), ta có:

E((ReLU(x)2)=0.5,Var((ReLU(x)2)=1.25E((ReLU(x)^2) = 0.5, \qquad Var((ReLU(x)^2) = 1.25

với E()E(\cdot)Var()Var(\cdot) lần lượt là expectation và variance. Do đó, nhóm tác giả tạo ra StarReLU để có thể giải quyết distribution shift như sau:

StarReLU(x)=(ReLU(x))2E((ReLU(x))2)Var((ReLU(x))2)=(ReLU(x))20.51.250.8944(ReLU(x))20.4472StarReLU(x) = \frac{(ReLU(x))^2 - E((ReLU(x))^2)}{\sqrt{Var((ReLU(x))^2)}} = \frac{(ReLU(x))^2 - 0.5}{\sqrt{1.25}} \approx 0.8944 \cdot (ReLU(x))^2 − 0.4472

Tuy nhiên, đấy là giả định khi sử dụng input là normal distribution. Để StarReLU phù hợp với nhiều distribution hơn thì, thay vì sử dụng 2 hệ số cố định 0.89440.89440.4472-0.4472, ta sẽ để cho nó tự học. StarReLU bản mở rộng như sau:

StarReLU(x)=s(ReLU(x))2+bStarReLU(x) = s \cdot (ReLU(x))^2 + b

với sRs \in \RbRb \in \R là learnable params (như BN). StarReLU lúc này chỉ tốn có 4 FLOPs

Các thay đổi khác

Output của một phần trong một khối MetaFormer được tính như sau:

Y=X+F(Norm(X))Y = X + F(Norm(X))

với XX là input, YY là output, NormNorm là lớp Norm và FF là Token Mixer hoặc channel MLP Scaling branch output. Giống như implicit knowledge dạng nhân, ta có 3 cách thêm scaling branch như sau:

  • Residual scale: thêm hệ số scale vào phần Residual:

Y=λrX+F(Norm(X))Y = \lambda_r \odot X + F(Norm(X))

  • Layer Scale: thêm hệ số scale vào phần tính toán

Y=X+λlF(Norm(X))Y = X + \lambda_l \odot F(Norm(X))

  • Branch Scale: Kết hợp cả Layer Scale và Residual Scale

Y=λrX+λlF(Norm(X))Y = \lambda_r \odot X + \lambda_l \odot F(Norm(X))

Hiệu năng của Residual Scale, theo những thử nghiệm của tác giả cho bài toán phân loại ảnh, đang là tốt nhất. Tuy nhiên, mọi người có thể thử các cách scale khác cho phù hợp với bài toán

IdentityFormer và RandFormer

Để chứng minh hiệu năng của việc cải thiện MetaFormer thay vì Token Mixer, tác giả của MetaFormer đã tạo ra IdentityFormer và RandFormer

IdentityFormer. Loại bỏ hoàn toàn Token Mixer ra khỏi model, hay nói cách khác là sử dụng Identity Mapping làm Token Mixer, ta có:

IdentityMapping(X)=XIdentityMapping(X) = X

hay:

X=X+Norm(X)X' = X + Norm(X)

RandFormer. Sử dụng một ma trận ngẫu nhiên (Random Matrix) làm Token Mixer, và Random Matrix này sẽ không được cập nhật trong quá trình backprop. Tức là ma trận được khởi tạo ngẫu nhiên ra sao thì giữ nguyên như thế đến cuối

RandomMixing(X)=XWRRandomMixing(X) = XW_R

XRN×CX \in \R^{N \times C}NN là số token (H×WH \times W), CC là số channels; WRRN×NW_R \in \R^{N \times N}

Bảng 2. Kết quả của IdentityFormer, RandFormer và PoolFormerV2 so với ResNet trong ResNet strikes back

ConvFormer và CAFormer

Với việc đã chứng minh được cải tiến MetaFormer thay vì Token Mixer có thể đem lại hiệu năng không tưởng, ta sẽ sử dụng Token Mixer đơn giảncó sẵn để tạo ra một model mạnh thay vì phải đau đầu suy nghĩ ra một Token Mixer mớiphức tạp ConvFormer. Sử dụng Depthwise Separable Convolution trong MobileNetV2 làm TokenMixer:

Conv(X)=Convpw2(Convdw(σ(Convpw1(X))))Conv(X) = Conv_{pw_{2}}(Conv_{dw}(\sigma(Conv_{pw_{1}}(X))))

với ConvpwConv_{pw}1×11 \times 1 DWConv (hay còn gọi là Point wise Conv), ConvdwConv_{dw}k×kk \times k DWConv, trong paper k=7k=7, và σ\sigma là một hàm phi tuyến tính.

CAFormer. Sử dụng nửa đầu giống như ConvFormer, còn nửa sau của model sử dụng Self-Attention.

❓ Các bạn có thể thử suy nghĩ tại sao nửa sau của CAFormer lại sử dụng Self-Attention mà không sử dụng từ đầu nhé

Bảng 3. Hiệu năng của ConvFormer và CAFormer. Đáng chú ý là CAFormer là model SOTA

TL;DR

[W]hat

  • Một hướng đi mới trong việc nghiên cứu model cho bài toán phân loại ảnh, có thể mở rộng thành các backbone chung chung

[W]hy

  • Mình cũng không có đánh giá gì về phần này vì nó không phải cải tiến những cái gì chưa tốt từ những model cũ mà mở ra một hướng nghiên cứu mới khá là thú vị

Ho[W]

  • Tập trung vào cải tiến khối tổng thể (MetaFormer) thay vì chỉ tập trung vào cải tiến một module (Token Mixer)

Reference

PoolFormer: https://arxiv.org/abs/2111.11418

MetaFormer: https://arxiv.org/abs/2210.13452

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 222

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 231

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139