Building a Convolutional Neural Network (CNN) to Recognize Handwritten Digits
Following up on the previous post, which recognized handwritten digits with Softmax regression, this post builds a convolutional neural network step by step for the same task. The improvement is striking: accuracy on the MNIST test set reaches 99.31%.
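The network stacks two convolution + max-pooling blocks on the 28x28 input, followed by a fully connected layer with dropout and a 10-way output. The shape bookkeeping below is derived from the code that follows; it explains why the first fully connected layer expects 7 * 7 * 64 inputs:

input : (N, 28, 28, 1)    the 784-pixel rows reshaped back into images
conv1 : 5x5 kernel, 32 filters, SAME padding -> (N, 28, 28, 32)
pool1 : 2x2 max pool, stride 2               -> (N, 14, 14, 32)
conv2 : 5x5 kernel, 64 filters, SAME padding -> (N, 14, 14, 64)
pool2 : 2x2 max pool, stride 2               -> (N, 7, 7, 64)
fc1   : flatten to 7 * 7 * 64 = 3136, then 1024 ReLU units with dropout
fc2   : 10 units, the logits fed to softmax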
Complete code
# encoding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.examples.tutorials.mnist import input_data


def define_cnn(learning_rate=0.001):
    """Build the CNN graph: two conv/pool blocks, a fully connected
    layer with dropout, and a 10-way softmax output."""
    g = tf.Graph()
    with g.as_default():
        # Flattened 28x28 grayscale images and one-hot labels.
        x = tf.placeholder(tf.float32, shape=(None, 784), name='image')
        y = tf.placeholder(tf.float32, shape=(None, 10), name='label')

        with tf.name_scope('reshape'):
            # Restore the 2-D structure: (batch, height, width, channels).
            x_image = tf.reshape(x, [-1, 28, 28, 1])

        with tf.name_scope('conv1'):
            # First convolution: 5x5 kernels, 1 input channel, 32 feature maps.
            w_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
            b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))
            h_conv1 = tf.nn.relu(
                tf.nn.conv2d(x_image, w_conv1, strides=[1, 1, 1, 1],
                             padding='SAME') + b_conv1)

        with tf.name_scope('pool1'):
            # 2x2 max pooling halves the spatial size: 28x28 -> 14x14.
            h_pool1 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1], padding='SAME')

        with tf.name_scope('conv2'):
            # Second convolution: 5x5 kernels, 32 -> 64 feature maps.
            w_conv2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
            b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))
            h_conv2 = tf.nn.relu(
                tf.nn.conv2d(h_pool1, w_conv2, strides=[1, 1, 1, 1],
                             padding='SAME') + b_conv2)

        with tf.name_scope('pool2'):
            # 14x14 -> 7x7, so the flattened feature vector has 7*7*64 entries.
            h_pool2 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1], padding='SAME')

        with tf.name_scope('fc1'):
            # Fully connected layer: 3136 -> 1024.
            w_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
            b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
            h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

        with tf.name_scope('dropout'):
            # keep_prob is fed as 0.5 during training and 1.0 at test time.
            keep_prob = tf.placeholder(tf.float32)
            h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

        with tf.name_scope('fc2'):
            # Output layer produces the 10-class logits.
            w_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
            b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
            y_ = tf.matmul(h_fc1_drop, w_fc2) + b_fc2

        with tf.name_scope('loss'):
            # Softmax cross-entropy computed directly on the logits.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=y, logits=y_)
            loss = tf.reduce_mean(cross_entropy)

        with tf.name_scope('train'):
            optimizer = tf.train.AdamOptimizer(learning_rate)
            train = optimizer.minimize(loss)

        with tf.name_scope('predict'):
            predict = tf.argmax(tf.nn.softmax(y_), 1)

        with tf.name_scope('accuracy'):
            label = tf.argmax(y, 1)
            correct = tf.equal(predict, label)
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    return [g, x, y, keep_prob, loss, train, predict, accuracy]


def main(argv):
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    [g, x, y, keep_prob, loss, train, predict, accuracy] = define_cnn(
        FLAGS.learning_rate)
    # Evaluation always uses the full test set with dropout disabled.
    test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}
    template = 'Batch: {}, loss: {:.5f}, accuracy: {:.2f}%'
    with tf.Session(graph=g) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        for i in range(1, FLAGS.train_step + 1):
            xs, ys = mnist.train.next_batch(FLAGS.batch_size)
            sess.run(train, feed_dict={x: xs, y: ys, keep_prob: 0.5})
            if i % FLAGS.display_step == 0:
                [curr_loss, curr_accuracy] = sess.run([loss, accuracy],
                                                      feed_dict=test_feed)
                print(template.format(i, curr_loss, 100 * curr_accuracy))
        # Per-class metrics and confusion matrix on the test set.
        curr_test = np.argmax(mnist.test.labels, 1)
        curr_predict = sess.run(predict, feed_dict=test_feed)
        print(classification_report(curr_test, curr_predict, labels=range(10)))
        print(confusion_matrix(curr_test, curr_predict, labels=range(10)))


if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.flags.DEFINE_string('data_dir', 'MNIST_data', 'input data directory')
    tf.flags.DEFINE_integer('batch_size', 50, 'batch size')
    tf.flags.DEFINE_integer('train_step', 20000, 'training steps')
    tf.flags.DEFINE_integer('display_step', 100, 'display step')
    tf.flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
    FLAGS = tf.flags.FLAGS
    tf.app.run(main)
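Assuming the script is saved as mnist_cnn.py (the filename here is illustrative), the hyperparameters can be overridden through the flags defined at the bottom of the script, for example:

python mnist_cnn.py --data_dir=MNIST_data --batch_size=50 --train_step=20000 --learning_rate=1e-4

With the defaults, training runs 20,000 mini-batches of 50 images, evaluates on the full test set every 100 steps, and finishes with a per-class report and confusion matrix.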
Computation graph
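The name_scope blocks in define_cnn are what make the graph legible: each layer collapses into a single node in TensorBoard. A minimal sketch for exporting the graph so it can be inspected, assuming the two lines are placed just after the session is created (the 'logs' directory name is arbitrary):

        # Export the graph definition for TensorBoard; 'logs' is an arbitrary directory.
        writer = tf.summary.FileWriter('logs', sess.graph)
        writer.close()
        # Inspect with: tensorboard --logdir=logs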
Program output
Batch: 100, loss: 0.52098, accuracy: 83.99%
Batch: 200, loss: 0.32583, accuracy: 90.25%
Batch: 300, loss: 0.25579, accuracy: 92.39%
Batch: 400, loss: 0.21144, accuracy: 93.84%
Batch: 500, loss: 0.18473, accuracy: 94.77%
......
Batch: 19600, loss: 0.02193, accuracy: 99.27%
Batch: 19700, loss: 0.02318, accuracy: 99.22%
Batch: 19800, loss: 0.02183, accuracy: 99.25%
Batch: 19900, loss: 0.02433, accuracy: 99.24%
Batch: 20000, loss: 0.02342, accuracy: 99.31%
             precision    recall  f1-score   support

          0       0.99      1.00      0.99       980
          1       1.00      1.00      1.00      1135
          2       1.00      1.00      1.00      1032
          3       0.99      1.00      0.99      1010
          4       0.99      1.00      0.99       982
          5       0.99      0.99      0.99       892
          6       1.00      0.99      0.99       958
          7       0.99      0.99      0.99      1028
          8       0.99      0.99      0.99       974
          9       0.99      0.99      0.99      1009

avg / total       0.99      0.99      0.99     10000
[[ 977    0    0    0    0    0    1    1    1    0]
 [   0 1130    1    2    0    0    0    1    1    0]
 [   1    0 1029    0    0    0    0    2    0    0]
 [   0    0    0 1006    0    2    0    2    0    0]
 [   0    0    0    0  978    0    0    0    0    4]
 [   1    0    0    7    0  881    1    0    1    1]
 [   3    2    0    0    1    1  950    0    1    0]
 [   0    0    2    1    0    0    0 1021    1    3]
 [   2    0    2    2    0    1    0    2  962    3]
 [   1    0    0    1    5    3    0    1    1  997]]
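Rows of the confusion matrix are the true digits and columns the predictions, so the off-diagonal entries locate the remaining errors: for example, 7 fives were misread as 3 and 5 nines as 4.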