Skip to content

Commit dac96b2

Browse files
authored
fix(hf): pass fake eval dataset since it is required (#853)
1 parent a15bb31 commit dac96b2

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/dvclive/huggingface.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def on_train_end(
7373
):
7474
if self._log_model is True and state.is_world_process_zero:
7575
fake_trainer = Trainer(
76-
args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer")
76+
args=args,
77+
model=kwargs.get("model"),
78+
tokenizer=kwargs.get("tokenizer"),
79+
eval_dataset=["fake"],
7780
)
7881
name = "best" if args.load_best_model_at_end else "last"
7982
output_dir = os.path.join(args.output_dir, name)

tests/frameworks/test_huggingface.py

+1
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def test_huggingface_log_model(
162162
live_callback = callback(live=live, log_model=log_model)
163163

164164
args.load_best_model_at_end = best
165+
args.metric_for_best_model = "loss"
165166

166167
trainer = Trainer(
167168
model,

0 commit comments

Comments
 (0)