@@ -234,7 +234,8 @@ def test_latest_checkpoint_path(self) -> None:
234234 path_4 = os .path .join (temp_dir , "epoch_700" )
235235 os .mkdir (path_4 )
236236 self .assertEqual (
237- get_latest_checkpoint_path (temp_dir , METADATA_FNAME ), path_2
237+ get_latest_checkpoint_path (temp_dir , METADATA_FNAME ),
238+ path_2 ,
238239 )
239240
240241 @skip_if_not_distributed
@@ -284,7 +285,8 @@ def _latest_checkpoint_path_distributed() -> None:
284285 expected_path = path_container [0 ]
285286 tc .assertIsNotNone (expected_path )
286287 tc .assertEqual (
287- get_latest_checkpoint_path (temp_dir , METADATA_FNAME ), expected_path
288+ get_latest_checkpoint_path (temp_dir , METADATA_FNAME ),
289+ expected_path ,
288290 )
289291
290292 if is_rank0 :
@@ -368,7 +370,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
368370 # compares set equality since order of returned dirpaths is not guaranteed
369371 # in _retrieve_checkpoint_dirpaths
370372 self .assertEqual (
371- set (_retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = None )),
373+ {
374+ str (x )
375+ for x in _retrieve_checkpoint_dirpaths (
376+ temp_dir , metadata_fname = None
377+ )
378+ },
372379 {os .path .join (temp_dir , path ) for path in paths [:- 1 ]},
373380 )
374381 self .assertEqual (
@@ -382,9 +389,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None:
382389 pass
383390
384391 self .assertEqual (
385- set (
386- _retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = ".metadata" )
387- ),
392+ {
393+ str (x )
394+ for x in _retrieve_checkpoint_dirpaths (
395+ temp_dir , metadata_fname = ".metadata"
396+ )
397+ },
388398 {os .path .join (temp_dir , paths [2 ])},
389399 )
390400
@@ -394,30 +404,36 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
394404 """
395405 with tempfile .TemporaryDirectory () as temp_dir :
396406 paths = [
397- "epoch_0_step_10_val_loss=10" ,
398- "epoch_1_step_10_val_loss=5" ,
407+ "epoch_0_step_10_val_loss=10.0 " ,
408+ "epoch_1_step_10_val_loss=5.0 " ,
399409 "epoch_2_step_10" ,
400410 "epoch_0_step_5" ,
401- "epoch_0_step_6_train_loss=13" ,
411+ "epoch_0_step_6_train_loss=13.0 " ,
402412 ]
403413 for path in paths :
404414 os .mkdir (os .path .join (temp_dir , path ))
405415 # make last path a file instead of a directory
406- with open (os .path .join (temp_dir , "epoch_0_step_3_val_loss=3" ), "w" ):
416+ with open (os .path .join (temp_dir , "epoch_0_step_3_val_loss=3.0 " ), "w" ):
407417 pass
408418
409419 # compares set equality since order of returned dirpaths is not guaranteed
410420 # in _retrieve_checkpoint_dirpaths
411421 self .assertEqual (
412- set (_retrieve_checkpoint_dirpaths (temp_dir , metadata_fname = None )),
422+ {
423+ str (x )
424+ for x in _retrieve_checkpoint_dirpaths (
425+ temp_dir , metadata_fname = None
426+ )
427+ },
413428 {os .path .join (temp_dir , path ) for path in paths },
414429 )
415430 self .assertEqual (
416- set (
417- _retrieve_checkpoint_dirpaths (
431+ {
432+ str (x )
433+ for x in _retrieve_checkpoint_dirpaths (
418434 temp_dir , metadata_fname = None , metric_name = "val_loss"
419435 )
420- ) ,
436+ } ,
421437 {
422438 os .path .join (temp_dir , path ) for path in paths [:2 ]
423439 }, # since last path is a file
@@ -433,11 +449,12 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None:
433449 pass
434450
435451 self .assertEqual (
436- set (
437- _retrieve_checkpoint_dirpaths (
452+ {
453+ str (x )
454+ for x in _retrieve_checkpoint_dirpaths (
438455 temp_dir , metadata_fname = ".metadata" , metric_name = "val_loss"
439456 )
440- ) ,
457+ } ,
441458 {os .path .join (temp_dir , paths [1 ])},
442459 )
443460
@@ -467,7 +484,7 @@ def create_tmp_dir() -> str:
467484 os .mkdir (path2 )
468485 torch .distributed .barrier ()
469486
470- ckpt_dirpaths = get_checkpoint_dirpaths (temp_dir )
487+ ckpt_dirpaths = [ str ( x ) for x in get_checkpoint_dirpaths (temp_dir )]
471488 tc = unittest .TestCase ()
472489 tc .assertEqual (set (ckpt_dirpaths ), {path1 , path2 })
473490
@@ -492,7 +509,7 @@ def test_get_checkpoint_dirpaths(self) -> None:
492509 os .mkdir (path3 )
493510
494511 self .assertEqual (
495- set ( get_checkpoint_dirpaths (temp_dir )) ,
512+ { str ( x ) for x in get_checkpoint_dirpaths (temp_dir )} ,
496513 {path1 , path2 , path3 },
497514 )
498515
@@ -505,7 +522,10 @@ def test_get_checkpoint_dirpaths(self) -> None:
505522 os .mkdir (path3 )
506523
507524 self .assertEqual (
508- set (get_checkpoint_dirpaths (temp_dir , metric_name = "val_loss" )),
525+ {
526+ str (x )
527+ for x in get_checkpoint_dirpaths (temp_dir , metric_name = "val_loss" )
528+ },
509529 {path1 , path2 , path3 },
510530 )
511531
@@ -519,20 +539,27 @@ def test_checkpoint_sorting_utils(self) -> None:
519539 """
520540 Tests the sort utilities
521541 """
522- paths = ["epoch_1_step_20" , "epoch_4_step_130" , "epoch_0_step_10_val_loss=10" ]
523- self .assertEqual (_sort_by_recency (paths ), [paths [2 ], paths [0 ], paths [1 ]])
542+ paths = [
543+ "foo/epoch_1_step_20" ,
544+ "foo/epoch_4_step_130" ,
545+ "foo/epoch_0_step_10_val_loss=10.0" ,
546+ ]
547+ ckpts = [CheckpointPath .from_str (x ) for x in paths ]
548+ sorted_paths = [str (x ) for x in _sort_by_recency (ckpts )]
549+ self .assertEqual (sorted_paths , [paths [2 ], paths [0 ], paths [1 ]])
524550
525551 paths = [
526- "epoch_1_step_20_val_loss=0.09" ,
527- "epoch_4_step_130_val_loss=29" ,
528- "epoch_0_step_10_val_loss=10" ,
552+ "foo/ epoch_1_step_20_val_loss=0.09" ,
553+ "foo/ epoch_4_step_130_val_loss=29.0 " ,
554+ "foo/ epoch_0_step_10_val_loss=10.0 " ,
529555 ]
530- self .assertEqual (
531- _sort_by_metric_value (paths , mode = "min" ), [paths [1 ], paths [2 ], paths [0 ]]
532- )
533- self .assertEqual (
534- _sort_by_metric_value (paths , mode = "max" ), [paths [0 ], paths [2 ], paths [1 ]]
535- )
556+ ckpts = [CheckpointPath .from_str (x ) for x in paths ]
557+
558+ sorted_paths = [str (x ) for x in _sort_by_metric_value (ckpts , mode = "min" )]
559+ self .assertEqual (sorted_paths , [paths [1 ], paths [2 ], paths [0 ]])
560+
561+ sorted_paths = [str (x ) for x in _sort_by_metric_value (ckpts , mode = "max" )]
562+ self .assertEqual (sorted_paths , [paths [0 ], paths [2 ], paths [1 ]])
536563
537564 def test_delete_checkpoint (self ) -> None :
538565 """
0 commit comments