Skip to content

Commit c807777

Browse files
committed
增加数据增强代码
1 parent 5a51d09 commit c807777

File tree

1 file changed

+50
-2
lines changed

1 file changed

+50
-2
lines changed

ocr/gen_printed_char.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,48 @@
1717
import numpy as np
1818
import shutil
1919
import traceback
20+
import copy
2021

2122

23+
class dataAugmentation(object):
24+
def __init__(self,noise=True,dilate=True,erode=True):
25+
self.noise = noise
26+
self.dilate = dilate
27+
self.erode = erode
28+
29+
@classmethod
30+
def add_noise(cls,img):
31+
for i in range(20): #添加点噪声
32+
temp_x = np.random.randint(0,img.shape[0])
33+
temp_y = np.random.randint(0,img.shape[1])
34+
img[temp_x][temp_y] = 255
35+
return img
36+
37+
@classmethod
38+
def add_erode(cls,img):
39+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))
40+
img = cv2.erode(img,kernel)
41+
return img
42+
43+
@classmethod
44+
def add_dilate(cls,img):
45+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3, 3))
46+
img = cv2.dilate(img,kernel)
47+
return img
48+
49+
def do(self,img_list=[]):
50+
aug_list= copy.deepcopy(img_list)
51+
for i in range(len(img_list)):
52+
im = img_list[i]
53+
if self.noise and random.random()<0.5:
54+
im = self.add_noise(im)
55+
if self.dilate and random.random()<0.5:
56+
im = self.add_dilate(im)
57+
elif self.erode:
58+
im = self.add_erode(im)
59+
aug_list.append(im)
60+
return aug_list
61+
2262
# 对字体图像做等比例缩放
2363
class PreprocessResizeKeepRatio(object):
2464

@@ -296,6 +336,9 @@ def args_parse():
296336
parser.add_argument('--rotate_step', dest='rotate_step',
297337
default=0, required=False,
298338
help='rotate step for the rotate angle')
339+
parser.add_argument('--need_aug', dest='need_aug',
340+
default=False, required=False,
341+
help='need data augmentation', action='store_true')
299342
args = vars(parser.parse_args())
300343
return args
301344

@@ -316,6 +359,7 @@ def args_parse():
316359
need_crop = not options['no_crop']
317360
margin = int(options['margin'])
318361
rotate = int(options['rotate'])
362+
need_aug = options['need_aug']
319363
rotate_step = int(options['rotate_step'])
320364
train_image_dir_name = "train"
321365
test_image_dir_name = "test"
@@ -379,10 +423,14 @@ def args_parse():
379423
for k in all_rotate_angles:
380424
image = font2image.do(verified_font_path, char, rotate=k)
381425
image_list.append(image)
382-
426+
427+
428+
if need_aug:
429+
data_aug = dataAugmentation()
430+
image_list = data_aug.do(image_list)
431+
383432
test_num = len(image_list) * test_ratio
384433
random.shuffle(image_list) # 图像列表打乱
385-
386434
count = 0
387435
for i in range(len(image_list)):
388436
img = image_list[i]

0 commit comments

Comments
 (0)