Skip to content

3.5 图像分类数据集 (Fashion-MNIST)

在介绍 softmax 回归的实现前我们先引入一个多类图像分类数据集。它将在后面的章节中被多次使用,以方便我们观察比较算法之间在模型精度和计算效率上的区别。图像分类数据集中最常用的是手写数字识别数据集 MNIST[1]。但大部分模型在 MNIST 上的分类精度都超过了 95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集 Fashion-MNIST[2](这个数据集也比较小,只有几十 M,没有 GPU 的电脑也能吃得消)。

本节我们将使用 torchvision 包,它是服务于 PyTorch 深度学习框架的,主要用来构建计算机视觉模型。torchvision 主要由以下几部分构成:

  1. torchvision.datasets:一些加载数据的函数及常用的数据集接口;
  2. torchvision.models:包含常用的模型结构 (含预训练模型),例如 AlexNet、VGG、ResNet 等;
  3. torchvision.transforms:常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils:其他的一些有用的方法。

3.5.1 获取数据集

首先导入本节需要的包或模块。

python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..") # 为了导入上层目录的 d2lzh_pytorch
import d2lzh_pytorch as d2l

下面,我们通过 torchvision 的 torchvision.datasets 来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数 train 来指定获取训练数据集或测试数据集 (testing data set)。测试数据集也叫测试集 (testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数 transform = transforms.ToTensor() 使所有数据转换为 Tensor,如果不进行转换则返回的是 PIL 图片。transforms.ToTensor() 将尺寸为 (H x W x C) 且数据位于 [0, 255] 的 PIL 图片或者数据类型为 np.uint8 的 NumPy 数组转换为尺寸为 (C x H x W) 且数据类型为 torch.float32 且位于 [0.0, 1.0] 的 Tensor

注意: 由于像素值为 0 到 255 的整数,所以刚好是 uint8 所能表示的范围,包括 transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是 uint8 型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值 (0-255 整数) 表示图片数据,那么一律将其类型设置成 uint8,避免不必要的 bug。 本人就被这点坑过,详见 我的这个博客 2.2.4 节

python
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的 mnist_trainmnist_test 都是 torch.utils.data.Dataset 的子类,所以我们可以用 len() 来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为 6,000 和 1,000。因为有 10 个类别,所以训练集和测试集的样本数分别为 60,000 和 10,000。

python
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

python
feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量 feature 对应高和宽均为 28 像素的图像。由于我们使用了 transforms.ToTensor(),所以每个像素的数值为 [0.0, 1.0] 的 32 位浮点数。需要注意的是,feature 的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为 1。后面两维分别是图像的高和宽。

Fashion-MNIST 中一共包括了 10 个类别,分别为 t-shirt(T 恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包) 和 ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

python
# 本函数已保存在 d2lzh 包中方便以后使用
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

python
# 本函数已保存在 d2lzh 包中方便以后使用
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的 _ 表示我们忽略 (不使用) 的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

现在,我们看一下训练数据集中前 10 个样本的图像内容和文本标签。

python
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

3.5.2 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset 的子类,所以我们可以将其传入 torch.utils.data.DataLoader 来创建一个读取小批量数据样本的 DataLoader 实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch 的 DataLoader 中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数 num_workers 来设置 4 个进程读取数据。

python
batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0 表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

我们将获取并读取 Fashion-MNIST 数据集的逻辑封装在 d2lzh_pytorch.load_data_fashion_mnist 函数中供后面章节调用。该函数将返回 train_itertest_iter 两个变量。随着本书内容的不断深入,我们会进一步改进该函数。它的完整实现将在 5.6 节中描述。

最后我们查看读取一遍训练数据需要的时间。

python
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

输出:

1.57 sec

小结

  • Fashion-MNIST 是一个 10 类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为 hhww 像素的图像的形状记为 h×wh \times w(h,w)

参考文献

[1] LeCun, Y., Cortes, C., & Burges, C. http://yann.lecun.com/exdb/mnist/

[2] Xiao, H., Rasul, K., & Vollgraf, R. (2017). Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747.


注:本节除了代码之外与原书基本相同,原书传送门