17
17
import numpy as np
18
18
import shutil
19
19
import traceback
20
+ import copy
20
21
21
22
23
+ class dataAugmentation (object ):
24
+ def __init__ (self ,noise = True ,dilate = True ,erode = True ):
25
+ self .noise = noise
26
+ self .dilate = dilate
27
+ self .erode = erode
28
+
29
+ @classmethod
30
+ def add_noise (cls ,img ):
31
+ for i in range (20 ): #添加点噪声
32
+ temp_x = np .random .randint (0 ,img .shape [0 ])
33
+ temp_y = np .random .randint (0 ,img .shape [1 ])
34
+ img [temp_x ][temp_y ] = 255
35
+ return img
36
+
37
+ @classmethod
38
+ def add_erode (cls ,img ):
39
+ kernel = cv2 .getStructuringElement (cv2 .MORPH_RECT ,(3 , 3 ))
40
+ img = cv2 .erode (img ,kernel )
41
+ return img
42
+
43
+ @classmethod
44
+ def add_dilate (cls ,img ):
45
+ kernel = cv2 .getStructuringElement (cv2 .MORPH_RECT ,(3 , 3 ))
46
+ img = cv2 .dilate (img ,kernel )
47
+ return img
48
+
49
+ def do (self ,img_list = []):
50
+ aug_list = copy .deepcopy (img_list )
51
+ for i in range (len (img_list )):
52
+ im = img_list [i ]
53
+ if self .noise and random .random ()< 0.5 :
54
+ im = self .add_noise (im )
55
+ if self .dilate and random .random ()< 0.5 :
56
+ im = self .add_dilate (im )
57
+ elif self .erode :
58
+ im = self .add_erode (im )
59
+ aug_list .append (im )
60
+ return aug_list
61
+
22
62
# 对字体图像做等比例缩放
23
63
class PreprocessResizeKeepRatio (object ):
24
64
@@ -296,6 +336,9 @@ def args_parse():
296
336
parser .add_argument ('--rotate_step' , dest = 'rotate_step' ,
297
337
default = 0 , required = False ,
298
338
help = 'rotate step for the rotate angle' )
339
+ parser .add_argument ('--need_aug' , dest = 'need_aug' ,
340
+ default = False , required = False ,
341
+ help = 'need data augmentation' , action = 'store_true' )
299
342
args = vars (parser .parse_args ())
300
343
return args
301
344
@@ -316,6 +359,7 @@ def args_parse():
316
359
need_crop = not options ['no_crop' ]
317
360
margin = int (options ['margin' ])
318
361
rotate = int (options ['rotate' ])
362
+ need_aug = options ['need_aug' ]
319
363
rotate_step = int (options ['rotate_step' ])
320
364
train_image_dir_name = "train"
321
365
test_image_dir_name = "test"
@@ -379,10 +423,14 @@ def args_parse():
379
423
for k in all_rotate_angles :
380
424
image = font2image .do (verified_font_path , char , rotate = k )
381
425
image_list .append (image )
382
-
426
+
427
+
428
+ if need_aug :
429
+ data_aug = dataAugmentation ()
430
+ image_list = data_aug .do (image_list )
431
+
383
432
test_num = len (image_list ) * test_ratio
384
433
random .shuffle (image_list ) # 图像列表打乱
385
-
386
434
count = 0
387
435
for i in range (len (image_list )):
388
436
img = image_list [i ]
0 commit comments