@@ -340,14 +340,19 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
340
340
start_epoch = 1
341
341
start_k = 1
342
342
343
+ # search for saved weights files
343
344
pth_files = [x for x in os .listdir (args ["output_dir" ]) if x .endswith ('.pth' )]
345
+
346
+ # Search for end-of-epoch saved models (containing both parts - txt2im and im2txt)
344
347
avail_model_epochs = [int (x [8 :- 4 ]) for x in pth_files if x .startswith ('mod' )]
345
348
if len (avail_model_epochs ):
346
- max_model_epoch = max (avail_model_epochs )
349
+ max_model_epoch = max (avail_model_epochs ) # get the max epoch saved
347
350
model_pth_path = os .path .join (args ["output_dir" ], f'models_e{ max_model_epoch } .pth' )
348
- epoch_checkpoint = torch .load (model_pth_path , map_location = device )
349
- start_epoch = epoch_checkpoint ["epochs" ] + 1
351
+ epoch_checkpoint = torch .load (model_pth_path , map_location = device ) # load it
352
+ start_epoch = epoch_checkpoint ["epochs" ] + 1 # the starting epoch should be the following epoch
350
353
losses = epoch_checkpoint ["losses" ]
354
+
355
+ # load the last saved models
351
356
im2txt_model .load_state_dict (epoch_checkpoint ["im2txt" ])
352
357
im2txt_optimizer .load_state_dict (epoch_checkpoint ["optimizer_im2txt" ])
353
358
txt2im_model .load_state_dict (epoch_checkpoint ["txt2im" ])
@@ -357,6 +362,8 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
357
362
del epoch_checkpoint
358
363
torch .cuda .empty_cache ()
359
364
365
+ # the txt2im model can continue to be updated (and saved) in the middle of an epoch, so now we search for the
366
+ # latest saved gstep after the max epoch found previously
360
367
gen_files = [x for x in pth_files if x .startswith ('gen' )]
361
368
max_gstep = 0
362
369
best_gstep_checkpoint = None
@@ -368,20 +375,24 @@ def load_checkpoint(txt2im_model, im2txt_model, txt2im_optimizer, im2txt_optimiz
368
375
369
376
del gstep_checkpoint
370
377
torch .cuda .empty_cache ()
371
-
378
+
379
+ # if we found a more recent gstep we now need to load the updated txt2im
372
380
if max_gstep != 0 :
373
381
best_gstep_checkpoint = torch .load (os .path .join (args ["output_dir" ], f'generator_k{ max_gstep } .pth' ), map_location = device )
374
382
375
383
if best_gstep_checkpoint is None and len (avail_model_epochs ):
376
384
print (f"loaded TXT2IM from { model_pth_path } \n "
377
385
f"Starting on epoch { start_epoch } , k { start_k } " )
386
+ # the TXT2IM model was loaded from the last saved epoch
378
387
elif best_gstep_checkpoint is None :
388
+ # If we got here someone deleted the saved pth files while the code was running
379
389
raise RuntimeError ("No pth files saved. the code shouldn't reach here" )
380
390
else :
381
391
txt2im_model .load_state_dict (best_gstep_checkpoint ["txt2im" ])
382
392
txt2im_optimizer .load_state_dict (best_gstep_checkpoint ["optimizer_txt2im" ])
383
393
start_k = best_gstep_checkpoint ["k" ] + 1
384
394
print (f"loaded TXT2IM from { os .path .join (args ['output_dir' ], f'generator_k{ max_gstep } .pth' )} \n "
385
395
f"Starting on epoch { start_epoch } , k { start_k } " )
396
+ # the TXT2IM model was successfully loaded from the last saved gstep
386
397
387
398
return txt2im_model , im2txt_model , txt2im_optimizer , im2txt_optimizer , losses , start_epoch , start_k
0 commit comments