- vừa được xem lúc

Paper reading | Tóm tắt mô hình ResNeSt: Split-Attention Networks

0 0 13

Người đăng: Viblo AI

Theo Viblo Asia

Đóng góp của bài báo

Bài báo giới thiệu một kiến trúc mô hình đơn giản có tên ResNeSt sử dụng channel-wise attention trên các nhánh của mạng với mục tiêu tận dụng sức mạnh capture thông tin tương tác giữa các đặc trưng (cross-feature interaction) và học đa dạng các biểu diễn. Mô hình ResNeSt vượt qua mô hình EfficientNet trên khía cạnh đánh đổi độ chính xác và độ trễ (accuracy and latency trade-off) trên task image classification.

image.png

Split-Attention Networks

Toàn bộ ý tưởng hay ho của ResNeSt nằm trong Split-Attention block. Split-Attention block bao gồm 2 thành phần là featuremap group và các split attention.

image.png

Featuremap Group. Tại Featuremap Group, feature được chia thành các nhóm, ta có thể đặt số lượng Featuremap group bằng một cardinality hyperparameter KK. Featuremap group có thể gọi là Cardinal group (xem hình trên). Trong bài báo, nóm tác giả cũng giới thiệu một hyperparameter nữa là RR (radix) thể hiện số lượng split trong cardinal group. Do đó, số lượng feature group là G=KRG = KR. Tại mỗi feature group, ta thực hiện trích xuất feature sử dụng các layer Conv. Đầu ra của các layer này sẽ được đưa vào Split Attention.

image.png

Split Attention trong Cardinal Group. Đầu ra của các split được tổng hợp thông qua phép toán tính tổng element-wise tất cả các split trong cardinal group. Biểu diễn của cardinal group thứ kkU^k=j=R(k1)+1RkUj\hat{U}^k=\sum_{j=R(k-1)+1}^{R k} U_j trong đó UjU_j là biểu diễn đầu ra của từng split. Các thông tin ngữ cảnh toàn cục sau đó được tổng hợp thông qua một layer global average pooling theo chiều không gian. Thành phần thứ cc được tính như sau:

image.png

trong đó skRC/Ks^k \in \mathbb{R}^{C / K}.

Sau đó, mỗi featuremap của channel cc được tính toán như sau:

image.png

trong đó aik(c)a_i^k(c) là trọng số được tính như sau:

image.png

Gic\mathcal{G}_i^c có vai trò xác định trọng số của mỗi split cho channel cc dựa vào biểu diễn ngữ cảnh toàn cục sks^k.

ResNeSt Block. Các biểu diễn của cardinal group sau đó được concat theo chiều channel V=Concat(V1,V2,...,VK)V = Concat(V^1, V^2,..., V^K). Giống như block trong model ResNet, ta sử dụng một kết nối tắt: Y=V+XY = V + X nếu input và output featuremap có cùng kích thước. Nếu kích thước khác nhaum ta có thể sử dụng thêm một lớp convolution hoặc kết hợp convolution với pooling. Khi đó ta có Y=V+T(X)Y = V + \mathcal{T}(X).

Coding

Ta xây dựng các layer của model ResNeSt như sau:

import torch
import torch.nn as nn
import torch.nn.functional as F class GlobalAvgPool2d(nn.Module): ''' global average pooling 2D class ''' def __init__(self): super(GlobalAvgPool2d, self).__init__() def forward(self, x): return F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1) class ConvBlock(nn.Module): ''' convolution 2D -> batch normalization -> ReLU ''' def __init__(self, in_channels, out_channels, kernel_size, stride, padding ): super(ConvBlock, self).__init__() self.block = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): x = self.block(x) return x '''
Split Attention
''' class rSoftMax(nn.Module): ''' (radix-majorize) softmax class input is cardinal-major shaped tensor. transpose to radix-major ''' def __init__(self, groups=1, radix=2 ): super(rSoftMax, self).__init__() self.groups = groups self.radix = radix def forward(self, x): B = x.size(0) # transpose to radix-major x = x.view(B, self.groups, self.radix, -1).transpose(1, 2) x = F.softmax(x, dim=1) x = x.view(B, -1, 1, 1) return x class SplitAttention(nn.Module): def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, radix=2, reduction_factor=4 ): super(SplitAttention, self).__init__() self.radix = radix self.radix_conv = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=channels*radix, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups*radix, bias=bias ), nn.BatchNorm2d(channels*radix), nn.ReLU(inplace=True) ) inter_channels = max(32, in_channels*radix//reduction_factor) self.attention = nn.Sequential( nn.Conv2d( in_channels=channels, out_channels=inter_channels, kernel_size=1, groups=groups ), nn.BatchNorm2d(inter_channels), nn.ReLU(inplace=True), nn.Conv2d( in_channels=inter_channels, out_channels=channels*radix, kernel_size=1, groups=groups ) ) self.rsoftmax = rSoftMax( groups=groups, radix=radix ) def forward(self, x): ''' input : | in_channels | ''' ''' radix_conv : | radix 0 | radix 1 | ... | radix r | | group 0 | group 1 | ... | group k | group 0 | group 1 | ... | group k | ... | group 0 | group 1 | ... | group k | ''' x = self.radix_conv(x) ''' split : [ | group 0 | group 1 | ... | group k |, | group 0 | group 1 | ... | group k |, ... ] sum : | group 0 | group 1 | ...| group k | ''' B, rC = x.size()[:2] splits = torch.split(x, rC // self.radix, dim=1) gap = sum(splits) ''' !! becomes cardinal-major !! attention : | group 0 | group 1 | ... | group k | | radix 0 | radix 1| ... | radix r | radix 0 | radix 1| ... | radix r | ... | radix 0 | radix 1| ... | radix r | ''' att_map = self.attention(gap) ''' !! transposed to radix-major in rSoftMax !! rsoftmax : same as radix_conv ''' att_map = self.rsoftmax(att_map) ''' split : same as split sum : same as sum ''' att_maps = torch.split(att_map, rC // self.radix, dim=1) out = sum([att_map*split for att_map, split in zip(att_maps, splits)]) ''' output : | group 0 | group 1 | ...| group k | concatenated tensors of all groups, which split attention is applied ''' return out.contiguous() '''
Bottleneck Block
''' class BottleneckBlock(nn.Module): expansion = 4 def __init__(self, in_channels, channels, stride=1, dilation=1, downsample=None, radix=2, groups=1, bottleneck_width=64, is_first=False ): super(BottleneckBlock, self).__init__() group_width = int(channels * (bottleneck_width / 64.)) * groups layers = [ ConvBlock( in_channels=in_channels, out_channels=group_width, kernel_size=1, stride=1, padding=0 ), SplitAttention( in_channels=group_width, channels=group_width, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, groups=groups, bias=False, radix=radix ) ] if stride > 1 or is_first: layers.append( nn.AvgPool2d( kernel_size=3, stride=stride, padding=1 ) ) layers += [ nn.Conv2d( group_width, channels*4, kernel_size=1, bias=False ), nn.BatchNorm2d(channels*4) ] self.block = nn.Sequential(*layers) self.downsample = downsample def forward(self, x): residual = x if self.downsample: residual = self.downsample(x) out = self.block(x) out += residual return F.relu(out) if __name__ == "__main__": m = BottleneckBlock(256, 64) x = torch.randn(3, 256, 4, 4) print(m(x).size())

Xây dựng model ResNeSt từ các layer trên như sau:

'''
ResNeSt
'''
import torch
import torch.nn as nn from layers import ConvBlock
from layers import GlobalAvgPool2d
from layers import BottleneckBlock class ResNeSt(nn.Module): ''' ResNeSt [1] class [1] ResNeSt : Split-Attention Networks, Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola, https://arxiv.org/abs/2004.08955 official implementation : https://github.com/zhanghang1989/ResNeSt ''' def __init__(self, layers, radix=2, groups=1, bottleneck_width=64, n_classes=1000, stem_width=64 ): super(ResNeSt, self).__init__() self.radix = radix self.groups = groups self.bottleneck_width = bottleneck_width self.deep_stem = nn.Sequential( ConvBlock( in_channels=3, out_channels=stem_width, kernel_size=3, stride=2, padding=1 ), ConvBlock( in_channels=stem_width, out_channels=stem_width, kernel_size=3, stride=1, padding=1 ), ConvBlock( in_channels=stem_width, out_channels=stem_width*2, kernel_size=3, stride=1, padding=1 ), nn.MaxPool2d( kernel_size=3, stride=2, padding=1 ) ) self.in_channels = stem_width*2 self.layer1 = self._make_layers( channels=64, blocks=layers[0], stride=1, is_first=False ) self.layer2 = self._make_layers( channels=128, blocks=layers[1], stride=2 ) self.layer3 = self._make_layers( channels=256, blocks=layers[2], stride=2 ) self.layer4 = self._make_layers( channels=512, blocks=layers[3], stride=2 ) self.classifier = nn.Sequential( GlobalAvgPool2d(), nn.Linear( in_features=512*BottleneckBlock.expansion, out_features=n_classes ) ) def _make_layers(self, channels, blocks, stride=1, is_first=True ): down_layers = None if not stride ==1 or not self.in_channels == channels * BottleneckBlock.expansion: down_layers = nn.Sequential( nn.AvgPool2d( kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False ), nn.Conv2d( in_channels=self.in_channels, out_channels=channels*BottleneckBlock.expansion, kernel_size=1, stride=stride, bias=False ), nn.BatchNorm2d(channels*BottleneckBlock.expansion) ) layers = [] layers.append( BottleneckBlock( in_channels=self.in_channels, channels=channels, stride=stride, downsample=down_layers, radix=self.radix, groups=self.groups, bottleneck_width=self.bottleneck_width, is_first=is_first ) ) self.in_channels = channels * BottleneckBlock.expansion for _ in range(1, blocks): layers.append( BottleneckBlock( in_channels=self.in_channels, channels=channels, radix=self.radix, groups=self.groups, bottleneck_width=self.bottleneck_width ) ) return nn.Sequential(*layers) def forward(self, img): x = self.deep_stem(img) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.classifier(x) return x if __name__ == "__main__": m = ResNeSt( [3, 4, 6, 3] ) img = torch.randn(3, 3, 224, 224) print(m(img).size())

