- vừa được xem lúc

Paper reading | ViViT: A Video Vision Transformer

0 0 3

Người đăng: Viblo AI

Theo Viblo Asia

Đóng góp của bài báo

Bài báo đề xuất một mô hình transformer thuần có tên là ViViT sử dụng cho bài toán video classification. Thao tác chính vẫn là sử dụng self attention cho chuỗi các spatio-temporal (không gian-thời gian) token được trích xuất từ video đầu vào 😀 Để xử lý hiệu quả một lượng lớn các spatio-temporal token có thể gặp trong video, nhóm tác giả đề xuất một số phương pháp phân tích mô hình theo chiều spatial và temporal nhằm tăng độ hiệu quả và khả năng mở rộng. Ngoài ra, để train model hiệu quả trên tập dữ liệu nhỏ, nhóm tác giả thực hiện reguliarise model trong suốt quá trình training và tất nhiên sẽ tận dụng thêm các pretrained image model nữa 😀 Tổng quan mô hình trong hình dưới đây:

image.png

Mô hình đạt kết quả SOTA trên các benchmark cho bài toán video classification như Kinetics 400, Kinetics 600, Epic Kitchens 100, Something-Something v2 và Moments in Time.

Phương pháp

Tổng quan về Vision Transformer

Để hiểu về mô hình Transformer dùng cho video, ta sẽ ngó lại chút xem mô hình Transformer dùng cho hình ảnh (model ViT) như nào 😀

Mô hình ViT nhận đầu vào là NN patch ảnh không chồng nhau, kí hiệu là xiRh×wx_i \in \mathbb{R}^{h \times w}. Các patch ảnh này sẽ đi qua một trainable linear projection, đầu ra sẽ là các patch embedding. Các patch embedding này được concat thành các token z\mathbf{z}. Cụ thể z\mathbf{z} được xác định như sau:

image.png

zclsz_{cls} là learnable token được thêm vào chuỗi các embedded patch. Ngoài ra, chuỗi các embedded patch được cộng với learned positional embedding pRN×d\mathbf{p} \in \mathbb{R}^{N \times d} với mục tiêu duy trì thông tin vị trí.

Các token z\mathbf{z} tiếp tục được đưa vào một encoder bao gồm một chuỗi LL các transformer layer. Mỗi layer ll chữa các thành phần Multi-Headed Self-Attention, layer normalisation (LN) và các MLP block như sau:

image.png

MLP bao gồm 2 lớp linear projection, ở giữa 2 lớp này là một hàm activation phi tuyến tính GELU. Chiều của các token là dd được giữ cố định qua các layer.

Embedding video clips

Bài toán bây giờ đặt ra là làm cách nào để có thể embed được video và đưa vào mô hình😀

Dựa vào ý tưởng cơ bản từ mô hình ViT, nhóm tác giả thực hiện nghiên cứu 2 method để map từ một video VRT×H×W×C\mathbf{V} \in \mathbb{R}^{T \times H \times W \times C} sang chuỗi các token z~Rnt×nh×nw×d\tilde{\mathbf{z}} \in \mathbb{R}^{n_t \times n_h \times n_w \times d}. Sau đó ta sẽ cộng thêm positional embedding và reshape thành RN×d\mathbb{R}^{N \times d} để thu được z\mathbf{z}z\mathbf{z} sẽ là đầu vào của transformer.

Cách đầu tiên là Uniform frame sampling được mô tả trong hình dưới: image.png Đây là cách làm khá đơn giản và tường minh, ta thực hiện lấy mẫu ntn_t frame, sau đó thực hiện embed các frame này như trong mô hình ViT và sau đó concat các token lại với nhau. Theo cách này, ta sẽ có tổng cộng ntnhnwn_t \cdot n_h \cdot n_w token được đưa vào transformer encoder.

Cách embed 2D patch được cài đặt sử dụng Jax như sau:

def embed_2d_patch(x, patches, embedding_dim): """Standard ViT method of embedding input patches.""" n, h, w, c = x.shape assert patches.get('size') is not None, ('patches.size is now the only way' 'to define the patches') fh, fw = patches.size gh, gw = h // fh, w // fw if embedding_dim: x = nn.Conv( embedding_dim, (fh, fw), strides=(fh, fw), padding='VALID', name='embedding')(x) else: # This path often results in excessive padding: b/165788633 x = jnp.reshape(x, [n, gh, fh, gw, fw, c]) x = jnp.transpose(x, [0, 1, 3, 2, 4, 5]) x = jnp.reshape(x, [n, gh, gw, -1]) return x

Cách thứ 2 là Tubelet embedding được mô tả trong hình dưới:

image.png

Thay vì trích xuất các frame riêng lẻ, ta sẽ trích xuất các "ống" (tube) spatio-temporal từ input và sau đó đi qua lớp linear projection để chuyển về chiều Rd\mathbb{R}^d. Đây chính là phiên bản mở rộng của ViT embedding lên 3D và do đó sẽ sử dụng 3D convolution. Một tube sẽ có 3 chiều là t×h×wt \times h \times w, mỗi token được trích xuất theo chiều temporal, height và width tương ứng như sau nt=Tt,nh=Hh,nw=Wwn_t = \left\lfloor\frac{T}{t}\right\rfloor, n_h = \left\lfloor\frac{H}{h}\right\rfloor, n_w = \left\lfloor\frac{W}{w}\right\rfloor. Chiều của các tube càng nhỏ thì càng nhiều token và làm tăng chi phí tính toán. Cách embedding video này khác với cách đầu tiên ở chỗ, cách Uniform frame sampling là lấy các frame rời rạc và tổng hợp thông tin temporal trong transformer encoder còn cách Tubelet embedding sẽ tổng hợp được luôn thông tin spatio-temporal 😀

Cách cài đặt 3D patch sử dụng Jax như sau:

def embed_3d_patch(x, patches, embedding_dim, kernel_init_method, name='embedding'): """Embed 3D input patches into tokens.""" assert patches.get('size') is not None, 'patches.size must be defined' assert len(patches.size) == 3, 'patches.size must have 3 elements' assert embedding_dim, 'embedding_dim must be specified' fh, fw, ft = patches.size x = nn.Conv( embedding_dim, (ft, fh, fw), strides=(ft, fh, fw), padding='VALID', name=name, kernel_init=kernel_initializer)( x) return x

Mô hình Transformer cho video

Nhóm tác giả đề xuất các kiến trúc transformer-based cho video bao gồm 4 model.

Model 1: Spatio-temporal attention

Đầu tiên ta thực hiện tokenize video sử dụng Tubelet embedding như đã trình bày ở phần trước. Sau khi tokenize, ta thực hiện cộng các patch embedding với positional embedding. Cuối cùng, các spatio-temporal token được truyền vào transformer encoder. Mỗi lớp transformer sẽ mô hình hóa tất cả tương tác giữa các spatio-temporal token và do đó sẽ mô hình được các tương tác dài hạn trong video. Tuy nhiên, Multi-Headed Self Attention (MSA) có độ phức tạp bậc 2 tương ứng với số lượng token. Vấn đề này đặc biệt khoai đối với dữ liệu đầu vào là video. Khi số frame trong video tăng thì độ phức tạp cũng ngày càng lớn hơn 😀

Một encoder block cơ bản được cài đặt như sau:

class EncoderBlock(nn.Module): """Transformer encoder block. Attributes: mlp_dim: Dimension of the mlp on top of attention block. num_heads: Number of heads. attention_axis: Axis over which we run attention. dropout_rate: Dropout rate. attention_dropout_rate: Dropout for attention heads. droplayer_p: Probability of dropping a layer. attention_kernel_initializer: Initializer to use for attention layers. deterministic: Deterministic or not (to apply dropout). attention_fn: dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` dtype: The dtype of the computation (default: float32). Returns: Output after transformer encoder block. """ mlp_dim: int num_heads: int dtype: jnp.dtype = jnp.float32 dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 attention_kernel_initializer: Initializer = nn.initializers.xavier_uniform() attention_fn: Any = nn.dot_product_attention droplayer_p: float = 0.0 def get_drop_pattern(self, x, deterministic): if not deterministic and self.droplayer_p: shape = (x.shape[0],) + (1,) * (x.ndim - 1) return jax.random.bernoulli( self.make_rng('dropout'), self.droplayer_p, shape).astype('float32') else: return 0.0 @nn.compact def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray: """Applies Encoder1DBlock module.""" # Attention block. x = nn.LayerNorm(dtype=self.dtype)(inputs) x = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=self.attention_kernel_initializer, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, attention_fn=self.attention_fn, dtype=self.dtype)( x, x, deterministic=deterministic) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic) drop_pattern = self.get_drop_pattern(x, deterministic) x = x * (1.0 - drop_pattern) + inputs # MLP block. y = nn.LayerNorm(dtype=self.dtype)(x) y = attention_layers.MlpBlock( mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate, activation_fn=nn.gelu, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6))( y, deterministic=deterministic) drop_pattern = self.get_drop_pattern(x, deterministic) return y * (1.0 - drop_pattern) + x

Model 2: Factorised encoder

image.png

Như trong hình trên, mô hình gồm 2 loại transformer encoder phân biệt. Sở dĩ mô hình được gọi là Factorised encoder bởi vì ta không dùng cùng 1 encoder cho tất cả các token trong video, thay vào đó ta tách ra làm 2 encoder phân biệt.

Transformer encoder đầu tiên là spatial encoder có cấu trúc khá tương tự như một Transformers encoder chuẩn thông thường. Spatial encoder có nhiệm vụ chỉ mô hình hóa tương tác giữa các token được trích xuất từ cùng temporal index (cùng temporal index và khác spatial index). Cụ thể, video được chia thành các clip. Mỗi clip được chia thành các tube và mỗi tube sẽ như là một token. Ta sẽ đưa các tube và token CLS vào spatial transformer. Đầu ra của spatial encoder là global average pooling từ các encoded token hoặc chỉ là CLS token được trích xuất từ clip (nếu như token CLS được thêm vào input ban đầu).

Các biểu diễn này sau đó được concat và đưa vào temporal encoder bao gồm LtL_t lớp transformer để mô hình hóa tương tác giữa các token từ các temporal index khác nhau. Temporal encoder cũng giống như một Transformers encoder chuẩn. Đầu ra của temporal encoder được đưa vào một MLP head và trả về label cho video.

Mặc dù mô hình này có nhiều lớp transformer hơn Model 1 (và tất nhiên sẽ nhiều tham số hơn) nhưng nó yêu cầu ít FLOP hơn vì 2 khối transformer có độ phức tạp tổng là O((nhnw)2+nt2)\mathcal{O}\left(\left(n_h \cdot n_w\right)^2+n_t^2\right), tốt hơn nhiều so với O((ntnhnw)2)\mathcal{O}\left(\left(n_t \cdot n_h \cdot n_w\right)^2\right) của model 1.

Code cài đặt như sau:

