Skip to content

Commit a504a9d

Browse files
author
ZouKexin-522
committed
[Trainer] Fix data skip logic for IterableDataset when resuming from checkpoint
1 parent 469e808 commit a504a9d

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

paddleformers/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,10 +1865,10 @@ def _inner_training_loop(
1865 1865
assert (
1866 1866
paddle.sum(paddle.stack(global_step_list) - global_step_list[0]) == 0
1867 1867
), f"Error, get different global step, please check! step list: {[x.item() for x in global_step_list]}"
1868-
1869-
epochs_trained = self.state.global_step // num_update_steps_per_epoch
1868+
_num_update_steps_per_epoch_for_skip = (self.state.max_steps if len_dataloader is None else num_update_steps_per_epoch)
1869+
epochs_trained = self.state.global_step // _num_update_steps_per_epoch_for_skip
1870 1870
if not args.ignore_data_skip:
1871-
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
1871+
steps_trained_in_current_epoch = self.state.global_step % _num_update_steps_per_epoch_for_skip
1872 1872
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
1873 1873
else:
1874 1874
steps_trained_in_current_epoch = 0

0 commit comments

Comments
 (0)