Last time I shared a deep-learning tools post, a long-form practical guide to common PyTorch modeling code. Today I'd like to introduce a handy library, torchkeras. As you know, training and inference with plain torch require writing a lot of tedious trainer boilerplate; the author of this library wraps torch model training and inference in a Keras-style API that is simple and easy to use.
The author also provides a large collection of notebook examples covering reinforcement learning, CV, NLP, large language models, and other tasks, so the library is very convenient in practice. Now let's get into today's topic.
Background
Introduction to the torchkeras library
Tutorial: training your own deep-learning tasks easily with torchkeras
Installing and importing the libraries
Defining a torch network model
Printing the model's structure and parameter counts
Model training: visualizing results for each epoch
Model training: printing detailed per-epoch training and validation results
Model validation: evaluating model performance
Applications: using torchkeras for reinforcement learning, CV, NLP, and LLM tasks
Reinforcement learning tasks
Computer vision tasks
NLP tasks
LLM tasks
References
torchkeras (GitHub repository: https://github.com/lyhue1991/torchkeras) is a general-purpose PyTorch model-training template tool. It is built entirely on the torch library and wraps torch model training and inference in a Keras-style API.
Below is my walkthrough of how to use torchkeras. For more details, see its GitHub repository, which provides a large number of worked examples.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import pytorch_lightning as pl
import torchmetrics
import accelerate
import torchkeras
print("accelerate version:", accelerate.__version__)
print("torchmetrics version:", torchmetrics.__version__) # torchmetrics version: 1.3.2
print("pytorch_lightning version:", pl.__version__) # pytorch_lightning version: 2.2.1
print("torchkeras version:", torchkeras.__version__) # torchkeras version: 3.9.6
print("torch version:", torch.__version__) # torch version: 2.1.2
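The examples below reference `dl_train` and `dl_val` dataloaders that the original post does not show. As a self-contained stand-in (the real post presumably uses MNIST, given the 1x28x28 input shape), randomly generated MNIST-shaped tensors work; the sizes and batch size here are assumptions chosen to match the log output later in the post:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Synthetic stand-in for MNIST: 1x28x28 images, 10 classes.
X_train = torch.randn(3000, 1, 28, 28)
y_train = torch.randint(0, 10, (3000,))
X_val = torch.randn(500, 1, 28, 28)
y_val = torch.randint(0, 10, (500,))

dl_train = DataLoader(TensorDataset(X_train, y_train), batch_size=128, shuffle=True)
dl_val = DataLoader(TensorDataset(X_val, y_val), batch_size=128)

# Grab one batch to inspect shapes.
features, labels = next(iter(dl_train))
print(features.shape)  # torch.Size([128, 1, 28, 28])
```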
def create_net():
    net = nn.Sequential()
    net.add_module("conv1", nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3))
    net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv2", nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5))
    net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("dropout", nn.Dropout2d(p=0.1))
    net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1, 1)))
    net.add_module("flatten", nn.Flatten())
    net.add_module("linear1", nn.Linear(64, 32))
    net.add_module("relu", nn.ReLU())
    net.add_module("linear2", nn.Linear(32, 10))
    return net
net = create_net()
from torchkeras import summary
features, labels = next(iter(dl_train))  # features: [128, 1, 28, 28]
summary(net, input_data=features);
The network structure is printed as follows:
--------------------------------------------------------------------------
Layer (type) Output Shape Param #
==========================================================================
Conv2d-1 [-1, 32, 26, 26] 320
MaxPool2d-2 [-1, 32, 13, 13] 0
Conv2d-3 [-1, 64, 9, 9] 51,264
MaxPool2d-4 [-1, 64, 4, 4] 0
Dropout2d-5 [-1, 64, 4, 4] 0
AdaptiveMaxPool2d-6 [-1, 64, 1, 1] 0
Flatten-7 [-1, 64] 0
Linear-8 [-1, 32] 2,080
ReLU-9 [-1, 32] 0
Linear-10 [-1, 10] 330
==========================================================================
Total params: 53,994
Trainable params: 53,994
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000076
Forward/backward pass size (MB): 0.263016
Params size (MB): 0.205971
Estimated Total Size (MB): 0.469063
--------------------------------------------------------------------------
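As a sanity check, the parameter counts in the table can be reproduced by hand: a conv layer has out_channels * in_channels * kH * kW weights plus out_channels biases, and a linear layer has in * out weights plus out biases:

```python
# Reproduce the per-layer parameter counts reported by summary():
conv1 = 32 * 1 * 3 * 3 + 32     # 320
conv2 = 64 * 32 * 5 * 5 + 64    # 51,264
linear1 = 64 * 32 + 32          # 2,080
linear2 = 32 * 10 + 10          # 330

total = conv1 + conv2 + linear1 + linear2
print(total)  # 53994, matching "Total params" above
```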
from torchmetrics import Accuracy

model = torchkeras.KerasModel(net,
                              loss_fn=nn.CrossEntropyLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=5e-3),
                              metrics_dict={"acc": Accuracy(task="multiclass", num_classes=10)})

dfhistory = model.fit(train_data=dl_train,
                      val_data=dl_val,
                      epochs=30,
                      patience=3,
                      monitor="val_acc",
                      mode="max",
                      ckpt_path='checkpoint.pt',
                      plot=True,
                      wandb=False)
Training progress is plotted live, and the returned dfhistory is a pandas DataFrame recording the per-epoch results:
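Because dfhistory is an ordinary DataFrame, it can be inspected and plotted directly. A minimal sketch, using hypothetical values and assuming it records per-epoch metrics in columns such as train_acc and val_acc (the column names here are assumptions based on the metric names in the training log):

```python
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
from matplotlib import pyplot as plt

# Hypothetical history such as model.fit might return.
dfhistory = pd.DataFrame({
    "epoch": [1, 2, 3],
    "train_loss": [0.0797, 0.0651, 0.0525],
    "val_loss": [0.264, 0.200, 0.229],
    "train_acc": [0.971, 0.976, 0.980],
    "val_acc": [0.928, 0.934, 0.940],
})

def plot_metric(df, metric):
    """Plot train vs. validation curves for one metric and save to a PNG."""
    plt.figure()
    plt.plot(df["epoch"], df["train_" + metric], label="train_" + metric)
    plt.plot(df["epoch"], df["val_" + metric], label="val_" + metric)
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig(metric + ".png")

plot_metric(dfhistory, "acc")
best_val_acc = dfhistory["val_acc"].max()
print(best_val_acc)
```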
from torchmetrics import Accuracy

model = torchkeras.KerasModel(net,
                              loss_fn=nn.CrossEntropyLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=5e-3),
                              metrics_dict={"acc": Accuracy(task="multiclass", num_classes=10)})

dfhistory = model.fit(train_data=dl_train,
                      val_data=dl_val,
                      epochs=30,
                      patience=3,
                      monitor="val_acc",
                      mode="max",
                      ckpt_path='checkpoint.pt',
                      plot=False,
                      wandb=False)
Output:
<<<<<< ⚡️ cuda is used >>>>>>
================================================================================2024-04-21 04:02:17
Epoch 1 / 30
100%|█████████████████| 24/24 [00:00<00:00, 57.35it/s, lr=0.005, train_acc=0.971, train_loss=0.0797]
100%|██████████████████████████████████| 4/4 [00:00<00:00, 24.80it/s, val_acc=0.928, val_loss=0.264]
<<<<<< reach best val_acc : 0.9279999732971191 >>>>>>
================================================================================2024-04-21 04:02:17
Epoch 2 / 30
100%|█████████████████| 24/24 [00:00<00:00, 47.63it/s, lr=0.005, train_acc=0.976, train_loss=0.0651]
100%|████████████████████████████████████| 4/4 [00:00<00:00, 21.21it/s, val_acc=0.934, val_loss=0.2]
<<<<<< reach best val_acc : 0.9340000152587891 >>>>>>
================================================================================2024-04-21 04:02:18
Epoch 3 / 30
100%|██████████████████| 24/24 [00:00<00:00, 45.93it/s, lr=0.005, train_acc=0.98, train_loss=0.0525]
100%|███████████████████████████████████| 4/4 [00:00<00:00, 21.11it/s, val_acc=0.94, val_loss=0.229]
<<<<<< reach best val_acc : 0.9399999976158142 >>>>>>
================================================================================2024-04-21 04:02:19
Epoch 4 / 30
100%|█████████████████| 24/24 [00:00<00:00, 62.65it/s, lr=0.005, train_acc=0.989, train_loss=0.0271]
100%|██████████████████████████████████| 4/4 [00:00<00:00, 26.42it/s, val_acc=0.956, val_loss=0.202]
<<<<<< reach best val_acc : 0.9559999704360962 >>>>>>
================================================================================2024-04-21 04:02:19
Epoch 5 / 30
100%|█████████████████| 24/24 [00:00<00:00, 59.41it/s, lr=0.005, train_acc=0.986, train_loss=0.0362]
100%|██████████████████████████████████| 4/4 [00:00<00:00, 25.56it/s, val_acc=0.952, val_loss=0.202]
================================================================================2024-04-21 04:02:20
Epoch 6 / 30
100%|█████████████████| 24/24 [00:00<00:00, 60.27it/s, lr=0.005, train_acc=0.992, train_loss=0.0239]
100%|██████████████████████████████████| 4/4 [00:00<00:00, 27.19it/s, val_acc=0.948, val_loss=0.211]
================================================================================2024-04-21 04:02:20
Epoch 7 / 30
100%|█████████████████| 24/24 [00:00<00:00, 60.93it/s, lr=0.005, train_acc=0.991, train_loss=0.0253]
100%|██████████████████████████████████| 4/4 [00:00<00:00, 24.58it/s, val_acc=0.936, val_loss=0.248]
<<<<<< val_acc without improvement in 3 epoch,early stopping >>>>>>
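The patience=3 early stopping seen in this log can be sketched in plain Python. This is a simplified illustration of the monitor/patience logic (mode="max"), not torchkeras's actual implementation:

```python
def early_stop_epoch(val_scores, patience=3):
    """Return the 1-based epoch at which training stops, or None if it never does.

    Stops once the monitored score has not improved for `patience` epochs.
    """
    best, best_epoch = float("-inf"), 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best:
            best, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            return epoch
    return None

# val_acc values from the log above: the best comes at epoch 4,
# then 3 epochs without improvement trigger early stopping at epoch 7.
scores = [0.928, 0.934, 0.940, 0.956, 0.952, 0.948, 0.936]
print(early_stop_epoch(scores))  # 7
```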
# load the saved model parameters
net_clone = create_net()
model_clone = torchkeras.KerasModel(net_clone,
                                    loss_fn=nn.CrossEntropyLoss(),
                                    optimizer=torch.optim.Adam(net_clone.parameters(), lr=0.001),
                                    metrics_dict={"acc": Accuracy(task="multiclass", num_classes=10)})
model_clone.net.load_state_dict(torch.load("checkpoint.pt"))
val_metrics = model_clone.evaluate(dl_val)
print(val_metrics)
100%|████████████████████████████████████| 4/4 [00:00<00:00, 24.12it/s, val_acc=0.97, val_loss=0.16]
{'val_loss': 0.16028987243771553, 'val_acc': 0.9700000286102295}
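After restoring the checkpoint, inference only needs the underlying torch module. A minimal sketch (the tiny stand-in network here is an assumption so the example is self-contained; in practice you would call the restored net_clone the same way):

```python
import torch
from torch import nn

# Stand-in module with the same 1x28x28 -> 10-class interface.
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
net.eval()  # disable dropout etc. for inference

batch = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    logits = net(batch)          # shape [4, 10]
    preds = logits.argmax(dim=1) # predicted class per sample

print(preds.shape)  # torch.Size([4])
```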
The torchkeras author also provides a large number of deep-learning examples and notebooks; see the repository for the full list.
To wrap up: today I introduced torchkeras, a handy deep-learning training library that is a great fit for torch users. It offers powerful, easy-to-use features; for detailed documentation, please refer to its GitHub repository.