Pythonと機械学習

Pythonも機械学習も初心者ですが、頑張ってこのブログで勉強してこうと思います。

多クラス分類ロジスティック回帰

目次

はじめに

最近は忙しくて、全然ブログを更新してませんでした。久しぶりの更新です。

多クラス分類のロジスティック回帰のお勉強ログです。

過去に2クラス分類のロジスティック回帰の記事を書きましたが、その拡張版です。

それにしても、過去の記事を読み返してみると、つくづく文章が下手だなと思ってしまいます。

冗長的な表現が多く、自分で読み返してもよくわからない箇所が多いなと。。

今回は必要なことのみを箇条書きで書くようにしてみました。

モデルの概要

  • 特徴量ベクトル  \mathbf{x} = [x_{1}, \cdots, x_{m} ] から、多クラス分類を実施するロジスティック回帰を考えていきます。

  • ロジスティック回帰の出力がクラス数 {c} 個に分類される場合を想定し、以下のようなモデルを考えます。


\underset{(1 \times c)}{\mathbf{z}} = \underset{(1 \times m)}{\mathbf{x}}\underset{(m \times c)}{\mathbf{W}} + \underset{(1 \times c)}{\mathbf{b}}
\tag{1}


\underset{(1 \times c)}{\mathbf{\phi}} = f(\underset{(1 \times c)}{\mathbf{z}})
\tag{2}


L = g(\underset{(1 \times c)}{\mathbf{\phi}})
\tag{3}

  • \mathbf{W},~\mathbf{b} は重みと閾値を表し、多クラスの場合は重みは行列、閾値はベクトルとなります。

  • 行列やベクトルの下に書いてある {()} の中には、わかりやすいように次数を表記しておきました。

  • f が活性化関数、 g が損失関数、 L が損失値を表しています。

  • 多クラス分類の場合は、活性化関数にソフトマックス関数、損失関数に交差エントロピーを用います。

One-hot表現

  • 多クラス分類では、教師データとしてOne-hot表現のベクトルを用います。

  • One-hot表現とは、成分の一つだけが1でその他は0であるベクトルで表現する方法です。

  • 例えばクラスラベルが c 個ある分類では、以下のように表現してやります。

    • クラスラベルが1のものは  [\underbrace{1,~0,\cdots,~0}_{c個}]
    • クラスラベルが2のものは  [\underbrace{0,~1,\cdots,~0}_{c個}]

    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\vdots

    • クラスラベルがcのものは  [\underbrace{0,~0,\cdots,~1}_{c個}]

ソフトマックス関数

  • ソフトマックス関数の入力を \mathbf{z}=[z_1,\cdots,z_c] 、 出力を \mathbf{\phi}=[\varphi_1,\cdots,\varphi_c] としたとき、ソフトマックス関数は以下のように定義されます。

\displaystyle
 \left[\varphi_1,\cdots,\varphi_c \right]  = \frac{1}{\displaystyle \sum_{i=1}^{c} e^{z_i}} \left[e^{z_1},\cdots,e^{z_c} \right]
\tag{4}

  • \mathbf{\epsilon}=[\underbrace{1,~1,\cdots,~1}_{c個}] として、ベクトル演算で無理やり書くと、


\displaystyle
\mathbf{\phi} = \frac{1}{e^{\mathbf{z}}\mathbf{\epsilon}^T}e^{\mathbf{z}}
\tag{5}

  • このソフトマックス関数の出力の各成分は、全部足すと1になっており、それぞれの成分 \varphi_i は、クラスラベルが i である確率を表していると考えます。

  • 例えば、クラス数が3の場合、ソフトマックス関数の出力が [0.2, ~0.7, ~0.1] であったとしましょう。これが意味することは、クラスラベルが1である確率が20%、クラスラベルが2である確率が70%、クラスラベルが3である確率が10%であり、つまりクラスラベルが2である確率が一番高いという事になります。

  • 教師データのOne-hot表現についても同様に考えることができます。

  • 例えば [1,~0,\cdots,~0] という教師データがあったときは、クラスラベルが1である確率が100%で、その他のクラスラベルについては0%という意味です。

