
Commit: david's suggestions

ahmeda14960 committed Feb 5, 2025
1 parent 1fdeec5 commit 361b680
Showing 3 changed files with 19 additions and 23 deletions.
src/levanter/data/packing.py (3 additions, 0 deletions)
@@ -111,7 +111,9 @@ def pack_prompt_completions(
     """
     Packs a list of prompt completions into LmExamples using the SequencePacker
     """
+
     packers = [SequencePacker(Pos, max_segments_per_example, pad_token)]
+
     for sequence in sequences:
         loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1
         loss_mask[-1] = 0
@@ -120,6 +122,7 @@
         for packer in packers:
             if packer.can_pack(sequence.ids):
                 packer.add_example(sequence.ids, loss_mask, sequence.segment_id)
+
                 if packer.num_segments == max_segments_per_example:
                     yield packer.pack()
                     packers.remove(packer)
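The loss mask built in the first hunk trains only on completion tokens: position i predicts token i + 1, so loss begins at prompt_length - 1, and the final position is masked because it has no next token to predict. A minimal sketch of that semantics, with made-up token values:

import numpy as np

# Made-up example: 2 prompt tokens followed by 3 completion tokens.
ids = [11, 12, 13, 14, 15]
prompt_length = 2

loss_mask = np.arange(len(ids)) >= prompt_length - 1  # loss starts one before the completion
loss_mask[-1] = 0  # final position predicts nothing
print(loss_mask.astype(int))  # [0 1 1 1 0]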
src/levanter/data/text.py (0 additions, 4 deletions)
@@ -1000,10 +1000,6 @@ def mk_cached_sft_dataset(

     # Cache the processed data
     cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
-
-    # Ensure padding token is set (needed by _prepare_supervised_example)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
     return cached_dataset


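The deleted guard followed a common Hugging Face pattern: causal-LM tokenizers such as GPT-2's ship without a pad token, so EOS is reused for padding. A sketch of the removed behavior in isolation (the model name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2 defines no pad token

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as PAD

print(tokenizer.pad_token_id == tokenizer.eos_token_id)  # True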
src/levanter/main/sft.py (16 additions, 19 deletions)
@@ -35,7 +35,6 @@
 from levanter.optim import AdamConfig, OptimizerConfig
 from levanter.trainer import Trainer, TrainerConfig
 from levanter.utils.background_iterable import BackgroundIterator
-from levanter.utils.hf_utils import HfTokenizer
 from levanter.utils.jax_utils import use_cpu_device


@@ -233,7 +232,13 @@ def train(config: SFTConfig):
         logger.info("Starting SFT from scratch")

     logger.info("Packing prompt completions")
-    packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4)
+    packed_iterator = pack_prompt_completions(
+        Pos,
+        prompt_completion_iterator,
+        max_segments_per_example=4,
+        pad_token=tokenizer.pad_token_id,
+        max_buffered_examples=16,
+    )
     logger.info("Stacking batches to train batch")
     packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, TrainBatch=trainer.TrainBatch)
     # TODO what's a good number for max_capacity?
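For context, a hypothetical, self-contained sketch of the call being inlined above; the toy Pos size, the token ids, and the field access on the yielded LmExamples are assumptions for illustration, not taken from this commit:

import haliax as hax
from levanter.data.packing import PromptCompletion, pack_prompt_completions

Pos = hax.Axis("position", 16)  # toy length; real runs use the model context size

completions = iter([
    PromptCompletion(ids=[5, 6, 7, 8], prompt_length=2, segment_id=0),
    PromptCompletion(ids=[9, 10, 11], prompt_length=1, segment_id=1),
])

# Up to 4 segments share one packed example; leftover space is padding.
for example in pack_prompt_completions(
    Pos, completions, max_segments_per_example=4, pad_token=0, max_buffered_examples=16
):
    print(example.tokens, example.loss_mask)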
@@ -272,8 +277,8 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axis
     if length is None:
         raise ValueError("Dataset length cannot be None")

-    for batch_indicies in batched(range(length), 4096):
-        examples = asyncio.run(cached_dataset.get_batch(batch_indicies))
+    for indicies in batched(range(length), 4096):
+        examples = asyncio.run(cached_dataset.get_batch(indicies))

         for i in range(len(examples)):
             example = examples[i]
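The loop above pulls cached examples in chunks of 4096 indices rather than awaiting one call per example. The chunking helper behaves like itertools.batched (Python 3.12+; where this repository's batched actually comes from is not shown in the diff), sketched here with a small range:

from itertools import batched  # Python 3.12+; older codebases vendor an equivalent

for indicies in batched(range(10), 4):
    print(list(indicies))  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]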
@@ -289,24 +294,16 @@
                 continue

             try:
-                yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=batch_indicies[i])
-            except ValueError:
+                yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=indicies[i])
+            except ValueError as e:
+                # Likely error: PromptCompletion may raise a ValueError if the token list is empty or if its length is not greater than the prompt_length.
+                logger.error(
+                    f"Error creating PromptCompletion (ids length: {len(ids)}, sources_len: {sources_len}, segment id:"
+                    f" {indicies[i]}): {e}"
+                )
                 continue


-def _pack_requests(
-    prompt_completion_iterator: Iterator[PromptCompletion], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int
-) -> Iterator[LmExample]:
-    # TODO: use a better packing algorithm?
-    yield from pack_prompt_completions(
-        Pos,
-        prompt_completion_iterator,
-        max_segments_per_example=max_pack_size,
-        pad_token=tokenizer.pad_token_id,
-        max_buffered_examples=16,
-    )
-
-
 """
 Helper function to create a dummy instance with the same shape as the batch.
 When we reach the end of the dataset but we want a full batch,
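The new except branch exists because, per the comment in the hunk above, PromptCompletion is expected to reject an empty token list or ids no longer than prompt_length. A hypothetical sketch of the path now being logged:

from levanter.data.packing import PromptCompletion

try:
    # No completion tokens: ids is not longer than prompt_length.
    PromptCompletion(ids=[1, 2], prompt_length=2, segment_id=0)
except ValueError as e:
    print(f"skipped: {e}")  # the iterator logs the error and continues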

