Ở bài trước, chúng ta đã tìm hiểu về một mô hình thời gian liên tục sử dụng SDE. Nếu chúng ta bỏ đi hệ số diffusion, phương trình này sẽ trở thành phương trình vi phân toàn phần theo thời gian . Lúc này, việc thay đổi trạng thái sẽ trở nên tất định, do đó chúng ta có thể mô hình sự thay đổi của xác suất trạng thái theo thời gian, từ đó có thể mô hình một phiên bản tương tự của normalizing flow theo thời gian liên tục. Không chỉ vậy, cách làm này còn có thể sử dụng tương tự ResNet với kiến trúc bất kì.
ResNet và phương trình vi phân
Một mô hình ResNet về cơ bản có dạng sau
với , là đầu vào và biến đổi ở lớp thứ . Nếu chúng ta coi là chuỗi số thực , chúng ta có thể viết lại thành
Đây chính là cách xấp xỉ một phương trình vi phân bằng phương pháp Euler. Cụ thể hơn, khi , cách làm này xấp xỉ phương trình sau
Từ góc nhìn này, ta có thể xem mạng neural như một quá trình thay đổi của một trạng thái theo thời gian, biểu diễn bởi phương trình vi phân (ordinary differential equation - ODE) như trên thay vì mô hình theo từng lớp như truyền thống. Đầu ra của mô hình sẽ là trạng thái tại thời điểm , được tìm bằng cách giải ODE với điều kiện đầu là đầu vào . Mô hình này có thể sử dụng để thay thế bất kì mô hình ResNet nào. Hàm ở đây có thể là một kiến trúc tùy ý, nhận trạng thái và thời gian , trả về vector cùng chiều với .
Một tính chất quan trọng của ODE là liệu từ phương trình này có thể xác định được không. Định lý Picard–Lindelöf chỉ ra rằng trong trường hợp là Lipschitz theo , tồn tại sao cho tồn tại và xác định duy nhất quanh . Như vậy, để ODE định nghĩa tốt, chúng ta cần mô hình thỏa mãn tính chất Lipschitz.
Giải ODE
Với ODE với điều kiện đầu bên trên, trạng thái tại thời điểm sẽ được tính như sau
Mục tiêu của chúng ta sẽ là xấp xỉ tích phân trên. Cách đơn giản nhất là phương pháp Euler: Với mỗi chuỗi , chúng ta tính lần lượt giá trị tại những thời điểm trên như sau:
Như đã nói ở trên, cách làm này giống với mô hình ResNet quen thuộc.
def odeint_euler(f, y0, t): def step(state, t): y_prev, t_prev = state dt = t - t_prev y = y_prev + dt * f(t_prev, y_prev) return y, t t_curr = t[0] y_curr = y0 ys = [] for i in t[1:]: y_curr, t_curr = step((y_curr, t_curr), i) ys.append(y_curr) return torch.stack(ys)
Một cách xấp xỉ phổ biến khác có sai số thấp hơn là phương pháp Runge-Kutta, xấp xỉ sai khác giữa các thời điểm bởi 4 giá trị
def odeint_rk4(f, y0, t): def step(state, t): y_prev, t_prev = state dt = t - t_prev k1 = dt * f(t_prev, y_prev) k2 = dt * f(t_prev + dt/2., y_prev + k1/2.) k3 = dt * f(t_prev + dt/2., y_prev + k2/2.) k4 = dt * f(t + dt, y_prev + k3) y = y_prev + (k1+ 2 * k2 + 2 * k3 + k4) / 6 return y, t t_curr = t[0] y_curr = y0 ys = [] for i in t[1:]: y_curr, t_curr = step((y_curr, t_curr), i) ys.append(y_curr) return torch.stack(ys)
Tính thử một ví dụ với ODE sau
ODE có nghiệm là . Dùng bước để xấp xỉ tích phân để tính , hai cách tính trên cho kết quả như bên dưới
Ta có thể thấy phương pháp Euler cho kết quả không chính xác. Điều này thể hiện khoảng cách giữa các bước ảnh hưởng đến độ chính xác của phương pháp xấp xỉ. Do đó ta có thể xấp xỉ ODE chính xác hơn bằng cách chọn độ dài mỗi bước sao cho ước lượng lỗi tối ưu (việc này yêu cầu một cách để ước lượng lỗi, ví dụ như dùng một phương pháp khác để xấp xỉ, rồi tính sai khác giữa kết quả của hai phương pháp). Tuy nhiên điều này nảy sinh một vấn đề sau: Trong trường hợp ta muốn dùng minibatch, sai số giữa các ODE trong batch là khác nhau, do đó thời gian giữa các ODE sẽ khác nhau, việc xử lý toàn batch sẽ không giống như mạng neural thông thường. Một cách giải quyết là gộp chung toàn batch thành 1 ODE, các mốc thời gian sẽ dùng chung, tuy nhiên có thể tăng sai số. Đối với jax, ta có thể dùng vmap để tính song song các ODE trong batch (gần đây torch cũng có cài đặt vmap).
Cập nhật tham số
Ở bài trước, chúng ta đã làm quen với một mô hình thời gian liên tục với SDE bằng mô hình trực tiếp score theo thời gian. Tuy nhiên, đối với neural ODE, ta đang mô hình sự thay đổi của trạng thái theo thời gian. Do đó việc cập nhật gradient trở nên không hiển nhiên, yêu cầu tham số hóa lại đối với tham số của mô hình.
Phần này sẽ trình bày cách cập nhật gradient cho hai cách cài đặt automatic differentiation là tích vector-Jacobian (VJP) và tích Jacobian-vector (JVP). Chi tiết về hai cách cài đặt này có thể xem ở tài liệu tham khảo của thư viện jax.
Tính với tích vector-Jacobian (reverse mode)
Để cho thuận tiện, chúng ta sẽ viết lại phương trình vi phân dưới dạng sau
Giả sử hàm mục tiêu được tính tại trạng thái cuối tại thời điểm thông qua hàm , từ định lí tồn tại duy nhất hàm này cũng có thể được tính từ trạng thái thông qua hàm .
Mục tiêu của chúng ta là đạo hàm đối với trạng thái ban đầu và tham số , nói cách khác là tính đạo hàm riêng và .
Đặt
chúng ta đã biết và cần tính . Như vậy, chúng ta có thể mô hình sự thay đổi của hàm theo thời gian , từ đó tính ra bằng cách tích phân theo thời gian từ về .
Do ODE có nghiệm duy nhất xung quanh lân cận của , ta có thể lấy đạo hàm riêng theo tại hai vế
Đổi thứ tự đạo hàm riêng và áp dụng chain rule ta có
Quay lại với hàm mục tiêu, áp dụng chain rule ta được
Từ hai điều trên, ta có thể mô hình sự thay đổi của theo thời gian như sau
Lúc này có thể tính bởi
Để tính , chúng ta sẽ dùng vector-Jacobian với đầu vào là . Trạng thái này có thể được tính lại bằng ODE ban đầu, hoặc có thể sử dụng lại chính trạng thái đã tính trong quá trình forward nếu sử dụng cũng một cách để xấp xỉ.
Tiếp theo chúng ta sẽ tính đạo hàm riêng với tham số của mô hình, áp dụng chain rule ta được
Tương tự như trên, nếu chúng ta có thể mô hình được sự thay đổi của theo thời gian, có thể tính bằng cách tích phân từ trạng thái .
Lấy đạo hàm theo ở hai vế, ta có
Tương tự như trạng thái đầu , ta có thể giả sử ODE thỏa mãn quanh lân cận của và lấy đạo hàm theo ở hai vế, sau đó đổi thứ tự đạo hàm và áp dụng chain rule
Thay và , ta được
Suy ra
Một câu hỏi nữa là giá trị của điều kiện đầu là gì. Chúng ta có thể nhận ra hàm mất mát được tính dựa trên trạng thái cuối mà không cần đến tham số của quá trình, do đó .
Từ đây ta có thể tính được
Tổng hợp lại, để tìm đạo hàm riêng theo trạng thái ban đầu và tham số của mô hình, ta sẽ giải hệ phương trình vi phân sau
với trạng thái ban đầu là
Tính với tích Jacobian-vector (forward mode)
Đối với cách cài đặt này, ta quan tâm đến vi phân của khi biết vi phân của và . Ta có
với mọi ( kí hiệu vector tiếp tuyến). Tương tự phần trên, ta nghĩ đến việc tìm sự thay đổi của theo thời gian.
Đặt . Ở phần trên chúng ta đã có
Do đó
Việc còn lại là tìm điều kiện đầu. Tại thời điểm , , do vậy . Lúc này việc tìm vi phân tại tương đương với việc giải ODE
với điều kiện đầu .
Ghi chú: Với cả hai cách cài đặt, ta đều phải tích phân ngược theo thời gian. Điều này yêu cầu phương pháp xấp xỉ ODE phải thỏa mãn tính chất thời gian khả nghịch, cụ thể hơn khi giải ODE theo chiều thuận rồi từ đó giải theo chiều nghịch, ta được chính xác điều kiện đầu. Các phương pháp giải ODE bậc nhất (bao gồm phương pháp Euler, Runge-Kutta) không thoả mãn tính chất này.
Ví dụ
Trong phần này mình sẽ minh họa với pytorch, sử dụng hàm vjp và jvp. Hai hàm này nhận vào một hàm bất kì có đầu vào và đầu ra là tensor, rồi tính VJP/JVP tại đầu vào theo một vector tiếp tuyến nào đó.
Đối với VJP/JVP theo tham số của mô hình, chúng ta có thể xóa attribute rồi đặt lại để đưa tham số vào đối số của hàm forward, xem cụ thể tại đây
def del_attr(obj, names): if len(names) == 1: delattr(obj, names[0]) else: del_attr(getattr(obj, names[0]), names[1:])
def set_attr(obj, names, val): if len(names) == 1: setattr(obj, names[0], val) else: set_attr(getattr(obj, names[0]), names[1:], val) def make_functional(mod): orig_params = tuple(mod.parameters()) names = [] for name, p in list(mod.named_parameters()): del_attr(mod, name.split(".")) names.append(name) return orig_params, names def load_weights(mod, names, *params): for name, p in zip(names, params): set_attr(mod, name.split("."), p) def del_weights(mod): for name, p in list(mod.named_parameters()): del_attr(mod, name.split(".")) class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.module = nn.Sequential(nn.Linear(4, 5), nn.LeakyReLU(), nn.Linear(5,3),nn.Tanh()) def get_params(self): self.params, self.names = make_functional(self) def forward(self, t, state, *args): if len(args) == 0: load_weights(self, self.names, *self.params) elif len(args) > 0: del_weights(self) load_weights(self, self.names, *args) return self.module(torch.cat([t.view(1), state])) model = Model()
model.get_params()
Khi tính JVP/VJP, chúng ta cần giải hệ ODE, do đó thuật toán cần được chỉnh sửa một chút
def odeint_rk4_system(f, y0, t): """ y0 : list of states f : func returns list of states """ def step(state, t): y_prev, t_prev = state dt = t - t_prev k1 = [dt * i for i in f(t_prev, y_prev)] k2 = [dt * i for i in f(t_prev + dt/2., [y + j1/2. for y, j1 in zip(y_prev, k1)])] k3 = [dt * i for i in f(t_prev + dt/2., [y + j2/2. for y, j2 in zip(y_prev, k2)])] k4 = [dt * i for i in f(t + dt, [y + j3 for y, j3 in zip(y_prev, k3)])] y = [i + (j1+ 2 * j2 + 2 * j3 + j4) / 6 for i, j1, j2, j3, j4 in zip(y_prev, k1, k2, k3, k4)] return y, t t_curr = t[0] y_curr = y0 ys = [] for i in t[1:]: y_curr, t_curr = step((y_curr, t_curr), i) ys.append(y_curr) return ys
Chúng ta sẽ mô hình đạo hàm theo thời gian của vị trí 1 điểm trong với phương pháp Runge-Kutta bậc 4, được kết quả như hình dưới
Với vector tiếp tuyến tại điều kiện đầu, pushforward theo thời gian được vector tiếp tuyến tại từng thời điểm như sau
Với vector tiếp tuyến tại , chúng ta kéo lùi lại và . Áp dụng JVP với được kết quả như hình
Áp dụng JVP với được kết quả sau
Code sử dụng trong bài có thể xem ở đây.
Trong bài tiếp theo, chúng ta sẽ tìm hiểu về mô hình continuous normalizing flow với neural ODE, và liên hệ với SDE ở bài trước.