ソフトマックス関数の微分

  • (4)式の \varphi_jz_i微分することを考えます。 A=\displaystyle \sum^{c}_{i=1}{e^{z_i}} とおくと、

  • j\neq i の時、


\displaystyle
\begin{eqnarray*}
\frac{\partial \varphi_j}{\partial z_i}
&=& \frac{ - e^{z_j}e^{z_i}}{A^2} \\
&=& - \varphi_j \varphi_i
\end{eqnarray*}
\tag{6}

  • j=i の時、

\displaystyle
\begin{eqnarray*}
\frac{\partial \varphi_j}{\partial z_i}
&=& \frac{e^{z_i}A - e^{z_i}e^{z_i}}{A^2} \\
&=& \varphi_i(1 - \varphi_i)
\end{eqnarray*}
\tag{7}


\displaystyle
\frac{\partial \varphi_j}{\partial z_i}
= \varphi_i(\delta_{ij} - \varphi_j)\\
\begin{align}
    where,~~
    \delta_{ij} =
    \begin{cases}
        1 & (i=j) \\
        0 & (i \neq j )
    \end{cases}
\end{align}
\tag{8}

  • 行列形式で並べて書いたらどうなるのかな?( i を縦に j を横に並べてみます。)

\displaystyle
\begin{eqnarray*}
\frac{\partial \varphi_j}{\partial z_i}
&=&
    \left[
        \begin{array}{ccc}
            \varphi_1(1-\varphi_1) & \cdots & -\varphi_1\varphi_c \\
            \vdots & \ddots & \vdots \\
            -\varphi_c\varphi_1 & \cdots & \varphi_c(1-\varphi_c) \\
        \end{array}
    \right] \\
&=&
    \left[
        \begin{array}{ccc}
            -\varphi_1\varphi_1 & \cdots & -\varphi_1\varphi_c \\
            \vdots & \ddots & \vdots \\
            -\varphi_c\varphi_1 & \cdots & -\varphi_c\varphi_c \\
        \end{array}
    \right]+
    \left[
        \begin{array}{ccc}
            \varphi_1 & \cdots & 0 \\
            \vdots & \ddots & \vdots \\
            0 & \cdots & \varphi_c \\
        \end{array}
    \right] \\
&=&
    -\left[
        \begin{array}{c}
            \varphi_1\\
            \vdots\\
            \varphi_c\\
        \end{array}
    \right]
    \left[
        \begin{array}{ccc}
            \varphi_1,&\cdots,&\varphi_c
        \end{array}
    \right]
    +
    \left[
        \begin{array}{ccc}
            \varphi_1 & \cdots & 0 \\
            \vdots & \ddots & \vdots \\
            0 & \cdots & \varphi_c \\
        \end{array}
    \right] \\
&=&
    -\mathbf{\phi}^T\mathbf{\phi} +
    \left[
        \begin{array}{ccc}
            \varphi_1 & \cdots & 0 \\
            \vdots & \ddots & \vdots \\
            0 & \cdots & \varphi_c \\
        \end{array}
    \right] \\
\end{eqnarray*}
\tag{9}

  • う〜ん。これ以上まとめられない。。

交差エントロピー

  • クラスラベルが k である教師データ \mathbf{y}_{class=k} をOne-hot表現で表すと、 k 番目の要素が1でそれ以外が0のベクトルとして表せます。

\displaystyle
\begin{eqnarray*}
\mathbf{y}_{class=k}
&=& \left[y_1,\cdots,~y_k,\cdots,~y_c\right] \\
&=& \left[~0~,\cdots,~\underset{k番目}{1},\cdots,~0\right]  \\
\end{eqnarray*}
\tag{10}

  • 一つの特徴量ベクトル \mathbf{x} からそのクラスラベルの確率ベクトル \mathbf{\phi}=[\varphi_1,\cdots,\varphi_c] が与えられた時、クラスラベルが k である条件付確率 p(\mathbf{y}=\mathbf{y}_{class=k}|\phi) は、