class SpaceTimeViViT(nn.Module): """ViT model for Video with factorized space-time attention.""" spatial_mlp_dim: int spatial_num_layers: int spatial_num_heads: int temporal_mlp_dim: int temporal_num_layers: int temporal_num_heads: int num_classes: int patches: ml_collections.ConfigDict hidden_size: int temporal_encoding_config: ml_collections.ConfigDict attention_config: ml_collections.ConfigDict representation_size: Optional[int] = None dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 stochastic_droplayer_rate: float = 0. classifier: str = 'gap' return_prelogits: bool = False dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False): del debug x, _ = temporal_encode( x, self.temporal_encoding_config, self.patches, self.hidden_size, return_1d=False) bs, t, h, w, c = x.shape x = x.reshape(bs, t, h * w, c) def vit_body(x, mlp_dim, num_layers, num_heads, encoder_name='Transformer'): # If we want to add a class token, add it here. if self.classifier in ['token']: n, _, c = x.shape cls = self.param(f'cls_{encoder_name}', nn.initializers.zeros, (1, 1, c), x.dtype) cls = jnp.tile(cls, [n, 1, 1]) x = jnp.concatenate([cls, x], axis=1) x = Encoder( temporal_dims=None, # This is unused for Factorised-Encoder mlp_dim=mlp_dim, num_layers=num_layers, num_heads=num_heads, attention_config=self.attention_config, dropout_rate=self.dropout_rate, attention_dropout_rate=self.attention_dropout_rate, stochastic_droplayer_rate=self.stochastic_droplayer_rate, dtype=self.dtype, name=encoder_name)(x, train=train) if self.classifier in ['token', '0']: x = x[:, 0] elif self.classifier in ('gap', 'gmp', 'gsp'): fn = {'gap': jnp.mean, 'gmp': jnp.max, 'gsp': jnp.sum}[self.classifier] x = fn(x, axis=list(range(1, x.ndim - 1))) return x # run attention across spacec, per frame x = jax.vmap( functools.partial( vit_body, mlp_dim=self.spatial_mlp_dim, num_layers=self.spatial_num_layers, num_heads=self.spatial_num_heads, encoder_name='SpatialTransformer'), in_axes=1, out_axes=1, axis_name='time')( x) assert x.ndim == 3 and x.shape[:2] == (bs, t) # run attention across time, over all frames if not self.attention_config.get('spatial_only_baseline', False): x = vit_body( x, mlp_dim=self.temporal_mlp_dim, num_layers=self.temporal_num_layers, num_heads=self.temporal_num_heads, encoder_name='TemporalTransformer') else: # Do global average pooling instead, as method of combining temporal info. x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) if self.representation_size is not None: x = nn.Dense(self.representation_size, name='pre_logits')(x) x = nn.tanh(x) else: x = nn_layers.IdentityLayer(name='pre_logits')(x) if self.return_prelogits: return x else: x = nn.Dense( self.num_classes, kernel_init=nn.initializers.zeros, name='output_projection')(x) return x

Model 3: Factorised self-attention

Model 3 thì lại tương tự như model 1 😀 điểm khác biệt là transformer encoder block được sử dụng không phải là block transformer chuẩn thông thường. Thay vì sử dụng multi-headed self-attention (MSA) cho tất cả các cặp token, lớp MSA được tách ra làm 2 phần. Phần MSA layer đầu tiên sử dụng để tính toán attention giữa các token được trích xuất có cùng temporal index và MSA layer thứ 2 dùng để tính toán attention giữa các token được trích xuất có temporal index khác nhau. Hình dưới mô tả trực quan 2 layer này.

image.png

Nhóm tác giả cũng nhận thấy rằng thứ tự spatial đến temporal hay temporal đến spatial self-attention cũng không mang lại nhiều sự khác biệt. Số lượng tham số của model 3 lớn hơn model 1 do có thêm một self-attention layer.

Code cài đặt như sau:

