|
9 | 9 | cudnn.benchmark = True
|
10 | 10 | cudnn.fastest = True
|
11 | 11 |
|
| 12 | +# FLAG_PLATFORM = 'laptop' |
| 13 | +FLAG_PLATFORM = 'colab' |
| 14 | + |
12 | 15 | ## setup parse
|
13 | 16 | parser = argparse.ArgumentParser(description='Train the unet network',
|
14 | 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
15 | 18 |
|
16 |
| -parser.add_argument('--gpu_ids', default='-1', dest='gpu_ids') |
| 19 | +if FLAG_PLATFORM == 'colab': |
| 20 | + parser.add_argument('--gpu_ids', default='0', dest='gpu_ids') |
17 | 21 |
|
18 |
| -parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode') |
19 |
| -parser.add_argument('--train_continue', default='off', choices=['on', 'off'], dest='train_continue') |
| 22 | + parser.add_argument('--dir_checkpoint', default='./drive/My Drive/GitHub/pytorch-noise2void/checkpoints', dest='dir_checkpoint') |
| 23 | + parser.add_argument('--dir_log', default='./drive/My Drive/GitHub/pytorch-noise2void/log', dest='dir_log') |
| 24 | + parser.add_argument('--dir_result', default='./drive/My Drive/GitHub/pytorch-noise2void/results', dest='dir_result') |
| 25 | + parser.add_argument('--dir_data', default='./drive/My Drive/datasets', dest='dir_data') |
| 26 | +elif FLAG_PLATFORM == 'laptop': |
| 27 | + parser.add_argument('--gpu_ids', default='-1', dest='gpu_ids') |
20 | 28 |
|
21 |
| -parser.add_argument('--scope', default='resnet', dest='scope') |
22 |
| -parser.add_argument('--norm', type=str, default='inorm', dest='norm') |
| 29 | + parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint') |
| 30 | + parser.add_argument('--dir_log', default='./log', dest='dir_log') |
| 31 | + parser.add_argument('--dir_result', default='./results', dest='dir_result') |
| 32 | + parser.add_argument('--dir_data', default='./datasets', dest='dir_data') |
| 33 | + |
| 34 | +parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode') |
| 35 | +parser.add_argument('--train_continue', default='on', choices=['on', 'off'], dest='train_continue') |
23 | 36 |
|
24 |
| -parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint') |
25 |
| -parser.add_argument('--dir_log', default='./log', dest='dir_log') |
| 37 | +parser.add_argument('--scope', default='denoising_resnet', dest='scope') |
| 38 | +parser.add_argument('--norm', type=str, default='bnorm', dest='norm') |
26 | 39 |
|
27 | 40 | parser.add_argument('--name_data', type=str, default='bsd500', dest='name_data')
|
28 |
| -parser.add_argument('--dir_data', default='../datasets', dest='dir_data') |
29 |
| -parser.add_argument('--dir_result', default='./results', dest='dir_result') |
30 | 41 |
|
31 | 42 | parser.add_argument('--num_epoch', type=int, default=300, dest='num_epoch')
|
32 |
| -parser.add_argument('--batch_size', type=int, default=4, dest='batch_size') |
| 43 | +parser.add_argument('--batch_size', type=int, default=1, dest='batch_size') |
33 | 44 |
|
34 |
| -parser.add_argument('--lr_G', type=float, default=1e-4, dest='lr_G') |
| 45 | +parser.add_argument('--lr_G', type=float, default=1e-3, dest='lr_G') |
35 | 46 |
|
36 | 47 | parser.add_argument('--optim', default='adam', choices=['sgd', 'adam', 'rmsprop'], dest='optim')
|
37 | 48 | parser.add_argument('--beta1', default=0.5, dest='beta1')
|
|
52 | 63 |
|
53 | 64 | parser.add_argument('--data_type', default='float32', dest='data_type')
|
54 | 65 |
|
55 |
| -parser.add_argument('--num_freq_disp', type=int, default=1, dest='num_freq_disp') |
56 |
| -parser.add_argument('--num_freq_save', type=int, default=1, dest='num_freq_save') |
| 66 | +parser.add_argument('--num_freq_disp', type=int, default=10, dest='num_freq_disp') |
| 67 | +parser.add_argument('--num_freq_save', type=int, default=50, dest='num_freq_save') |
57 | 68 |
|
58 | 69 | PARSER = Parser(parser)
|
59 | 70 |
|
|
0 commit comments