微调TrOCR,训练TrOCR能识别弯曲和模糊文本


作者:Sovit Rath

编译:ronghuaiyang

导读

本文对TrOCR模型在弯曲和模糊文本数据集上进行了微调,并分析了每一步的代码和训练结果。


TrOCR (Transformer based Optical Character Recognition)模型是最好的 OCR 模型之一,在之前的文章中,我们分析了这个模型在单行打印文本和手写文本上的效果有多好。但是,和其他深度学习模型一样,它也有局限性。它在弯曲文本上的效果不好。本文会对 TrOCR 在弯曲文本数据集上进行微调。

我们知道,之前的 TrOCR 无法识别弯曲文本和竖向文本图像。这些图像在 SCUT-CTW1500 中有,我们会在这个数据集上训练 TrOCR 并分析结果,我们会知道 TrOCR 模型在不同的用例上的边界是哪里。

我们使用 Hugging Face 的训练 API 来训练模型,为了完成整个操作,我们需要安装下面的步骤进行:

  • 准备和分析弯曲文本图像数据集

  • 加载 Hugging Face 的 TrOCR 小打印文本模型

  • 初始化 Hugging Face 的 Sequence to Sequence 训练 API

  • 定义评估度量

  • 训练模型并跑推理

弯曲文本数据集

SCUT-CTW1500 数据集(后面称为 CTW1500)包含了几千张现实场景中的弯曲的文本图像。

原始数据集在官方仓库中有:https://github.com/Yuliang-Liu/Curve-Text-Detector,包含训练集和测试集,我们把训练集划分为训练集和验证集。

最终的数据集包含 6052 个训练样本和 1651 个验证样本。图像的标签存在一个文本文件中,数据集中的图像和标签如下:

从上图中可以明显看出一些事情。除了弯曲的文本图像外,数据集还包含模糊和朦胧的图像。这种真实世界的图像变化给深度学习模型带来了挑战。了解如此多样化的数据集中图像和文本的特征对于任何 OCR 模型的最新性能至关重要。这给 TrOCR 模型带来了一个有趣的挑战,当然,经过训练,它会在此类图像上表现得更好。

在弯曲文本上进行 TrOCR 的微调

让我们跳到本文的技术方面。从这里开始,我们将详细讨论 TrOCR 训练过程的代码。

安装和导入需要的库

第一步是安装所有需要的库。

!pip install -q transformers
!pip install -q sentencepiece
!pip install -q jiwer
!pip install -q datasets
!pip install -q evaluate
!pip install -q -U accelerate


!pip install -q matplotlib
!pip install -q protobuf==3.20.1
!pip install -q tensorboar

在这些里面,有一些十分非常重要的:

  • transformers: 这是 Hugging Face 的 transformers 库,我们可以通过它获取几百个基于 transformer 的模型,包括 TrOCR 模型。

  • sentencepiece: 这是 sentencepiece tokenizer 库,可以将词转化为 tokens 和数字,这也是 Hugging Face 的一部分。

  • jiwer: jiwer 库中有一些语音和语言识别的度量,包括 WER (Word Error Rate) 和**CER (Character Error Rate)**。我们会使用 CER 度量来评估模型训练结果。

然后,我们导入需要的库和包。

import os
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms


from PIL import Image
from zipfile import ZipFile
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from urllib.request import urlretrieve
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

上面的代码中,有一些重要的导入语句:

  • VisionEncoderDecoderModel: 我们需要这个类来定义不同的 TrOCR 模型。

  • TrOCRProcessor: TrOCR 需要将数据集进行特定的归一化处理,这个类会对图像进行合适的归一化和预处理。

  • Seq2SeqTrainer: 这个用来初始化训练 API。

  • Seq2SeqTrainingArguments: 在训练时,训练 API 需要一些参数, Seq2SeqTrainingArguments 类会初始化所有需要的参数,然后传给 API。

  • transforms: Torchvision transforms 模块用来对图像进行数据增强。

现在,设置好随机数种子,定义计算设备。

def seed_everything(seed_value):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

下载和解压数据集

下面的代码包含了下载和解压 CTW1500 的帮助函数。

def download_and_unzip(url, save_path):
    print(f"Downloading and extracting assets....", end="")


    # Downloading zip file using urllib package.
    urlretrieve(url, save_path)


    try:
        # Extracting zip file using the zipfile package.
        with ZipFile(save_path) as z:
            # Extract ZIP file contents in the same directory.
            z.extractall(os.path.split(save_path)[0])


        print("Done")


    except Exception as e:
        print("\nInvalid file.", e)


URL = r"https://www.dropbox.com/scl/fi/vyvr7jbdvu8o174mbqgde/scut_data.zip?rlkey=fs8axkpxunwu6if9a2su71kxs&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "scut_data.zip")

# Download if asset ZIP does not exist.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)

解压之后,数据集的目录结构是这样的。

scut_data/
├── scut_train
├── scut_test
├── scut_train.txt
└── scut_test.txt

