- vừa được xem lúc

Paper reading | Scaling Language-Image Pre-training via Masking

0 0 10

Người đăng: Viblo AI

Theo Viblo Asia

Động lực và đóng góp

Các model Language-supervised visual pre-training điển hình như CLIP thể hiện được sự mạnh mẽ trong việc học các biểu diễn chung giữa hình ảnh và ngôn ngữ tự nhiên. Mặt khác, CLIP cũng tận dụng được việc sử dụng các pretrained encoder, điều này này cải thiện đáng kể hiệu suất cho các task multimodel và cả unimodel.

Tuy nhiên, do sự phức tạp của việc training multimodel hình ảnh và ngôn ngữ, đặc biệt nếu thực hiện training từ đầu, ta phải cần một lượng lớn dữ liệu được train trong nhiều giờ và tiêu tốn nhiều tài nguyên. Ví dụ như model CLIP gốc được train trên 400 triệu dữ liệu (image,text) trong 32 epoch, cần rất nhiều GPU khỏe. Ngay cả khi sử dụng cơ sở hạ tầng cao cấp, thời gian training vẫn là một nút cổ chai lớn cản trở việc thực hiện training.

Do đó, nhóm tác giả đề xuất một phương pháp có thể training CLIP hiệu quả hơn có tên Fast Language-Image Pre-training (FLIP) giúp tối ưu thời gian training và tài nguyên sử dụng. Ý tưởng cơ bản của phương pháp này là thực hiện xóa ngẫu nhiên một lượng lớn image patch trong quá trình training. Ý tưởng ở đây là sự đánh đổi giữa việc "sự cẩn thận khi xét 1 mẫu (image - text) nào đó" và "số mẫu ta có thể xử lý". Việc sử dụng masking cho ta một số lợi ích. Thứ nhất, vẫn cùng một thời gian training nhưng cho ta train được nhiều mẫu dữ liệu hơn (vì mỗi mẫu dữ liệu giống như việc ta nhìn lướt qua vậy 😄 từ đó có thể nhìn được thêm nhiều mẫu khác). Thứ hai, vẫn cùng lượng memory footprint, ta có thể so sánh đối chiếu nhiều mẫu tại mỗi step hơn (tức là có thể sử dụng batch size lớn hơn).

Quan sát hình dưới ta thấy rằng với thời gian training ít hơn nhưng ta vẫn thu được kết quả training tương tự, thậm chí còn tốt hơn so với phương pháp ban đầu của CLIP.

Nhóm tác giả nhận thấy rằng FLIP hiệu suất cạnh tranh so với CLIP tại nhiều downstream task. Đặc biệt, do việc training diễn ra nhanh hơn, ta có thể tính đến việc scale FLIP. Nhóm tác giả thực hiện nghiên cứu theo 3 hướng:

  • Scale model size
  • Scale dataset size
  • Scale thời gian training

Nhóm tác giả quan sát việc scale model size và dataset có thể làm tăng độ chính xác mà không làm tăng training cost. Điều này là động lực cho việc nghiên cứu scale model vision-language trong tương lai.

Phương pháp

Nhóm tác giả đề xuất 4 ý tưởng chính trong FLIP.

Đầu tiên là image masking. Như model CLIP, nhóm tác giả sử dụng model ViT làm image encoder. Image ban đầu cũng được chia thành các patch không overlap nhau. Sau đó, nhóm tác giả thực hiện mask các patch (với tỉ lệ 50%, 75%). Image encoder ViT chỉ được train trên các patch không bị mask. Việc sử dụng tỉ lệ mask là 50% hoặc 75% giảm độ phức tạp thời gian của image encoding đi 1/2 (hoặc 1/4) và đồng thời giúp tăng batch size lên gấp 2 hoặc 4 lần mà vẫn giữ nguyên cost tài nguyên sử dụng.

Tiếp theo là text masking. Bước này thì có hay không là tùy ý 😄 Cách thực hiện giống như image masking, ta cũng thực hiện mask các text token và sử dụng text encoder cho các text token không bị mask. Điều này khác với BERT là ta sẽ cho encoder học cả mask token. Với cách masking và chỉ học trên text token không bị che nên ta có thể giảm cost cho việc training text encoder. Tuy nhiên, việc training text encoder nhanh hơn (do giảm lượng token) có thể làm cho hiệu suất giảm.