\displaystyle
p(\mathbf{y}=\mathbf{y}_{class=k}|\phi) = \varphi_k
\tag{11}

  • これをクラスラベルがどの場合でもいいように、 \mathbf{y}=[y_1,\cdots,y_c] を使ってより一般的に表すと、以下のように書くことができます。多クラスの場合のベルヌーイ分布ですね。

\displaystyle
\begin{eqnarray*}
p\left(\mathbf{y}|\phi\right)
&=& \varphi_1^{y_1}\varphi_2^{y_2}\cdots\varphi_c^{y_c} \\
&=& \prod_{k=1}^{c}\varphi_k^{y_k}
\end{eqnarray*}
\tag{12}

  • 更にここから、この確率 p(\mathbf{y}|\mathbf{\phi}) の確からしさを表す尤度を計算したいのですが、尤度を計算するためには、いくつかデータサンプルが必要なので、更にインデックスを増やす形になります。

  • データサンプルのインデックスを i とし n 個のデータサンプルがあるとします。

  • (12)式にデータサンプルのインデックスを追加してやります。


\displaystyle
p\left(\mathbf{y}_i|\phi_i\right) =
\prod_{k=1}^{c}\varphi_{ik}^{y_{ik}}
\tag{13}

  • 尤度 l は、(13)式の確率を各データサンプルで掛け合わせたものなので、

\displaystyle
\begin{eqnarray*}
l
&=& \prod_{i=1}^{n}p\left(\mathbf{y}_i|\phi_i\right) \\
&=& \prod_{i=1}^{n}\prod_{k=1}^{c}\varphi_{ik}^{y_{ik}}
\end{eqnarray*}
\tag{14}

  • 対数尤度に-1をかけたものを交差エントロピー(以下 L )と呼び、最小化する損失関数として利用します。

\displaystyle
\begin{eqnarray*}
L
&=& -\log\left(l\right) \\
&=& -\log\left(\prod_{i=1}^{n}\prod_{k=1}^{c}\varphi_{ik}^{y_{ik}}\right) \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c}y_{ik}\log\varphi_{ik}
\end{eqnarray*}
\tag{15}

交差エントロピーを重みで微分

  • 学習時に重みを更新するために、損失関数の重み勾配を求める必要があります。

  • \frac{\partial L}{\partial w_{jl}} を考えます。


\displaystyle
\begin{eqnarray*}
\frac{\partial L}{\partial w_{jl}}
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} \frac{\partial}{\partial w_{jl}}\left(y_{ik}\log\varphi_{ik}\right) \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial}{\partial w_{jl}}\log\varphi_{ik} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}}\frac{\partial \varphi_{ik}}{\partial w_{jl}} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}}\sum_{p=1}^{c}\frac{\partial \varphi_{ik}}{\partial z_{ip}}\frac{\partial z_{ip}}{\partial w_{jl}} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}} \sum_{p=1}^{c}\frac{\partial \varphi_{ik}}{\partial z_{ip}} \sum_{q=1}^{m}\frac{\partial \left(x_{iq}w_{qp}+b_{p}\right)}{\partial w_{jl}} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}}\frac{\partial \varphi_{ik}}{\partial z_{il}}x_{ij} ~~~~\small{(\because q=j,p=lの項のみが残る)} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{1}{\varphi_{ik}}\frac{\partial \varphi_{ik}}{\partial z_{il}}x_{ij} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{1}{\varphi_{ik}}\varphi_{ik}\left(\delta_{kl}-\varphi_{il} \right)x_{ij}  ~~~~\small{(\because (8)式より)} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\left(\delta_{kl}-\varphi_{il} \right)x_{ij} \\
&=& -\sum_{i=1}^{n}\left(\sum_{k=1}^{c} y_{ik}\delta_{kl}- \sum_{k=1}^{c}y_{ik}\varphi_{il}\right)x_{ij} \\
&=& -\sum_{i=1}^{n}\left(y_{il} - \varphi_{il}\right)x_{ij} ~~~~\small{(\because k=lの項のみが残る)}\\
\end{eqnarray*}
\tag{16}

  • 同様に \frac{\partial L}{\partial b_{l}} は、

