- vừa được xem lúc

Quantization với Pytorch (Phần 2)

0 0 34

Người đăng: Bui Quang Manh

Theo Viblo Asia

3. Giải thuật quantization (Tiếp theo)

Tiếp tục phần giới thiệu giải thiệu quantization với pytorch, ta đến thuật toán đạt hiệu quả cao nhất trong ba phương pháp mà mình có đề cập trong bài Quantization với Pytorch (Phần 1): Quantize Aware Training.

3.3. Quantize Aware Training (QAT)

QAT mô hình hóa những ảnh hưởng của quantization trong suốt quá trình huấn luyện và hiệu chỉnh nó nhờ đó giúp cho phương pháp này đạt hiệu quả cao hơn so với các phương pháp quantization khác.

QAT hoạt động bằng cách chèn những lớp Fake Quantization vào trong mô hình. Gọi chúng là các lớp Fake Quantization bởi vì chúng mô hình việc quantization số nguyên nhưng tính toán bằng các phép floating point.

Qfake(x)=D(Q(x))Q_{fake}(x) = D(Q(x))

trong đó:

  • Q là một hàm quantization sẽ ánh xạ các giá trị ở số phẩy động về dạng số nguyên.
  • D là hàm dequantization sẽ ánh xạ ngược các giá trị đã được quantize bằng hàm Q về dạng số phẩy động.

Ví dụ ta có một lớp Fake Quantization có công thức hoạt động như sau:

Qfake(x,s,b)=s2b1clamp(round(2b1xs),2b1,2b11)Q_{fake}(x, s, b) = \frac{s}{2^{b - 1}}clamp(round(\frac{2^{b - 1}x}{s}), -2^{b - 1}, 2^{b - 1} - 1)

trong đó:

  • x là input tensor ở dạng số phẩy động
  • s là một scale factos
  • b là số bit để quantize
  • round: hàm làm tròn
  • clamp(x, min, max): hàm giới hạn x trong khoảng [min, max]

Ở đây ta nhận thấy phần tử s2b1\frac{s}{2^{b - 1}} đóng vai trò như hàm dequantization trong khi phần còn lại là hàm quantization.

Như đã giải thích bên trên, trong quá trình huấn luyện, phương pháp này vẫn sử dụng các tensor ở dạng float point như bình thường tuy nhiên các lớp Fake Quantize sẽ mô hình hóa ảnh hưởng của quantization bằng cách nhân inputs với một số (scale factor) để biểu diễn số ở floating point sang tập số hữu hạn mới và làm tròn. Quá trình này diễn ra trong cả hai quá trình forward và và backpropagation. Vì vậy mô hình có thể tự tiến hành tối ưu chính nó nếu nó nhận thức được (aware) những ảnh hưởng này. Quantize dựa trên việc nhận thức được ảnh hưởng khi chuyển từ float sang int cũng là lý do phương pháp này có tên là Quantize Aware Training.

Ở phần bên dưới bài viết, chúng ta cùng đi vào phần thực hành sử dụng phương pháp này với thư viện vietocr. Nhưng tạm thời chúng ta sẽ gác lại để lướt qua một vài điểm cần lưu ý khi sử dụng quantization với pytorch.

4. Một số lưu ý

Phần này mình có thấy bài viết A developer-friendly guide to model quantization with PyTorch khá đầy đủ, mình tham khảo và bổ sung chi tiết hơn theo ý hiểu của mình. Các bạn có thể đọc bài viết gốc bằng cách vào trực tiếp đường dẫn bên trên.

1. Quantzation chỉ là phương pháp dùng khi inference.

Ảnh minh họa forward and backpropagation (Nguồn Internet)

Như chúng ta đã biết các số dấu phẩy động có khả năng biểu diễn chính xác hơn nhiều so với các số nguyên (int8). Do đó int8 không thể dùng trong quá trình lan truyền ngược (backpropagation) vì quá trình này rất nhạy cảm với biểu diễn không chính xác của weight và dẫn tới mô hình bị phân kỳ.

