- vừa được xem lúc

[paper explain] Meta Pseudo Labels: khi semi-supervised lên ngôi

0 0 13

Người đăng: Nguyễn Văn Quân

Theo Viblo Asia

1. Mở đầu

Như chúng ta đã biết thì các phương pháp semi-supervised learning đã góp công không nhỏ trong việc cải thiện hơn nữa các model state-of-the-art trong rất nhiều computer vision tasks như image classification, object detection, và semantic segmentation. Các phương pháp như Pseudo Labels hay self-training chắc cũng khá quen thuộc với những người từng làm về semi-supervised learning. Hôm nay mình sẽ giới thiệu đến các bạn một phiên bản nâng cấp của Pseudo Labels, giúp đưa semi-supervised learning lên đỉnh của Imagenet. Nói sơ qua thì cách hoạt động của Pseudo Label khá đơn giản: chúng ta cần 2 model, một gọi là teacher và một là student. Đầu tiên, ta cần huấn luyện teacher model với dữ liệu có nhãn, sau đó sử dụng model teacher để predict ra nhãn giả - pseudo label của dữ liệu chưa có nhãn, từ bây giờ ta sẽ gọi dữ liệu có nhãn chuẩn là labeled data và dữ liệu được sinh nhãn giả là pseudo data cho ngắn gọn nhé. Pseudo data sẽ được kết hợp với labeled data để huấn luyện cho model student, nhờ có sự bổ sung này mà student có thể mang lại kết quả tốt hơn so với teacher.

Mặc dù phương pháp kể trên khá hiệu quả nhưng vẫn tồn tại một nhược điểm lớn là sẽ có những pseudo label mà teacher sinh ra không chính xác, kéo theo student cũng sẽ học từ dữ liệu sai lệch đấy và kết quả là performance của student bị giảm sút. Điểm yếu này được gọi là confirmation bias trong pseudo-labeling.

Paper mà hôm nay mình muốn thảo luận với mọi người là một phiên bản nâng cấp xịn xò của Pseudo Label - Meta Pseudo Labels. Những gì mà Meta Pseudo Label muốn làm là cải thiện nhược điểm kể trên của teacher thông qua việc quan sát pseudo label mà nó sinh ra sẽ có ảnh hưởng gì đến student, nghĩa là nó sẽ nhận lại feedback của student sau khi học từ pseudo label và tự chỉnh sửa lại bản thân để cho ra những phiên bản pseudo label tốt hơn. Vì phần chứng minh của paper khá nhiều toán nên mình sẽ cố gắng đi chậm, nếu các bạn phát hiện ra mình sai thì đừng ngần ngại góp ý nhé 😁

2. Meta Pseudo Labels

Trái: Pseudo Labels, teacher được cố định sau khi train với labeled data sau đó sinh pseudo label cho student học. Phải: Meta Pseudo Labels, teacher được train song song với student.

Ký hiệu:

  • T,ST, S : mô hình teacher và student
  • θt,θs\theta_t,\theta_s : tham số của teacher và student
  • θsPL\theta_s^{PL}: tham số của mô hình student được train với pseudo label tạo bởi teacher
  • T,S|T|, |S| : dimension của student và teacher
  • (xL,yL)(x_L,y_L) : labeled data gồm image và label
  • xUx_U : unlabeled data chỉ gồm image
  • T(xU,θT)T(x_U,\theta_T): soft prediction của teacher với unlabeled data
  • S(xU,θS),S(xL,θS)S(x_U,\theta_S), S(x_L,\theta_S): soft prediction của student với xUx_UxLx_L
  • CE(q,p)\operatorname{CE}(q,p): cross-entropy loss giữa 2 phân phối q, p với q là label
  • Ex[f]\mathbb{E}_x[f] : giá trị kỳ vọng của phương trình f với biến ngẫu nhiên x.
  • \nabla : gradient

2.1 Revisit Pseudo Label

Trước khi nói về Meta Pseudo Label, ta sẽ mở đầu với việc ôn lại 1 chút về Pseudo Label nhé. Như đã giới thiệu ở phần mở đầu, Pseudo Label huấn luyện student model với unlabeled data để tối thiểu hóa hàm cross-entropy:

θSPL=argminθSExu[CE(T(xu;θT),S(xu;θS))]:=Lu(θT,θS)(1)\theta_{S}^{\mathrm{PL}}=\underset{\theta_{S}}{\operatorname{argmin}} \underbrace{\mathbb{E}_{x_{u}}\left[\operatorname{CE}\left(T\left(x_{u} ; \theta_{T}\right), S\left(x_{u} ; \theta_{S}\right)\right)\right]}_{:=\mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)} \tag{1}

  • Lu(θT,θS)\mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right) : loss của student khi train với pseudo label tạo bởi teacher trên unlabeled data.

Giả sử ta đã có model teacher được train tốt với tập labeled, mục tiêu của pseudo label là tạo ra θsPL\theta_s^{PL} tối ưu trên tập labeled data:

Exl,yl[CE(yl,S(xl;θSPL))]:=Ll(θSPL)(2)\mathbb{E}_{x_{l}, y_{l}}\left[\operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\mathrm{PL}}\right)\right)\right]:=\mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\right) \tag{2}

2.2 Solution for confirmation-bias

Với Pseudo Labels, muốn cho student θsPL\theta_s^{PL} tối ưu thì bắt buộc phải phụ thuộc vào teacher θT\theta_T thông qua pseudo label T(xU,θT)T(x_U,\theta_T). Để miêu tả sự phụ thuộc này ta sẽ dùng ký hiệu θSPL(θT)\theta_{S}^{\mathrm{PL}}(\theta_{T}). Như vậy hàm loss của student trên labeled data có thể được viết gọn lại như sau: Ll(θSPL(θT))\mathcal{L}_{l}(\theta_{S}^{\mathrm{PL}}(\theta_{T})) và tất nhiên nhiệm vụ của hàm này sẽ là tối ưu 2 tham số θSPL\theta_{S}^{\mathrm{PL}}θT\theta_{T}. Từ đó ta có thể tối ưu hóa Ll\mathcal{L}_{l} theo θT\theta_T như sau:

minθTLl(θSPL(θT)), trong đoˊ θSPL(θT)=argminθS Lu(θT,θS)(3)\begin{aligned}\min _{\theta_{T}} & \mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)\right), \\\text { trong đó } & \theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)=\underset{\theta_{S}}{\operatorname{argmin}}\; \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\end{aligned} \tag{3}

Theo như công thức trên thì ta có thể tối ưu hóa teacher thông qua biểu hiện của student, từ đó pseudo label dùng để train student cũng sẽ dần được cải thiện. Tuy nhiên do mối phụ thuộc θSPL(θT)\theta_{S}^{\mathrm{PL}}(\theta_{T})θT\theta_T là vô cùng phức tạp nên việc tính gradient θT(θSPL(θT))\nabla_{\theta_{T}}(\theta_{S}^{\mathrm{PL}}(\theta_{T})) nếu muốn diễn ra thì bắt buộc phải thay đổi toàn bộ quá trình training của student.

Để đơn giản hóa việc này, ta sẽ áp dụng ý tưởng của meta-learning : xấp xỉ argminθS\underset{\theta_{S}}{\operatorname{argmin}} bằng cách update từng bước gradient của θT\theta_T:

θSPL(θT)θSηSθSLu(θT,θS)(4)\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right) \approx \theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\tag{4}

 với ηS laˋ learning rate của student \text{ với } \eta_S \text{ là learning rate của student }

Thay biểu thức trên vào phương trình (3) ta sẽ có hàm tối ưu của teacher trong Meta Pseudo Labels:

minθTLl(θSηSθSLu(θT,θS))(5)\min _{\theta_{T}} \quad \mathcal{L}_{l}\left(\theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\right) \tag{5}

Về cơ bản thì quá trình training của student vẫn phụ thuộc vào phương trình (1) của Pseudo Labels, ngoại trừ việc tham số của teacher sẽ không còn cố định mà thay đổi dần dựa vào student. Từ đó chúng ta sẽ rút ra được quá trình tối ưu hóa song song teacher - student:

  • Student: sử dụng pseudo label từ teacher - T(xU,θT)T(x_U,\theta_T) để tối ưu hóa hàm mục tiêu với SGD:

θS=θSηSθSLu(θT,θS)(6)\theta_S^{\prime}=\theta_S-\eta_{S} \cdot\nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\tag{6}

  • Teacher: sử dụng labeled data kết hợp với feedback của student để cải thiện pseudo label và tối ưu hóa hàm mục tiêu với SGD:

θT=θTηTθTLl(θSθSLu(θT,θS))(7)\theta_T^{\prime}=\theta_T-\eta_{T} \cdot\nabla_{\theta_{T}} \mathcal{L}_{l}\left(\theta_S-\nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\right)\tag{7}

2.3 Teacher's auxiliary losses

Các tác giả thấy rằng Meta Pseudo Labels tự thân nó đã khá tốt rồi, tuy nhiên nếu thêm một task phụ vào quá trình training của teacher thì performance sẽ còn tốt hơn. Do đó khi train teacher với labeled data, ta có thể thêm một auxiliary task dạng self-supervised để tận dụng unlabeled data giúp tăng độ generalization của model teacher. Auxiliary task này được thực hiện theo paper UDA (Unsupervised Data Augmentation for Consistency Training) với tổng quan như sau:

Ta có thể mô tả 1 cách đơn giản về UDA như sau:

  • B1 : Với labeled data (x,y)(x,y), ta để model predict label y^=Pθ(yx)\hat{y} = P_{\theta}(y|x) và tính supervised loss Lsup=CE(y,y^)L_{sup}= \operatorname{CE}(y, \hat{y})
  • B2 : Với unlabeled data (x)(x), ta tiến hành augment xx để có x^\hat{x}, sau đó để model predict label cho xxx^\hat{x} : Pθ(yx)P_{\theta}(y|x)Pθ(yx^)P_{\theta}(y|\hat{x}) rồi tính unsupervised loss với 2 label trên : Lunsup(Pθ(yx),Pθ(yx^))L_{unsup}(P_{\theta}(y|x), P_{\theta}(y|\hat{x}))
  • B3: tính loss tổng : Lfinal=Lsup+αLunsupL_{final}=L_{sup} + \alpha\cdot L_{unsup} và optimize model dựa trên loss tổng

2.4 Derivation of the Teacher’s Update Rule

Nhắc lại một số ký hiệu toán học:

  • cho hàm khả vi f:RmRn,xf(x),xRmf : R ^ { m } \rightarrow R^ {n}, x \mapsto f(x) , x \in R ^ { m }, ta sẽ tìm được ma trận jacobi của ff dựa trên đạo hàm từng phần hàm ff với xx:

xf=gradf=dfdx=[f(x)x1f(x)x2f(x)xn]\nabla _ { x } f = \operatorname { g r a d } f = \frac { d f } { d x } = \left [ \frac { \partial f \left ( x \right ) } { \partial x _ { 1 } } \quad \frac { \partial f \left ( x \right ) } { \partial x _ { 2 } } \quad \ldots \quad \frac { \partial f \left ( x \right ) } { \partial x _ { n } } \right ]

Giờ ta sẽ vào món chính: tính gradient cho quá trình cập nhật teacher. Giả sử với một batch unlabeled data xux_u, teacher sẽ sinh pseudo label y^uT(xu;θT)\hat { y } _ { u } \sim T ( x _ { u } ; \theta _ { T } ), sau đó student sử dụng (xu,y^u)(x_u,\hat{y}_u) để cập nhật tham số θS\theta_S của nó. Chúng ta kỳ vọng tham số mới của student sẽ có dạng Ey^uT(xu;θT)[θSηSθSCE(y^u,S(xu;θS))]\mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]. Ta sẽ cập nhật tham số của teacher trên tập labeled data thông qua cross-entropy của sự thay đổi giữa tham số của student cũ và student mới:

RθT1×T=θTCE(yl,S(xl;Ey^uT(xu;θT)[θSηSθSCE(y^u,S(xu;θS))]))(8)\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|}=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\right)\right)\tag{8}

Đặt:

θˉSS×1=Ey^uT(xu;θT)[θSηSθSCE(y^u,S(xu;θS))](9)\underbrace{\bar{\theta}_{S}^{\prime}}_{|S| \times 1}=\mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\tag{9}

