Có bao giờ bạn tự hỏi tại sao model machine learning có thể hoạt động cực kỳ tốt trong quá khứ, nhưng bỗng dưng lại trở nên "kém vui" trong hiện tại không? Đó là lúc mà khái niệm data drift được ra đời trong MLOps concept, khi mà việc train model được thực hiện lặp đi lặp lại, vậy khi nào thì ta phát hiện ra dữ liệu bị thay đổi và nên cảnh báo các Data Scientist, rằng tính chất dữ liệu khác rồi anh ơi?
"Sông có lúc, data có khúc" - câu này mình chế ra
Data drift là gì nhỉ?
Vạn vật luôn đổi thay và dữ liệu của ta cũng thế, trong một MLOps pipeline, mình hiểu đơn giản là quá trình lấy data mới và huấn luyện các mô hình học máy một cách tự động. Và vì bộ siêu tham số (hyper parameters) đã được tinh chỉnh để fit với một tập data duy nhất (thường các anh DS hay thực hiện trong file csv import lên notebook), cho nên, khi có data mới được cào liên tục hàng năm/hàng tháng/hàng ngày, không có gì đảm bảo rằng tính chất toán học của tập dữ liệu mới này sẽ toàn vẹn được như cũ?
Ví dụ cho việc thay đổi:
- Dữ liệu kinh doanh bị ảnh hưởng bởi các sự kiện đột ngột (như covid19,...)
- Dữ liệu về chiều cao của thanh thiếu nên Việt nam năm 2000 với năm 2020 sẽ khác nhau, vì bây giờ các bạn được ăn uống đầy đủ hơn lúc trước, nên chiều cao cũng phát triển hơn.
- ... rất nhiều trường hợp khác dữ liệu bị thay đổi về các tính chất.
Và sự thay đổi tính chất đó gọi là... data drift.
Hiểu theo mặt thống kê, data drift đơn giản là phân phối của dữ liệu bị trượt, hoặc bị thay đổi theo thời gian, ví dụ như mô hình bạn train năm 2018 hội tụ và generalize tốt với dữ liệu 2018, không có nghĩa dữ liệu sẽ có cùng ý nghĩa thống kê với của năm 2021. Và mục tiêu của mình là phải có một hệ thống tự động nhận biết được thời điểm 2021, tức là thời điểm dữ liệu bị thay đổi tính chất, và thực hiện việc cảnh báo (alert) cho các MLOps Engineer và Data Scientist.
Sau khi biết data bị drift thì người ta thường làm gì nhỉ, thật ra thì có nhiều cách lắm, nhưng xoay quanh vẫn là tìm cách để huấn luyện (train) lại mô hình (model):
- Train mô hình với data gộp giữa cũ và mới.
- Train mô hình với chỉ mỗi data mới.
- Train mô hình với bộ siêu tham số mới (cần được tuning lại bởi data scientist).
- ....
- Và best practice thì thường là không có, mà sẽ phụ thuộc vào quyết định của các data scientist, là những người hiểu rõ nhất về tính chất của dữ liệu mà họ đang làm.
Phương pháp tối ưu để phát hiện data drift
Có rất nhiều phương pháp để phát hiện data drift, nhưng phương pháp vừa có ý nghĩa về mặt thống kê, vừa có kết quả thực tế, hiệu quả và giải thích được đó chính là sử dụng một kiểm định tên là kiểm định Kolmogorov-Smirnov (K-S Test) , có người còn gọi là Goodness-of-Fit Test.
Nghe lạ quá nhỉ? Nhưng về ý nghĩa toán học - thống kê nó không có gì cao siêu cả, chỉ đơn giản là kiểm định xem hai phân phối của hai mẫu có khác biệt tính chất với nhau không.
Nó có một metric để đo lường sự khác biệt biểu diễn như sau:
Trong đó:
- là hàm phân phối tích lũy của tập data trong quá khứ.
- là hàm phân phối tích lũy của tập data mới.
Nếu giá trị càng lớn, càng xảy ra data drift.
Kiểm định : (cho 2 mẫu nha)
Bác bỏ khi của kiểm định Kolmogorov-Smirnov (K-S) nhỏ hơn mức ý nghĩa đã chọn (ví dụ chẳng hạn). Tính liên hệ như nào với thì mình sẽ không đi sâu ở bài viết này để tránh việc dài dòng. Vậy thì... thực hành thôi!
Sử dụng Python để detect data drift
- Thư viện mà mình cần dùng ở đây hết sức quen thuộc, đó là
scipy
, khỏi phải giới thiệu nhiều vì package này quá quen thuộc với các bạn data science rồi, cho những ai chưa biết thì đây là một thư viện chứa rất nhiều hàm thống kê của Python. Cách cài đặt đơn giản bằngpip
thôi:
pip install scipy
- Sử dụng
ks_2samp
nằm trong modulescipy.stats
như sau:
from scipy import stats
test = stats.ks_2samp(df[column], df_new[column])
- Ở đây mình đang thực hiện việc test ở hai dataframe là
df
(dữ liệu quá khứ) vàdf_new
(dữ liệu nóng hổi vừa thu thập được). - Để kiểm định xem 1 feature / column có đảm bao tính chất hay không, mình sẽ thực hiện như trên, kết quả của function
ks_2samp
trả về sẽ được gán vào biếntest
test[1]
chính là của kiểm định, và như mọi người đã thấy, mình đang dùng mức ý nghĩa:
Nếu , tức là sẽ bác bỏ , đồng nghĩa với việc data bị drift.
if test[1] < 0.05: print("Data drift at column: ", column)
- Chỉ vậy thôi đó, ta có thể đóng function lại code trên để tiện dùng trong MLOps pipeline, và tiếp đến phần sau, mình sẽ ví dụ việc tích hợp detect tự động như nào nhé!
Tích hợp detect vào MLOps pipeline
Ở đây mình sẽ ví dụ như là 1 task của Airflow nhé.
- Trong task này, mình sẽ duyệt qua từng column, nếu column vào bị drift, thì ngay lập tức mail cho data scientist, để thực hiện tune parameter nếu cần thiết, hoặc train model lại với data mới.
- Đồng thời cũng sẽ return True nếu gặp column drift, và return False nếu không gặp column drift nào.
from airflow.decorators import task
import pandas as pd
from scipy import stats @task.python( show_return_value_in_logs=True,
)
def detect_drift()-> bool: # NOTE: Download new data from S3 (bước fetch df and df_new) for column in df.columns: test = stats.ks_2samp(df[column], df_new[column]) if test[1] < 0.05: print("Data drift at column: ", column) mail_to_data_scientist() return True return False
Kết luận
Vậy đó, giờ thì ta đã hiểu thêm về data drift và cách kiểm tra nó để giữ cho các mô hình của ta luôn hiệu quả! Đừng quên rằng trong thế giới mà data thay đổi liên tục, "sông có lúc, data có khúc" – hãy luôn sẵn sàng để tuning model 24/7 và optimize model của bạn. 😀