2. Độ chính xác sẽ giảm sau khi quantization ?

Quantization thường làm giảm độ chính xác của mô hình. Đây là vấn đề tradeoff giữa độ chính xác và thời gian xử lý. Tuy nhiên, việc chúng ta đánh đổi bao nhiêu độ chính xác để giảm thời gian xử lý phụ thuộc vào rất nhiều yếu tố như kích thước mô hình ban đầu, kĩ thuật quantization hay việc chúng ta quantize bao nhiêu lớp trong mô hình và lớp đó có ảnh hưởng như thế nào đến toàn bộ mô hình,.... Ví dụ một mô hình có kích thước lớn thường có nhiều kết nối dư thừa hay mô hình vẫn biếu diễn tốt với ít kết nối hơn do đó quantize sẽ không gây ảnh hưởng quá nhiều. Những yếu tố này đều được cần nghiên cứu kĩ càng để chúng ta có thể thực hiện tối ưu mô hình một cách tốt nhất.

3. Không cần thực hiện quantization đối với toàn bộ mô hình.

Chúng ta hoàn toàn có thể quantize một phần mô hình và xác định lớp nào được quantize hay không. Để thực hiện điều này, Pytorch cung cấp cho chúng ta hai cách để thực hiện như sau:

  • Tắt / bật chế độ quantization của từng lớp bằng gán các giá trị .qconfig của các lớp với một giá trị qconfig_dict cụ thể. Ví dụ conv1.qconfig = None nghĩa là conv1 không được quantize hoặc conv1.qconfig = custom_qconfig có nghĩa là sử dụng custom_qconfig thay cho config mà ta đã chỉ định sẵn.
  • Dùng QuantStub và DeQuantSub.
import torch # define a floating point model where some layers could be statically quantized
class M(torch.nn.Module): def __init__(self): super(M, self).__init__() # QuantStub converts tensors from floating point to quantized self.quant = torch.quantization.QuantStub() self.conv = torch.nn.Conv2d(1, 1, 1) self.relu = torch.nn.ReLU() # DeQuantStub converts tensors from quantized to floating point self.dequant = torch.quantization.DeQuantStub() def forward(self, x): # manually specify where tensors will be converted from floating # point to quantized in the quantized model x = self.quant(x) x = self.conv(x) x = self.relu(x) # manually specify where tensors will be converted from quantized # to floating point in the quantized model x = self.dequant(x) return x

4. Pytorch chỉ hỗ trợ quantization với CPU

Bạn có thể vô tư thực hiện huấn luyện với Quantize Aware Training ở trên các thiết bị GPU tuy nhiên khi thực hiện inference sử dụng quantization bắt buộc bạn phải sử dụng cpu hoặc trên mobie.

5. Thực hành quantize mô hình VietOCR

Mọi người chắc hẳn đã quen thuộc với thư viện VietOCR - một thư viện OCR cho tiếng Việt. Ở bài trước, mình cũng đã có bài Chuyển đổi mô hình học sâu về ONNX hướng dẫn mọi người chuyển mô hình VietOCR qua dạng ONNX - một định dạng được Pytorch hỗ trợ tối ưu cũng như dễ dàng trong triển khai mô hình. Ở trong bài viết hôm nay, mình cũng sẽ giới thiệu phương pháp quantization giúp cho mô hình VietOCR chạy nhanh hơn trên những thiết bị CPU hoặc edge device. Các bạn có thể xem toàn bộ phần mã nguồn ở đây nhé. Mình cùng bắt tay vào làm nào ?

5.1. Định nghĩa cấu hình huấn luyện

Mình sẽ định nghĩa các tham số dùng cho lúc huấn luyện mô hình ở đây.

