利用 Softmax 回归识别手写数字

本例中使用到的图片来自开源数据集 — MNIST. MNIST 是一个入门级的机器视觉数据集, 在机器学习领域里, 我们经常用 MNIST 数据集来实验各种模型. 打开官方网站, 可以看到如下介绍信息

MNIST 数据集

TensorFlow 提供了一个库, 可以自动将 MNIST 数据集下载到本地, 代码如下

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

上面的代码将自动下载数据文件, 并保存在当前代码所在目录下的 MNIST_data 文件夹下. 如果文件夹中已经有了下载好的文件, 那么不会再次下载. 有时候, 直接使用上面的代码下载数据文件, 可能会因网络不佳而出错, 这时, 可以提前下载好放在 MNIST_data 文件夹下.

代码中的 one_hot=True, 表示将样本标签转化为 one_hot 编码. 我们先来了解一下数据的基本形态

print('Training data size: {}'.format(mnist.train.num_examples))
print('Testing data size: {}'.format(mnist.test.num_examples))
print('Validating data size: {}'.format(mnist.validation.num_examples))
print('Training image shape: {}'.format(mnist.train.images.shape))
print('Training label shape: {}'.format(mnist.train.labels.shape))

输出

Training data size: 55000
Testing data size: 10000
Validating data size: 5000
Training image shape: (55000, 784)
Training label shape: (55000, 10)

样本包含三部分: 训练数据集、测试数据集、验证数据集, 分别有 55000、10000、5000 条记录, 分别用于模型训练, 模型训练过程中的效果评估, 以及最终模型的效果评估.

MNIST 数据集的图片是 28×2828\times28 Pixel, 因此一幅图就是 1 行 784 列. 训练集包含 55000 个样本, 因此 mnist.train.images 是一个形状为 (55000, 784) 的张量, 张量的每个元素, 表示图片的像素点, 值介于 0~255 之间.

查看一个样本

import matplotlib.pyplot as plt
sample_image = mnist.train.images[1]
sample_label = mnist.train.labels[1]
print('Sample label: {}'.format(sample_label))
im = sample_image.reshape(-1, 28)
plt.imshow(im)
plt.show()

输出 Sample label: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.], 这时 one_hot 编码, 表示 3, 输出图片如下.

MNIST 中的手写数字样例

完整代码

# encoding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.examples.tutorials.mnist import input_data


def define_graph(learning_rate=0.001):
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=(None, 784), name='image')
        y = tf.placeholder(tf.float32, shape=(None, 10), name='label')

        with tf.name_scope('regression'):
            w = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
            b = tf.Variable(tf.constant(0.1, shape=[10]))
            y_ = tf.add(tf.matmul(x, w), b)

        with tf.name_scope('loss'):
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=y_)
            loss = tf.reduce_mean(cross_entropy)

        with tf.name_scope('train'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            train = optimizer.minimize(loss)

        with tf.name_scope('predict'):
            predict = tf.argmax(tf.nn.softmax(y_), 1)

        with tf.name_scope('accuracy'):
            label = tf.argmax(y, 1)
            correct = tf.equal(predict, label)
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return [g, x, y, loss, train, predict, accuracy]


def main(argv):
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    [g, x, y, loss, train, predict, accuracy] = define_graph(FLAGS.learning_rate)
    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    template = 'Batch: {}, loss: {:.5f}, accuracy: {:.2f}%'
    with tf.Session(graph=g) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        for i in range(1, FLAGS.train_step + 1):
            xs, ys = mnist.train.next_batch(FLAGS.batch_size)
            sess.run(train, feed_dict={x: xs, y: ys})
            if i % FLAGS.display_step == 0:
                [curr_loss, curr_accuracy] = sess.run([loss, accuracy], feed_dict=test_feed)
                print(template.format(i, curr_loss, 100 * curr_accuracy))
        curr_test = np.argmax(mnist.test.labels, 1)
        curr_predict = sess.run(predict, feed_dict=test_feed)
        print(classification_report(curr_test, curr_predict, labels=range(10)))
        print(confusion_matrix(curr_test, curr_predict, labels=range(10)))


if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.flags.DEFINE_string('data_dir', 'MNIST_data', 'input data directory')
    tf.flags.DEFINE_integer('batch_size', 50, 'batch size')
    tf.flags.DEFINE_integer('train_step', 20000, 'training steps')
    tf.flags.DEFINE_integer('display_step', 100, 'display step')
    tf.flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
    FLAGS = tf.flags.FLAGS
    tf.app.run(main)

计算图

计算图

执行程序输出

Batch: 100, loss: 2.08542, accuracy: 26.50%
Batch: 200, loss: 1.80991, accuracy: 42.29%
Batch: 300, loss: 1.59328, accuracy: 54.29%
Batch: 400, loss: 1.41938, accuracy: 63.38%
Batch: 500, loss: 1.27804, accuracy: 69.38%
......
Batch: 19600, loss: 0.28959, accuracy: 91.88%
Batch: 19700, loss: 0.28959, accuracy: 91.86%
Batch: 19800, loss: 0.28898, accuracy: 92.03%
Batch: 19900, loss: 0.28897, accuracy: 92.02%
Batch: 20000, loss: 0.28862, accuracy: 92.01%
             precision    recall  f1-score   support

          0       0.95      0.98      0.97       980
          1       0.97      0.98      0.97      1135
          2       0.93      0.88      0.91      1032
          3       0.90      0.90      0.90      1010
          4       0.92      0.93      0.93       982
          5       0.90      0.85      0.88       892
          6       0.94      0.96      0.95       958
          7       0.93      0.91      0.92      1028
          8       0.87      0.88      0.87       974
          9       0.89      0.90      0.90      1009

avg / total       0.92      0.92      0.92     10000

[[ 964    0    2    1    0    3    7    1    2    0]
 [   0 1114    2    2    0    3    4    1    9    0]
 [   9    7  913   19   12    1   11   12   42    6]
 [   3    2   18  913    0   29    3   12   21    9]
 [   2    2    6    0  914    1    9    2   10   36]
 [  10    3    4   38    9  762   17    5   34   10]
 [  10    3    6    1    8   10  916    2    2    0]
 [   1    9   23    8    9    0    0  937    3   38]
 [   8    7    9   24    9   27   11   13  857    9]
 [  10    6    2   11   33    8    0   20    8  911]]

从以上输出可以看出, 仅用一个神经元, 就能取得较好的效果, 整体准确率在 92.01%. 可是如果还想取得更好的效果, 这种浅层的神经网络就无能为力了, 下一篇, 构建卷积神经网络 (CNN) 识别手写数字, 用于这个识别任务, 将取得高达 99.31% 的准确率.

参考文献