Objective của bài toán này là tối thiểu hóa contrastive loss. Negative sample ở đây là các sample còn lại ở cùng batch. Số lượng negative sample lớn là rất cần thiết cho self-supervised contrastive learning trên ảnh.

Mặc dù encoder được pretrain trên các ảnh bị mask, nhưng nó có thể được train trực tiếp trên các ảnh nguyên vẹn mà không cần thay đổi. Để thu hẹp khoảng cách phân phối do masking, ta có thể đặt tỷ lệ masking là 0% và tiếp tục pretraining cho các step nhỏ đầu tiên. Chiến lược unmasking này làm tăng hiệu quả cho việc đánh đổi độ chính xác và thời gian training.

Thực nghiệm

Trong bảng dưới, nhóm tác giả thực hiện đánh giá tác động của các tham số lên hiệu suất mô hình.

Zero-shot accuracy trên dataset ImageNet-1K, so với các baseline CLIP khác nhau. Kích thước ảnh là 224. Các mục được đánh dấu bằng màu xám được pretrain trên một tập dữ liệu khác. Các mô hình FLIP sử dụng batchsize 64.000, tỷ lệ masking 50% và sử dụng chiến lược unmasked tuning.

4 bảng dưới là so sánh kết quả mô hình với nhiều task và các bộ dữ liệu khác nhau

Cài đặt

Hàm random mask được nhóm tác giả implement sử framework JAX như sau:

def gather_by_einsum(x, ids): """ x: [N, L, ...] ids: [N, K] N: Số lượng mẫu trong batch K: Số lượng các chỉ mục được lấy từ ids L: Chiều dài của mỗi mẫu trong batch (trong trường hợp này, đây là chiều thứ hai của x). """ # Tạo một vector one hot có chiều [N, K, L] mat = jax.nn.one_hot(ids, x.shape[1]) # [N, K, L] # Khởi tạo giá trị x là tensor chứa giá trị của các phần tử được giữ lại sau khi masked x = jnp.einsum("nl...,nkl->nk...", x, mat) return x def random_mask(rng, x, mask_ratio, bias=None): """ x: [N, L, C] input bias: [N, L], an additional map to the noise map (small is keep, large is remove) """ # Khởi tạo giá trị N, L. Trong đó N là số lượng các mẫu trong batch và L là độ dài mỗi mẫu hay số lượng các vector trong mỗi mẫu N, L, _ = x.shape # batch, length, dim # Khởi tạo len_keep là số phần tử được giữ lại sau khi mask len_keep = int(L * (1 - mask_ratio)) # Khởi tạo noise theo phân phối uniform có shape bằng N và L noise = random.uniform(rng, shape=x.shape[:2]) # Thêm bias vào noise if bias is not None: noise += bias # Khởi tạo biến ids_shuffle là tensor có giá trị là cần phần tử của noise được sắp xếp theo thứ tự tăng dần ids_shuffle = jnp.argsort(noise, axis=1) # ascend: small is keep, large is remove # Lưu lại chỉ số của các phần tử sau khi sắp xếp  ids_restore = jnp.argsort(ids_shuffle, axis=1) # Giữ lại một số chỉ số bằng len_keep ids_keep = ids_shuffle[:, :len_keep] x_masked = gather_by_einsum(x, ids_keep) x_masked = t5x.layers.with_sharding_constraint( x_masked, ("batch", "length", "embed") ) # generate the binary mask: 0 is keep, 1 is remove # Khởi tạo numpy array có shape là N, L mask = jnp.ones([N, L]) mask = t5x.layers.with_sharding_constraint(mask, ("batch", "length")) # Set các giá trị tại axis = 1 có chỉ số nhỏ hơn len_keep bằng 0  mask = mask.at[:, :len_keep].set(0) # Tạo binary mask  mask = gather_by_einsum(mask, ids_restore) mask = t5x.layers.with_sharding_constraint(mask, ("batch", "length")) return x_masked, mask, ids_restore

Tham khảo

[1] Learning Transferable Visual Models From Natural Language Supervision

[2] Scaling Language-Image Pre-training via Masking

[3] https://github.com/facebookresearch/flip

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 223

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 231

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139