- vừa được xem lúc

Exponential Moving Average trong Deep Learning

0 0 13

Người đăng: Hoang Thuy Ha

Theo Viblo Asia

Chất lượng của mô hình học sâu (deep learning) có liên quan chặt chẽ đến quá trình huấn luyện chúng. Để huấn luyện được mô hình tốt, việc giảm nhiễu (noise) từ quá trình cập nhật ngẫu nhiên (stochastic updates) là cần thiết. Cách chính quy, đã được chứng minh bằng toán học, dùng để giảm nhiễu trong tối ưu hàm lồi phải kể đến (tail) average. Ứng dụng trong học sâu (deep learning), để có được mô hình tốt hơn, kỹ thuật lấy trung bình các bộ trọng số (weights) được tạo ra trong quá trình train là lựa chọn không tồi.

Trong bài viết này, kỹ thuật lấy trung bình Exponential Moving Average (EMA) sẽ được đề cập. EMA được sử dụng như là một trick để cải thiện độ chính xác của model. Ví dụ một cách sử dụng, thay vì sử dụng bộ trọng số (weight) từ training iteration cuối cùng, thì EMA của các bộ trọng số (weights) trong quá trình huấn luyện sẽ được dùng để dự đoán.

Giới thiệu về noise

Trong học sâu (deep learning), cho tập dữ liệu dùng để huấn luyện DD bao gồm NN điểm dữ liệu, thông thường mục tiêu muốn đạt được là tối ưu bộ trọng số Θ\Theta của mạng nơ-ron để tối thiểu hóa hàm mất mát (loss) L(D,Θ)L(D, \Theta). Gọi bộ trọng số (weight) sau khi được tối ưu là Θ\Theta^{\\*}.

Trong thực tế, giá chị của NN thường rất lớn, nên việc tối ưu trọng số (weight) mà dùng toàn bộ dữ liệu trong tập huấn luyện là vô cùng "đắt đỏ", đặc biệt là khi dùng các kỹ thuật như gradient descent. Chính vì vậy, thay vì dùng cả bộ dữ liệu, một lượng nhỏ điểm dữ liệu sẽ được dùng để huấn luyện một lúc, gọi là mini-batch, và chúng được lấy một cách ngẫu nhiên từ tập huấn luyện. Điều này sẽ dẫn đến nhiễu (noise). Nhiễu trong quá trình huấn luyện (training noise) có cả mặt tốt, có cả mặt xấu.

Trong gradient descent, đạo hàm (gradient) để cập nhật bộ trọng số (weight) được tính bằng tất cả dữ liệu trong tập huấn luyện. Tuy nhiên như đã đề cập, việc cùng một lúc tính đạo hàm mà dùng tât cả dữ liệu huấn luyện sẽ rất "đắt đỏ", nên nó được ước tính (estimate) bằng đạo hàm (gradient) của mini-batch. Nhờ đó, đạo hàm (gradient) được tính toán nhanh hơn, cũng đi kèm với việc giảm thiểu độ chính xác. Nhưng, việc sử dụng đạo hàm ước tính (estimated gradient), hay (noisy gradient) đôi khi hữu ích trong việc tối ưu, kết quả là đạt được cực trị địa phương tốt hơn so với khi huấn luyện dùng toàn bộ dữ liệu (để tính đạo hàm cùng một lúc). Đôi khi, quá trình tối ưu bộ trọng số (weight) chuẩn bị hội tụ, nhiễu (noise) của đạo hàm (gradient) khi dùng mini-batch có thể làm cho Θ\Theta^{\\*} không phải là cực tiểu, mà Θ\Theta dao động xung quang cực tiểu.

Kĩ thuật giảm nhiễu (noise) như EMA được dùng để khắc phục tình trạng đó.

Giới thiệu về Moving Average

Moving average là kĩ thuật làm mịn (smoothing technique) thường được dùng để giảm nhiễu và biến động trong dữ liệu chuỗi thời gian (time-series data).

Dạng đơn giản nhất của moving average là simple moving average (SMA), được tính bằng giá trị trung bình của kk điểm dữ liệu trước đó. Cụ thể, cho chuỗi dữ liệu p1,p2,p3,...,pt,...p_1, p_2, p_3, ..., p_t, ..., SMA trên kk điểm dữ liệu gần nhất so với điểm dữ liệu ptp_t, được ký hiệu là SMAt{SMA}_{t} :

