学习 Keras 3

Keras 是一个用 Python 编写的深度学习框架,能够运行在 JAX、TensorFlow 或 PyTorch 之上,即支持多后端。它具有以下主要特点:

作为一个支持多后端的深度学习框架,Keras 具有众多优势,这也是我们选择学习并将其作为深度学习算法实践框架的原因。

从今天开始,我们将系统学习 Keras 3,这是自 2023 年 11 月发布的一个大版本更新。我们的演示代码大多以 TensorFlow 作为运行后端,可以直接通过安装 TensorFlow 的方式安装 Keras,并以tf.keras的方式使用 Keras。我们先来看看演示代码运行的版本环境。

In [1]:
import platform
In [2]:
print(f"Python 的版本:{platform.python_version()}")
Python 的版本:3.12.8
In [3]:
import tensorflow as tf
2025-01-06 00:22:44.461122: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
In [4]:
print(f"TensorFlow 的版本:{tf.__version__}")
print(f"Keras 的版本:{tf.keras.__version__}")
TensorFlow 的版本:2.18.0
Keras 的版本:3.7.0

初次上手

Keras 最核心的数据结构是 (Layer) 和模型 (Model)。最简单的模型是顺序模型 (Sequential model),即层顺序线性堆叠而成的模型。对于更复杂的模型结构,可以使用 Keras 的函数式 API,它允许构建任意的模型结构。或者通过子类继承的方式,从头编写模型。这里也就罗列出了使用 Keras 的常见三种方式,本节先以简单的顺序模型作为演示。

In [5]:
model = tf.keras.Sequential()

以上代码初始化了一个顺序模型,接下来,通过 .add() 成员方法,就可以顺序堆叠层。

In [6]:
model.add(tf.keras.layers.Input(shape=(784,)))
model.add(tf.keras.layers.Dense(units=64, activation="relu"))
model.add(tf.keras.layers.Dense(units=10, activation="softmax"))

模型构建完成后,可以通过 .compile() 方法来配置学习过程。

In [7]:
model.compile(loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

接下来,使用 .summary() 方法输出模型的概要信息,并通过 plot_model() 工具方法画出模型的结构图。

In [8]:
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param# ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 64)             │        50,240 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 50,890 (198.79 KB)
 Trainable params: 50,890 (198.79 KB)
 Non-trainable params: 0 (0.00 B)
In [9]:
tf.keras.utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    show_layer_activations=True,
    show_trainable=True,
)
Out[9]:
No description has been provided for this image

这个模型共有 50890 个待训练的 float32 参数,所需存储空间约为 KB。接下来,你可以通过.fit()方法在训练集上训练模型,然后通过.evaluate()方法在测试集上评估模型效果,最后通过.predict()方法在新数据 (验证集) 上生成预测结果。这部分的演示将在下一节进行。

重构优化

本节最后,我们演示一种更简单的构建顺序模型的方法。

In [10]:
from tensorflow.keras import (
    Sequential,
    activations,
    layers,
    losses,
    metrics,
    optimizers,
    utils,
)
In [11]:
model = Sequential(
    [
        layers.Input(shape=(784,)),
        layers.Dense(units=64, activation=activations.relu),
        layers.Dense(units=10, activation=activations.softmax),
    ]
)

如果需要对学习过程进行进一步配置,可以参照以下代码方式。总之,Keras 的理念是让简单的事情保持简单,同时允许用户在需要时进行完全控制。(终极控制方式是使用子类继承,使整个过程源码化。)

In [12]:
model.compile(
    loss=losses.sparse_categorical_crossentropy,
    optimizer=optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True),
    metrics=[metrics.SparseCategoricalAccuracy()],
)
In [13]:
model.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param# ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_2 (Dense)                 │ (None, 64)             │        50,240 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 50,890 (198.79 KB)
 Trainable params: 50,890 (198.79 KB)
 Non-trainable params: 0 (0.00 B)
In [14]:
utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=True,
    show_layer_names=True,
    show_layer_activations=True,
    show_trainable=True,
)
Out[14]:
No description has been provided for this image

相关推荐