Skip to content

Commit 66ebeb2

Browse files
author
hanyoseob
committed
Upload init
1 parent 0f2439c commit 66ebeb2

File tree

3 files changed

+27
-18
lines changed

3 files changed

+27
-18
lines changed

dataset.py

-2
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ def __getitem__(self, index):
7373
if data.shape[0] > data.shape[1]:
7474
data = data.transpose((1, 0, 2))
7575

76-
sz = data.shape
77-
7876
label = data + self.noise[index]
7977
input, mask = self.generate_mask(copy.deepcopy(label))
8078

main.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,40 @@
99
cudnn.benchmark = True
1010
cudnn.fastest = True
1111

12+
# FLAG_PLATFORM = 'laptop'
13+
FLAG_PLATFORM = 'colab'
14+
1215
## setup parse
1316
parser = argparse.ArgumentParser(description='Train the unet network',
1417
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
1518

16-
parser.add_argument('--gpu_ids', default='-1', dest='gpu_ids')
19+
if FLAG_PLATFORM == 'colab':
20+
parser.add_argument('--gpu_ids', default='0', dest='gpu_ids')
1721

18-
parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode')
19-
parser.add_argument('--train_continue', default='off', choices=['on', 'off'], dest='train_continue')
22+
parser.add_argument('--dir_checkpoint', default='./drive/My Drive/GitHub/pytorch-noise2void/checkpoints', dest='dir_checkpoint')
23+
parser.add_argument('--dir_log', default='./drive/My Drive/GitHub/pytorch-noise2void/log', dest='dir_log')
24+
parser.add_argument('--dir_result', default='./drive/My Drive/GitHub/pytorch-noise2void/results', dest='dir_result')
25+
parser.add_argument('--dir_data', default='./drive/My Drive/datasets', dest='dir_data')
26+
elif FLAG_PLATFORM == 'laptop':
27+
parser.add_argument('--gpu_ids', default='-1', dest='gpu_ids')
2028

21-
parser.add_argument('--scope', default='resnet', dest='scope')
22-
parser.add_argument('--norm', type=str, default='inorm', dest='norm')
29+
parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint')
30+
parser.add_argument('--dir_log', default='./log', dest='dir_log')
31+
parser.add_argument('--dir_result', default='./results', dest='dir_result')
32+
parser.add_argument('--dir_data', default='./datasets', dest='dir_data')
33+
34+
parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode')
35+
parser.add_argument('--train_continue', default='on', choices=['on', 'off'], dest='train_continue')
2336

24-
parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint')
25-
parser.add_argument('--dir_log', default='./log', dest='dir_log')
37+
parser.add_argument('--scope', default='denoising_resnet', dest='scope')
38+
parser.add_argument('--norm', type=str, default='bnorm', dest='norm')
2639

2740
parser.add_argument('--name_data', type=str, default='bsd500', dest='name_data')
28-
parser.add_argument('--dir_data', default='../datasets', dest='dir_data')
29-
parser.add_argument('--dir_result', default='./results', dest='dir_result')
3041

3142
parser.add_argument('--num_epoch', type=int, default=300, dest='num_epoch')
32-
parser.add_argument('--batch_size', type=int, default=4, dest='batch_size')
43+
parser.add_argument('--batch_size', type=int, default=1, dest='batch_size')
3344

34-
parser.add_argument('--lr_G', type=float, default=1e-4, dest='lr_G')
45+
parser.add_argument('--lr_G', type=float, default=1e-3, dest='lr_G')
3546

3647
parser.add_argument('--optim', default='adam', choices=['sgd', 'adam', 'rmsprop'], dest='optim')
3748
parser.add_argument('--beta1', default=0.5, dest='beta1')
@@ -52,8 +63,8 @@
5263

5364
parser.add_argument('--data_type', default='float32', dest='data_type')
5465

55-
parser.add_argument('--num_freq_disp', type=int, default=1, dest='num_freq_disp')
56-
parser.add_argument('--num_freq_save', type=int, default=1, dest='num_freq_save')
66+
parser.add_argument('--num_freq_disp', type=int, default=10, dest='num_freq_disp')
67+
parser.add_argument('--num_freq_save', type=int, default=50, dest='num_freq_save')
5768

5869
PARSER = Parser(parser)
5970

train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def train(self):
146146
dataset_train = Dataset(dir_data_train, data_type=self.data_type, transform=transform_train, sgm=25, ratio=0.9, size_data=size_data, size_window=size_window)
147147
dataset_val = Dataset(dir_data_val, data_type=self.data_type, transform=transform_val, sgm=25, ratio=0.9, size_data=size_data, size_window=size_window)
148148

149-
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
150-
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=True, num_workers=0)
149+
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
150+
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=True, num_workers=8)
151151

152152
num_train = len(dataset_train)
153153
num_val = len(dataset_val)
@@ -356,7 +356,7 @@ def test(self):
356356

357357
# dataset_test = Dataset(dir_data_test, data_type=self.data_type, transform=transform_test, sgm=(0, 25))
358358
dataset_test = Dataset(dir_data_test, data_type=self.data_type, transform=transform_test, sgm=25, ratio=1, size_data=size_data, size_window=size_window)
359-
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=0)
359+
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=8)
360360

361361
num_test = len(dataset_test)
362362

0 commit comments

Comments (0)