导读作者:Sovit Rath
编译:ronghuaiyang
本文对TrOCR模型在弯曲和模糊文本数据集上进行了微调,并分析了每一步的代码和训练结果。
TrOCR (Transformer based Optical Character Recognition)模型是最好的 OCR 模型之一,在之前的文章中,我们分析了这个模型在单行打印文本和手写文本上的效果有多好。但是,和其他深度学习模型一样,它也有局限性。它在弯曲文本上的效果不好。本文会对 TrOCR 在弯曲文本数据集上进行微调。
我们知道,之前的 TrOCR 无法识别弯曲文本和竖向文本图像。这些图像在 SCUT-CTW1500 中有,我们会在这个数据集上训练 TrOCR 并分析结果,我们会知道 TrOCR 模型在不同的用例上的边界是哪里。
我们使用 Hugging Face 的训练 API 来训练模型,为了完成整个操作,我们需要安装下面的步骤进行:
准备和分析弯曲文本图像数据集
加载 Hugging Face 的 TrOCR 小打印文本模型
初始化 Hugging Face 的 Sequence to Sequence 训练 API
定义评估度量
训练模型并跑推理
SCUT-CTW1500 数据集(后面称为 CTW1500)包含了几千张现实场景中的弯曲的文本图像。
原始数据集在官方仓库中有:https://github.com/Yuliang-Liu/Curve-Text-Detector,包含训练集和测试集,我们把训练集划分为训练集和验证集。
最终的数据集包含 6052 个训练样本和 1651 个验证样本。图像的标签存在一个文本文件中,数据集中的图像和标签如下:
从上图中可以明显看出一些事情。除了弯曲的文本图像外,数据集还包含模糊和朦胧的图像。这种真实世界的图像变化给深度学习模型带来了挑战。了解如此多样化的数据集中图像和文本的特征对于任何 OCR 模型的最新性能至关重要。这给 TrOCR 模型带来了一个有趣的挑战,当然,经过训练,它会在此类图像上表现得更好。
让我们跳到本文的技术方面。从这里开始,我们将详细讨论 TrOCR 训练过程的代码。
第一步是安装所有需要的库。
!pip install -q transformers
!pip install -q sentencepiece
!pip install -q jiwer
!pip install -q datasets
!pip install -q evaluate
!pip install -q -U accelerate
!pip install -q matplotlib
!pip install -q protobuf==3.20.1
!pip install -q tensorboar
在这些里面,有一些十分非常重要的:
transformers
: 这是 Hugging Face 的 transformers
库,我们可以通过它获取几百个基于 transformer 的模型,包括 TrOCR 模型。
sentencepiece
: 这是 sentencepiece
tokenizer 库,可以将词转化为 tokens 和数字,这也是 Hugging Face 的一部分。
jiwer
: jiwer
库中有一些语音和语言识别的度量,包括 WER (Word Error Rate) 和**CER (Character Error Rate)**。我们会使用 CER 度量来评估模型训练结果。
然后,我们导入需要的库和包。
import os
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
VisionEncoderDecoderModel,
TrOCRProcessor,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator
)
上面的代码中,有一些重要的导入语句:
VisionEncoderDecoderModel
: 我们需要这个类来定义不同的 TrOCR 模型。
TrOCRProcessor
: TrOCR 需要将数据集进行特定的归一化处理,这个类会对图像进行合适的归一化和预处理。
Seq2SeqTrainer
: 这个用来初始化训练 API。
Seq2SeqTrainingArguments
: 在训练时,训练 API 需要一些参数, Seq2SeqTrainingArguments
类会初始化所有需要的参数,然后传给 API。
transforms
: Torchvision transforms
模块用来对图像进行数据增强。
现在,设置好随机数种子,定义计算设备。
def seed_everything(seed_value):
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
下面的代码包含了下载和解压 CTW1500 的帮助函数。
def download_and_unzip(url, save_path):
print(f"Downloading and extracting assets....", end="")
# Downloading zip file using urllib package.
urlretrieve(url, save_path)
try:
# Extracting zip file using the zipfile package.
with ZipFile(save_path) as z:
# Extract ZIP file contents in the same directory.
z.extractall(os.path.split(save_path)[0])
print("Done")
except Exception as e:
print("\nInvalid file.", e)
URL = r"https://www.dropbox.com/scl/fi/vyvr7jbdvu8o174mbqgde/scut_data.zip?rlkey=fs8axkpxunwu6if9a2su71kxs&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "scut_data.zip")
# Download if asset ZIP does not exist.
if not os.path.exists(asset_zip_path):
download_and_unzip(URL, asset_zip_path)
解压之后,数据集的目录结构是这样的。
scut_data/
├── scut_train
├── scut_test
├── scut_train.txt
└── scut_test.txt
数据在scut_data
文件夹中,包含了scut_train
和scut_test
两个子目录。
两个文本文件包含了标注信息,格式为:
006052.jpg ty Starts with Education
006053.jpg Cardi's
006054.jpg YOU THE BUSINESS SIDE OF GREEN
006055.jpg hat is
...
每行包括了图像的文件名以及文本信息,用空格分开。文本和图像用第一个空格分开,文件名不能包含空格。
在开始训练之前,需要定义训练,数据集和模型的一些配置。
@dataclass(frozen=True)
class TrainingConfig:
BATCH_SIZE: int = 48
EPOCHS: int = 35
LEARNING_RATE: float = 0.00005
@dataclass(frozen=True)
class DatasetConfig:
DATA_ROOT: str = 'scut_data'
@dataclass(frozen=True)
class ModelConfig:
MODEL_NAME: str = 'microsoft/trocr-small-printed'
模型会训练 35 个 epochs,训练时使用 batchsize 为 48,学习率为 0.00005,太高的学习率会导致训练不稳定,会从一开始就 loss 变得很大。
我们还定义了数据集的根目录路径,定义了需要微调的模型为 TrOCR 的小打印文本模型。
我们可视化几个数据集中的样本:
def visualize(dataset_path):
plt.figure(figsize=(15, 3))
for i in range(15):
plt.subplot(3, 5, i+1)
all_images = os.listdir(f"{dataset_path}/scut_train")
image = plt.imread(f"{dataset_path}/scut_train/{all_images[i]}")
plt.imshow(image)
plt.axis('off')
plt.title(all_images[i].split('.')[0])
plt.show()
visualize(DatasetConfig.DATA_ROOT)
标签在文本文件中,我们将训练和测试文本转化为Pandas的DataFrame的格式,这样更加方便加载。
train_df = pd.read_fwf(
os.path.join(DatasetConfig.DATA_ROOT, 'scut_train.txt'), header=None
)
train_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
test_df = pd.read_fwf(
os.path.join(DatasetConfig.DATA_ROOT, 'scut_test.txt'), header=None
)
test_df.rename(columns={0: 'file_name', 1: 'text'}, inplace=True)
现在,file_name列包含了文件名,text列包含了图像对应的文本。
下一步是定义数据增强。
# Augmentations.
train_transforms = transforms.Compose([
transforms.ColorJitter(brightness=.5, hue=.3),
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
])
我们对图像应用了 ColorJitter
和GaussianBlur
,不需要进行旋转和翻转的操作,因为原始数据集里已经有了足够的多样性。
准备数据集的最好的方法是写一个自定义的数据集类,这样就可以更好的控制输入,下面的代码定义了CustomOCRDataset
类,用来准备数据集。
class CustomOCRDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# The image file name.
file_name = self.df['file_name'][idx]
# The text (label).
text = self.df['text'][idx]
# Read the image, apply augmentations, and get the transformed pixels.
image = Image.open(self.root_dir + file_name).convert('RGB')
image = train_transforms(image)
pixel_values = self.processor(image, return_tensors='pt').pixel_values
# Pass the text through the tokenizer and get the labels,
# i.e. tokenized labels.
labels = self.processor.tokenizer(
text,
padding='max_length',
max_length=self.max_target_length
).input_ids
# We are using -100 as the padding token.
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
return encoding
__init__()
方法接收根文件夹的路径,DataFrame,TrOCR processor以及标签的最大长度作为参数。
__getitem__()
方法首先读取图像和标签,然后进行数据增强,然后经过TrOCRProcessor返回Pytorch的tensor格式的归一化的像素值,然后,文本标签经过tokenizer,如果标签比128个字符短,使用-100进行不全,如果比128长,需要被截断。最后,以字典的形式返回像素值和标签。
在生成验证集之前,需要初始化TrOCRProcessor,这样,TrOCRProcessor就可以传入dataset类中。
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_train/'),
df=train_df,
processor=processor
)
valid_dataset = CustomOCRDataset(
root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test/'),
df=test_df,
processor=processor
)
上面的代码包含了数据集的准备操作。
VisionEncoderDecoderModel
类可以访问所有的TrOCR模型,from_pretrained()
方法接受仓库的名称作为参数,然后加载预训练模型。
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")
这个模型包含61.5百万参数,在所有的参数上进行微调。
模型准备的一个最重要的部分是模型的参数配置,配置如下:
# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set Correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
预训练的TrOCR模型有它自己的配置信息,但是,要进行微调的话,需要修改一些参数,包括token IDs, vocabulary size还有End of Sequence token。
还有,early stopping设置为True
,这样确保当度量指标停止提升后的几个epochs,训练就可以停止。
我们使用AdamW优化器,权值衰减参数为0.0005。
optimizer = optim.AdamW(
model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)
度量指标为CER (Character Error Rate)。
cer_metric = evaluate.load('cer')
def compute_cer(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
cer = cer_metric.compute(predictions=pred_str, references=label_str)
return {"cer": cer}
CER就是模型没有预测正确的字符的数量,CER越低,模型越好。
在训练之前,需要初始化训练参数。
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
evaluation_strategy='epoch',
per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
fp16=True,
output_dir='seq2seq_model_printed/',
logging_strategy='epoch',
save_strategy='epoch',
save_total_limit=5,
report_to='tensorboard',
num_train_epochs=TrainingConfig.EPOCHS
)
我们训练时使用了FP16,这样可以使用更少的GPU内存,我们可以使用更大的batchsize,使用tensorboard进行日志的报告。
训练参数和其他需要的参数一起送到训练API里面。
# Initialize trainer.
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
compute_metrics=compute_cer,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
data_collator=default_data_collator
)
使用train()方法来调用训练过程。
res = trainer.train()
Epoch Training Loss Validation Loss Cer
1 3.822000 2.677871 0.687739
2 2.497100 2.474666 0.690800
3 2.180700 2.336284 0.627641
.
.
.
33 0.146800 2.130022 0.504209
34 0.145800 2.167060 0.511095
35 0.138300 2.120335 0.494496
训练结束之后,模型的CER为49%,由于使用了small TrOCR,这时一个非常不错的结果。
下面是Tensorboard上训练过程中CER的图:
该曲线在训练过程中,总体趋势是下降的,尽管训练更长的时间会有更好的结果,我们还是先使用我们现有的模型试试。
有了训练好的模型,我们可以在验证数据上跑推理了。
第一步是加载训练好的最新保存的模型。
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)
下面是几个帮助函数,首先是读取一张图像。
def read_and_show(image_path):
"""
:param image_path: String, path to the input image.
Returns:
image: PIL Image.
"""
image = Image.open(image_path).convert('RGB')
return image
下面的函数对图像进行模型的前向传播。
def ocr(image, processor, model):
"""
:param image: PIL Image.
:param processor: Huggingface OCR processor.
:param model: Huggingface OCR model.
Returns:
generated_text: the OCR'd text string.
"""
# We can directly perform OCR on cropped images.
pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return generated_text
最后一个函数对所有的图像进行循环推理。
def eval_new_data(
data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test', '*'),
num_samples=50
):
image_paths = glob.glob(data_path)
for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
if i == num_samples:
break
image = read_and_show(image_path)
text = ocr(image, processor, trained_model)
plt.figure(figsize=(7, 4))
plt.imshow(image)
plt.title(text)
plt.axis('off')
plt.show()
eval_new_data(
data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test', '*'),
num_samples=100
)
我们在100个样本上进行了推理。
下面2张图是训练之前识别错的,一张是弯曲文本,一张是竖向文本:
在模型微调之后,可以正确的识别出结果了。在这个例子中,尽管文本非常的扭曲,还是可以正确的识别。
在上面的3个例子中,模型对于非常模糊的图也能预测正确。
在本文中,我们介绍了在弯曲文本识别数据集上对 TrOCR 模型的微调。我们从数据集讨论开始。接下来是数据集准备和 TrOCR 模型的训练。训练后,我们进行了推理实验并分析了结果。我们的结果表明,微调 TrOCR 模型可以带来更好的性能,即使在模糊或弯曲的文本图像上也是如此。
OCR 不仅仅是识别场景中的文本,它还涉及使用 OCR 构建应用程序,例如验证码识别器或将 TrOCR 识别器与车牌检测pipeline相结合。
英文原文:https://learnopencv.com/fine-tuning-trocr-training-trocr-to-recognize-curved-text/