class EncoderFactorizedSelfAttentionBlock(nn.Module): """Encoder with facctorized self attention block. Attributes: mlp_dim: Dimension of the mlp on top of attention block. num_heads: Number of heads. temporal_dims: Number of temporal dimensions in the flattened input attention_kernel_initializer: Initializer to use for attention layers. dropout_rate: Dropout rate. attention_dropout_rate: Dropout for attention heads. droplayer_p: Probability of dropping a layer. attention_order: The order to do the attention. Choice of {time_space, space_time}. dtype: the dtype of the computation (default: float32). """ mlp_dim: int num_heads: int temporal_dims: int attention_kernel_initializer: Initializer dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 droplayer_p: Optional[float] = None attention_order: str = 'time_space' dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, inputs: jnp.ndarray, *, deterministic: bool): """Applies Encoder1DBlock module.""" b, thw, d = inputs.shape inputs = _reshape_to_time_space(inputs, self.temporal_dims) self_attention = functools.partial( nn.SelfAttention, num_heads=self.num_heads, kernel_init=self.attention_kernel_initializer, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, dtype=self.dtype) if self.attention_order == 'time_space': attention_axes = (1, 2) elif self.attention_order == 'space_time': attention_axes = (2, 1) else: raise ValueError(f'Invalid attention order {self.attention_order}.') def _run_attention_on_axis(inputs, axis, two_d_shape): """Reshapes the input and run attention on the given axis.""" inputs = model_utils.reshape_to_1d_factorized(inputs, axis=axis) x = nn.LayerNorm( dtype=self.dtype, name='LayerNorm_{}'.format(_AXIS_TO_NAME[axis]))( inputs) x = self_attention( name='MultiHeadDotProductAttention_{}'.format(_AXIS_TO_NAME[axis]))( x, deterministic=deterministic) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic) x = x + inputs return model_utils.reshape_to_2d_factorized( x, axis=axis, two_d_shape=two_d_shape) x = inputs two_d_shape = inputs.shape for axis in attention_axes: x = _run_attention_on_axis(x, axis, two_d_shape) # MLP block. x = jnp.reshape(x, [b, thw, d]) y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_mlp')(x) y = attention_layers.MlpBlock( mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate, activation_fn=nn.gelu, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), name='MlpBlock')( y, deterministic=deterministic) return x + y

Model 4: Factorised dot-product attention

Model 4 này có độ phức tạp giống như model 2, 3 trong khi vẫn giữ lượng tham số như model 1 😀 Về cơ bản model 4 giống như model 1 về mặt kiến trúc nhưng khác ở MSA layer trong transformer block (xem hình dưới)

image.png

Trong MSA, nhóm tác giả tách ra làm 2 phần là spatial head và temporal head. Spatial head dùng để tính toán dot product self-attention giữa các token được trích xuất có cùng temporal index và Temporal head dùng để ùng để tính toán dot product self-attention giữa các token được trích xuất có cùng spatial index.

Dễ thấy, trong model 3 nhóm tác giả sử dụng toàn bộ head để tính toán lần lượt spatial và temporal attention, trong khi ở model 4, nhóm tác giả sử dụng các head khác nhau trong cùng MSA layer để tính toán temporal và spatial self-attention.

Code cài đặt như sau:

def factorized_dot_product_attention( query: jnp.ndarray, key: jnp.ndarray, value: jnp.ndarray, bias: Optional[jnp.ndarray] = None, broadcast_dropout: bool = True, dropout_rng: Optional[Any] = None, dropout_rate: float = 0.1, deterministic: bool = False, dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None,
) -> jnp.ndarray: """Applies head-factorized qkv dot-product attention. This factorizes the dot-product attention by assigning different heads to run attention on different axes. Args: query: Queries for calculating attention with shape of `[batch..., num_heads, qk_depth_per_head]`. key: Keys for calculating attention with shape of `[batch..., num_heads, qk_depth_per_head]`. value: Values to be used in attention with shape of `[batch..., num_heads, v_depth_per_head]`. bias: Bias for the attention weights. This should be broadcastable to the shape: `[batch...]`. This can be used for incorporating causal masks, padding masks, proximity bias, etc. Default is None, which means no bias is applied on attention matrix. broadcast_dropout: Use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey to be used for dropout. dropout_rate: Dropout rate. deterministic: Deterministic or not (to apply dropout). dtype: The dtype of the computation (default: float32). precision: Numerical precision of the computation see `jax.lax.Precision` for details. Returns: Output of shape `[bs, ..., num_heads, features]`. """ if query.shape != key.shape: raise ValueError('Axial dot product attention only supports ' 'query and key with the same shape.') if bias is not None: raise ValueError('Bias is not supported in ' 'factorized_dot_product_attention.') # Normalize the query with the square of its depth. query = query / jnp.sqrt(query.shape[-1]).astype(dtype) # Shape of query, key, and value: [bs, t, hw, h, c]. prefix_str = 'ab' # Split heads for each axial attention dimension. num_attn_dimensions = query.ndim - 3 # all dims but bs, heads, and channel. if query.shape[-2] % num_attn_dimensions != 0: raise ValueError(f'In head-axial dot-product attention, number of ' f'heads ({query.shape[-2]}) should be divisible by number ' f'of attention dimensions ({num_attn_dimensions})!') queries = jnp.split(query, num_attn_dimensions, axis=-2) keys = jnp.split(key, num_attn_dimensions, axis=-2) values = jnp.split(value, num_attn_dimensions, axis=-2) # queries, keys, and values are each a list with two arrays (sinec # we have two dims, t and hw) that are made by spliting heads: # [(bs, t, hw, h//2, c), (bs, t, hw, h//2, c)]. outputs = [] for i, (query, key, value) in enumerate(zip(queries, keys, values)): # Shape of query, key, and value: [bs, t, hw, h//2, c]. axis = i + 1 # to account for the batch dim batch_dims = prefix_str[:axis] einsum_str = f'{batch_dims}x...z,{batch_dims}y...z->{batch_dims}x...y' # For axis=1 einsum_str (q,k->a): ax...z,ay...z->ax...y # For axis=2 einsum_str (q,k->a): abx...z,aby...z->abx...y attn_logits = jnp.einsum(einsum_str, query, key, precision=precision) # For axis=1 (attention over t): attn_logits.shape: [bs, t, hw, h//2, t] # For axis=2 (attention over hw): attn_logits.shape: [bs, t, hw, h//2, hw] attn_weights = jax.nn.softmax(attn_logits, axis=-1) return jnp.concatenate(outputs, axis=-2)

Thực nghiệm

Bảng dưới là kết quả so sánh các phương pháp input encoding sử dụng ViViT-B và spatio-temporal attention trên tập dữ liệu Kinetics.

image.png

Bảng dưới là kết quả so sánh các kiến trúc model sử dụng ViViT-B là backbone. Kích thước tube là 16×216 \times 2. Kết quả trong bảng là Top-1 accuracy trên tập Kinetic 400 (K400) và accuracy trên tập Epic Kitchens (EK). TPU-v3 được sử dụng trong quá trình inference.

image.png

Nhóm tác giả cũng thực hiện nghiên cứu sự ảnh hưởng của số lượng temporal transformer LtL_t trong model 2. Kết quả là Top-1 acccuracy trên tập Kinetics 400. Lt=0L_t = 0 nghĩa là sử dụng “average pooling baseline".

image.png

So sánh với các method khác trên nhiều tập dữ liệu khác nhau.

image.png

Tham khảo

[1] ViViT: A Video Vision Transformer

[2] https://github.com/google-research/scenic/tree/main/scenic/projects/vivit

[3] AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

Bình luận

Bài viết tương tự

- vừa được xem lúc

Paper reading | ACTION-Net: Multipath Excitation for Action Recognition

Đóng góp của bài báo. Trong các bài toán liên quan tới video, thông tin cần nắm bắt không chỉ là thông tin về mặt không gian (spatial information) giống như hình ảnh mà còn là thông tin về thời gian (

0 0 10

- vừa được xem lúc

Paper reading | Masked Motion Encoding for Self-Supervised Video Representation Learning

Đóng góp của bài báo. Bài báo đóng góp một phương pháp để có thể học biểu diễn video (Video representation learning) một cách hiệu quả.

0 0 2