\displaystyle
\begin{eqnarray*}
\frac{\partial L}{\partial b_{l}}
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} \frac{\partial}{\partial b_{l}}\left(y_{ik}\log\varphi_{ik}\right) \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}} \sum_{p=1}^{c}\frac{\partial \varphi_{ik}}{\partial z_{ip}} \sum_{q=1}^{m}\frac{\partial \left(x_{iq}w_{qp}+b_{p}\right)}{\partial b_{l}} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{\partial \log\varphi_{ik}}{\partial \varphi_{ik}}\frac{\partial \varphi_{ik}}{\partial z_{il}}  ~~~~\small{(\because p=lの項のみが残る)} \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\frac{1}{\varphi_{ik}}\varphi_{ik}\left(\delta_{kl}-\varphi_{il} \right) \\
&=& -\sum_{i=1}^{n}\sum_{k=1}^{c} y_{ik}\left(\delta_{kl}-\varphi_{il} \right)   ~~~~\small{(\because (8)式より)} \\
&=& -\sum_{i=1}^{n}\left(\sum_{k=1}^{c} y_{ik}\delta_{kl}- \sum_{k=1}^{c}y_{ik}\varphi_{il} \right) \\
&=& -\sum_{i=1}^{n}\left(y_{il} - \varphi_{il} \right)  ~~~~\small{(\because k=lの項のみが残る)}\\
\end{eqnarray*}
\tag{17}

行列による表記

  • 本来は行列ではなくて、テンソルとして考えるべきなんでしょうが、プログラムで書いたときにnumpyで書きやすいようになるべく行列、ベクトル形式で書いていこうと思います。

特徴量行列

  • 行は各データサンプル、列は各特徴量を表します。

\displaystyle
\mathbf{X}=
\left[
    \begin{array}{c}
        \mathbf{x}_1 \\
        \vdots \\
        \mathbf{x}_n \\
    \end{array}
\right]
=\left[
    \begin{array}{ccc}
        x_{11} & \cdots & x_{1m} \\
        \vdots & \ddots & \vdots \\
        x_{n1} & \cdots & x_{nm} \\
    \end{array}
\right] \\
\tag{18}

特徴量行列×重み行列+閾値


\displaystyle
\underset{(n \times c)}{\mathbf{Z}} =\underset{(n \times m)}{\mathbf{X}}\underset{(m \times c)}{\mathbf{W}}+\underset{(1 \times c)}{\mathbf{b}}
\tag{19}



\displaystyle
\underset{(n \times c)}{
        \left[
            \begin{array}{ccc}
                z_{11} & \cdots & z_{1c} \\
                \vdots & \ddots & \vdots \\
                z_{nc} & \cdots & z_{nc} \\
            \end{array}
        \right]
    } =
\underset{(n \times m)}{
        \left[
            \begin{array}{ccc}
                x_{11} & \cdots & x_{1m} \\
                \vdots & \ddots & \vdots \\
                x_{n1} & \cdots & x_{nm} \\
            \end{array}
        \right]
    }
\underset{(m \times c)}{
    \left[
        \begin{array}{ccc}
            w_{11} & \cdots & x_{1c} \\
            \vdots & \ddots & \vdots \\
            w_{m1} & \cdots & x_{mc} \\
        \end{array}
    \right]
    }
+\underset{(1 \times c)}{
    \left[
        \begin{array}{ccc}
            b_{1} & \cdots & b_{c}
        \end{array}
    \right]
    }
\tag{20}

活性化関数(ソフトマックス関数)の出力