config = Cfg.load_config_from_name('vgg_seq2seq')
dataset_params = { 'name':'hw', 'data_root':'./data_line/', 'train_annotation':'train_line_annotation.txt', 'valid_annotation':'test_line_annotation.txt'
} params = { 'print_every':200, 'valid_every':15*200, 'iters':20000, 'checkpoint':'./weights/transformerocr.pth', 'export':'./weights/transformerocr.pth', 'metrics': 10000 } config['trainer'].update(params)
config['dataset'].update(dataset_params)
config['device'] = 'cuda:1'
config['cnn']['pretrained']=False
config['weights'] = "./weights/transformerocr.pth"

5.2. Chuẩn bị mô hình cho quantize aware training.

Khởi tạo mô hình và load dữ liệu từ weight có sẵn.

# get pretrained model
model, vocab = build_model(config)
weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device(device)))

Mô hình bên dưới sẽ giúp chúng ta quantize một phần nhỏ trong toàn bộ mô hình

class QuantizedCNN(nn.Module): def __init__(self, model_fp32): super(QuantizedCNN, self).__init__() # QuantStub converts tensors from floating point to quantized. # This will only be used for inputs. self.quant = torch.quantization.QuantStub() # DeQuantStub converts tensors from quantized to floating point. # This will only be used for outputs. self.dequant = torch.quantization.DeQuantStub() # FP32 model self.model_fp32 = model_fp32 def forward(self, x): # manually specify where tensors will be converted from floating # point to quantized in the quantized model x = self.quant(x) x = self.model_fp32(x) # manually specify where tensors will be converted from quantized # to floating point in the quantized model x = self.dequant(x) return x

Thực hiện fuse layer. Fuse layer là kỹ thuật gộp các layer riêng lẻ như Conv + Bathcnorm + Relu, Conv + Relu, Conv + BatchNorm, Linear + Relu vào một nhóm nhờ đó có thể tính toán trong một lần qua đó cải thiện hiệu suất và tăng tốc độ tính toán.

model = model.train()
for m in model.cnn.model.modules(): if type(m) == nn.Sequential: for n, layer in enumerate(m): if type(layer) == nn.Conv2d: torch.quantization.fuse_modules(m, [str(n), str(n + 1), str(n + 2)], inplace=True)

Trong Pytorch, quantization chỉ hỗ trợ cho một số hàm do đó phụ thuộc vào phương pháp mà mình sử dụng hoặc thiết bị backend mà chúng ta định sử dụng là cpu hay mobie nên chúng ta cần phải chọn cấu hình cho phù hợp.

quantized_cnn = QuantizedCNN(model_fp32=model.cnn)
quantized_cnn.qconfig = torch.quantization.get_default_qconfig("fbgemm") # Print quantization configurations
print(quantized_cnn.qconfig) # the prepare() is used in post training quantization to prepares your model for the calibration step
quantized_cnn = torch.quantization.prepare_qat(quantized_cnn, inplace=True) model.cnn = quantized_cnn

5.3. Huấn luyện mô hình

model.train()
model = model.to(device)
trainer = Trainer(qmodel=model, config=config, pretrained=False)
trainer.train()

Và chúng ta thu được kết quả là kích thước mô hình đã giảm từ 85MB xuống còn 29MB. Phụ thuộc vào bộ dữ liệu sử dụng huấn luyện sẽ dẫn đến kết quả khác nhau. Trong bài hướng dẫn này, mình sử dụng tạm thời bộ dữ liệu mẫu do thư viện VietOCR cung cấp.

5.4. Inference

Ở bước này, chúng ta sẽ sử dụng mô hình đã được quantize để dự đoán.

# define config for inference mode
config = Cfg.load_config_from_name('vgg_seq2seq')
# Pytorch support only cpu device
config['device'] = 'cpu'
config['cnn']['pretrained']=False
config['weights'] = "./weights/quantize_transformerocr.pth" # create quantized model
qmodel, vocab = build_model(config) ## fuse layer
qmodel = model.train()
for m in qmodel.cnn.model.modules(): if type(m) == nn.Sequential: for n, layer in enumerate(m): if type(layer) == nn.Conv2d: torch.quantization.fuse_modules(m, [str(n), str(n + 1), str(n + 2)], inplace=True) # prepare model for quantize aware training
quantized_cnn = QuantizedCNN(model_fp32=qmodel.cnn)
quantized_cnn.qconfig = torch.quantization.get_default_qconfig("fbgemm") # Print quantization configurations
print(quantized_cnn.qconfig) # the prepare() is used in post training quantization to prepares your model for the calibration step
quantized_cnn = torch.quantization.prepare_qat(quantized_cnn, inplace=True)
quantized_cnn = quantized_cnn.to(torch.device('cpu'))
qmodel.cnn = torch.quantization.convert(quantized_cnn, inplace=True) # create detector
detector = Predictor(config, qmodel=qmodel)

