多クラス分類ロジスティック回帰
目次
はじめに
最近は忙しくて、全然ブログを更新してませんでした。久しぶりの更新です。
多クラス分類のロジスティック回帰のお勉強ログです。
過去に2クラス分類のロジスティック回帰の記事を書きましたが、その拡張版です。
それにしても、過去の記事を読み返してみると、つくづく文章が下手だなと思ってしまいます。
冗長的な表現が多く、自分で読み返してもよくわからない箇所が多いなと。。
今回は必要なことのみを箇条書きで書くようにしてみました。
モデルの概要
特徴量ベクトル ] から、多クラス分類を実施するロジスティック回帰を考えていきます。
ロジスティック回帰の出力がクラス数 個に分類される場合を想定し、以下のようなモデルを考えます。
行列やベクトルの下に書いてある の中には、わかりやすいように次数を表記しておきました。
が活性化関数、 が損失関数、 が損失値を表しています。
多クラス分類の場合は、活性化関数にソフトマックス関数、損失関数に交差エントロピーを用います。
One-hot表現
多クラス分類では、教師データとしてOne-hot表現のベクトルを用います。
One-hot表現とは、成分の一つだけが1でその他は0であるベクトルで表現する方法です。
例えばクラスラベルが 個ある分類では、以下のように表現してやります。
- クラスラベルが1のものは ]
- クラスラベルが2のものは ]
- クラスラベルがcのものは ]
ソフトマックス関数
- ソフトマックス関数の入力を ] 、 出力を ] としたとき、ソフトマックス関数は以下のように定義されます。
- ] として、ベクトル演算で無理やり書くと、
このソフトマックス関数の出力の各成分は、全部足すと1になっており、それぞれの成分 は、クラスラベルが である確率を表していると考えます。
例えば、クラス数が3の場合、ソフトマックス関数の出力が ] であったとしましょう。これが意味することは、クラスラベルが1である確率が20%、クラスラベルが2である確率が70%、クラスラベルが3である確率が10%であり、つまりクラスラベルが2である確率が一番高いという事になります。
教師データのOne-hot表現についても同様に考えることができます。
例えば ] という教師データがあったときは、クラスラベルが1である確率が100%で、その他のクラスラベルについては0%という意味です。
ソフトマックス関数の微分
(4)式の を で微分することを考えます。 とおくと、
の時、
- の時、
- クロネッカーのデルタ を使って一本の式に書くと、
- 行列形式で並べて書いたらどうなるのかな?( を縦に を横に並べてみます。)
- う〜ん。これ以上まとめられない。。
交差エントロピー
- クラスラベルが である教師データ をOne-hot表現で表すと、 番目の要素が1でそれ以外が0のベクトルとして表せます。
- 一つの特徴量ベクトル からそのクラスラベルの確率ベクトル ] が与えられた時、クラスラベルが である条件付確率 は、
- これをクラスラベルがどの場合でもいいように、 ] を使ってより一般的に表すと、以下のように書くことができます。多クラスの場合のベルヌーイ分布ですね。
更にここから、この確率 の確からしさを表す尤度を計算したいのですが、尤度を計算するためには、いくつかデータサンプルが必要なので、更にインデックスを増やす形になります。
データサンプルのインデックスを とし 個のデータサンプルがあるとします。
(12)式にデータサンプルのインデックスを追加してやります。
- 尤度 は、(13)式の確率を各データサンプルで掛け合わせたものなので、
- 対数尤度に-1をかけたものを交差エントロピー(以下 )と呼び、最小化する損失関数として利用します。
交差エントロピーを重みで微分
学習時に重みを更新するために、損失関数の重み勾配を求める必要があります。
を考えます。
- 同様に は、
行列による表記
- 本来は行列ではなくて、テンソルとして考えるべきなんでしょうが、プログラムで書いたときにnumpyで書きやすいようになるべく行列、ベクトル形式で書いていこうと思います。
特徴量行列
- 行は各データサンプル、列は各特徴量を表します。
特徴量行列×重み行列+閾値
活性化関数(ソフトマックス関数)の出力
- これを更に分解するのは難しいのでやめときます。
損失関数(交差エントロピー)の出力
は、行が各データサンプル、列が各クラスラベルを表す教師データの行列です。
はアダマール積を表し、行列の各要素を掛け算する記号です。
損失関数(交差エントロピー)の重み微分
- について、(16)式の を縦に を横に並べます。
- について、(17)式の を横に並べます。
重みの更新
- 学習率を とし、(24)、(25)式で求めた を使って重みと閾値を更新します。