What is it, naokirin?

Juliaで機械学習入門(2)〜ニューラルネットワーク〜

今回はJulia & ScikitLearn で有名なアヤメのデータを用いて分類をやってみます。

前回と同じく Julia v1.2.0です。

Julia でニューラルネットワーク

まずは必要なパッケージを追加、インポートします。

In [ ]: ] add ScikitLearn
In [ ]: ] add Plots StatsPlots
In [ ]: ] add DataFrames RDatasets
In [ ]: using Plots, StatsPlots
        using RDatasets
        using ScikitLearn
        using ScikitLearn: fit!, predict
        using DataFrames
        @sk_import linear_model: LinearRegression
        @sk_import neural_network: MLPClassifier
        @sk_import model_selection: train_test_split
        @sk_import metrics: accuracy_score

次にアヤメのデータを読み込みます。

PythonのScikitLearnでは、 sklearn.datasets.load_iris から読み込んで利用することが多いかと思いますが、今回は DataFrame形式で読み込める RDatasets のデータを読み込みます。

In [ ]: iris = dataset("datasets", "iris")
        println(iris)
Out [ ]: 150×5 DataFrame
         │ Row │ SepalLength │ SepalWidth │ PetalLength │ PetalWidth │ Species      │
         │     │ [90mFloat64[39m     │ [90mFloat64[39m    │ [90mFloat64[39m     │ [90mFloat64[39m    │ [90mCategorical…[39m │
         ├─────┼─────────────┼────────────┼─────────────┼────────────┼──────────────┤
         │ 1   │ 5.1         │ 3.5        │ 1.4         │ 0.2        │ setosa       │
         │ 2   │ 4.9         │ 3.0        │ 1.4         │ 0.2        │ setosa       │
         │ 3   │ 4.7         │ 3.2        │ 1.3         │ 0.2        │ setosa       │
         │ 4   │ 4.6         │ 3.1        │ 1.5         │ 0.2        │ setosa       │
         │ 5   │ 5.0         │ 3.6        │ 1.4         │ 0.2        │ setosa       │
...
         │ 145 │ 6.7         │ 3.3        │ 5.7         │ 2.5        │ virginica    │
         │ 146 │ 6.7         │ 3.0        │ 5.2         │ 2.3        │ virginica    │
         │ 147 │ 6.3         │ 2.5        │ 5.0         │ 1.9        │ virginica    │
         │ 148 │ 6.5         │ 3.0        │ 5.2         │ 2.0        │ virginica    │
         │ 149 │ 6.2         │ 3.4        │ 5.4         │ 2.3        │ virginica    │
         │ 150 │ 5.9         │ 3.0        │ 5.1         │ 1.8        │ virginica    │
In [ ]: @df iris corrplot(cols(1:4), group=:Species, grid = false, marker=:auto, bg=RGB(0.95,0.95,0.95))

f:id:naokirin:20191006101843p:plain:w600

実際に分類する前に、学習用データとテストデータに分けます。

In [ ]: X_train, X_test, y_train, y_test = train_test_split(convert(Matrix, iris[!, [:SepalWidth, :PetalWidth]]), iris[!, :Species], test_size=0.3, random_state=0)
Out [ ]: 4-element Array{Array,1}:
         [2.0 1.0; 3.0 1.8; … ; 3.8 2.2; 3.2 0.2]                                                                                                                                                                                                                           
         [2.8 2.4; 2.2 1.0; … ; 3.5 0.6; 3.7 0.2]                                                                                                                                                                                                                           
         ["versicolor", "virginica", "virginica", "virginica", "virginica", "versicolor", "virginica", "versicolor", "versicolor", "virginica"  …  "versicolor", "versicolor", "versicolor", "setosa", "setosa", "setosa", "virginica", "versicolor", "virginica", "setosa"]
         ["virginica", "versicolor", "setosa", "virginica", "setosa", "virginica", "setosa", "versicolor", "versicolor", "versicolor"  …  "versicolor", "setosa", "versicolor", "versicolor", "versicolor", "virginica", "setosa", "virginica", "setosa", "setosa"]

学習とテストデータの分類をして、評価をしてみます。

In [ ]: clf = MLPClassifier(hidden_layer_sizes=[15, 15], solver="lbfgs", random_state=0, max_iter=10000, alpha=0.1)
        fit!(clf, X_train, y_train)
        y_pred = predict(clf, X_test)
        println(y_pred)
        accuracy_score(y_pred, y_test)
Out [ ]: Any["virginica", "versicolor", "setosa", "virginica", "setosa", "virginica", "setosa", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "versicolor", "setosa", "versicolor", "versicolor", "setosa", "setosa", "virginica", "versicolor", "setosa", "setosa", "virginica", "setosa", "setosa", "versicolor", "versicolor", "setosa", "virginica", "versicolor", "setosa", "virginica", "virginica", "versicolor", "setosa", "versicolor", "versicolor", "versicolor", "virginica", "setosa", "virginica", "setosa", "setosa"]

         0.9777777777777777