@@ -10,7 +10,7 @@
 import levanter
 from levanter import callbacks
 from levanter.compat.hf_checkpoints import HFCheckpointConverter
-from levanter.data.text import CausalLmDataset, LMDatasetConfig
+from levanter.data.text import LMDatasetConfig
 from levanter.lora import (
     LoraConfig,
     lora_trainable_params_filter,
@@ -80,23 +80,15 @@ def main(config: LoraLmConfig):
     parameter_axis_mapping = config.trainer.parameter_axis_mapping
 
     train_dataset = config.data.train_set(
-        Pos.size,
+        Pos,
         batch_schedule=config.trainer.batch_schedule,
-        QPos=Pos,
-        KPos=KeyPos,
+        key=data_key,
     )
 
     if train_dataset is None:
         raise ValueError("No training set!")
 
-    eval_datasets = {
-        name: config.data.validation_set(
-            Pos.size,
-            QPos=Pos,
-            KPos=KeyPos,
-        )
-        for name in config.data.validation_splits
-    }
+    eval_datasets = config.data.validation_sets(Pos)
 
     if len(eval_datasets) == 0:
         logger.warning("No evaluation datasets provided.")
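
Taken together, the hunks above show the new dataset-construction API: train_set now receives the named Pos axis itself (instead of Pos.size plus separate QPos/KPos axes) along with a PRNG key for shuffling, and the per-split validation_set comprehension collapses into a single validation_sets(Pos) call. A minimal sketch of the updated call sites, assuming data_key is split from the run's root PRNG key elsewhere in main() (not shown in this excerpt) and that validation_sets returns a name-to-dataset mapping, as inferred from the len(...) check above:

    train_dataset = config.data.train_set(
        Pos,                                           # named sequence axis, not a raw length
        batch_schedule=config.trainer.batch_schedule,  # batch-size schedule from the trainer config
        key=data_key,                                  # PRNG key for shuffling (assumption: split in main())
    )
    eval_datasets = config.data.validation_sets(Pos)   # {split name: dataset}, possibly empty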
@@ -135,19 +127,25 @@ def loraize_hf_model(model):
     logger.info(f"Trainable parameter count: {just_lora_params}")
     logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")
 
-    for name, eval_dataset in eval_datasets.items():
-        eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
-        trainer.add_eval_hook(eval_dataset, name=name)
+    max_eval_examples_per_ds = config.trainer.max_eval_batches
+    if max_eval_examples_per_ds is not None:
+        max_eval_examples_per_ds *= config.trainer.eval_batch_size
 
-    # boilerplate hooks and such
-    if len(eval_datasets) == 0:
-        logger.warning("No evaluation datasets provided.")
+    tagged_eval_datasets = config.data.tagged_eval_sets(Pos)
 
-    for name, eval_dataset in eval_datasets.items():
-        eval_dataset = CausalLmDataset(
-            eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id
+    if len(tagged_eval_datasets) == 0:
+        logger.warning("No evaluation datasets provided.")
+    else:
+        cb = levanter.eval.cb_tagged_lm_evaluate(
+            trainer.EvalBatch,
+            tagged_eval_datasets,
+            tokenizer,
+            trainer.device_mesh,
+            trainer.compute_axis_mapping,
+            max_eval_examples_per_ds,
+            mp=config.trainer.mp,
         )
-        trainer.add_eval_hook(eval_dataset, name=name)
 
+        trainer.add_hook(cb, every=config.trainer.steps_per_eval)
 
     trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
     if config.peft_save_path is not None: