- vừa được xem lúc

[Paper Explained] PSPNet - Mô hình Deep Learning kinh điển cho bài toán Semantic Segmentation

0 0 37

Người đăng: Tung

Theo Viblo Asia

1. Giới thiệu

Bài toán Semantic segmentation (Phân vùng ngữ nghĩa ảnh) là một trong những bài toán cơ bản trong lĩnh vực Thị giác máy tính, nhiệm vụ của bài toán là phân loại chính xác tới từng pixel trong ảnh. Hình ảnh dưới đây mô tả kết quả phân vùng với tập dữ liệu PASCAL VOC (theo thứ tự từ trái qua phải, lần lượt là ảnh đầu vào, ảnh kết quả và ảnh dự đoán).

Dễ thấy, kết quả của bài toán là một ảnh có cùng kích thước với ảnh đầu vào, trong ảnh kết quả thì các đối tượng trong ảnh nếu có cùng class sẽ được phân vùng thành cùng màu, hay nói cách khác đây chính là phân loại class cho từng pixel trong ảnh. Gần đây, các mô hình Transformer đang đạt hiệu năng rất cao cho bài toán Semantic segmentation do sức mạnh của các pretrained backbone như ViT, PVT, Swin, ... hay mô hình chuyên dụng như SegFormer, ... .Tuy nhiên, để hiểu hơn về bài toán semantic segmentation này, hãy cùng "back to basic" với PSPNet - một mô hình dựa trên kiến trúc CNN nổi tiếng, kinh điển cho bài toán Semantic segmentation.

2. Mô hình

image.png

Nguồn: https://www.researchgate.net/figure/The-receptive-field-of-each-convolution-layer-with-a-3-3-kernel-The-green-area-marks_fig4_316950618
Mô hình PSPNet được trình bày trong bài báo [Pyramid Scene Parsing Network](https://arxiv.org/pdf/1612.01105.pdf), trọng tâm của mô hình là tìm cách mở rộng Receptive field. Vậy khái niệm Receptive field là gì? Receptive field trong Deep learning có thể hiểu là "kích thước của vùng đầu vào để tính toán Feature map đầu ra hoặc một điểm trên Feature map đầu ra".

image.png

Nguồn: https://developer.nvidia.com/blog/image-segmentation-using-digits-5/
Vậy Receptive field quan trọng như thế nào trong các mô hình Deep learning, ta sẽ cùng nhìn vào hình minh họa trên. Ở hình này, ta dễ thấy vùng đầu vào màu vàng khá lớn, bao trọn được chiếc xe ô tô, dẫn đến kết quả dự đoán của điểm ảnh với vùng đầu vào này sẽ dễ dàng được phân loại vào class xe ô tô. Nhưng nếu như vùng đầu vào là vùng màu xanh, khá nhỏ, điểm ảnh với vùng đầu vào này không bao quát hết xe ô tô, và nếu chỉ "giới hạn tầm nhìn" ở trong khoảng ô màu xanh đó thì ngay thậm chí con người cùng khó nhận ra đó là phần đuôi của xe ô tô. Do đó, việc mở rộng Receptive field là vấn đề vô cùng quan trọng với các mô hình Deep learning trong không chỉ là bài toán Semantic segmentation mà còn trong các bài toán khác của lĩnh vực Thị giác máy tính.

Screenshot from 2023-02-23 09-40-07.png

Nguồn: Paper PSPNet

Hình trên mô tả kiến trúc tổng quát của mô hình, Backbone là ResNet với kỹ thuật Dilated convolution, Feature map cuối cùng sẽ có kích thước HW là 1/8 và đưa qua Pyramid Pooling Module (PPM). Sau đó, các Feature map từ PPM được kết hợp lại và đưa ra kết quả phân vùng cuối cùng.

2.1. Backbone

Backbone được trình bày trong PSPNet là ResNet với kỹ thuật Dilated Convolution ở các layer 3 và 4, như trong các mô hình DeepLab. Việc sử dụng kỹ thuật này sẽ giúp ouput tại feature map cuối cùng của backbone có kích thước HW là 1/8 so với ảnh gốc (thay vì 1/32 như các mô hình CNN thông thường khác), do vậy phần nào trách được việc suy hao thông tin về mặt không gian HW khi truyền qua mạng, cũng như kỹ thuật dilated convolution có thể giúp mở rộng thêm Receptive field. Trong lập trình, chỉ cần đơn giản duyệt các tham số trong mạng Backbone ResNet và sửa thuộc tính của các Convolutional kernel.

for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1)
for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1)

2.2. Pyramid Pooling Module

