ocr技术系列之三大批量生成文字训练集

冠军的试炼 冠军的试炼     2022-10-14     562

关键词:

放假了,终于可以继续可以静下心写一写OCR方面的东西。上次谈到文字的切割,今天打算总结一下我们怎么得到用于训练的文字数据集。如果是想训练一个手写体识别的模型,用一些前人收集好的手写文字集就好了,比如中科院的这些数据集。但是如果我们只是想要训练一个专门用于识别印刷汉字的模型,那么我们就需要各种印刷字体的训练集,那怎么获取呢?借助强大的图像库,自己生成就行了!

先捋一捋思路,生成文字集需要什么步骤:

  1. 确定你要生成多少字体,生成一个记录着汉字与label的对应表。
  2. 确定和收集需要用到的字体文件。
  3. 生成字体图像,存储在规定的目录下。
  4. 适当的数据增强。

第三步的生成字体图像最为重要,如果仅仅是生成很正规的文字,那么用这个正规文字集去训练模型,第一图像数目有点少,第二模型泛化能力比较差,所以我们需要对字体图像做大量的图像处理工作,以增大我们的印刷体文字数据集。

我总结了一下,我们可以做的一些图像增强工作有这些:

  1. 文字扭曲
  2. 背景噪声(椒盐)
  3. 文字位置(设置文字的中心点)
  4. 笔画粘连(膨胀来模拟)
  5. 笔画断裂(腐蚀来模拟)
  6. 文字倾斜(文字旋转)
  7. 多种字体

做完以上增强后,我们得到的数据集已经非常庞大了。

现在开始一步一步生成我们的3755个汉字的印刷体文字数据集。

一、生成汉字与label的对应表

这里的汉字、label映射表的生成我使用了pickel模块,借助它生成一个id:汉字的映射文件存储下来。
这里举个小例子说明怎么生成这个“汉字:id”映射表。

首先在一个txt文件里写入你想要的汉字,如果对汉字对应的ID没有要求的话,我们不妨使用该汉字的排位作为其ID,比如“一二三四五”中,五的ID就是00005。如此类推,把汉字读入内存,建立一个字典,把这个关系记录下来,再使用pickle.dump存入文件保存。

二、收集字体文件

字体文件上网收集就好了,但是值得注意的是,不是每一种字体都支持汉字,所以我们需要筛选出真正适合汉字生成的字体文件才可以。我一共使用了十三种汉字字体作为我们接下来汉字数据集用到的字体,具体如下图:

当然,如果需要进一步扩大数据集来增强训练得到的模型的泛化能力,可以花更多的时间去收集各类汉字字体,那么模型在面对各种字体时也能从容应对,给出准确的预测。

三、文字图像生成

首先是定义好输入参数,其中包括输出目录、字体目录、测试集大小、图像尺寸、图像旋转幅度等等。

def args_parse():
    #解析输入参数
    parser = argparse.ArgumentParser(
        description=description, formatter_class=RawTextHelpFormatter)
    parser.add_argument('--out_dir', dest='out_dir',
                        default=None, required=True,
                        help='write a caffe dir')
    parser.add_argument('--font_dir', dest='font_dir',
                        default=None, required=True,
                        help='font dir to to produce images')
    parser.add_argument('--test_ratio', dest='test_ratio',
                        default=0.2, required=False,
                        help='test dataset size')
    parser.add_argument('--width', dest='width',
                        default=None, required=True,
                        help='width')
    parser.add_argument('--height', dest='height',
                        default=None, required=True,
                        help='height')
    parser.add_argument('--no_crop', dest='no_crop',
                        default=True, required=False,
                        help='', action='store_true')
    parser.add_argument('--margin', dest='margin',
                        default=0, required=False,
                        help='', )
    parser.add_argument('--rotate', dest='rotate',
                        default=0, required=False,
                        help='max rotate degree 0-45')
    parser.add_argument('--rotate_step', dest='rotate_step',
                        default=0, required=False,
                        help='rotate step for the rotate angle')
    parser.add_argument('--need_aug', dest='need_aug',
                        default=False, required=False,
                        help='need data augmentation', action='store_true')   
    args = vars(parser.parse_args()) 
    return args

接下来需要将我们第一步得到的对应表读入内存,因为这个表示ID到汉字的映射,我们在做一下转换,改成汉字到ID的映射,用于后面的字体生成。

#将汉字的label读入,得到(ID:汉字)的映射表label_dict
label_dict = get_label_dict()

