- vừa được xem lúc

Paper reading | FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization

0 0 2

Người đăng: Viblo AI

Theo Viblo Asia

Đóng góp của bài báo

Bài báo giới thiệu mô hình FastViT có kiến trúc hybrid vision transformer, mô hình đạt kết quả SOTA trong việc cân bằng giữa latency-accuracy. FastViT nhanh hơn gấp 3.5 lần so với mô hình CMT (một mô hình hybrid transformer SOTA gần đây). FastViT cũng nhanh gấp 4.9 lần so với EfficientNet và gấp 1.9 lần so với ConvNeXt trên thiết bị di động với cùng độ chính xác trên tập dữ liệu ImageNet. Ngoài ra, FastViT cũng có hiệu suất tốt và nhất quán cho nhiều task khác nhau như image classification, detection, segmentation và 3D mesh regression với latency giảm đáng kể trên cả thiết bị di động và GPU.

Kiến trúc mô hình

Tổng quan

image.png

FastViT là mô hình hybrid transformer và có 4 stage như mô tả trong hình trên. Tác giả cũng cung cấp thông tin các biến thể của FastViT trong bảng dưới:

image.png

Quan sát mô hình tổng quan của FastViT ta thấy có một số điểm đáng chú ý. Đầu tiên, tại phần (a) của mô hình, nhóm tác giả thay các dense k×k convolution thường thấy trong các layer stem và patch embedding bằng một phiên bản factorized của nó sử dụng train-time overparameterization. Thứ hai, đó là tại phần (d), mô hình sử dụng RepMixer, đây là một token mixer sử dụng để tham số hóa một skip connection, giúp giảm chi phí truy cập bộ nhớ. Thứ ba, trong phần (c), nhóm tác giả sử dụng large kernel convolution với mục tiêu cải thiện receptive field ngay tại các stage đầu tiên của mô hình.

Nhóm tác giả sử dụng model PoolFormer làm baseline để thiết kế các biến thể của FastViT, từ đó đánh giá chất lượng của từng kiến trúc được thiết kế.

image.png

FastViT

Tham số hóa skip connection

Trong phần này, ta sẽ cùng tìm hiểu cách mô hình tại stage (d) thực hiện tham số hóa skip connection. Cho môt input tensor XX, mixing block trong layer được cài đặt như sau:

image.png

trong đó σ\sigma là một hàm activation phi tuyến tính và BN là Batch Normalization layer. DWConv là depthwise convolution layer. Mặc dù block này được chứng minh là có hiệu quả nhưng trong RepMixer, nhóm tác giả sắp xếp lại các thao tác và loại bỏ hàm activation phi tuyến tính, từ đó ta có cách cài đặt mới như sau:

image.png

Ưu điểm chính của cách thiết kế này là nó có thể được tham số hóa thành một layer depwise convolution duy nhất khi thực hiện inference như trong hình trên phần (d)

image.png

Để kiểm nghiệm lại ưu điểm của việc tham số hóa skip connection, nhóm tác giả so sánh kết quả khi sử dụng RepMixer và Pooling và thu được kết quả trong biểu đồ dưới. Nhận thấy rằng RepMixer có latency thấp hơn đáng kể so với Pooling, kể cả khi ở ảnh có độ phân giải cao.

image.png

Linear Train-time Overparameterization

Để cải thiện hiệu suất (số lượng tham số, FLOPs và latency) hơn nữa, nhóm tác thay thế tất cả các dense k×k convolution bằng phiên bản được phân tách của nó, tức là depthwise convolution k×k , sau đó là pointwise convolution 1×1. Tuy nhiên, việc giảm số lượng tham số từ việc phân tách có thể làm giảm chất lượng của mô hình. Để tăng chất lượng của các layer đã được phân tách, nhóm tác giả thực hiện Linear Train-time Overparameterization. Việc sử dụng overparameterization trong các layer stem, patch embedding và projection giúp tăng hiệu suất mô hình. Kết quả khi sử dụng Train-time Overparam được thể hiện trong bảng dưới:

image.png