Pyramid Pooling Module (PPM) là đặc trưng của mạng PSPNet, Module này sử dụng phép Global average pooling với nhiều tỷ lệ bin khác nhau để đưa kích thước HW của Feature map sau khi đã pooling về 1x1, 2x2, 3x3, 6x6. Như đã biết, phép Global average pooling là một cách tốt để giảm kích thước Feature map, tăng Receptive field và phép toán này thường sử dụng nhiều trong cách bài toán về Image classification. Tuy nhiên, nếu chỉ sử dụng phép Global average pooling 1 lần thì thông tin đặc trưng sẽ không được đa dạng, hơn nữa nếu chỉ Pooling về kích thước 1x1 sẽ mất đi rất nhiều đặc trưng theo chiều không gian HW của Feature map và làm kém kết quả phân vùng. Với việc sử dụng các phép Global average pooling cùng các tham số khác nhau, PSPNet sẽ học được đặc trưng toàn cục đa dạng hơn, từ đó cải thiện hiệu năng của kết quả phân vùng. Sau khi Pooling, các feature map được đưa qua lớp Convolution, sau đó phóng to về cùng kích thước trước khi Pooling (tức là bằng 1/8 kích thước HW của ảnh gốc) rồi sử dụng phép Cat các Feature map lại theo chiều Channel. Pyramid Pooling Module được lập trình như sau:

class PPM(nn.Module): def __init__(self, in_dim, reduction_dim, bins): super(PPM, self).__init__() self.features = [] for bin in bins: self.features.append(nn.Sequential( nn.AdaptiveAvgPool2d(bin), nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True) )) self.features = nn.ModuleList(self.features) def forward(self, x): x_size = x.size() out = [x] for f in self.features: out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True)) return torch.cat(out, 1)

Ngoài ra, trong quá trình training, PSPNet còn sử dụng thêm kỹ thuật Auxiliary loss để giúp Feature map tại giữa mạng đã phải học tốt kết quả phân vùng cũng như nhằm tăng cường Gradient khi lan truyền ngược tới những layer đầu trong mạng. Screenshot from 2023-02-23 10-17-35.png

Lập trình PSPNet như sau:

class PPM(nn.Module): def __init__(self, in_dim, reduction_dim, bins): super(PPM, self).__init__() self.features = [] for bin in bins: self.features.append(nn.Sequential( nn.AdaptiveAvgPool2d(bin), nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True) )) self.features = nn.ModuleList(self.features) def forward(self, x): x_size = x.size() out = [x] for f in self.features: out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True)) return torch.cat(out, 1) class PSPNet(nn.Module): def __init__(self, layers=50, bins=(1, 2, 3, 6), dropout=0.1, classes=2, zoom_factor=8, use_ppm=True, criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True): super(PSPNet, self).__init__() assert layers in [50, 101, 152] assert 2048 % len(bins) == 0 assert classes > 1 assert zoom_factor in [1, 2, 4, 8] self.zoom_factor = zoom_factor self.use_ppm = use_ppm self.criterion = criterion if layers == 50: resnet = models.resnet50(pretrained=pretrained) elif layers == 101: resnet = models.resnet101(pretrained=pretrained) else: resnet = models.resnet152(pretrained=pretrained) self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool) self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 for n, m in self.layer3.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) for n, m in self.layer4.named_modules(): if 'conv2' in n: m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) elif 'downsample.0' in n: m.stride = (1, 1) fea_dim = 2048 if use_ppm: self.ppm = PPM(fea_dim, int(fea_dim/len(bins)), bins) fea_dim *= 2 self.cls = nn.Sequential( nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Dropout2d(p=dropout), nn.Conv2d(512, classes, kernel_size=1) ) if self.training: self.aux = nn.Sequential( nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout2d(p=dropout), nn.Conv2d(256, classes, kernel_size=1) ) def forward(self, x, y=None): x_size = x.size() assert (x_size[2]-1) % 8 == 0 and (x_size[3]-1) % 8 == 0 h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1) w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1) x = self.layer0(x) x = self.layer1(x) x = self.layer2(x) x_tmp = self.layer3(x) x = self.layer4(x_tmp) if self.use_ppm: x = self.ppm(x) x = self.cls(x) if self.zoom_factor != 1: x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) if self.training: aux = self.aux(x_tmp) if self.zoom_factor != 1: aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True) main_loss = self.criterion(x, y) aux_loss = self.criterion(aux, y) return x.max(1)[1], main_loss, aux_loss else: return x

3. Kết luận

Tóm tắt lại, PSPNet là một mô hình ra đời từ rất lâu, mang tính chất kinh điển cho bài toán Semantic segmentation. Các điểm đáng chú ý của mô hình là việc sử dụng kỹ thuật Dilated convolution, sử dụng Pyramid Pooling Module cũng như thêm kỹ thuật Auxiliary loss. Mình có tham khảo code gốc của tác giả và viết lại đoạn code training cho PSPNet trên Google Colab, các bạn có thể tham khảo tại đây: https://github.com/tungbt-k62/PSPNet_pytorch_colab

Tham khảo

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139