数据在scut_data文件夹中,包含了scut_trainscut_test两个子目录。

两个文本文件包含了标注信息,格式为:

006052.jpg  ty Starts with Education
006053.jpg Cardi's
006054.jpg YOU THE BUSINESS SIDE OF GREEN
006055.jpg hat is
...

每行包括了图像的文件名以及文本信息,用空格分开。文本和图像用第一个空格分开,文件名不能包含空格。

定义配置

在开始训练之前,需要定义训练,数据集和模型的一些配置。

@dataclass(frozen=True)
class TrainingConfig:
    BATCH_SIZE:    int = 48
    EPOCHS:        int = 35
    LEARNING_RATE: float = 0.00005

@dataclass(frozen=True)
class DatasetConfig:
    DATA_ROOT:     str = 'scut_data'

@dataclass(frozen=True)
class ModelConfig:
    MODEL_NAME: str = 'microsoft/trocr-small-printed'

模型会训练 35 个 epochs,训练时使用 batchsize 为 48,学习率为 0.00005,太高的学习率会导致训练不稳定,会从一开始就 loss 变得很大。

我们还定义了数据集的根目录路径,定义了需要微调的模型为 TrOCR 的小打印文本模型。

可视化几个样本

我们可视化几个数据集中的样本:

def visualize(dataset_path):
    plt.figure(figsize=(153))
    for i in range(15):
        plt.subplot(35, i+1)
        all_images = os.listdir(f"{dataset_path}/scut_train")
        image = plt.imread(f"{dataset_path}/scut_train/{all_images[i]}")
        plt.imshow(image)
        plt.axis('off')
        plt.title(all_images[i].split('.')[0])
    plt.show()


visualize(DatasetConfig.DATA_ROOT)

准备数据集

标签在文本文件中,我们将训练和测试文本转化为Pandas的DataFrame的格式,这样更加方便加载。

train_df = pd.read_fwf(
    os.path.join(DatasetConfig.DATA_ROOT, 'scut_train.txt'), header=None
)
train_df.rename(columns={0'file_name'1'text'}, inplace=True)
test_df = pd.read_fwf(
    os.path.join(DatasetConfig.DATA_ROOT, 'scut_test.txt'), header=None
)
test_df.rename(columns={0'file_name'1'text'}, inplace=True)

现在,file_name列包含了文件名,text列包含了图像对应的文本。

下一步是定义数据增强。

# Augmentations.
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=.5, hue=.3),
    transforms.GaussianBlur(kernel_size=(59), sigma=(0.15)),
])

我们对图像应用了 ColorJitterGaussianBlur,不需要进行旋转和翻转的操作,因为原始数据集里已经有了足够的多样性。

准备数据集的最好的方法是写一个自定义的数据集类,这样就可以更好的控制输入,下面的代码定义了CustomOCRDataset类,用来准备数据集。

class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length


    def __len__(self):
        return len(self.df)


    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['file_name'][idx]
        # The text (label).
        text = self.df['text'][idx]
        # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(self.root_dir + file_name).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer and get the labels,
        # i.e. tokenized labels.
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids
        # We are using -100 as the padding token.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

__init__() 方法接收根文件夹的路径,DataFrame,TrOCR processor以及标签的最大长度作为参数。

__getitem__() 方法首先读取图像和标签,然后进行数据增强,然后经过TrOCRProcessor返回Pytorch的tensor格式的归一化的像素值,然后,文本标签经过tokenizer,如果标签比128个字符短,使用-100进行不全,如果比128长,需要被截断。最后,以字典的形式返回像素值和标签。

在生成验证集之前,需要初始化TrOCRProcessor,这样,TrOCRProcessor就可以传入dataset类中。

processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_train/'),
    df=train_df,
    processor=processor
)
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test/'),
    df=test_df,
    processor=processor
)

上面的代码包含了数据集的准备操作。

准备TrOCR Small打印文本模型

VisionEncoderDecoderModel 类可以访问所有的TrOCR模型,from_pretrained()方法接受仓库的名称作为参数,然后加载预训练模型。

model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
model.to(device)
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

这个模型包含61.5百万参数,在所有的参数上进行微调。

模型准备的一个最重要的部分是模型的参数配置,配置如下:

# Set special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Set Correct vocab size.
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id


model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

预训练的TrOCR模型有它自己的配置信息,但是,要进行微调的话,需要修改一些参数,包括token IDsvocabulary size还有End of Sequence token

还有,early stopping设置为True,这样确保当度量指标停止提升后的几个epochs,训练就可以停止。

优化度量指标

我们使用AdamW优化器,权值衰减参数为0.0005。

optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig.LEARNING_RATE, weight_decay=0.0005
)

度量指标为CER (Character Error Rate)。

cer_metric = evaluate.load('cer')


def compute_cer(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions


    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)


    cer = cer_metric.compute(predictions=pred_str, references=label_str)


    return {"cer": cer}

CER就是模型没有预测正确的字符的数量,CER越低,模型越好。