char_list=[]  # 汉字列表
value_list=[] # label列表
for (value,chars) in label_dict.items():
    print (value,chars)
    char_list.append(chars)
    value_list.append(value)

# 合并成新的映射关系表:(汉字:ID)
lang_chars = dict(zip(char_list,value_list)) 
font_check = FontCheck(lang_chars) 

我们对旋转的角度存储到列表中,旋转角度的范围是[-rotate,rotate].

if rotate < 0:
    roate = - rotate

if rotate > 0 and rotate <= 45:
    all_rotate_angles = []
    for i in range(0, rotate+1, rotate_step):  
        all_rotate_angles.append(i)
    for i in range(-rotate, 0, rotate_step):
        all_rotate_angles.append(i)
    #print(all_rotate_angles)

现在说一下字体图像是怎么生成的,首先我们使用的工具是PIL。PIL里面有很好用的汉字生成函数,我们用这个函数再结合我们提供的字体文件,就可以生成我们想要的数字化的汉字了。我们先设定好我们生成的字体颜色为黑底白色,字体尺寸由输入参数来动态设定。

# 生成字体图像
class Font2Image(object):

    def __init__(self,
                 width, height,
                 need_crop, margin):
        self.width = width
        self.height = height
        self.need_crop = need_crop
        self.margin = margin

    def do(self, font_path, char, rotate=0):
        find_image_bbox = FindImageBBox()
        # 黑色背景
        img = Image.new("RGB", (self.width, self.height), "black")
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype(font_path, int(self.width * 0.7),)
        # 白色字体
        draw.text((0, 0), char, (255, 255, 255),
                  font=font)
        if rotate != 0:
            img = img.rotate(rotate)
        data = list(img.getdata())
        sum_val = 0
        for i_data in data:
            sum_val += sum(i_data)
        if sum_val > 2:
            np_img = np.asarray(data, dtype='uint8')
            np_img = np_img[:, 0]
            np_img = np_img.reshape((self.height, self.width))
            cropped_box = find_image_bbox.do(np_img)
            left, upper, right, lower = cropped_box
            np_img = np_img[upper: lower + 1, left: right + 1]
            if not self.need_crop:
                preprocess_resize_keep_ratio_fill_bg = \
                    PreprocessResizeKeepRatioFillBG(self.width, self.height,
                                                    fill_bg=False,
                                                    margin=self.margin)
                np_img = preprocess_resize_keep_ratio_fill_bg.do(
                    np_img)
            # cv2.imwrite(path_img, np_img)
            return np_img
        else:
            print("img doesn't exist.")

我们写两个循环,外层循环是汉字列表,内层循环是字体列表,对于每个汉字会得到一个image_list列表,里面存储着这个汉字的所有图像。

for (char, value) in lang_chars.items():  # 外层循环是字
    image_list = []
    print (char,value)
    #char_dir = os.path.join(images_dir, "%0.5d" % value)
    for j, verified_font_path in enumerate(verified_font_paths):    # 内层循环是字体   
        if rotate == 0:
            image = font2image.do(verified_font_path, char)
            image_list.append(image)
        else:
            for k in all_rotate_angles:	
                image = font2image.do(verified_font_path, char, rotate=k)
                image_list.append(image)

我们将image_list中图像按照比例分为训练集和测试集存储。

        test_num = len(image_list) * test_ratio
        random.shuffle(image_list)  # 图像列表打乱
        count = 0
        for i in range(len(image_list)):
            img = image_list[i]
            #print(img.shape)
            if count < test_num :
                char_dir = os.path.join(test_images_dir, "%0.5d" % value)
            else:
                char_dir = os.path.join(train_images_dir, "%0.5d" % value)

            if not os.path.isdir(char_dir):
                os.makedirs(char_dir)

            path_image = os.path.join(char_dir,"%d.png" % count)
            cv2.imwrite(path_image,img)
            count += 1

写好代码后,我们执行如下指令,开始生成印刷体文字汉字集。

 python gen_printed_char.py --out_dir ./dataset --font_dir ./chinese_fonts --width 30 --height 30 --margin 4 --rotate 30 --rotate_step 1

解析一下上述指令的附属参数:

  1. --out_dir 表示生成的汉字图像的存储目录
  2. --font_dir 表示放置汉字字体文件的路径
  3. --width --height 表示生成图像的高度和宽度
  4. --margin 表示字体与边缘的间隔
  5. --rotate 表示字体旋转的范围,[-rotate,rotate]
  6. --rotate_step 表示每次旋转的间隔

