多クラス分類ロジスティック回帰の実装
目次
はじめに
前回、多クラス分類ロジスティック回帰の記事を書いたのでそのpython実装です。
やはり、実際にプログラムを書いて確認したくなりますね。
活性化関数・損失関数の実装
ソフトマックス関数
前回記事で、(22)式で書いていた部分です。
np.ndarray
同士の割り算は、同じインデックスの成分同士の割り算になるので、結局一行でかけました。
Z = np.dot(X, self.W) + self.b def softmax(self, Z): return np.exp(Z)/np.sum(np.exp(Z), axis=1)[:, np.newaxis]
np.sum(np.exp(Z), axis=1)
の引数でaxis=1
を指定すると、各行でのsum値を横ベクトルで返してくれます。[:, np.newaxis]
で縦ベクトルにしてます。
交差エントロピー
前回記事(23)式です。
交差エントロピーの算出を無理やり行列の積で書いていましたが、やりたいことは教師データ行列と、ソフトマックス関数の出力の行列のlogをとって同じ成分同士をかけて、全部の要素を足しこめばいいので(前回記事(15)式)、
np.ndarray
の掛け算、axis引数なしのnp.sum()
で書けました。
def cross_entropy(self, X, Y): Phi = self.softmax(np.dot(X, self.W) + self.b) n_samples = X.shape[0] return -np.sum(Y*np.log(Phi))/n_samples
交差エントロピーの重み勾配
- こちらは前回記事(24)、(25)式通り。交差エントロピーをサンプル数で割ることにしたので、その重み勾配もサンプル数で割っておきます。
def grad_loss(self, X, Y): Phi = self.softmax(np.dot(X, self.W) + self.b) n_samples = X.shape[0] dW = -np.dot(X.T, Y - Phi)/n_samples db = -np.dot(np.ones([1, n_samples]), Y - Phi)/n_samples return dW, db
多クラス分類ロジスティック回帰クラス
学習実施の
fit
関数内では、ミニバッチ学習機能を入れました。その他、予測値を出力(ソフトマックス関数の出力を量子化)する
predict
関数、正答率を出力するaccuracy_score
関数を作ってます。
分類の実施
例によってアヤメのデータを使います。4つの特徴量をもつアヤメを、3つの品種クラスに分類します。
アヤメデータのロード、教師データのonehotエンコーディング、データのシャッフル、トレーニング・テストデータ分割、特徴量の標準化は全部scikit-learnの関数を使ってます。便利です。
- エポック数は500回に設定しており、学習が終了すると損失値と正答率のログをプロットします。うまく学習できてますね。
- 最終エポックのトレーニングデータとテストデータの正答率です。テストデータは96%です。中々です。
acc_train: 1.0 acc_test: 0.955555555556
scikit-learnの多クラス分類結果
- 確認のためscikit-learnでも同じことをやってみました。
scikit-learnでは、多クラス分類でも教師データをonehotエンコーディングしなくていいみたいです。スクリプト上でまた元の表記に戻してます。
LogisticRegression()
のsolver
引数で最適化ソルバーを選べるのですが、ヘルプ上には確率的勾配降下法(sgd)が見当たらなかったので、一番それっぽいのsolver="sag"
を指定しました。Stochastic Average Gradient descent solver(確率的平均勾配降下法?) sgdを改良したものみたいです。正則化パラメータ
C
はデフォルトで1.0に設定されているみたいでして、私のモデルでは正則化は考慮してないので、1e10を入れて効かないようにしています。また
tol
は、各エポックで損失値の変化がこの値よりも小さくなったら学習を終了するというもので、最大エポックまで回したかったのですごく小さな値1e-10を入れました。最後に学習結果の正答率をトレーニングデータとテストデータを出力させてます。scikit-learnに正答率を求める
accuracy_score
関数があるのでそれを使ってます。
acc_train: 1.0 acc_test: 0.955555555556
- 同じ結果になるとうれしいですね~🎵 それにしてもscikit-learnは速いです。