\displaystyle
\begin{eqnarray*}
\underset{(n \times c)}{\mathbf{\Phi}}
&=& f(\underset{(n \times c)}{\mathbf{Z}} ) \\
\end{eqnarray*}
\tag{21}




\displaystyle
\underset{(n \times c)}{
        \left[
            \begin{array}{ccc}
                \varphi_{11} & \cdots & \varphi_{1c} \\
                \vdots & \ddots & \vdots \\
                \varphi_{n1} & \cdots & \varphi_{nc} \\
            \end{array}
        \right]
    } =
\underset{(n \times c)}{
        \left[
            \begin{array}{ccc}
                \frac{e^{z_{11}}}{\sum_{k=1}^{c}e^{z_{1k}}} & \cdots & \frac{e^{z_{1c}}}{\sum_{k=1}^{c}e^{z_{1k}}} \\
                \vdots & \ddots & \vdots \\
                \frac{e^{z_{n1}}}{\sum_{k=1}^{c}e^{z_{nk}}} & \cdots & \frac{e^{z_{nc}}}{\sum_{k=1}^{c}e^{z_{nk}}} \\
            \end{array}
        \right]
    }
\tag{22}

  • これを更に分解するのは難しいのでやめときます。

損失関数(交差エントロピー)の出力


\displaystyle
\begin{eqnarray*}
\underset{(1 \times 1)}{L}
&=& g(\underset{(n \times c)}{\mathbf{\Phi}} ) \\
&=& -\underset{(1 \times n)}{
    \left[
        \begin{array}{c}
            1 & \cdots & 1
        \end{array}
    \right]
    }
\underset{(n \times 1)}{
    \left[
        \begin{array}{ccc}
            \sum_{k=1}^{c}y_{1k}\log\varphi_{1k} \\
            \vdots \\
            \sum_{k=1}^{c}y_{nk}\log\varphi_{nk}
        \end{array}
    \right]
    } \\
&=& -
\underset{(1 \times n)}{
    \left[
        \begin{array}{c}
            1 & \cdots & 1
        \end{array}
    \right]
    }
\underset{(n \times k)}{
    \left[
        \begin{array}{ccc}
            y_{11}\log\varphi_{11} & \cdots & y_{1c}\log\varphi_{1c} \\
            \vdots & \ddots & \vdots \\
            y_{n1}\log\varphi_{n1} & \cdots & y_{nc}\log\varphi_{nc}\\
        \end{array}
    \right]
    } 
    \underset{(k \times 1)}{
        \left[
            \begin{array}{ccc}
                1 \\
                \vdots\\
                1\\
            \end{array}
        \right]
        } \\
&=& -
\underset{(1 \times n)}{
    \left[
        \begin{array}{c}
            1 & \cdots & 1
        \end{array}
    \right]
    }
\left(
\underset{(n \times k)}{
    \left[
        \begin{array}{ccc}
            y_{11} & \cdots & y_{1c} \\
            \vdots & \ddots & \vdots \\
            y_{n1} & \cdots & y_{nc}\\
        \end{array}
    \right]
    }
\circ
\underset{(n \times k)}{
    \left[
        \begin{array}{ccc}
            \log\varphi_{11} & \cdots & \log\varphi_{1c} \\
            \vdots & \ddots & \vdots \\
            \log\varphi_{n1} & \cdots & \log\varphi_{nc}\\
        \end{array}
    \right]
    }
\right)
\underset{(k \times 1)}{
    \left[
        \begin{array}{ccc}
            1 \\
            \vdots\\
            1\\
        \end{array}
    \right]
    } \\
&=& -
\underset{(1 \times n)}{
    \left[
        \begin{array}{c}
            1 & \cdots & 1
        \end{array}
    \right]
    }
\left(
\underset{(n \times k)}{
    \mathbf{Y}
    }
\circ
\underset{(n \times k)}{
    \log\mathbf{\Phi}
    }
\right)
\underset{(k \times 1)}{
    \left[
        \begin{array}{ccc}
            1 \\
            \vdots\\
            1\\
        \end{array}
    \right]
    }