SMAt=1ki=tk+1tpi \begin{equation} {SMA}_t = \frac { 1 } { k } \sum _{ i = t-k+1 } ^ { t } p _ { i } \end{equation}

SMAt{SMA}_t có thể tính bằng SMAt1{SMA}_{t-1}. Điều này hữu ích khi dùng SMA để tính toán với streaming data:

SMAt=SMAt1+1k(ptptk) \begin{equation} {SMA}_t = {SMA}_{t-1} + \frac{1}{k} (p_t - p_{t-k}) \end{equation}

Về mặt toán học, moving average là một phép toán, dùng để tính trung bình của một số lượng nhất định các điểm dữ liệu trước đó.

Giới thiệu về Exponential Moving Average (EMA)

Trong học sâu (deep learning), Exponential Moving Average (EMA) tính trung bình có trọng số của các bộ trọng số (weights). Cụ thể, cho dãy các trọng số thu được trong quá trình huấn luyện Θ1,Θ2,Θ3,...,Θt,...\Theta_1, \Theta_2, \Theta_3, ..., \Theta_t, ..., trong đó Θt\Theta_t là bộ trọng số (weight) thu được sau training iteration thứ tt, EMA của trọng số (weight) tại thời điểm thứ tt (training iteration thứ tt) ký hiệu là EMAt{EMA}_t.

Công thức của EMA thường được tính như sau:

EMAt={Θ1if iteration=1αEMAt1+(1α)Θtif iteration=t \begin{equation} EMA_{t} = \begin{cases} \Theta_1 & \text{if $iteration = 1$}\\ \alpha * {EMA}_{t-1} + (1-\alpha) * \Theta_t & \text{if $iteration = t$} \end{cases} \end{equation}

Trong đó, α[0,1)\alpha \in[0, 1), được gọi là EMA decay. Thông giá trị của α\alpha được chọn là 0.99,0.999,...0.99, 0.999, ...

Để hiểu rõ hơn ý nghĩa phía sau của EMA, cùng nhau phân tích (3)(3).

EMAt=(1α)[Θt+αΘt1+α2Θt2+...+αt2Θ2]+αt1Θ1(1α)[Θt+αΘt1+α2Θt2+...+αt2Θ2+αt1Θ1]=(1α)i=1tαtiΘi \begin{equation} \begin{aligned} EMA_{t} & = (1-\alpha) [\Theta_t + \alpha\Theta_{t-1} + \alpha^2\Theta_{t-2} + ... + \alpha^{t-2}\Theta_2] + \alpha^{t-1}\Theta_1 \\ & \approx (1-\alpha) [\Theta_t + \alpha\Theta_{t-1} + \alpha^2\Theta_{t-2} + ... + \alpha^{t-2}\Theta_2 + \alpha^{t-1}\Theta_1] \\ &= (1-\alpha) \sum_{i=1}^t \alpha^{t-i} \Theta_{i} \end{aligned} \end{equation}

Trong đó, tổng của các hệ số xấp xỉ 1 khi tt \rightarrow \infty

Một vài nhận xét của bản thân tác giả: Giá trị của EMAtEMA_t chịu ảnh hưởng nhiều hơn bởi các weights gần với training iteration tt hơn.

Code EMA trong deep learning

Mẫu code tham khảo về cách EMA được dùng trong deep learning.

class ModelEmaV2(nn.Module): def __init__(self, model, decay=0.9999, device=None): super(ModelEmaV2, self).__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) self.module.eval() self.decay = decay self.device = device # perform ema on different device from model if set if self.device is not None: self.module.to(device=device) def _update(self, model, update_fn): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) ema_v.copy_(update_fn(ema_v, model_v)) def update(self, model): self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): self._update(model, update_fn=lambda e, m: m)

Tài liệu tham khảo

Exponential Moving Average of Weights in Deep Learning: Dynamics and Benefits

Exponential-Moving-Average

Acceleration of Stochastic Approximation by Averaging

Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results

EMA code sample

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139