tensorflow1.xでNeural Network5

tensorflow1.xでNeural Network5

最近リアルが忙しい怠惰人間です。時間が随分と空いてしまい申し訳無いです。

さて、第2回目の記事ではtensorflowの学習について行いました。さて、今回はNeural Networkで多項分類を行おうと思います。

Neural Networkで分類を行う

前回までのプログラムで、Neural Networkで関数の近似を行えることがわかったと思います。しかし、現実におけるDeep Learningは不良品の分類や人の顔から個人名を挙げるなどといった分類に当たることを行っていると思います。今までのNeural Networkでは、可能なのは関数の近似のみで分類はできそうにありません。そこで、今回の記事では今までのNeuralNetworkに新要素を足すことで分類を行うための準備をしたいと思います。

では、早速新要素について見ていきましょう。

ソフトマックス関数

ソフトマックス関数は、一言で言うならばニューラルネットワークの出力の比を確立に直す関数です。第一回目の記事におけるニューラルネットワークの式での\(\sigma\)に相当します。では、式を見ていきましょう。

$$
f(x)
=
\displaystyle \frac{1}{\displaystyle \sum_{i=0}^{n} exp(x_i)}
\left(
\begin{array}{c}
exp(x_1) \\
exp(x_2) \\
\vdots \\
exp(x_n)
\end{array}
\right)
$$

この式における\(x_p\)はベクトルxのp番目の要素という意味です。

かなりわかりずらいですね。わかりやすくなるように説明をしていきます。まずは全体像を見てみましょう。ソフトマックス関数は任意の長さのベクトルを入力とし、入力と同じ長さのベクトルを出力とする関数です。出力のベクトルのk番目の要素は、以下の式で計算されます。

$$
y_k
=
\displaystyle \frac{exp(x_k)}{\displaystyle \sum_{i=0}^{n} exp(x_i)}
$$

では、この関数は何をしているのでしょうか?わかりやすく理解するには、関数exp(x)が単調増加の関数であることを思い出す必要があります。図1にexp(x)のグラフを示しておきます。

図1 exp(x)のグラフ

ソフトマックス関数がニューラルネットワークの出力の比を確立に直す関数であることは先ほど述べました。それを踏まえるとexp(x)は単調増加の関数であるがゆえに、削除しても式の意味に大きな影響を与えないことがわかります。では、各要素の計算式からexp(x) を消してみましょう。

$$
y_k
=
\displaystyle \frac{x_k}{\displaystyle \sum_{i=0}^{n} x_i}
$$

どこかで見たことがあるような式ですね。そう、ただの確立を算出する式です。ソフトマックス関数がかなり簡単な構造をしていることがわかると思います。

また、注意事項として、ソフトマックス関数は分類を行うための関数なので今までのプログラム例で扱ってきたようなスカラー値を出力とするニューラルネットワークでは使用できません。

では、ソフトマックス関数を使うことでどのようなメリットがあるのかを示します。

  • 確率で出力されるのでわかりやすい
  • ニューラルネットワークの学習がやりやすくなる。

特に、2つ目のニューラルネットワークの学習がやりやすくなると言うのが重要です。この特性があるためにソフトマックス関数は一般的に使用されます。これはどうしてかと言うと、誤差関数を作成する際に確率であれば、正しい分類がされる確率を最大化するのがやりやすいためです。(Aが正しい分類結果の時に、Aと認識されている確率が大きいほど良い結果とすると言うこと)

ソフトマックス関数はtensorflowでは以下のメソットで使用できます。

  • tf.nn.softmax(値)

多項分類機の作成

ここでは、MNISTというデータセットを使用しての分類を行なっていきたいと思います。MNISTというのは手書きの数字のデータセットのことです。数字なので、0~9の10種類の文字が存在します。MNISTについての詳しい説明は様々なサイトや本に記載されているので、今回は割愛します。では、実際にソースコードを見ていきましょう。

#インポート
import tensorflow as tf

#mnist画像読み込み用
from tensorflow.examples.tutorials.mnist import input_data

#関数の定義
# 入力の定義
x=tf.placeholder(tf.float32,(None,784))

# 重みの定義
weight=tf.Variable(tf.zeros([784,10],tf.float32))

# バイアスの定義(バッチ対応のためサイズは[1,10]に=[None,10]みたいな)
bias=tf.Variable(tf.zeros([1,10],tf.float32))

# 式の設定(weightとxのmatmulの順番に注意)
y=tf.add(tf.matmul(x,weight),bias)

# 教師データの定義
t=tf.placeholder(tf.float32,(None,10))

# ソフトマックス
softmax=tf.nn.softmax(y)

# 誤差関数の定義
loss=0.5*tf.reduce_sum(tf.square(t-softmax))

# 学習方法の定義(SGD)
train=tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#学習精度の定義
#tf.equal(x,y):x=yならばTrue
#tf.argmax(x,y):xのy軸方向の最大値を返す(y=0=縦、y=1=横)
correct_prediction=tf.equal(tf.argmax(softmax,1),tf.argmax(t,1))

#tf.cast(x,y):xの型をyにキャストする
#tf.reduce_mean(x):xの要素すべての平均を返す
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#式の定義終了

#セッションの開始
with tf.Session() as s:
#./MNISTにmnistのデータをダウンロード
mnist=input_data.read_data_sets("./MNIST/",one_hot=True)

s.run(tf.global_variables_initializer()) #変数の初期化

for i in range(0,10000,1): #10000回学習を行う

#教師データ(ミニバッチ法)の読み込み
input_pct,teacher_pct=mnist.train.next_batch(100)

#重みとバイアスの遷移を確認
if i%1000==0:
acc=s.run(accuracy,feed_dict={x:input_pct,t:teacher_pct})
print("Step:",i)
print(" accuracy=",acc,"\n")

#学習の実行
s.run(train,feed_dict={x:input_pct,t:teacher_pct})

#結果の出力
acc=s.run(accuracy,feed_dict={x:input_pct,t:teacher_pct})
print("Step:",i+1)
print(" accuracy=",acc,"\n")
<実行結果>
 Step: 0
accuracy= 0.09
Step: 1000
accuracy= 0.9
Step: 2000
accuracy= 0.89
Step: 3000
accuracy= 0.93
Step: 4000
accuracy= 0.95
Step: 5000
accuracy= 0.95
Step: 6000
accuracy= 0.91
Step: 7000
accuracy= 0.96
Step: 8000
accuracy= 0.97
Step: 9000
accuracy= 0.92
Step: 10000
accuracy= 0.92

どうでしたか、比較的単純なソースコードにしてはなかなかな精度が出ていると思います。次回以降はこのソースコードを改造して精度を上げていこうと思います。

二項分類機について

二項分類機を作成する際には、ソフトマックス関数を使用する必要はありません。何故ならば、次の章で使用する活性化関数にシグモイド関数を使用すれば良いだけだからです。ただし、ソフトマックス関数を使用しても十分に可能です。

ハードマックスについて

今回の記事で紹介したソフトマックスは結果を確率的に分類しました。しかし、確率的に分類しなくても良い場合もこの世の中には存在します。確率的に分類を行わずに、断定的に分類を行うことをハードマックスと呼ぶことがあります。

まとめ

  • ソフトマックス関数:tf.nn.softmax(値)

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です