\end{eqnarray*}
\tag{23}

  • \mathbf{Y} は、行が各データサンプル、列が各クラスラベルを表す教師データの行列です。

  •  \circ アダマール積を表し、行列の各要素を掛け算する記号です。

損失関数(交差エントロピー)の重み微分

  • \frac{\partial L}{\partial \mathbf{W}} について、(16)式の j を縦に l を横に並べます。

\displaystyle
\begin{eqnarray*}
\frac{\partial L}{\partial \mathbf{W}}
&=& -
\underset{(m \times c)}{
    \left[
        \begin{array}{ccc}
            \sum_{i=1}^{n}(y_{i1}-\varphi_{i1})x_{i1} & \cdots & \sum_{i=1}^{n}(y_{ic}-\varphi_{ic})x_{i1} \\
            \vdots & \ddots & \vdots \\
            \sum_{i=1}^{n}(y_{i1}-\varphi_{i1})x_{im} & \cdots & \sum_{i=1}^{n}(y_{ic}-\varphi_{ic})x_{im}
        \end{array}
    \right]
} \\
&=& -
\underset{(m \times n)}{
    \left[
        \begin{array}{ccc}
            x_{11} & \cdots & x_{n1} \\
            \vdots & \ddots & \vdots \\
            x_{1m} & \cdots & x_{nm}
        \end{array}
    \right]
}
\underset{(n \times c)}{
    \left[
        \begin{array}{ccc}
            y_{11}-\varphi_{11} & \cdots & y_{1c}-\varphi_{1c} \\
            \vdots & \ddots & \vdots \\
            y_{n1}-\varphi_{n1} & \cdots & y_{nc}-\varphi_{nc}
        \end{array}
    \right]
} \\
&=& -
\underset{(m \times n)}{
    \mathbf{X}^T
}
\left(
\underset{(n \times c)}{
    \mathbf{Y}
}
-
\underset{(n \times c)}{
    \mathbf{\Phi}
}
\right)
\end{eqnarray*}
\tag{24}

  • \frac{\partial L}{\partial \mathbf{b}} について、(17)式の l を横に並べます。

\displaystyle
\begin{eqnarray*}
\frac{\partial L}{\partial \mathbf{b}}
&=& -
\underset{(1 \times c)}{
    \left[
        \begin{array}{ccc}
            \sum_{i=1}^{n}(y_{i1}-\varphi_{i1}) & \cdots & \sum_{i=1}^{n}(y_{ic}-\varphi_{ic}) \\
        \end{array}
    \right]
} \\
&=& -
\underset{(1 \times n)}{
    \left[
        \begin{array}{ccc}
            1 & \cdots & 1 \\
        \end{array}
    \right]
}
\underset{(n \times c)}{
    \left[
        \begin{array}{ccc}
            y_{11}-\varphi_{11} & \cdots & y_{1c}-\varphi_{1c} \\
            \vdots & \ddots & \vdots \\
            y_{n1}-\varphi_{n1} & \cdots & y_{nc}-\varphi_{nc}
        \end{array}
    \right]
} \\
&=& -
\underset{(1 \times n)}{
    \left[
        \begin{array}{ccc}
            1 & \cdots & 1 \\
        \end{array}
    \right]
}
\left(
\underset{(n \times c)}{
    \mathbf{Y}
}
-
\underset{(n \times c)}{
    \mathbf{\Phi}
}
\right)
\end{eqnarray*}
\tag{25}

重みの更新

  • 学習率を \mu とし、(24)、(25)式で求めた \frac{\partial L}{\partial \mathbf{W}},\frac{\partial L}{\partial \mathbf{b}} を使って重みと閾値を更新します。

\mathbf{W}:=\mathbf{W}-\mu\frac{\partial L}{\partial \mathbf{W}}
\tag{26}



\mathbf{b}:=\mathbf{b}-\mu\frac{\partial L}{\partial \mathbf{b}}
\tag{27}