生成这么一个3755个汉字的数据集的所需的时间还是很久的,估计接近一个小时。其实这个生成过程可以用多线程、多进程并行加速,但是考虑到这种文字数据集只需生成一次就好,所以就没做这方面的优化了。数据集生成完我们可以发现,在dataset文件夹下得到train和test两个文件夹,train和test文件夹下都有3755个子文件夹,分别存储着生成的3755个汉字对应的图像,每个子文件的名字就是该汉字对应的id。随便选择一个train文件夹下的一个子文件夹打开,可以看到所获得的汉字图像,一共634个。

dataset下自动生成测试集和训练集

测试集和训练集下都有3755个子文件夹,用于存储每个汉字的图像。

生成出来的汉字图像

额外的图像增强

第三步生成的汉字图像是最基本的数据集,它所做的图像处理仅有旋转这么一项,如果我们想在数据增强上再做多点东西,想必我们最终训练出来的OCR模型的性能会更加优秀。我们使用opencv来完成我们定制的汉字图像增强任务。

因为生成的图像比较小,仅仅是30*30,如果对这么小的图像加噪声或者形态学处理,得到的字体图像会很糟糕,所以我们在做数据增强时,把图片尺寸适当增加,比如设置为100×100,再进行相应的数据增强,效果会更好。

噪点增加

def add_noise(cls,img):
    for i in range(20): #添加点噪声
        temp_x = np.random.randint(0,img.shape[0])
        temp_y = np.random.randint(0,img.shape[1])
        img[temp_x][temp_y] = 255
    return img

适当腐蚀

def add_erode(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.erode(img,kernel) 
    return img

适当膨胀

def add_dilate(cls,img):
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))    
    img = cv2.dilate(img,kernel) 
    return img

然后做随机扰动

def do(self,img_list=[]):
    aug_list= copy.deepcopy(img_list)
    for i in range(len(img_list)):
        im = img_list[i]
        if self.noise and random.random()<0.5:
            im = self.add_noise(im)
        if self.dilate and random.random()<0.25:
            im = self.add_dilate(im)
        if self.erode and random.random()<0.25:
            im = self.add_erode(im)    
        aug_list.append(im)
    return aug_list

输入指令

python gen_printed_char.py --out_dir ./dataset2 --font_dir ./chinese_fonts --width 100 --height 100 --margin 10 --rotate 30 --rotate_step 1 --need_aug

使用这种生成的图像如下图所示,第一数据集扩大了两倍,第二图像的丰富性进一步提高,效果还是明显的。当然,如果要获得最好的效果,还需要调一下里面的参数,这里就不再详细说明了。

至此,我们所需的印刷体汉字数据集已经成功生成完毕,下一步要做的就是利用这些数据集设计一个卷积神经网络做文字识别了!完整的代码可以在我的github获取。

ocr技术系列之六文本检测ctpn的代码实现(代码片段)

这几天一直在用Pytorch来复现文本检测领域的CTPN论文,本文章将从数据处理、训练标签生成、神经网络搭建、损失函数设计、训练主过程编写等这几个方面来一步一步复现CTPN。CTPN算法理论可以参考这里。训练数据处理我们的训... 查看详情

ocr文字识别方法综述

 📝OCR文字识别技术介绍合集:1️⃣OCR文字识别技术系列第一章:OCR文字识别技术总结(一)2️⃣OCR文字识别技术系列第二章:OCR文字识别技术总结(二)3️⃣OCR文字识别技术系列第三章:OCR... 查看详情

ocr文字识别经典论文详解

📝OCR文字识别技术介绍合集:1️⃣OCR文字识别技术系列第一章:OCR文字识别技术总结(一)2️⃣OCR文字识别技术系列第二章:OCR文字识别技术总结(二)3️⃣OCR文字识别技术系列第三章:OCR... 查看详情

ocr文字识别技术总结(代码片段)

...。本系列目录:1️⃣OCR系列第一章:OCR文字识别技术总结(一)2️⃣OCR系列第二章:OCR文字识别技术总结(二)3️⃣OCR系列第三章:OCR文字识别技术总结(三)4️⃣OCR系列第四章:OCR... 查看详情

ocr技术系列之一字符识别技术总览

...世今生也有了一个比较清晰的了解。所以想写一篇关于OCR技术的综述,对OCR相关的知识点都好好总结一遍,以加深个人理解。什么是OCR?OCR英文全称是OpticalCharacterRecognition,中文叫做光学字符识别。它是利用光学技术和计算机技... 查看详情

