- vừa được xem lúc

Sử dụng Mô hình học sâu Keras với Scikit-Learn

0 0 9

Người đăng: Chung Hoàng Ngọc Kỳ An

Theo Viblo Asia

Tổng quan

Keras là một thư viện phổ biến để học sâu trong Python. Trên thực tế, Keras hướng đến sự tối giản, chỉ tập trung vào những gì cần để xác định và xây dựng các mô hình học sâu một cách nhanh chóng và đơn giản. Thư viện scikit-learning trong Python được xây dựng dựa trên ngăn xếp SciPy để tính toán số hiệu quả. Nó là một thư viện đầy đủ tính năng dành cho học máy nói chung và cung cấp nhiều tiện ích hữu ích trong việc phát triển các mô hình học sâu. Một trong số đó là:

  • Đánh giá các mô hình bằng các phương pháp lấy mẫu như k-fold cross validation
  • Tìm kiếm và đánh giá hiệu quả các siêu tham số mô hình

Có một trình bao bọc trong thư viện TensorFlow/Keras để tạo các mô hình học sâu được sử dụng làm công cụ ước tính phân loại hoặc hồi quy trong scikit-learning. Nhưng gần đây, trình bao bọc này đã được gỡ bỏ để trở thành một mô-đun Python độc lập.

Trong ví dụ này, chúng ta sẽ làm việc với tệp dữ liệu Pima Indians onset of diabetes

Trước hết chúng ta sẽ cài đặt một vài thư viện cần thiết: !pip install tensorflow scikeras scikit-learn

Đánh giá mô hình Deep Learning với cross validation

Các lớp KerasClassifier và KerasRegressor trong SciKeras lấy một mô hình đối số là tên của hàm cần gọi để lấy mô hình.

Chúng ta cần định nghĩa một hàm để xác định mô hình, biên dịch và trả về hàm đó.

Trong ví dụ dưới đây, chúng ta định nghĩa một hàm create_model() để tạo một mạng nơ-ron nhiều lớp đơn giản cho bài toán.

Chuyển tên hàm này cho lớp KerasClassifier bằng đối số mô hình. Hoặc chuyển vào các đối số bổ sung của nb_epoch=150 và batch_size=10. Chúng được tự động nhóm lại và chuyển đến hàm fit(), được gọi bởi lớp KerasClassifier.

Trong ví dụ này, chúng ta sẽ sử dụng StratifiedKFold scikit-learning để thực hiện cross validation phân tầng 10 lần. Đây là một kỹ thuật lấy mẫu lại có thể cung cấp ước tính chính xác về hiệu suất của mô hình máy học trên dữ liệu không nhìn thấy được.

Tiếp theo, sử dụng hàm scikit-learning cross_val_score() để đánh giá mô hình bằng sơ đồ cross validation và in kết quả.

from numpy.random.mtrand import random
import tensorflow as tf import numpy as np
import pandas as pd from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import train_test_split from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold def create_model(): model = Sequential( [ layers.Dense(12, input_dim =8, activation = "relu"), layers.Dense(8, activation = "relu"), layers.Dense(1, activation = "sigmoid") ] ) model.compile( loss = "binary_crossentropy", optimizer = "adam", metrics = ["accuracy"] ) return model seed = 7
np.random.seed(seed) dataset = np.loadtxt("https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv", delimiter = ",") X = dataset[:, 0:8]
Y = dataset[:, 8] model= KerasClassifier(model = create_model, epochs = 150, batch_size = 10, verbose = 0)
kfold = StratifiedKFold(n_splits = 10, shuffle = True, random_state = seed)
results = cross_val_score(model, X, Y, cv = kfold)
print(results.mean())

Lưu ý: Kết quả có thể thay đổi do tính chất ngẫu nhiên của thuật toán hoặc quy trình đánh giá hoặc sự khác biệt về độ chính xác của các con số. Có thể chạy nhiều lần và so sánh kết quả trung bình.

Tham số mô hình Deep Learning Grid Search

Ví dụ trước cho thấy việc đóng gói mô hình học sâu từ Keras và sử dụng nó trong các chức năng từ thư viện scikit-learning dễ dàng như thế nào.

Trong ví dụ này, chúng ta sẽ tiến thêm một bước. Hàm chỉ định cho đối số mô hình khi tạo trình bao bọc KerasClassifier có thể nhận đối số. Chúng ta có thể sử dụng các đối số này để tùy chỉnh thêm việc xây dựng mô hình. Ngoài ra, cần phải biết rằng có thể cung cấp đối số cho hàm fit().

