Skip to content

Commit 213bef3

Browse files
committed
1.修复TrainBatchLoop和EvaluateBatchLoop中的问题; 2.修复OverfitDataLoader不break的问题
1 parent bdb24a3 commit 213bef3

File tree

3 files changed

+6
-0
lines changed

3 files changed

+6
-0
lines changed

fastNLP/core/controllers/loops/evaluate_batch_loop.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class EvaluateBatchLoop(Loop):
1818
def __init__(self, batch_step_fn:Optional[Callable]=None):
1919
if batch_step_fn is not None:
2020
self.batch_step_fn = batch_step_fn
21+
else:
22+
self.batch_step_fn = EvaluateBatchLoop.batch_step_fn
2123

2224
def run(self, evaluator, dataloader) -> Dict:
2325
r"""

fastNLP/core/controllers/loops/train_batch_loop.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class TrainBatchLoop(Loop):
2020
def __init__(self, batch_step_fn: Optional[Callable] = None):
2121
if batch_step_fn is not None:
2222
self.batch_step_fn = batch_step_fn
23+
else:
24+
self.batch_step_fn = TrainBatchLoop.batch_step_fn
2325

2426
def run(self, trainer, dataloader):
2527
r"""

fastNLP/core/dataloaders/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def __init__(self, dataloader, overfit_batches: int, batches=None):
131131
for idx, batch in enumerate(dataloader):
132132
if idx < self.overfit_batches or self.overfit_batches <= -1:
133133
self.batches.append(batch)
134+
else:
135+
break
134136
else:
135137
assert isinstance(batches, list)
136138
self.batches = batches

0 commit comments

Comments
 (0)