Commit 9f46d77

cleaning up after cursor's bumbling
1 parent: dd594a8

5 files changed: +37, -33 lines

src/levanter/data/dataset.py (+9)

@@ -123,8 +123,17 @@ def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs)
         return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
 
     def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
+        """
+        Slices the dataset from `start_index` to `end_index`.
+        """
         return SlicedAsyncDataset(self, start_index, end_index)
 
+    def take(self, n: int):
+        """
+        Alias for `slice_dataset(end_index=n)`.
+        """
+        return self.slice_dataset(end_index=n)
+
     def shuffle(self, key: PRNGKey):
         import levanter.data.permutation as permutation
 
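The new `take` helper is just sugar over `slice_dataset`. A minimal usage sketch, assuming some `AsyncDataset` instance `ds` (the variable names and numbers here are illustrative, not from the diff):

# Hypothetical usage of the new helpers on an AsyncDataset `ds`:
first_1000 = ds.take(1000)                                  # same as ds.slice_dataset(end_index=1000)
window = ds.slice_dataset(start_index=100, end_index=200)   # examples 100 through 199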

src/levanter/main/eval_lm.py (+6, -7)

@@ -47,26 +47,25 @@ def main(config: EvalLmConfig):
 
     Batch = Axis("batch", config.trainer.eval_batch_size)
     Pos = config.model.Pos
-    KeyPos = config.model.KeyPos
 
     if config.eval_on_train:
-        raw_dataset = config.data.train_set(
+        ds = config.data.train_set(
             Pos,
             config.trainer.batch_schedule,
             key=jax.random.PRNGKey(0),
-            KPos=KeyPos,
         )
     else:
-        raw_dataset = config.data.validation_set(Pos, KPos=KeyPos)
+        ds = config.data.validation_set(Pos) # type: ignore
+        assert ds is not None, "No validation set found"
 
-    if raw_dataset is None:
+    if ds is None:
         raise ValueError("no dataset found!")
 
     if config.max_batches is not None:
-        raw_dataset = raw_dataset.take(config.max_batches * config.trainer.eval_batch_size)
+        ds = ds.take(config.max_batches * config.trainer.eval_batch_size)
 
     eval_loader = DataLoader(
-        raw_dataset,
+        ds,
         Batch,
         max_buffered_batches=None,
         mesh=config.trainer.device_mesh,
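One detail worth noting in the hunk above: `take` counts individual examples, not batches, so capping evaluation at `config.max_batches` batches means taking `max_batches * eval_batch_size` examples. A throwaway illustration with made-up numbers:

# Illustrative only: with max_batches=10 and eval_batch_size=32,
# the eval dataset is truncated to the first 10 * 32 = 320 examples.
ds = ds.take(10 * 32)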

src/levanter/main/lora_lm.py (+20, -22)

@@ -10,7 +10,7 @@
 import levanter
 from levanter import callbacks
 from levanter.compat.hf_checkpoints import HFCheckpointConverter
-from levanter.data.text import CausalLmDataset, LMDatasetConfig
+from levanter.data.text import LMDatasetConfig
 from levanter.lora import (
     LoraConfig,
     lora_trainable_params_filter,

@@ -80,23 +80,15 @@ def main(config: LoraLmConfig):
     parameter_axis_mapping = config.trainer.parameter_axis_mapping
 
     train_dataset = config.data.train_set(
-        Pos.size,
+        Pos,
         batch_schedule=config.trainer.batch_schedule,
-        QPos=Pos,
-        KPos=KeyPos,
+        key=data_key,
     )
 
     if train_dataset is None:
         raise ValueError("No training set!")
 
-    eval_datasets = {
-        name: config.data.validation_set(
-            Pos.size,
-            QPos=Pos,
-            KPos=KeyPos,
-        )
-        for name in config.data.validation_splits
-    }
+    eval_datasets = config.data.validation_sets(Pos)
 
     if len(eval_datasets) == 0:
         logger.warning("No evaluation datasets provided.")

@@ -135,19 +127,25 @@ def loraize_hf_model(model):
     logger.info(f"Trainable parameter count: {just_lora_params}")
     logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")
 
-    for name, eval_dataset in eval_datasets.items():
-        eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
-        trainer.add_eval_hook(eval_dataset, name=name)
+    max_eval_examples_per_ds = config.trainer.max_eval_batches
+    if max_eval_examples_per_ds is not None:
+        max_eval_examples_per_ds *= config.trainer.eval_batch_size
 
-    # boilerplate hooks and such
-    if len(eval_datasets) == 0:
-        logger.warning("No evaluation datasets provided.")
+    tagged_eval_datasets = config.data.tagged_eval_sets(Pos)
 
-    for name, eval_dataset in eval_datasets.items():
-        eval_dataset = CausalLmDataset(
-            eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id
+    if len(tagged_eval_datasets) == 0:
+        logger.warning("No evaluation datasets provided.")
+    else:
+        cb = levanter.eval.cb_tagged_lm_evaluate(
+            trainer.EvalBatch,
+            tagged_eval_datasets,
+            tokenizer,
+            trainer.device_mesh,
+            trainer.compute_axis_mapping,
+            max_eval_examples_per_ds,
+            mp=config.trainer.mp,
         )
-        trainer.add_eval_hook(eval_dataset, name=name)
+        trainer.add_hook(cb, every=config.trainer.steps_per_eval)
 
     trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
     if config.peft_save_path is not None:
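For orientation, here is the new evaluation wiring from the last hunk with the role of each argument spelled out in comments. The argument roles are my reading of the surrounding code, not something this diff documents:

# Sketch of the consolidated eval path (comments are inferred, not from the diff):
max_eval_examples_per_ds = config.trainer.max_eval_batches
if max_eval_examples_per_ds is not None:
    max_eval_examples_per_ds *= config.trainer.eval_batch_size  # turn a batch cap into an example cap

tagged_eval_datasets = config.data.tagged_eval_sets(Pos)        # eval datasets paired with their tags

if len(tagged_eval_datasets) == 0:
    logger.warning("No evaluation datasets provided.")
else:
    cb = levanter.eval.cb_tagged_lm_evaluate(
        trainer.EvalBatch,              # batch axis used for evaluation
        tagged_eval_datasets,           # the tagged eval sets built above
        tokenizer,                      # tokenizer passed through to the evaluator
        trainer.device_mesh,            # mesh the evaluation is sharded over
        trainer.compute_axis_mapping,   # axis mapping used for compute
        max_eval_examples_per_ds,       # per-dataset example cap (may be None)
        mp=config.trainer.mp,           # mixed-precision policy
    )
    trainer.add_hook(cb, every=config.trainer.steps_per_eval)  # run evaluation every steps_per_eval steps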

src/levanter/main/viz_logprobs.py (+1, -2)

@@ -53,9 +53,8 @@ def main(config: VizLmConfig):
     # some axes we use outside the model proper
     EvalBatch = config.trainer.EvalBatch
     Pos = config.model.Pos
-    KeyPos = config.model.KeyPos
 
-    validation_sets = config.data.validation_sets(Pos, KPos=KeyPos)
+    validation_sets = config.data.validation_sets(Pos)
 
     # some axes we use outside the model proper
     Pos = config.model.Pos

tests/test_text.py (+1, -2)

@@ -20,9 +20,8 @@ def test_dont_blow_up_without_validation_set():
     )
 
     Pos = hax.Axis("Pos", 10)
-    KPos = hax.Axis("KPos", 10)
     # mostly just making sure this doesn't blow up
-    assert config.validation_set(Pos, KPos=KPos) is None
+    assert config.validation_set(Pos) is None
 
 
 def test_lm_example_handles_ignore_id():
