勾配降下法を用いてTensorflowで単回帰を実装する

前回の記事に引き続き、単回帰の勉強です。

今回用いたコードはGithub上にもあげています。

 

目標

 

損失関数の理解を深め、最小二乗法以外の方法を用いて単回帰分析をする。

 

RMSE

 

今回は損失関数に平均自乗平方根誤差を用いています。

 

データの数:n
真の値: y
予測した値: f

 

RMSE = \sqrt{\frac{1}{n}\sum_{k=1}^{n}(f_{i} - y_{i})}

 

標準偏差とほとんど同じような式ですね。

 

他によく使われるものに平均自乗誤差というのもあり、平方根を取るかどうかの差があります。
平方根を取ることによって単位を元の統計値と同じにすることができます。

 

勾配降下法とは

 

機械学習では、目的関数を定義して、その値が最小になるようにパラメータを更新して最適なパラメータを決定します。
このパラメータの更新時に使われる一つの手段に勾配降下法というのがあります。

 

適当な初期値を決め、そこから手探りで目的関数を最も小さくするパラメータを探ります。
訓練データから一つ選んで微分してもっとも傾きが急な方向に進む、また一つ選んで傾きが急な方向に進むというのを繰り返します。

 

一口に勾配降下法と言っても主に使われるものに

  • 最急降下法
  • 確率的勾配降下法
  • ミニバッチ確率的勾配降下法

などがあります。

超わかりやすいので詳しくは以下のサイトを参考にして下さい

 

【参考】
確率的勾配降下法とは何か、をPythonで動かして解説する – Qiita
勾配降下法ってなんだろう – 白猫のメモ帳
ロジスティック回帰 (勾配降下法 / 確率的勾配降下法) を可視化する – StatsFragments

 

大まかな手順は、以下のような感じになります。

  1. 適当に初期点を選ぶ
  2. 現在地における最急降下方向を計算する
  3. その方向に進む
  4. 2,3を繰り返す

 

最急降下法

 

全ての出力を用いてパラメータの更新度を決めます。

 

メリット

 

  • 単純

デメリット

 

  • すべての誤差の合計を取ってから更新するので、学習データが多いと計算コストが大きくなる
  • 1つのデータが増えるたびにすべてのデータを計算し直さないといけない(オンライン学習ができない。)
  • 最小化する目的関数が常に同じなので、間違った局所解に1度入り込んだが最後、二度とそこから抜け出せない

 

確率的勾配降下法(SGD)

 

注意としては重回帰などのように変数が複数ある時は、
変数によって取る値域が異なると、学習がうまくいかないのでそんなときはスケーリングをすると良いらしいです。

 

【参考】
Coursera Machine Learning (2): 重回帰分析、スケーリング、正規方程式 – Qiita

 

メリット

 

  • 1つの学習データだけで更新するので、学習データが増えても耐える
  • 最適解にたどりつきやすい
  • オンライン学習ができる
  • そこそこの結果が欲しければ速い

 

デメリット

 

  • 最適解にたどりつくまでに時間がかかることがある
  • 例外データに引っ張られやすい(前処理が重要!)
  • 厳密な最適解が欲しい場合は時間がかかる

 

ミニバッチ確率的勾配降下法

 

上2つの良いとこ取りで、一部の出力を使ってパラメータの更新度合いを決めます。
これによってより多くのデータを使いながらもオンライン学習をすることができます。

 

ただ、ミニバッチのサイズを大きくしすぎると、確率的勾配降下法の良さを損なうことになるので注意です。

 

実装

 

コードの一部を掲載します。

 

 

結果

 

ハイパーパラメータを以下のように設定して実行してみました。
全然学習が上手く行っていないことがわかります
learning_rate = 0.05
batch_size = 100
epoch = 1000

 

y = 2.48720455 * x + 19.91344643
Loss = 55.4958

 

次に、ハイパーパラメータを以下のように設定して実行してみました。
割とうまく行っています。
learning_rate = 0.01
batch_size = 200
epoch = 100000

 

y = -2.45794249 * x + 134.38780212
Loss = 23.769

 

さいごに

 

今回はTensorFlowの勉強の為に、わざわざTensorFlowで実装してみましたが、
やはりこの辺のシンプルな機械学習だとscikit-learnで実装するのと比べるとちょっと手間がありますね。

コメントを残す