Thực nghiệm

Bảng dưới là hiệu suất của các cải tiến từ model ResNet trên tập dữ liệu ImageNet.

image.png

Bảng dưới là hiệu suất của model ResNeSt với các setting khác nhau. Ví dụ, 2s2x40d là radix = 2, cardinality = 2 và width = 40.

image.png

Bảng dưới so sánh độ chính xác và tốc độ inference của các SOTA model trên tập dữ liệu ImageNet. ResNeSt thể hiện sự cân bằng giữa độ chính xác và tốc độ inference một cách tối ưu nhất.

image.png

Kết quả trên task Object Detection với tập dữ liệu MS-COCO.

image.png

Bảng dưới so sánh kết quả trên task Instance Segmentation với tập dữ liệu MS-COCO.

image.png

Tương tự với tập dữ liệu ADE20K, ta có kết quả sau:

image.png

Với bộ dữ liệu Citscapes, ResNeSt vẫn thể hiện sự vượt trội với các model SOTA trước đó.

image.png

Không chỉ với các task hình ảnh đơn thuần, bảng dưới thể hiện kết quả trên task Pose estimation với tập MS-COCO.

image.png

Tham khảo

[1] ResNeSt: Split-Attention Networks

[2] Amazon Introduces ResNeSt: Strong, Split-Attention Networks

[3] https://github.com/zhanghang1989/ResNeSt/tree/master

[4] https://paperswithcode.com/method/channel-attention-module

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139