phương trình (6) trở thành:

RθT1×T=θTCE(yl,S(xl;θS))(8’)\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|}=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}(y_l, S(x_l;\theta_{S}^{\prime}))\tag{8'}

Áp dụng quy tắc đạo hàm của hàm hợp với RθT\frac{\partial R}{\partial \theta_{T}}:

RθT=RθS×θSθT=CE(yl,S(xl;θˉS))θSθS=θˉS)1×SθˉSθTS×T=A×B(10)\begin{aligned}\frac{\partial R}{\partial \theta_{T}}&=\frac{\partial R}{\partial \theta_{S}}\times\frac{\partial \theta_S}{\partial \theta_{T}}\\&=\underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\left.\theta_{S} =\bar{\theta}_{S}^{\prime}\right)}}_{1\times| S \mid} \cdot \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} \\&=\qquad\qquad\qquad A \qquad\qquad\times \quad B\end{aligned}\tag{10}

Xét phương trình (8), phần A chính là quá trình train student θS\theta_S^{\prime} với labeled data sau khi đã train θS\theta_S với pseudo data để có θS\theta_S^{\prime}, phần này hoàn toàn có thể tính thông qua backprop thông thường.

Ta xét tiếp phần B:

B=θˉSθTS×T=θTEy^uT(xu;θT)[θSηSθSCE(y^u,S(xu;θS))]=θTEy^uT(xu;θT)[θSηS(CE(y^u,S(xu;θS))θSθS=θS)](11)\begin{aligned}B=\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} &=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right] \\&=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \cdot\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}\right]\end{aligned}\tag{11}

Chú ý : với pt (11), jacobian của CE(y^u,S(xu;θS))\operatorname{CE}(\widehat{y}_{u}, S(x_{u} ; \theta_{S}))dim=1×S\text{dim}=1\times |S| cần được chuyển vị để khớp với dimθS=S×1\text{dim}_{\theta_S}=|S|\times1.

Vậy thì tại sao θS\theta_Sdim=S×1dim=|S|\times1θSCE\nabla_{\theta_S}\operatorname{CE}dim=1×Sdim=1\times|S| ? Ở đây S|S| chính là số lượng tham số có trong student. Với θS\theta_S là tham số của student nên dĩ nhiên dim của nó sẽ là S|S| và mỗi tham số trong student là duy nhất nên dimθS=S×1\text{dim}_{\theta_S}=|S|\times1. Còn θSCE\nabla_{\theta_S}\operatorname{CE} là gradient của hàm loss với biến là θS\theta_S và chỉ có 1 θS\theta_S được xét đến, trong θS\theta_SS|S| lượng tham số nên dimθSCE=1×S\text{dim}_{\nabla_{\theta_S}\operatorname{CE}}=1\times|S|.

Xét phương trình (11), để đơn giản thì ta đặt gSg_{S} là ký hiệu gradient của student:

gS(y^u)S×1=(CE(y^u,S(xu;θS))θSθS=θS)(12)\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times|1|}=\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}\tag{12}

Do θS\theta_S không phụ thuộc vào θT\theta_T, nên θS\theta_S ở phương trình (11) sẽ bị triệt tiêu khi đạo hàm theo θT\theta_T, do đó phương trình (11) sẽ trở thành:

θˉSθTS×T=ηSθTEy^uT(xu;θT)[gS(y^u)S×1](13)\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|}=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1}]\tag{13}

Bây giờ chúng ta sẽ đi giải quyết "củ khoai" này nhé, theo như paper thì đúng ra sẽ dùng REINFORCE algorithm, nhưng mà mình có đọc qua paper gốc được viết năm 1992 thì thấy khó nuốt quá nên có thử tự giải theo cách "dễ nhai" hơn. Mọi người xem qua và cho ý kiến về cách giải của mình nhé.

Với phương trình (13) thì ta sẽ đi giải quyết đạo hàm của hàm kỳ vọng Ey^u[gS(y^u)]\mathbb{E}_{\widehat{y}_{u}}[g_{S}\left(\widehat{y}_{u}\right)]. Một cách tổng quát thì kỳ vọng của hàm f(x)f(x) với biến ngẫu nhiên rời rạc xx sẽ có dạng:

E[f(x)]=xP(x)f(x)\mathbb{E}[f(x)]=\sum_{x}{P(x)f(x)}

Áp dụng công thức trên vào phương trình (13) với Ey^u[gS(y^u)]\mathbb{E}_{\widehat{y}_{u}}[g_{S}\left(\widehat{y}_{u}\right)]:

θTEy^uT(xu;θT)[gs(y^u)]=θTy^up(y^uxu;θT)gs(y^u)=y^uθTp(y^uxu;θT)gs(y^u)(14)\frac{\partial }{\partial \theta_T} \mathbb{E}_{\hat{y}_u \sim T(x_u;\theta_T)}[g_s(\hat{y}_u)] \\ = \frac{\partial}{\partial \theta_T} \sum_{\hat{y}_u} p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \\ = \sum_{\hat{y}_u} \frac{\partial}{\partial \theta_T}p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \tag{14}

Để tính đạo hàm của pp , ta sẽ cần công thức (*) ở dưới. Tuy nhiên để có công thức (*) ta cần thực hiện một số bước. Tính đạo hàm của y=log(f(x))y = log(f(x)):

  • Đặt y=log(f(x))y = log(f(x))u=f(x)u=f(x)

y=log(u) => dydu=log(u)d u=1u vaˋ dudx=f(x)y=log(u) \text{ => }\frac{dy}{du}=\frac{\text{d }log(u)}{d\;u}=\frac{1}{u}\text{ và }\frac{du}{dx}=f^{\prime}(x)

=> dydx=dydududx=1f(x)f(x)\text{=> } \frac { d y } { d x } = \frac { d y } { d u }\cdot \frac { d u } { d x } = \frac{1}{f(x)}\cdot f^{\prime}(x)

Viết lại công thức tổng quát:

log(x)x=1f(x)f(x)x\frac{\partial{log(x)}}{\partial{x}}=\frac{1}{f(x)}\cdot \frac{\partial{f(x)}}{\partial{x}}

f(x)x=f(x)log(x)x(*)\frac{\partial{f(x)}}{\partial{x}}=f(x)\cdot \frac{\partial{log(x)}}{\partial{x}}\tag{*}

Áp dụng (*) vào (14):

θTEy^uT(xu;θT)[gs(y^u)]=y^uθTp(y^uxu;θT)gs(y^u)=y^up(y^uxu;θT)θTlog(p(y^uxu;θT)gs(y^u)=Ey^uT(xu;θT)[gs(y^u)θTlog(p(y^uxu;θT)](15)\frac{\partial }{\partial \theta_T} \mathbb{E}_{\hat{y}_u \sim T(x_u;\theta_T)}[g_s(\hat{y}_u)]\\ = \sum_{\hat{y}_u} \frac{\partial}{\partial \theta_T}p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u)\\=\sum_{\hat{y}_u} p(\hat{y}_u|x_u;\theta_T) \frac{\partial}{\partial \theta_T} log(p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \\ =\mathbb{E}_{\hat{y}_u \sim T(x_u;\theta_T)} [g_s(\hat{y}_u) \frac{\partial}{\partial \theta_T} log(p(\hat{y}_u|x_u;\theta_T)] \tag{15}

Cuối cùng ta có thể diễn giải θˉSθT\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}} như sau:

θˉS(t+1)θTS×T=ηSθTEy^uT(xu;θT)[gS(y^u)]=ηSEy^uT(xu;θT)[gS(y^u)S×1logP(y^uxu;θT)θT1×T]=ηSEy^uT(xu;θT)[gS(y^u)S×1CE(y^u,T(xu;θT))θT1×T](17)\begin{aligned}\underbrace{\frac{\partial \bar{\theta}_{S}^{(t+1)}}{\partial \theta_{T}}}_{|S| \times|T|} &=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[g_{S}\left(\widehat{y}_{u}\right)\right] \\&=-\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \underbrace{\cdot \underbrace{\frac{\partial \log P\left(\widehat{y}_{u} \mid x_{u} ; \theta_{T}\right)}{\partial \theta_{T}}}_{1 \times|T|}]}\\&=\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned}\tag{17}