Tải bộ dữ liệu mẫu do thư viện VietOCR cung cấp

# Download sample image
! gdown --id 1uMVd6EBjY4Q0G2IkU5iMOQ34X0bysm0b
! unzip -qq -o sample.zip

Tiến hành dự đoán kết quả

img = './sample/031189003299.jpeg'
img = Image.open(img)
plt.imshow(img)
s = detector.predict(img)
s

6. Lời kết

Đến đây nhiều bạn chắc chắn sẽ có thắc mắc tại sao mình mới quantize phần CNN còn phần encoder và decoder thì sao ? Bởi vì QAT chỉ tốt nhất cho những kiến trúc convolution. Còn đối với kiến trúc như LSTM, GRU, Transformer, chúng ta thường sử dụng phương pháp dynamic quantization. Cách này tương đối đơn giản. Các bạn có thể xem lại bài viết trước để nắm rõ thêm. Cảm ơn các bạn đã theo dõi bài viết của mình và đừng quên upvote cho mình. Nếu có bất cứ thắc mắc nào về bài viết, các bạn hãy comment xuống bên dưới để được giải đáp nhé!

Tham khảo.

  1. A developer-friendly guide to model quantization with PyTorch
  2. Aspects and best practices of quantization aware training for custom network accelerators

Bình luận

Bài viết tương tự

- vừa được xem lúc

TorchServe, công cụ hỗ trợ triển khai mô hình PyTorch

Lời mở đầu. Hôm nay tôi sẽ giới thiệu sơ qua cho các bạn công cụ triển khai mô hình dành riêng cho mô hình PyTorch.

0 0 37

- vừa được xem lúc

Bacteria classification bằng thư viện fastai

Giới thiệu. fastai là 1 thư viện deep learning hiện đại, cung cấp API bậc cao để giúp các lập trình viên AI cài đặt các mô hình deep learning cho các bài toán như classification, segmentation... và nhanh chóng đạt được kết quả tốt chỉ bằng vài dòng code. Bên cạnh đó, nhờ được phát triển trên nền tản

0 0 37

- vừa được xem lúc

Pytorch - Một số tips hay, tối ưu cho quá trình huấn luyện model của bạn

Xin chào các bạn, cũng lâu rồi mình mới quay trở lại ngồi viết mấy bài chia sẻ trên viblo. Chẹp, dạo này làm remote nên lười vận động, lười cả viết bài hẳn.

0 0 276

- vừa được xem lúc

Hướng dẫn tất tần tật về Pytorch để làm các bài toán về AI

Giới thiệu về pytorch. Pytorch là framework được phát triển bởi Facebook.

0 0 181

- vừa được xem lúc

Nhận diện khuôn mặt với mạng MTCNN và FaceNet (Phần 2)

Chào mừng các bạn đã quay lại với series "Nhận diện khuôn mặt với mạng MTCNN và FaceNet" của mình. Ở phần 1, mình đã giải thích qua về lý thuyết và nền tảng của 2 mạng là MTCNN và FaceNet.

0 0 733

- vừa được xem lúc

Video Understanding: Tổng quan

"Thợ lặn" hơi lâu, sau sự kiện MayFest thì đến bây giờ cũng là 3 tháng rồi mình không viết thêm bài mới. Thế nên là, hôm nay mình lại ngoi lên, đầu tiên là để luyện lại văn viết một chút, tiếp theo cũ

0 0 97