Tuy nhiên, việc sử dụng train-time overparameterization làm cho tăng thời gian training. Trong FastViT, nhóm tác giả chỉ thực hiện overparameterize các layer mà thay thế dense k×k convolution bằng phiên bản được phân tách của nó. Các layer này ở trong các layer stem, patch embedding và projection. Chi phí tính toán phát sinh trong các layer này thấp hơn so với phần còn lại của mô hình, do đó việc overparameterizing các lớp này không làm tăng đáng kể thời gian training.

Coding

File code mô hình FastViT tham khảo từ https://github.com/apple/

import os
import copy
from functools import partial
from typing import List, Tuple, Optional, Union import torch
import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model from models.modules.mobileone import MobileOneBlock
from models.modules.replknet import ReparamLargeKernelConv try: from mmseg.models.builder import BACKBONES as seg_BACKBONES from mmseg.utils import get_root_logger from mmcv.runner import _load_checkpoint has_mmseg = True
except ImportError: print("If for semantic segmentation, please install mmsegmentation first") has_mmseg = False try: from mmdet.models.builder import BACKBONES as det_BACKBONES from mmdet.utils import get_root_logger from mmcv.runner import _load_checkpoint has_mmdet = True
except ImportError: print("If for detection, please install mmdetection first") has_mmdet = False def _cfg(url="", **kwargs): return { "url": url, "num_classes": 1000, "input_size": (3, 256, 256), "pool_size": None, "crop_pct": 0.95, "interpolation": "bicubic", "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD, "classifier": "head", **kwargs, } default_cfgs = { "fastvit_t": _cfg(crop_pct=0.9), "fastvit_s": _cfg(crop_pct=0.9), "fastvit_m": _cfg(crop_pct=0.95),
} def convolutional_stem( in_channels: int, out_channels: int, inference_mode: bool = False
) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. Args: in_channels: Number of input channels. out_channels: Number of output channels. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` Returns: nn.Sequential object with stem elements. """ return nn.Sequential( MobileOneBlock( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, groups=out_channels, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), MobileOneBlock( in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ), ) class MHSA(nn.Module): """Multi-headed Self Attention module. Source modified from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ def __init__( self, dim: int, head_dim: int = 32, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. Args: dim: Number of embedding dimensions. head_dim: Number of hidden dimensions per head. Default: ``32`` qkv_bias: Use bias or not. Default: ``False`` attn_drop: Dropout rate for attention tensor. proj_drop: Dropout rate for projection tensor. """ super().__init__() assert dim % head_dim == 0, "dim should be divisible by head_dim" self.head_dim = head_dim self.num_heads = dim // head_dim self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: torch.Tensor) -> torch.Tensor: shape = x.shape B, C, H, W = shape N = H * W if len(shape) == 4: x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C) qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, self.head_dim) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) # trick here to make q@k.t more stable attn = (q * self.scale) @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) if len(shape) == 4: x = x.transpose(-2, -1).reshape(B, C, H, W) return x class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( self, patch_size: int, stride: int, in_channels: int, embed_dim: int, inference_mode: bool = False, ) -> None: """Build patch embedding layer. Args: patch_size: Patch size for embedding computation. stride: Stride for convolutional embedding layer. in_channels: Number of channels of input tensor. embed_dim: Number of embedding dimensions. inference_mode: Flag to instantiate model in inference mode. Default: ``False`` """ super().__init__() block = list() block.append( ReparamLargeKernelConv( in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=stride, groups=in_channels, small_kernel=3, inference_mode=inference_mode, ) ) block.append( MobileOneBlock( in_channels=embed_dim, out_channels=embed_dim, kernel_size=1, stride=1, padding=0, groups=1, inference_mode=inference_mode, use_se=False, num_conv_branches=1, ) ) self.proj = nn.Sequential(*block) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x class RepMixer(nn.Module): """Reparameterizable token mixer. For more details, please refer to our paper: `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_ """ def __init__( self, dim, kernel_size=3, use_layer_scale=True, layer_scale_init_value=1e-5, inference_mode: bool = False, ): """Build RepMixer Module. Args: dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`. kernel_size: Kernel size for spatial mixing. Default: 3 use_layer_scale: If True, learnable layer scale is used. Default: ``True`` layer_scale_init_value: Initial value for layer scale. Default: 1e-5 inference_mode: If True, instantiates model in inference mode. Default: ``False`` """ super().__init__() self.dim = dim self.kernel_size = kernel_size self.inference_mode = inference_mode if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=self.dim, out_channels=self.dim, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, groups=self.dim, bias=True, ) else: self.norm = MobileOneBlock( dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, use_act=False, use_scale_branch=False, num_conv_branches=0, ) self.mixer = MobileOneBlock( dim, dim, kernel_size, padding=kernel_size // 2, groups=dim, use_act=False, ) self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "reparam_conv"): x = self.reparam_conv(x) return x else: if self.use_layer_scale: x = x + self.layer_scale * (self.mixer(x) - self.norm(x)) else: x = x + self.mixer(x) - self.norm(x) return x def reparameterize(self) -> None: """Reparameterize mixer and norm into a single convolutional layer for efficient inference. """ if self.inference_mode: return self.mixer.reparameterize() self.norm.reparameterize() if self.use_layer_scale: w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * ( self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = torch.squeeze(self.layer_scale) * ( self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias ) else: w = ( self.mixer.id_tensor + self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight ) b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias self.reparam_conv = nn.Conv2d( in_channels=self.dim, out_channels=self.dim, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, groups=self.dim, bias=True, ) self.reparam_conv.weight.data = w self.reparam_conv.bias.data = b for para in self.parameters(): para.detach_() self.__delattr__("mixer") self.__delattr__("norm") if self.use_layer_scale: self.__delattr__("layer_scale") class ConvFFN(nn.Module): """Convolutional FFN Module.""" def __init__( self, in_channels: int, hidden_channels: Optional[int] = None, out_channels: Optional[int] = None, act_layer: nn.Module = nn.GELU, drop: float = 0.0, ) -> None: """Build convolutional FFN module. Args: in_channels: Number of input channels. hidden_channels: Number of channels after expansion. Default: None out_channels: Number of output channels. Default: None act_layer: Activation layer. Default: ``GELU`` drop: Dropout rate. Default: ``0.0``. """ super().__init__() out_channels = out_channels or in_channels hidden_channels = hidden_channels or in_channels self.conv = nn.Sequential() self.conv.add_module( "conv", nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=7, padding=3, groups=in_channels, bias=False, ), ) self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1) self.act = act_layer() self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class RepCPE(nn.Module): """Implementation of conditional positional encoding. For more details refer to paper: `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_ In our implementation, we can reparameterize this module to eliminate a skip connection. """ def __init__( self, in_channels: int, embed_dim: int = 768, spatial_shape: Union[int, Tuple[int, int]] = (7, 7), inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding Args: in_channels: Number of input channels. embed_dim: Number of embedding dimensions. Default: 768 spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super(RepCPE, self).__init__() if isinstance(spatial_shape, int): spatial_shape = tuple([spatial_shape] * 2) assert isinstance(spatial_shape, Tuple), ( f'"spatial_shape" must by a sequence or int, ' f"get {type(spatial_shape)} instead." ) assert len(spatial_shape) == 2, ( f'Length of "spatial_shape" should be 2, ' f"got {len(spatial_shape)} instead." ) self.spatial_shape = spatial_shape self.embed_dim = embed_dim self.in_channels = in_channels self.groups = embed_dim if inference_mode: self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) else: self.pe = nn.Conv2d( in_channels, embed_dim, spatial_shape, 1, int(spatial_shape[0] // 2), bias=True, groups=embed_dim, ) def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self, "reparam_conv"): x = self.reparam_conv(x) return x else: x = self.pe(x) + x return x def reparameterize(self) -> None: # Build equivalent Id tensor input_dim = self.in_channels // self.groups kernel_value = torch.zeros( ( self.in_channels, input_dim, self.spatial_shape[0], self.spatial_shape[1], ), dtype=self.pe.weight.dtype, device=self.pe.weight.device, ) for i in range(self.in_channels): kernel_value[ i, i % input_dim, self.spatial_shape[0] // 2, self.spatial_shape[1] // 2, ] = 1 id_tensor = kernel_value # Reparameterize Id tensor and conv w_final = id_tensor + self.pe.weight b_final = self.pe.bias # Introduce reparam conv self.reparam_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.embed_dim, kernel_size=self.spatial_shape, stride=1, padding=int(self.spatial_shape[0] // 2), groups=self.embed_dim, bias=True, ) self.reparam_conv.weight.data = w_final self.reparam_conv.bias.data = b_final for para in self.parameters(): para.detach_() self.__delattr__("pe") class RepMixerBlock(nn.Module): """Implementation of Metaformer block with RepMixer as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_ """ def __init__( self, dim: int, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode: bool = False, ): """Build RepMixer Block. Args: dim: Number of embedding dimensions. kernel_size: Kernel size for repmixer. Default: 3 mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 inference_mode: Flag to instantiate block in inference mode. Default: ``False`` """ super().__init__() self.token_mixer = RepMixer( dim, kernel_size=kernel_size, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop Path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = self.token_mixer(x) x = x + self.drop_path(self.layer_scale * self.convffn(x)) else: x = self.token_mixer(x) x = x + self.drop_path(self.convffn(x)) return x class AttentionBlock(nn.Module): """Implementation of metaformer block with MHSA as token mixer. For more details on Metaformer structure, please refer to: `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_ """ def __init__( self, dim: int, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop: float = 0.0, drop_path: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, ): """Build Attention Block. Args: dim: Number of embedding dimensions. mlp_ratio: MLP expansion ratio. Default: 4.0 act_layer: Activation layer. Default: ``nn.GELU`` norm_layer: Normalization layer. Default: ``nn.BatchNorm2d`` drop: Dropout rate. Default: 0.0 drop_path: Drop path rate. Default: 0.0 use_layer_scale: Flag to turn on layer scale. Default: ``True`` layer_scale_init_value: Layer scale value at initialization. Default: 1e-5 """ super().__init__() self.norm = norm_layer(dim) self.token_mixer = MHSA(dim=dim) assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format( mlp_ratio ) mlp_hidden_dim = int(dim * mlp_ratio) self.convffn = ConvFFN( in_channels=dim, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # Drop path self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # Layer Scale self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True ) def forward(self, x): if self.use_layer_scale: x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x))) x = x + self.drop_path(self.layer_scale_2 * self.convffn(x)) else: x = x + self.drop_path(self.token_mixer(self.norm(x))) x = x + self.drop_path(self.convffn(x)) return x def basic_blocks( dim: int, block_index: int, num_blocks: List[int], token_mixer_type: str, kernel_size: int = 3, mlp_ratio: float = 4.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.BatchNorm2d, drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_layer_scale: bool = True, layer_scale_init_value: float = 1e-5, inference_mode=False,
) -> nn.Sequential: """Build FastViT blocks within a stage. Args: dim: Number of embedding dimensions. block_index: block index. num_blocks: List containing number of blocks per stage. token_mixer_type: Token mixer type. kernel_size: Kernel size for repmixer. mlp_ratio: MLP expansion ratio. act_layer: Activation layer. norm_layer: Normalization layer. drop_rate: Dropout rate. drop_path_rate: Drop path rate. use_layer_scale: Flag to turn on layer scale regularization. layer_scale_init_value: Layer scale value at initialization. inference_mode: Flag to instantiate block in inference mode. Returns: nn.Sequential object of all the blocks within the stage. """ blocks = [] for block_idx in range(num_blocks[block_index]): block_dpr = ( drop_path_rate * (block_idx + sum(num_blocks[:block_index])) / (sum(num_blocks) - 1) ) if token_mixer_type == "repmixer": blocks.append( RepMixerBlock( dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio, act_layer=act_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) ) elif token_mixer_type == "attention": blocks.append( AttentionBlock( dim, mlp_ratio=mlp_ratio, act_layer=act_layer, norm_layer=norm_layer, drop=drop_rate, drop_path=block_dpr, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, ) ) else: raise ValueError( "Token mixer type: {} not supported".format(token_mixer_type) ) blocks = nn.Sequential(*blocks) return blocks class FastViT(nn.Module): """ This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_ """ def __init__( self, layers, token_mixers: Tuple[str, ...], embed_dims=None, mlp_ratios=None, downsamples=None, repmixer_kernel_size=3, norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.GELU, num_classes=1000, pos_embs=None, down_patch_size=7, down_stride=2, drop_rate=0.0, drop_path_rate=0.0, use_layer_scale=True, layer_scale_init_value=1e-5, fork_feat=False, init_cfg=None, pretrained=None, cls_ratio=2.0, inference_mode=False, **kwargs, ) -> None: super().__init__() if not fork_feat: self.num_classes = num_classes self.fork_feat = fork_feat if pos_embs is None: pos_embs = [None] * len(layers) # Convolutional stem self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode) # Build the main stages of the network architecture network = [] for i in range(len(layers)): # Add position embeddings if requested if pos_embs[i] is not None: network.append( pos_embs[i]( embed_dims[i], embed_dims[i], inference_mode=inference_mode ) ) stage = basic_blocks( embed_dims[i], i, layers, token_mixer_type=token_mixers[i], kernel_size=repmixer_kernel_size, mlp_ratio=mlp_ratios[i], act_layer=act_layer, norm_layer=norm_layer, drop_rate=drop_rate, drop_path_rate=drop_path_rate, use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, inference_mode=inference_mode, ) network.append(stage) if i >= len(layers) - 1: break # Patch merging/downsampling between stages. if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: network.append( PatchEmbed( patch_size=down_patch_size, stride=down_stride, in_channels=embed_dims[i], embed_dim=embed_dims[i + 1], inference_mode=inference_mode, ) ) self.network = nn.ModuleList(network) # For segmentation and detection, extract intermediate output if self.fork_feat: # add a norm layer for each output self.out_indices = [0, 2, 4, 6] for i_emb, i_layer in enumerate(self.out_indices): if i_emb == 0 and os.environ.get("FORK_LAST3", None): """For RetinaNet, `start_level=1`. The first norm layer will not used. cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...` """ layer = nn.Identity() else: layer = norm_layer(embed_dims[i_emb]) layer_name = f"norm{i_layer}" self.add_module(layer_name, layer) else: # Classifier head self.gap = nn.AdaptiveAvgPool2d(output_size=1) self.conv_exp = MobileOneBlock( in_channels=embed_dims[-1], out_channels=int(embed_dims[-1] * cls_ratio), kernel_size=3, stride=1, padding=1, groups=embed_dims[-1], inference_mode=inference_mode, use_se=True, num_conv_branches=1, ) self.head = ( nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes) if num_classes > 0 else nn.Identity() ) self.apply(self.cls_init_weights) self.init_cfg = copy.deepcopy(init_cfg) # load pre-trained model if self.fork_feat and (self.init_cfg is not None or pretrained is not None): self.init_weights() def cls_init_weights(self, m: nn.Module) -> None: """Init. for classification""" if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) @staticmethod def _scrub_checkpoint(checkpoint, model): sterile_dict = {} for k1, v1 in checkpoint.items(): if k1 not in model.state_dict(): continue if v1.shape == model.state_dict()[k1].shape: sterile_dict[k1] = v1 return sterile_dict def init_weights(self, pretrained: str = None) -> None: """Init. for mmdetection or mmsegmentation by loading ImageNet pre-trained weights. """ logger = get_root_logger() if self.init_cfg is None and pretrained is None: logger.warning( f"No pre-trained weights for " f"{self.__class__.__name__}, " f"training start from scratch" ) pass else: assert "checkpoint" in self.init_cfg, ( f"Only support " f"specify `Pretrained` in " f"`init_cfg` in " f"{self.__class__.__name__} " ) if self.init_cfg is not None: ckpt_path = self.init_cfg["checkpoint"] elif pretrained is not None: ckpt_path = pretrained ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location="cpu") if "state_dict" in ckpt: _state_dict = ckpt["state_dict"] elif "model" in ckpt: _state_dict = ckpt["model"] else: _state_dict = ckpt sterile_dict = FastViT._scrub_checkpoint(_state_dict, self) state_dict = sterile_dict missing_keys, unexpected_keys = self.load_state_dict(state_dict, False) def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) return x def forward_tokens(self, x: torch.Tensor) -> torch.Tensor: outs = [] for idx, block in enumerate(self.network): x = block(x) if self.fork_feat and idx in self.out_indices: norm_layer = getattr(self, f"norm{idx}") x_out = norm_layer(x) outs.append(x_out) if self.fork_feat: # output the features of four stages for dense prediction return outs # output only the features of last layer for image classification return x def forward(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.forward_embeddings(x) # through backbone x = self.forward_tokens(x) if self.fork_feat: # output features of four stages for dense prediction return x # for image classification x = self.conv_exp(x) x = self.gap(x) x = x.view(x.size(0), -1) cls_out = self.head(x) return cls_out @register_model
def fastvit_t8(pretrained=False, **kwargs): """Instantiate FastViT-T8 model variant.""" layers = [2, 2, 4, 2] embed_dims = [48, 96, 192, 384] mlp_ratios = [3, 3, 3, 3] downsamples = [True, True, True, True] token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_t"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_t12(pretrained=False, **kwargs): """Instantiate FastViT-T12 model variant.""" layers = [2, 2, 6, 2] embed_dims = [64, 128, 256, 512] mlp_ratios = [3, 3, 3, 3] downsamples = [True, True, True, True] token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_t"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_s12(pretrained=False, **kwargs): """Instantiate FastViT-S12 model variant.""" layers = [2, 2, 6, 2] embed_dims = [64, 128, 256, 512] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_s"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_sa12(pretrained=False, **kwargs): """Instantiate FastViT-SA12 model variant.""" layers = [2, 2, 6, 2] embed_dims = [64, 128, 256, 512] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_s"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_sa24(pretrained=False, **kwargs): """Instantiate FastViT-SA24 model variant.""" layers = [4, 4, 12, 4] embed_dims = [64, 128, 256, 512] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, token_mixers=token_mixers, embed_dims=embed_dims, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, **kwargs, ) model.default_cfg = default_cfgs["fastvit_s"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_sa36(pretrained=False, **kwargs): """Instantiate FastViT-SA36 model variant.""" layers = [6, 6, 18, 6] embed_dims = [64, 128, 256, 512] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, embed_dims=embed_dims, token_mixers=token_mixers, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, layer_scale_init_value=1e-6, **kwargs, ) model.default_cfg = default_cfgs["fastvit_m"] if pretrained: raise ValueError("Functionality not implemented.") return model @register_model
def fastvit_ma36(pretrained=False, **kwargs): """Instantiate FastViT-MA36 model variant.""" layers = [6, 6, 18, 6] embed_dims = [76, 152, 304, 608] mlp_ratios = [4, 4, 4, 4] downsamples = [True, True, True, True] pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))] token_mixers = ("repmixer", "repmixer", "repmixer", "attention") model = FastViT( layers, embed_dims=embed_dims, token_mixers=token_mixers, pos_embs=pos_embs, mlp_ratios=mlp_ratios, downsamples=downsamples, layer_scale_init_value=1e-6, **kwargs, ) model.default_cfg = default_cfgs["fastvit_m"] if pretrained: raise ValueError("Functionality not implemented.") return model

Thực nghiệm

Bảng dưới so sánh kết quả của các model SOTA trên tập dữ liệu ImageNet-1k. Có thể thấy rằng Fast-ViT đạt được sự hiệu quả cả về độ chính xác và latency trên thiết bị di động và GPU.

image.png

Bảng dưới so sánh kết quả của các model SOTA cho task classification trên tập dữ liệu ImageNet-1k khi được train sử dụng distillation objective.

image.png

Bảng dưới so sánh kết quả các model SOTA được nhóm dựa vào FLOPs trên các bộ dữ liệu khó hơn là ImageNet-A, ImageNet-B, ImageNet-C, ImageNet-Sketch. Với ImageNet-C là đó mean corruption error, tức là càng nhỏ thì càng tốt, với các dataset còn lại thì giá trị càng lớn thì càng tốt.

image.png

Bảng dưới là kết quả trên tập dữ liệu FreiHAND test.

image.png

Bảng dưới là kết quả cho ADE20K semantic segmentation task.

image.png

Bảng dưới là kết quả cho task object detection và instance segmentation trên tập MS-COCO val2017.

image.png

Tham khảo

[1] FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization

[2] https://github.com/apple/

[3] MobileOne: An Improved One millisecond Mobile Backbone

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 45

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 232

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 240

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 166

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 49

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 158