Trong ví dụ này, chúng ta sẽ sử dụng tìm kiếm dạng lưới (grid search) để đánh giá các cấu hình khác nhau cho mô hình mạng thần kinh và báo cáo về sự kết hợp mang lại hiệu suất ước tính tốt nhất.

Hàm create_model() được xác định để nhận hai đối số, optimizer và init, cả hai đều phải có giá trị mặc định. Điều này sẽ cho phép đánh giá hiệu quả của việc sử dụng các thuật toán tối ưu hóa khác nhau và sơ đồ khởi tạo trọng số.

Sau khi tạo mô hình, hãy xác định các mảng giá trị cho tham số muốn tìm kiếm, cụ thể:

  • Optimizers: tìm kiếm các giá trị trọng lượng khác nhau
  • Initializers: chuẩn bị trọng số mạng bằng các lược đồ khác nhau
  • Epochs : các giai đoạn huấn luyện mô hình cho số lần tiếp xúc khác nhau với tập dữ liệu huấn luyện
  • Batches: thay đổi số lượng mẫu trước khi cập nhật trọng lượng

Các tùy chọn được chỉ định trong một dictionary và được chuyển đến cấu hình của lớp scikit-learning GridSearchCV. Lớp này sẽ đánh giá một phiên bản mô hình mạng thần kinh của bạn cho từng tổ hợp tham số (2 x 3 x 3 x 3 đối với tổ hợp optimizers, initializations, epochs, và batches). Sau đó, mỗi kết hợp được đánh giá bằng cách sử dụng cros validation phân tầng 3 lần mặc định.

import numpy as np
import tensorflow as tf from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV def create_model(optimizer = "rmsprop", init = "glorot_uniform"): model = Sequential( [ layers.Dense(12, input_dim = 8, kernel_initializer = init, activation = "relu"), layers.Dense(8, kernel_initializer = init, activation = "relu"), layers.Dense(1, kernel_initializer = init, activation = "sigmoid") ] ) model.compile( loss = "binary_crossentropy", optimizer = "adam", metrics = ["accuracy"] ) return model seed = 7
np.random.seed(seed) dataset = np.loadtxt("https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv", delimiter = ",") X = dataset[:, 0:8]
Y = dataset[:, 8] model = KerasClassifier(model = create_model, verbose = 0)
print(model.get_params().keys()) optimizers = ['rmsprop', 'adam']
init = ['glorot_uniform', 'normal', 'uniform']
epochs = [50, 100, 150]
batches = [5, 10, 20]
param_grid = dict(optimizer = optimizers, epochs = epochs, batch_size = batches, model__init = init)
grid = GridSearchCV(estimator = model, param_grid = param_grid)
grid_result = grid.fit(X, Y) print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_["mean_test_score"]
stds = grid_result.cv_results_["params"]
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params): print("%f (%f) with: %r" % (mean, stdev, param))

Lưu ý: Kết quả có thể thay đổi do tính chất ngẫu nhiên của thuật toán hoặc quy trình đánh giá hoặc sự khác biệt về độ chính xác của các con số. Có thể chạy nhiều lần và so sánh kết quả trung bình.

Tham khảo thêm: machinelearningmastery.com

Bình luận

Bài viết tương tự

- vừa được xem lúc

Một số thủ thuật hay ho với Linux (1).

1. Ctrl + x + e. Giữ CTRL, nhấn phím x rồi nhấn phím e. Thao tác này sẽ mở ra editor mặc định (echo $EDITOR | $VISUAL để kiểm tra) chứa sẵn.

0 0 45

- vừa được xem lúc

How to deploy Amplication app to DigitalOcean

This article shows you the way to deploy an app generated by Amplication to DigitalOcean. Amplication provides the dockerfile to use containers for deployment, but this blog explains how to do it manu

0 0 53

- vừa được xem lúc

Có gì mới trong Laravel 9.0?

Laravel v9 là phiên bản LTS tiếp theo của Laravel và ra mắt vào tháng 2 năm 2022. Trong bài viết này, mình xin giới thiệu một vài tính năng mới trong Laravel trong Laravel 9.

0 0 78

- vừa được xem lúc

Xây dựng trang web tra cứu ảnh sử dụng phân cụm Spectral Clustering

1. Tổng quan tra cứu ảnh. 1.1.

0 0 45

- vừa được xem lúc

Scanning network 1 - quét mạng như một hacker

Chào mọi người mình là Tuntun. Một năm qua là một năm khá bận rộn nhỉ.

0 0 46

- vừa được xem lúc

Interpreter Design Pattern - Trợ thủ đắc lực của Developers

1. Giới thiệu. . Interpreter là một mẫu thiết kế thuộc nhóm hành vi (Behavioral Pattern).

0 0 43