Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Jan 17, 2025
1 parent 3a534dd commit fa049d6
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,17 +783,32 @@ def _preprocess_supervised_example(


def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]:
    """
    Prepare examples for training. This function converts the (cached) encodings into LmExamples.

    It goes through the following steps:

    1. Pad the batch to ``Pos.size`` with ``tokenizer.pad``.
    2. Keep only the last ``Pos.size`` token ids of every sequence.
    3. Build one LmExample per sequence with ``_mk_sup_example_jit``, handing it the
       recorded source length together with the tokenizer's pad and EOS token ids.

    Args:
        ex: Encoded examples. Each dict must carry a ``"sources_len"`` entry and be
            accepted by ``tokenizer.pad``.
        tokenizer: Tokenizer used for padding; also supplies ``pad_token_id`` and ``eos_token_id``.
        Pos: Position axis; its ``size`` is the target sequence length.

    Returns:
        One LmExample per input example, in input order.
    """
    # Read the source lengths before padding.
    lens = np.array([example["sources_len"] for example in ex])

    ex_pad = tokenizer.pad(
        ex,
        padding="max_length",
        max_length=Pos.size,
    )

    # Slice to the trailing Pos.size ids so every sequence fits the Pos axis.
    # NOTE(review): when a sequence is longer than Pos.size this drops ids from the front,
    # but the source length passed on below is not reduced to match. Confirm that
    # _mk_sup_example_jit is meant to receive the untruncated length.
    truncated = [ids[-Pos.size :] for ids in ex_pad["input_ids"]]

    return [
        _mk_sup_example_jit(Pos, hax.named(ids, Pos), source_len, tokenizer.pad_token_id, tokenizer.eos_token_id)
        for ids, source_len in zip(truncated, lens)
    ]

@functools.partial(jax.jit, static_argnums=(0, 3, 4))
Expand All @@ -807,22 +822,6 @@ def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_i
return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id)


















def mk_supervised_datasets(
sources: Mapping[str, SupervisedSourceConfigBase] | SupervisedSourceConfigBase,
split: str,
Expand Down

0 comments on commit fa049d6

Please sign in to comment.