読者です 読者をやめる 読者になる 読者になる

Pythonと機械学習

Pythonも機械学習も初心者ですが、頑張ってこのブログで勉強してこうと思います。

ランダムフォレスト

目次

はじめに

ランダムフォレストは複数の決定木学習による多数決で学習結果を決定するアルゴリズムです。

複数の学習機で学習させ、それぞれの学習機の予測結果を多数決で決定する手法を、アンサンブル学習といい、一般的によく使われているみたいです。

ランダムフォレストはアンサンブル学習の一種です。

ランダムに選択したトレーニングサンプルと特徴量で個々の決定木を学習させてやります。

ランダム選択+多数決によりトレーニングサンプル中のノイズを除去する効果があり、過学習を防ぐ効果があります。

各決定木で学習を実施する時は、トレーニングサンプルを全て使わずにランダムに選択した少し少ないデータを使ってやります。

ブートストラップサンプリング

トレーニングサンプルが{m}個ある時、重複を許して{m}個をランダムに選択する方法をブートストラップサンプリングと言います。

ランダムフォレストでは各決定木の学習に、このブートストラップサンプリングという方法でトレーニングサンプルをランダムに選択します。

もちろんこの様な選択の仕方では、全く同じデータがいくつかダブってしまいますがそのままにしておきます。

例えばクラスラベルがA、B、Cという3個のデータを含むノードがあった時Aの割合は1/3ですが、Aがダブってがあったとすると(つまりA、A、Bの様な感じです)Aの割合は2/3になります。

つまりデータがダブってあると、ある特徴量でノードを2つに分けた時に得られる情報利得が変化します。

また、ランダムフォレストでは、特徴量もランダムに選択します。{n}個ある特徴量からそれよりも少ない{d}個の特徴量をランダムに選択します。

一般的に{d=\sqrt{n}}にすると良いらしいです。(普通は{\sqrt{n}}は整数にはならないので自作スクリプト中では小数点以下を四捨五入してやることにします。)

特徴量の重要度

scikit-learnのランダムフォレストでは特徴量の重要度も出力できます。

scikit-learnのソースコードで確認しましたが、各決定木で特徴量の重要度を求めて、それらの平均値として算出しているみたいです。自作スクリプト版ではこちらも考慮しようと思います。

自作スクリプト

決定木ができてしまうと、ランダムフォレストをコーディングするのはそれほど難しくありません。

以下自作スクリプトです。scikit-learnのランダムフォレストと結果を比較しています。

RandomForestの引数、 n_estimatorsで決定木の個数を指定し、random_stateで乱数のシードを指定します。

スコアと特徴量の重要度比較

--------------------------------------------------
my random forest score:0.977777777778
sklearn random forest score:0.977777777778

--------------------------------------------------
my random forest feature importances:
     petal length (cm) : 0.540643113442
     petal width (cm) : 0.459356886558
sklearn random forest feature importances:
     petal length (cm) : 0.546606134925
     petal width (cm) : 0.453393865075

決定領域

自作スクリプト

f:id:darden:20170105224645p:plain:w600

scikit-learn版

f:id:darden:20170105224709p:plain:w600

いくつかブートストラップサンプリングの乱数シードを振ってみましたが、どうも自作版とscikit-learn版で結果が異なっています。

違いの理由が分からず、結構悩んでしまいました。

scikit-learnのソースコードを確認したのですが、scikit-learnでは各決定木のノード分割の時に、ブートストラップサンプリングでダブったトレーニングサンプルの個数で重み付けして情報利得を計算している様な感じでした。

結局複雑すぎてよくわかりませんでしたが、あまり悩んでもしょうがないのでその内またコード探索してみたいと思います。

もしどなたか分かる人いれば教えていただけるとありがたいです。

おわりに

最近はPythonに慣れて来たのでscikit-learnのソースコードも少し読めるようになって来ました。

Jupyterもいいですが、PyCharmはCtrlを押しながら関数をクリックすると、その関数の定義に飛べるのでソースコードを読むときにすごく便利です。

前回の記事で自作版とscikit-learn版で決定木の学習速度が随分違うなーと思っていたら、scikit-learnではコアな部分では爆速のCythonを使っているみたいですね。

爆速Cythonもその内勉強したいですね。