今回はJulia & ScikitLearn で有名なアヤメのデータを用いて分類をやってみます。
前回と同じく Julia v1.2.0です。
Julia でニューラルネットワーク
まずは必要なパッケージを追加、インポートします。
In [ ]: ] add ScikitLearn
In [ ]: ] add Plots StatsPlots
In [ ]: ] add DataFrames RDatasets
In [ ]: using Plots, StatsPlots using RDatasets using ScikitLearn using ScikitLearn: fit!, predict using DataFrames @sk_import linear_model: LinearRegression @sk_import neural_network: MLPClassifier @sk_import model_selection: train_test_split @sk_import metrics: accuracy_score
次にアヤメのデータを読み込みます。
PythonのScikitLearnでは、 sklearn.datasets.load_iris
から読み込んで利用することが多いかと思いますが、今回は DataFrame形式で読み込める RDatasets のデータを読み込みます。
In [ ]: iris = dataset("datasets", "iris") println(iris)
Out [ ]: 150×5 DataFrame │ Row │ SepalLength │ SepalWidth │ PetalLength │ PetalWidth │ Species │ │ │ [90mFloat64[39m │ [90mFloat64[39m │ [90mFloat64[39m │ [90mFloat64[39m │ [90mCategorical…[39m │ ├─────┼─────────────┼────────────┼─────────────┼────────────┼──────────────┤ │ 1 │ 5.1 │ 3.5 │ 1.4 │ 0.2 │ setosa │ │ 2 │ 4.9 │ 3.0 │ 1.4 │ 0.2 │ setosa │ │ 3 │ 4.7 │ 3.2 │ 1.3 │ 0.2 │ setosa │ │ 4 │ 4.6 │ 3.1 │ 1.5 │ 0.2 │ setosa │ │ 5 │ 5.0 │ 3.6 │ 1.4 │ 0.2 │ setosa │ ... │ 145 │ 6.7 │ 3.3 │ 5.7 │ 2.5 │ virginica │ │ 146 │ 6.7 │ 3.0 │ 5.2 │ 2.3 │ virginica │ │ 147 │ 6.3 │ 2.5 │ 5.0 │ 1.9 │ virginica │ │ 148 │ 6.5 │ 3.0 │ 5.2 │ 2.0 │ virginica │ │ 149 │ 6.2 │ 3.4 │ 5.4 │ 2.3 │ virginica │ │ 150 │ 5.9 │ 3.0 │ 5.1 │ 1.8 │ virginica │
In [ ]: @df iris corrplot(cols(1:4), group=:Species, grid = false, marker=:auto, bg=RGB(0.95,0.95,0.95))
実際に分類する前に、学習用データとテストデータに分けます。
In [ ]: X_train, X_test, y_train, y_test = train_test_split(convert(Matrix, iris[!, [:SepalWidth, :PetalWidth]]), iris[!, :Species], test_size=0.3, random_state=0)
Out [ ]: 4-element Array{Array,1}: [2.0 1.0; 3.0 1.8; … ; 3.8 2.2; 3.2 0.2] [2.8 2.4; 2.2 1.0; … ; 3.5 0.6; 3.7 0.2] ["versicolor", "virginica", "virginica", "virginica", "virginica", "versicolor", "virginica", "versicolor", "versicolor", "virginica" … "versicolor", "versicolor", "versicolor", "setosa", "setosa", "setosa", "virginica", "versicolor", "virginica", "setosa"] ["virginica", "versicolor", "setosa", "virginica", "setosa", "virginica", "setosa", "versicolor", "versicolor", "versicolor" … "versicolor", "setosa", "versicolor", "versicolor", "versicolor", "virginica", "setosa", "virginica", "setosa", "setosa"]
学習とテストデータの分類をして、評価をしてみます。
In [ ]: clf = MLPClassifier(hidden_layer_sizes=[15, 15], solver="lbfgs", random_state=0, max_iter=10000, alpha=0.1) fit!(clf, X_train, y_train) y_pred = predict(clf, X_test) println(y_pred) accuracy_score(y_pred, y_test)
Out [ ]: Any["virginica", "versicolor", "setosa", "virginica", "setosa", "virginica", "setosa", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "setosa", "versicolor", "versicolor", "setosa", "setosa", "virginica", "versicolor", "setosa", "setosa", "virginica", "setosa", "setosa", "versicolor", "versicolor", "setosa", "virginica", "versicolor", "setosa", "virginica", "virginica", "versicolor", "setosa", "versicolor", "versicolor", "versicolor", "virginica", "setosa", "virginica", "setosa", "setosa"] 0.9777777777777777