TrOCR的训练和验证

在训练之前,需要初始化训练参数。

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    fp16=True,
    output_dir='seq2seq_model_printed/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig.EPOCHS
)

我们训练时使用了FP16,这样可以使用更少的GPU内存,我们可以使用更大的batchsize,使用tensorboard进行日志的报告。

训练参数和其他需要的参数一起送到训练API里面。

# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)

使用train()方法来调用训练过程。

res = trainer.train()
Epoch Training Loss Validation Loss Cer
1 3.822000 2.677871 0.687739
2 2.497100 2.474666 0.690800
3 2.180700 2.336284 0.627641
.
.
.
33 0.146800 2.130022 0.504209
34 0.145800 2.167060 0.511095
35 0.138300 2.120335 0.494496

训练结束之后,模型的CER为49%,由于使用了small TrOCR,这时一个非常不错的结果。

下面是Tensorboard上训练过程中CER的图:

该曲线在训练过程中,总体趋势是下降的,尽管训练更长的时间会有更好的结果,我们还是先使用我们现有的模型试试。

使用微调过的TrOCR模型进行推理

有了训练好的模型,我们可以在验证数据上跑推理了。

第一步是加载训练好的最新保存的模型。

processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('seq2seq_model_printed/checkpoint-'+str(res.global_step)).to(device)

下面是几个帮助函数,首先是读取一张图像。

def read_and_show(image_path):
    """
    :param image_path: String, path to the input image.


    Returns:
        image: PIL Image.
    """

    image = Image.open(image_path).convert('RGB')
    return image

下面的函数对图像进行模型的前向传播。

def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.


    Returns:
        generated_text: the OCR'd text string.
    """

    # We can directly perform OCR on cropped images.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

最后一个函数对所有的图像进行循环推理。

def eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test''*'),
    num_samples=50
)
:

    image_paths = glob.glob(data_path)
    for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        if i == num_samples:
            break
        image = read_and_show(image_path)
        text = ocr(image, processor, trained_model)
        plt.figure(figsize=(74))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()

eval_new_data(
    data_path=os.path.join(DatasetConfig.DATA_ROOT, 'scut_test''*'),
    num_samples=100
)

我们在100个样本上进行了推理。

下面2张图是训练之前识别错的,一张是弯曲文本,一张是竖向文本:

在模型微调之后,可以正确的识别出结果了。在这个例子中,尽管文本非常的扭曲,还是可以正确的识别。

在上面的3个例子中,模型对于非常模糊的图也能预测正确。

总结

在本文中,我们介绍了在弯曲文本识别数据集上对 TrOCR 模型的微调。我们从数据集讨论开始。接下来是数据集准备和 TrOCR 模型的训练。训练后,我们进行了推理实验并分析了结果。我们的结果表明,微调 TrOCR 模型可以带来更好的性能,即使在模糊或弯曲的文本图像上也是如此。

OCR 不仅仅是识别场景中的文本,它还涉及使用 OCR 构建应用程序,例如验证码识别器或将 TrOCR 识别器与车牌检测pipeline相结合。


END

英文原文:https://learnopencv.com/fine-tuning-trocr-training-trocr-to-recognize-curved-text/

相关推荐

  • C++库文件和头文件编写教程
  • 10节课+200篇论文!实战深度学习热门领域
  • 直播来袭 | 微盟技术沙龙-数字化时代下的SaaS SCRM系统实战
  • 国美APP抽奖弹窗辱骂创始人;小米14系列或搭载MIOS;知名开发者遭微软MVP项目组除名;DHH锐评:前端根本不需要构建
  • 我用过很多代码生成器,还是选了他
  • 2023 年 Serverless 状态报告发布:采用率大幅增长
  • 代码生成:基于 AI 大模型的挑战与前景
  • 创新风潮迭起,2023深圳国际金融科技大赛——西丽湖金融科技大学生挑战赛正式启动
  • 谷歌如何释放和衡量开发人员的生产力
  • 大模型时代下的技术变革:训练、负载、部署、效率、安全……都遇到了新挑战?
  • 20 个最频繁使用的 Python 代码片段
  • GPT-4肆虐「谁是卧底」桌游!交谈逼真,类人属性仍有发展空间
  • 碾压GPT-4,微软最强AutoGen爆火!多个智能体协作,编码速度飙升4倍,GitHub狂揽10k星
  • 百度谷歌成为AI黄埔军校,Transformer八子融资超8.7亿刀!「AI行业全景报告」总结GenAI大爆发
  • 再阉割H800?美商务部新政加强限制GPU出口,预计本周公布
  • 英伟达爆火智能体研究:AI逼真还原人类情感!会饿会孤独,会跑步会发火
  • YOLO再升级!华为诺亚提出Gold-YOLO,聚集-分发机制打造新SOTA
  • 靠发AIGC论文拿了100万年薪!不是靠努力和勤奋,而是......
  • 数据科学的业务价值转化秘籍
  • AB实验的关键变量