The images used in this example come from an open-source dataset: MNIST. MNIST is an entry-level computer vision dataset; in machine learning, it is frequently used to experiment with all kinds of models. A detailed introduction is available on the official website.
TensorFlow provides a library that automatically downloads the MNIST dataset to the local machine:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
The code above automatically downloads the data files and saves them in the MNIST_data folder under the directory containing the code. If the folder already contains the downloaded files, they will not be downloaded again. Sometimes the download fails due to a poor network connection; in that case, you can download the files in advance and place them in the MNIST_data folder yourself.
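If you do need to pre-download manually, a minimal sketch along the following lines works; the mirror URL here is the one newer versions of input_data read from, and is an assumption — adjust it if your TensorFlow version uses a different source.

# Pre-download the four MNIST files into MNIST_data (sketch; mirror URL is an assumption).
import os
from six.moves import urllib

SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
         't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

if not os.path.isdir('MNIST_data'):
    os.makedirs('MNIST_data')
for name in FILES:
    path = os.path.join('MNIST_data', name)
    if not os.path.exists(path):
        # Skip files that are already present, mirroring input_data's behavior.
        urllib.request.urlretrieve(SOURCE_URL + name, path)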
The one_hot=True in the code means that sample labels are converted to one-hot encoding: a digit d becomes a length-10 vector with a 1 at index d and 0 everywhere else.
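To make the encoding concrete, here is a minimal NumPy sketch; the to_one_hot helper is purely illustrative and not part of input_data:

import numpy as np

def to_one_hot(labels, num_classes=10):
    # e.g. 3 -> [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
    encoded = np.zeros((len(labels), num_classes), dtype=np.float32)
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded

print(to_one_hot([3, 0, 9]))

With that in mind, let's first look at the basic shape of the loaded data: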
print('Training data size: {}'.format(mnist.train.num_examples))
print('Testing data size: {}'.format(mnist.test.num_examples))
print('Validating data size: {}'.format(mnist.validation.num_examples))
print('Training image shape: {}'.format(mnist.train.images.shape))
print('Training label shape: {}'.format(mnist.train.labels.shape))
Output:
Training data size: 55000
Testing data size: 10000
Validating data size: 5000
Training image shape: (55000, 784)
Training label shape: (55000, 10)
The samples consist of three parts: a training set, a test set, and a validation set, with 55000, 10000, and 5000 records respectively. The training set is used to fit the model, the validation set to gauge performance during training, and the test set to evaluate the final model.
Images in the MNIST dataset are 28×28 = 784 pixels, so each image is flattened into 1 row of 784 columns. The training set contains 55000 samples, so mnist.train.images is a tensor of shape (55000, 784). Each element of the tensor represents one pixel; input_data scales the raw 0–255 grayscale values down to floats between 0 and 1.
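A quick sanity check of the value range (assuming mnist has been loaded as above) should confirm this:

# Expect [0.0, 1.0]: input_data divides the 0-255 grayscale values by 255.
print('Pixel value range: [{}, {}]'.format(
    mnist.train.images.min(), mnist.train.images.max()))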
Let's inspect a single sample:
import matplotlib.pyplot as plt

# Take the second training sample and its one-hot label.
sample_image = mnist.train.images[1]
sample_label = mnist.train.labels[1]
print('Sample label: {}'.format(sample_label))
# Restore the flat 784-vector to its 28x28 image shape for display.
im = sample_image.reshape(-1, 28)
plt.imshow(im)
plt.show()
This outputs Sample label: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.], which is the one-hot encoding for the digit 3. The rendered image is shown below.
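To recover the digit from a one-hot label directly, np.argmax returns the index of the 1:

import numpy as np
print(np.argmax(sample_label))  # prints 3 for the label above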
Complete code:
# encoding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.examples.tutorials.mnist import input_data


def define_graph(learning_rate=0.001):
    g = tf.Graph()
    with g.as_default():
        # Placeholders for a batch of flattened images and one-hot labels.
        x = tf.placeholder(tf.float32, shape=(None, 784), name='image')
        y = tf.placeholder(tf.float32, shape=(None, 10), name='label')
        with tf.name_scope('regression'):
            # A single fully connected layer: logits = x * w + b.
            w = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[10]))
            y_ = tf.add(tf.matmul(x, w), b)
        with tf.name_scope('loss'):
            # Softmax cross-entropy between logits and one-hot labels.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=y_)
            loss = tf.reduce_mean(cross_entropy)
        with tf.name_scope('train'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            train = optimizer.minimize(loss)
        with tf.name_scope('predict'):
            # Predicted digit is the index of the largest probability.
            predict = tf.argmax(tf.nn.softmax(y_), 1)
        with tf.name_scope('accuracy'):
            label = tf.argmax(y, 1)
            correct = tf.equal(predict, label)
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return [g, x, y, loss, train, predict, accuracy]


def main(argv):
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    [g, x, y, loss, train, predict, accuracy] = define_graph(FLAGS.learning_rate)
    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    template = 'Batch: {}, loss: {:.5f}, accuracy: {:.2f}%'
    with tf.Session(graph=g) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        for i in range(1, FLAGS.train_step + 1):
            # Train on one mini-batch per step.
            xs, ys = mnist.train.next_batch(FLAGS.batch_size)
            sess.run(train, feed_dict={x: xs, y: ys})
            if i % FLAGS.display_step == 0:
                # Periodically evaluate loss and accuracy on the test set.
                [curr_loss, curr_accuracy] = sess.run([loss, accuracy], feed_dict=test_feed)
                print(template.format(i, curr_loss, 100 * curr_accuracy))
        # Final per-class evaluation on the test set.
        curr_test = np.argmax(mnist.test.labels, 1)
        curr_predict = sess.run(predict, feed_dict=test_feed)
        print(classification_report(curr_test, curr_predict, labels=range(10)))
        print(confusion_matrix(curr_test, curr_predict, labels=range(10)))


if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.flags.DEFINE_string('data_dir', 'MNIST_data', 'input data directory')
    tf.flags.DEFINE_integer('batch_size', 50, 'batch size')
    tf.flags.DEFINE_integer('train_step', 20000, 'training steps')
    tf.flags.DEFINE_integer('display_step', 100, 'display step')
    tf.flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
    FLAGS = tf.flags.FLAGS
    tf.app.run(main)
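The script accepts the flags defined above on the command line; for example (the filename mnist_softmax.py is just an assumed name for this listing):

python mnist_softmax.py --learning_rate=0.0001 --batch_size=50 --train_step=20000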
The computation graph is shown below.
Running the program produces the following output:
Batch: 100, loss: 2.08542, accuracy: 26.50%
Batch: 200, loss: 1.80991, accuracy: 42.29%
Batch: 300, loss: 1.59328, accuracy: 54.29%
Batch: 400, loss: 1.41938, accuracy: 63.38%
Batch: 500, loss: 1.27804, accuracy: 69.38%
......
Batch: 19600, loss: 0.28959, accuracy: 91.88%
Batch: 19700, loss: 0.28959, accuracy: 91.86%
Batch: 19800, loss: 0.28898, accuracy: 92.03%
Batch: 19900, loss: 0.28897, accuracy: 92.02%
Batch: 20000, loss: 0.28862, accuracy: 92.01%
             precision    recall  f1-score   support

          0       0.95      0.98      0.97       980
          1       0.97      0.98      0.97      1135
          2       0.93      0.88      0.91      1032
          3       0.90      0.90      0.90      1010
          4       0.92      0.93      0.93       982
          5       0.90      0.85      0.88       892
          6       0.94      0.96      0.95       958
          7       0.93      0.91      0.92      1028
          8       0.87      0.88      0.87       974
          9       0.89      0.90      0.90      1009

avg / total       0.92      0.92      0.92     10000
[[ 964 0 2 1 0 3 7 1 2 0]
[ 0 1114 2 2 0 3 4 1 9 0]
[ 9 7 913 19 12 1 11 12 42 6]
[ 3 2 18 913 0 29 3 12 21 9]
[ 2 2 6 0 914 1 9 2 10 36]
[ 10 3 4 38 9 762 17 5 34 10]
[ 10 3 6 1 8 10 916 2 2 0]
[ 1 9 23 8 9 0 0 937 3 38]
[ 8 7 9 24 9 27 11 13 857 9]
[ 10 6 2 11 33 8 0 20 8 911]]
As the output shows, even a single fully connected layer of neurons achieves fairly good results, with an overall accuracy of 92.01%. But if we want to do better, this kind of shallow network cannot take us much further. In the next post, we will build a convolutional neural network (CNN) for the same handwritten-digit recognition task and reach an accuracy of 99.31%.