Đến đây thì ta đã có thể sử dụng đạo hàm của cross-entropy để tính brackrop như thông thường. Thay phương trình (17) vào (10):

RθT1×T=CE(yl,S(xl;θˉS))θSθS=θˉS1×SθˉSθTS×T=ηSCE(yl,S(xl;θˉS))θSθS=θˉS1×SEy^uT(xu;θT)[gS(y^u)S×1CE(y^u,T(xu;θT))θT1×T](18)\begin{aligned}\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|} &=\underbrace{\left.\frac{\partial \mathbf{C E}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} \\&=\eta_{S} \cdot \underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned} \qquad\tag{18}

Cuối cùng, ta sẽ sử dụng phép xấp xỉ Monte-Carlo cho mọi biểu thức trong pt(18) với y^u\hat{y}_u đã tính từ trước. Cụ thể hơn thì ta sẽ tính xấp xỉ θˉS\bar{\theta}_{S}^{\prime} với θS\theta_S bằng cách cập nhật tham số student với (xu,yu)(x_u,y_u): θS=θSηSθSCE(y^u,S(xu;θs))\theta_S^{\prime}=\theta_S-\eta_{S} \cdot\nabla_{\theta_S}\operatorname{CE}(\hat{y}_u, S(x_u;\theta_s)). Đồng thời ước lượng E\mathbb{E} cũng với y^u\hat{y}_u. Với kết quả ước lượng vừa rồi, ta sẽ tính được gradient của θTLu(θT,θS)\nabla_{\theta_T}\mathcal{L}_u(\theta_T, \theta_S).

Pt(18) là dạng tổng quát cho 1 batch dữ liệu. Để tường minh hơn ta sẽ lấy 1 mẫu ngẫu nhiên trong batch để tính gradient:

θTLl=ηSCE(yl,S(xl;θS))θS1×S(CE(y^u,S(xu;θS))θSθS=θS)S×1CE(y^u,T(xu;θT))θT1×T=ηS((θSCE(yl,S(xl;θS))θSCE(y^u,S(xu;θS)))A scalar :=hθTCE(y^u,T(xu;θT))(19)\begin{aligned}\nabla_{\theta_{T}} \mathcal{L}_{l} &=\eta_{S} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)}{\partial \theta_{S}}}_{1 \times|S|} \cdot \underbrace{\left(\left.\frac{\partial \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|} \\&=\underbrace{\eta_{S} \cdot\left(\left(\nabla_{\theta_{S}^{\prime}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)^{\top} \cdot \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right)\right.}_{\text {A scalar }:=h} \cdot \nabla_{\theta_{T}} \mathbf{C E}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)\end{aligned} \tag{19}

Đến đây là hết phần diễn giải cách cập nhật của teacher dựa trên gradient của student rồi nhỉ, các bạn thấy scalar hh bên trên chứ ? Đấy chính là thứ mà chúng ta mong muốn từ đầu đến giờ : feedback của student để teacher cải thiện performance. Khi các bạn xem phần pseudo code với UDA bên dưới thì sẽ thấy một hh tương tự:

Tuy nhiên, khi xem code của Meta Pseudo Label thì các bạn sẽ thấy hh được tính như thế này: . Biến dot_product chính là công thức tính hh lằng nhằng phía trên đấy :v Nếu viết lại theo công thức toán học dựa trên đoạn code thì hh sẽ được tính như sau: h=L(θS)L(θS)h = L(\theta_S) - L(\theta_S'). Vậy tại sao từ h dài dòng lại có thể biến đổi thành phép trừ 2 hàm loss đơn giản như vậy? Thử chứng minh 1 chút nhé:

θS=θSηSθSCE(y^u,S(xu;θS))\theta'_S = \theta_S - \eta_S \nabla_{\theta_S}CE(\hat{y}_u, S(x_u;\theta_S))

Đặt η=ηSθSCE(y^u,S(xu;θS)) ta coˊ\text{Đặt }\eta = \eta_S \nabla_{\theta_S}CE(\hat{y}_u, S(x_u;\theta_S)) \text{ ta có: }

θS=θSη\theta'_S = \theta_S - \eta

Áp dụng công thức xấp xỉ taylor: f(x+h)=f(x)+hf(x)f(x+h)=f(x) + hf'(x)

L(θS)=L(θSη)L(θS)ηθSL(θS)=L(θS)ηSθSCE(y^u,S(xu;θS))θSL(θS)=L(θS)ηSθSCE(y^u,S(xu;θS))θSCE(yl,S(x;θS))=L(θS)hL(\theta_S') = L(\theta_S - \eta) \approx L(\theta_S) - \eta \nabla_{\theta_S}L(\theta_S) \\ =L(\theta_S) - \eta_S \nabla_{\theta_S}CE(\hat{y}_u, S(x_u;\theta_S)) \nabla_{\theta_S}L(\theta_S) \\ =L(\theta_S) - \eta_S \nabla_{\theta_S}CE(\hat{y}_u, S(x_u;\theta_S)) \nabla_{\theta_S}CE(y_l, S(x;\theta_S)) \\ = L(\theta_S) - h

h=L(θS)L(θS)h = L(\theta_S) - L(\theta_S')

Done!!

Dưới đây là toàn bộ quá trình train teacher với UDA và feedback từ student:

Và kết quả SoTA của MPL với EfficientNet-L2:

Lời kết

Bài viết của mình đến đây là đã hoàn thành mục đích ban đầu: cố gắng thử thách bản thân với một paper kinh điển do các idol người Việt viết và mang paper này đến với mọi người một cách dễ hiểu nhất. Nếu có thắc mắc thì các bạn có thể comment bên dưới, mình sẽ cố gắng trả lời trong tầm kiến thức của bản thân. Hoặc nếu các bạn phát hiện lỗi sai thì cứ thẳng thắn góp ý nhé. Cảm ơn các bạn đã đọc bài.

References

  1. Meta Pseudo Labels
  2. Unsupervised Data Augmentation for Consistency Training

Bình luận

Bài viết tương tự

- vừa được xem lúc

Tấn công và phòng thủ bậc nhất cực mạnh cho các mô hình học máy

tấn công bậc nhất cực mạnh = universal first-order adversary. Update: Bleeding edge của CleverHans đã lên từ 3.1.0 đến 4.

0 0 42

- vừa được xem lúc

[Deep Learning] Key Information Extraction from document using Graph Convolution Network - Bài toán trích rút thông tin từ hóa đơn với Graph Convolution Network

Các nội dung sẽ được đề cập trong bài blog lần này. . Tổng quan về GNN, GCN. Bài toán Key Information Extraction, trích rút thông tin trong văn bản từ ảnh.

0 0 219

- vừa được xem lúc

Trích xuất thông tin bảng biểu cực đơn giản với OpenCV

Trong thời điểm nhà nước đang thúc đẩy mạnh mẽ quá trình chuyển đổi số như hiện nay, Document Understanding nói chung cũng như Table Extraction nói riêng đang trở thành một trong những lĩnh vực được quan tâm phát triển và chú trọng hàng đầu. Vậy Table Extraction là gì? Document Understanding là cái

0 0 230

- vừa được xem lúc

Con đường AI của tôi

Gần đây, khá nhiều bạn nhắn tin hỏi mình những câu hỏi đại loại như: có nên học AI, bắt đầu học AI như nào, làm sao tự học cho đúng, cho nhanh, học không bị nản, lộ trình học AI như nào... Sau nhiều lần trả lời, mình nghĩ rằng nên viết hẳn một bài để trả lời chi tiết hơn, cũng như để các bạn sau này

0 0 157

- vừa được xem lúc

[B5'] Smooth Adversarial Training

Đây là một bài trong series Báo khoa học trong vòng 5 phút. Được viết bởi Xie et. al, John Hopkins University, trong khi đang intern tại Google. Hiện vẫn là preprint do bị reject tại ICLR 2021.

0 0 45

- vừa được xem lúc

Deep Learning với Java - Tại sao không?

Muốn tìm hiểu về Machine Learning / Deep Learning nhưng với background là Java thì sẽ như thế nào và bắt đầu từ đâu? Để tìm được câu trả lời, hãy đọc bài viết này - có thể kỹ năng Java vốn có sẽ giúp bạn có những chuyến phiêu lưu thú vị. DJL là tên viết tắt của Deep Java Library - một thư viện mã ng

0 0 139