Skip to content

Commit 7a54b72

Browse files
committed on Mar 15, 2022
final doc fixes
1 parent 7f0b337 commit 7a54b72

6 files changed

+21
-78
lines changed
 

‎code/FlowersDataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def __getitem__(self, idx):
7979
return im, txt2im_labels, im2txt_masked_labels, img_idx, txt_idx
8080

8181
def get_captions_of_image(self, img_idx):
82-
"""Get all the captions (10) of a given image
82+
"""Get all the captions (10) of a given image by its index
8383
84-
:param img_idx: idx of a given image
84+
:param img_idx: index of a given image
8585
:return: a list of all the captions (10) of a given image
8686
"""
8787
with open(os.path.join(self.txts_path, f'image_{img_idx:05}.txt')) as f:

‎code/fid_score_override.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
4444
'Setting batch size to data size'))
4545
batch_size = len(files)
4646

47-
dataset = fid_score.ImagePathDataset(files, transforms=TF.Compose([TF.Resize((224, 224)),
48-
TF.ToTensor()]))
47+
dataset = fid_score.ImagePathDataset(files, transforms=TF.Compose([TF.Resize((224, 224)),
48+
TF.ToTensor()]))
4949
dataloader = torch.utils.data.DataLoader(dataset,
5050
batch_size=batch_size,
5151
shuffle=False,

‎code/generating.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def generate_test_examples(device, gens_dir, im2txt_model, test_loader, txt2im_m
211211
parser.add_argument('--out_dir', required=True, type=str, help='A directory of a trained model to generate for')
212212
parser.add_argument('--text', type=str, default=None, help='Text prompt for which to generate an image')
213213
parser.add_argument('--img_path', type=str, default=None, help='Path to the image for which to generate a caption')
214-
parser.add_argument('--amount', type=int, default=1, help="The amount of images to generate from the cutsom text, "
214+
parser.add_argument('--amount', type=int, default=1, help="The amount of images to generate from the custom text, "
215215
"if given (via '--text')")
216216
parsed_args = parser.parse_args()
217217

‎code/main.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
args = yaml.load(f, Loader=yaml.FullLoader)
3535
args.update(vars(parsed_args))
3636

37+
# If the option for training continuation from a checkpoint was marked, check if it's possible to continue
3738
if args["continue_training"] is not None:
3839
# The first saved thing is the generator_k1.pth file, so if no such file exists - there is no mid-training
3940
# state to continue from

‎code/tests.py

-69
This file was deleted.

‎code/training.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,19 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
340340
start_epoch = 1
341341
start_k = 1
342342

343+
# search for saved weights files
343344
pth_files = [x for x in os.listdir(args["output_dir"]) if x.endswith('.pth')]
345+
346+
# Search for end-of-epoch saved models (containing both parts - txt2im and im2txt)
344347
avail_model_epochs = [int(x[8:-4]) for x in pth_files if x.startswith('mod')]
345348
if len(avail_model_epochs):
346-
max_model_epoch = max(avail_model_epochs)
349+
max_model_epoch = max(avail_model_epochs) # get the max epoch saved
347350
model_pth_path = os.path.join(args["output_dir"], f'models_e{max_model_epoch}.pth')
348-
epoch_checkpoint = torch.load(model_pth_path, map_location=device)
349-
start_epoch = epoch_checkpoint["epochs"] + 1
351+
epoch_checkpoint = torch.load(model_pth_path, map_location=device) # load it
352+
start_epoch = epoch_checkpoint["epochs"] + 1 # the starting epoch should be the following epoch
350353
losses = epoch_checkpoint["losses"]
354+
355+
# load the last saved models
351356
im2txt_model.load_state_dict(epoch_checkpoint["im2txt"])
352357
im2txt_optimizer.load_state_dict(epoch_checkpoint["optimizer_im2txt"])
353358
txt2im_model.load_state_dict(epoch_checkpoint["txt2im"])
@@ -357,6 +362,8 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
357362
del epoch_checkpoint
358363
torch.cuda.empty_cache()
359364

365+
# the txt2im model can continue to be updated (and saved) in the middle of an epoch, so now we search for the
366+
# latest saved gstep after the max epoch found previously
360367
gen_files = [x for x in pth_files if x.startswith('gen')]
361368
max_gstep = 0
362369
best_gstep_checkpoint = None
@@ -368,20 +375,24 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
368375

369376
del gstep_checkpoint
370377
torch.cuda.empty_cache()
371-
378+
379+
# if we found a more recent gstep we now need to load the updated txt2im
372380
if max_gstep != 0:
373381
best_gstep_checkpoint = torch.load(os.path.join(args["output_dir"], f'generator_k{max_gstep}.pth'), map_location=device)
374382

375383
if best_gstep_checkpoint is None and len(avail_model_epochs):
376384
print(f"loaded TXT2IM from {model_pth_path}\n"
377385
f"Starting on epoch {start_epoch}, k {start_k}")
386+
# the TXT2IM model was loaded from the last saved epoch
378387
elif best_gstep_checkpoint is None:
388+
# If we got here someone deleted the saved pth files while the code was running
379389
raise RuntimeError("No pth files saved. the code shouldn't reach here")
380390
else:
381391
txt2im_model.load_state_dict(best_gstep_checkpoint["txt2im"])
382392
txt2im_optimizer.load_state_dict(best_gstep_checkpoint["optimizer_txt2im"])
383393
start_k = best_gstep_checkpoint["k"] + 1
384394
print(f"loaded TXT2IM from {os.path.join(args['output_dir'], f'generator_k{max_gstep}.pth')}\n"
385395
f"Starting on epoch {start_epoch}, k {start_k}")
396+
# the TXT2IM model was successfully loaded from the last saved gstep
386397

387398
return txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimizer, losses, start_epoch, start_k

0 commit comments

Comments (0)
Please sign in to comment.