ocr文字识别经典论文详解

...识别综述合集:1️⃣OCR系列第一章:OCR文字识别技术总结(一)2️⃣OCR系列第二章:OCR文字识别技术总结(二)3️⃣OCR系列第三章:OCR文字识别技术总结(三)4️⃣OCR系列第四章:OCR... 查看详情

中文识别数据集生成脚本

概述该脚本能够将用户指定的字符输出为不同字体的图像文件,用于训练文字识别的机器学习模型或用于其他文字识别OCR项目详细代码下载:http://www.demodashi.com/demo/13952.html 一、开发背景随着近几年来计算机算力的不断提升... 查看详情

基于深度学习和语言模型的印刷文字ocr系统

作者:苏剑林系列博文:科学空间OCR技术浅探:1.全文简述OCR技术浅探:2.背景与假设OCR技术浅探:3.特征提取(1)OCR技术浅探:3.特征提取(2)OCR技术浅探:4.文字定位OCR技术浅探:5.文本切割OCR技术浅探:6.光学识别OCR技术浅探:7.... 查看详情

ocr文字识别—基于ctc/attention/ace的三大解码算法

本文全面梳理一下OCR文字识别三种解码算法,先介绍一下什么是OCR文字识别,然后介绍一下常用的特征提取方法CRNN,最后介绍3种常用的解码算法CTC/Attention/ACE。什么是OCR文字识别?一般来说,文字识别之前需... 查看详情

paddle入门实战系列:中文场景文字识别(代码片段)

...构化数据中取得不错成果及应用,中文场景文字识别技术在人们的日常生活中受到广泛关注,在文档识别、身份证识别、银行卡识别、病例识别、名片识别等领域广泛应用.但由于中文场景中的文字容易受光照变化、低分... 查看详情

数平精准推荐|ocr技术之数据篇

...精准推荐团队利用图像增强,语义理解,生成对抗网络等技术生成高质足量的数据,为算法模型提供燃料,帮助OCR技术服务在多种业务场景中快速迭代,提升效果。一.背景介绍如果把深度学习看做引擎,大量带标注的数据则是... 查看详情

在自定义数据集上训练 tensorflow attention-ocr 的管道是啥?

】在自定义数据集上训练tensorflowattention-ocr的管道是啥?【英文标题】:what\'sthepipelinetotraintensorflowattention-ocroncustomizeddataset?在自定义数据集上训练tensorflowattention-ocr的管道是什么?【发布时间】:2018-10-0108:56:44【问题描述】:我... 查看详情

用transformer实现ocr字符识别!(代码片段)

...务数据集,讲解如何使用transformer实现一个简单的OCR文字识别任务,并从中体会transformer是如何应用到除分类以外更复杂的CV任务中的。全文分为四部分:一、数据集简介与获取二、数据分析与关系构建三、如何将transf... 查看详情

全球名校课程作业分享系列(11)--斯坦福cs231n之生成对抗网络(代码片段)

课程作业原地址:CS231nAssignment3作业及整理:@邓姸蕾&&@Molly&&@寒小阳时间:2018年2月。出处:http://blog.csdn.net/han_xiaoyang/article/details/79316554引言CS231N到目前位置,所有对神经网络的应用都是... 查看详情

给ai换个“大动力小心脏”之ocr异构加速

...能为GPUP4130%,处理延时仅为P4的1/10,CPU的1/30。文字识别技术-OCROCR技术,通俗来讲就是从图像中检测并识别字符的一种方法,在证通用文字识别、书籍电子化、自动信息 查看详情

ocr文字识别方法综述

...研究者学习。➡️点击跳转到网站。 📝OCR文字识别技术介绍合集:1️⃣OCR文字识别技术系列第一章:OCR文字识 查看详情

aigcchatgptgpt系列?我的认识

...content),新型内容生产方式。AIGC是利用人工智能技术来生成内容,也就是,它可以用输入数据生成相同或不同类型的内容,比如输入文字、生成文字,输入文字、生成图像等。GPT-3是生成型的预训练变换... 查看详情

python下实现tesseractocr训练字符库(opencv-python边缘检测代替jtessboxeditor手动矫正)(代码片段)

...opencv-python和pytesseract(可直接在cmd中下载)三、生成数据集和训练所需文件1.导入需要的包,以及自己写的一个简单的data_agumentation.py2.看一下生成数据集函数的形参和函数内部我预设的参数3.生成font_properties.txt文件4.... 查看详情