From a68c4fcb7d1619df62d1cf925aa3ccb2ddbb5881 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Fri, 17 Jan 2025 13:58:11 -0800 Subject: [PATCH] remap --- src/levanter/data/text.py | 1697 +++++++++++++++++-------------------- 1 file changed, 795 insertions(+), 902 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 755ed80fb..c512a9e12 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -265,7 +265,6 @@ def _create_lm_example(tokens, key): async def async_len(self) -> int: return await self.dataset.async_len() - def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): if tokenizer.is_fast and os.getenv("TOKENIZERS_PARALLELISM") is None: # if we're using a fast tokenizer, we want to force parallelism @@ -277,303 +276,368 @@ def _maybe_force_tokenizer_parallelism(tokenizer: PreTrainedTokenizerBase): ws = regex.compile(r"\s") -class PassthroughTokenizer(PreTrainedTokenizer): - def __init__(self, vocab_size, **kwargs): - self._vocab = {i: i for i in range(vocab_size)} - self._vocab_size = vocab_size - super().__init__(**kwargs) - @property - def vocab_size(self) -> int: - return self._vocab_size +class BatchTokenizer(BatchProcessor[str, dict]): + """ + A batch processor that tokenizes a batch of strings using a tokenizer. + By default, this will append eos to the end of the string, even if the tokenizer doesn't. + """ - def get_vocab(self): - return self._vocab + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + enforce_bos=True, + enforce_eos=True, + *, + override_resources=None, + _workaround_len=LONG_STRING_WORKAROUND, + return_attention_mask=False, + padding=False, + max_length=None, + ): + _maybe_force_tokenizer_parallelism(tokenizer) + self.tokenizer = tokenizer + self.override_resources = override_resources + self.return_attention_mask = return_attention_mask + self.padding = padding + if max_length is not None: + self.max_length = max_length + else: + self.max_length = self.tokenizer.model_max_length - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]: - return () + # see if the tokenizer appends bos/eos + # if we don't have an eos/bos token in the tokenizer, skip + if tokenizer.bos_token_id is None: + enforce_bos = False + if tokenizer.eos_token_id is None: + enforce_eos = False - def _tokenize(self, text, **kwargs): - tokens = np.fromstring(text, dtype=int, sep=" ") - return tokens + # HF's BPE-based tokenizers do not, but the bert and roberta ones do + # TODO: this doesn't necessarily ensure it, I guess, but eh + if enforce_eos or enforce_bos: + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id and enforce_eos + should_append_bos = input_ids[0] != tokenizer.bos_token_id and enforce_bos + else: + should_append_eos = False + should_append_bos = False - def _convert_token_to_id(self, token: str) -> int: - return int(token) + self._need_to_add_eos = should_append_eos + self._need_to_add_bos = should_append_bos + self._workaround_len = _workaround_len - def _convert_id_to_token(self, index: int) -> str: - return str(index) + def __call__(self, batch: Sequence[str]) -> list[dict]: + if self._need_to_add_bos: + batch = [self.tokenizer.bos_token + " " + d for d in batch] -@dataclass -class LMDatasetSourceConfig: - """This class represents a dataset source with URLs or hf name/id.""" + if self._need_to_add_eos: + batch = [d + " " + self.tokenizer.eos_token for d in batch] - tags: 
Optional[List[str]] = None - """tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well""" + if self._needs_long_sequence_workaround: + batch, needs_merge = self._break_for_long_sequences(batch) + else: + needs_merge = [] - id: Optional[str] = None # id (or path) for hf dataset - name: Optional[str] = None # name for hf dataset + if self.padding is not False: + encoding = self.tokenizer( + batch, + return_attention_mask=self.return_attention_mask, + verbose=False, + padding=self.padding, + max_length=self.max_length, + truncation=True, + ) # type: ignore + else: + encoding = self.tokenizer( + batch, return_attention_mask=self.return_attention_mask, verbose=False + ) # type: ignore - plaintext: bool = False - stream: bool = True # whether to use streaming when doing hf - text_key: str = "text" # key for the text field in the jsonl file or hf dataset + if needs_merge: + new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) + encoding = BatchEncoding(new_encoding) - train_urls: List[str] = () # type: ignore - validation_urls: List[str] = () # type:ignore - cache_dir: Optional[str] = None # Optionally override the cache dir for this component + # debatch the encoding + unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])] - def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]: - if self.id is not None: - try: - ds = WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.stream) - except ValueError as e: - # if the message starts with Bad split, then just return None - if str(e).startswith("Bad split"): - logger.warning(f"Splits {split} not found for {self.id} {self.name}") - return None + return unbatched + + def _break_for_long_sequences(self, batch): + orig_lengths = [len(d) for d in batch] + # break any strings that are longer than LONG_STRING_WORKAROUND characters into smaller chunks + orig_batch = batch + batch = [] + needs_merge = [] + for i, d in enumerate(orig_batch): + needs_merge.append(False) + orig_len = orig_lengths[i] + while len(d) > self._workaround_len: + # we'd rather break strings at whitespace, so find the first whitespace + match = ws.search(d, self._workaround_len) + # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit + if match is None: + split = len(d) else: - raise + split = match.start() - if len(ds.shard_names) == 0: - return None + batch.append(d[:split]) + needs_merge.append(True) - return ds.map(lambda x: x[self.text_key]) - else: - split_urls = self.urls_for_split(split) - if len(split_urls) == 0: - return None - return TextUrlDataSource(split_urls, self.text_key) + d = d[split:] + orig_len -= split - def doc_iterator(self, split: str): - if self.id is not None: - dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream) - data = dataset[split] - for doc in data: - yield doc[self.text_key] - else: - urls = self.urls_for_split(split) + batch.append(d) + return batch, needs_merge - yield from TextUrlDataSource(urls, self.text_key) + @property + def metadata(self) -> Dict[str, Any]: + return { + "tokenizer": self.tokenizer.name_or_path, + "vocab_size": len(self.tokenizer), + "return_attention_mask": self.return_attention_mask, + "padding": self.padding, + "max_length": self.max_length, + "append_bos": self._need_to_add_bos, + "append_eos": self._need_to_add_eos, + } - def urls_for_split(self, split): - if split == "train": - urls = self.train_urls - elif split == "validation": 
- urls = self.validation_urls - else: - raise ValueError(f"Unknown split {split}") + @property + def output_exemplar(self) -> dict: + return dict(**self.tokenizer("hi there", return_attention_mask=self.return_attention_mask, verbose=False)) - urls = [globbed for url in urls for globbed in expand_glob(url)] - return urls + @property + def name_or_path(self): + return self.tokenizer.name_or_path + @property + def vocab_size(self): + return self.tokenizer.vocab_size -@dataclass -class LMTaskConfig(abc.ABC): - tokenizer: str = "gpt2" - vocab_size: Optional[int] = None # if using the passthrough tokenizer, this is required + @staticmethod + def _merge_split_encodings(batch, encoding, needs_merge): + # merge the encodings back together + # we might need to merge multiple encodings together + # needs merge marks the first n-1 encodings that need to be merged for each document + new_encoding = {} + for k, v in encoding.items(): + if len(v) == 0: + continue + if isinstance(v[0], np.ndarray): + assert len(v) == len(batch) + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + v_out.append(np.concatenate(vs_to_merge)) + vs_to_merge = [] + vs_to_merge.append(v[i]) - # config related to caching - cache_dir: Optional[str] = "cache/" - cache_options: CacheOptions = field(default_factory=CacheOptions) - enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't + if len(vs_to_merge) > 0: + v_out.append(np.concatenate(vs_to_merge)) - ignore_token_id: Optional[int] = None - shuffle: bool | int = False - """whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle. - If you want to shuffle in eras, set this to the era length""" + new_encoding[k] = v_out + elif isinstance(v[0], list): + v_out = [] + vs_to_merge = [] + for i in range(len(batch)): + if not needs_merge[i]: + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + vs_to_merge = [] + vs_to_merge.append(v[i]) + + if len(vs_to_merge) > 0: + v_out.append(list(chain(*vs_to_merge))) + new_encoding[k] = v_out + else: + raise ValueError(f"Unknown type {type(v[0])}") + return new_encoding + # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1495 @cached_property - def the_tokenizer(self) -> HfTokenizer: - if self.tokenizer == "passthrough": - return PassthroughTokenizer(self.vocab_size) + def _needs_long_sequence_workaround(self): + if isinstance(self.tokenizer, PreTrainedTokenizerFast): + normalizer = self.tokenizer.backend_tokenizer.normalizer + if normalizer is None: + return False + # if there's a "Replace" normalizer, then we need to do the workaround + # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it + return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) else: - return load_tokenizer(self.tokenizer) + return False - @abc.abstractmethod - def train_set( - self, - seq_len: int, - monitors: Union[bool, List[MetricsMonitor]] = True, - *, - key: Optional[PRNGKeyArray], - epochs: Optional[int] = None, - ) -> AsyncDataset[np.ndarray]: - pass - - @abc.abstractmethod - def validation_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, AsyncDataset[np.ndarray]]: - pass + @property + def num_cpus(self) -> int: + if self.override_resources is not None: + cpus = self.override_resources.get("num_cpus", None) + if cpus is not None: + return cpus + return num_cpus_used_by_tokenizer(self.tokenizer) @property - 
@abc.abstractmethod - def sources(self) -> Mapping[str, LMDatasetSourceConfig]: - pass + def num_gpus(self) -> int: + if self.override_resources is not None: + return self.override_resources.get("num_gpus", 0) + return 0 - def tagged_eval_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> list[Tuple[AsyncDataset[np.ndarray], List[str]]]: - tags = {name: (config.tags or []) + [name] for name, config in self.sources.items()} - eval_sets = self.validation_sets(seq_len, monitors) - return [(eval_sets[name], tags[name]) for name in eval_sets] +def concatenate_and_group_texts( + encoding: BatchEncoding, + seq_len: int, + stride: Optional[int] = None, + drop_remainder: bool = True, + mask_stride_overlap=True, +) -> Iterator[BatchEncoding]: + """Groups texts in a batch together. Typically, you'll want to use this with a fairly large + set of texts, e.g. 1000 docs. + You should set mask_stride_overlap to True and drop_remainder to False if you want to use this for test data -CANONICAL_INPUT_FIELD = "prompt" -CANONICAL_OUTPUT_FIELD = "response" + Args: + encoding: The batch of texts to concatenate and group. + seq_len: The max length of sequences to emit + stride: The stride to use when grouping texts. If None, then the stride is set to seq_len. + mask_stride_overlap: Whether to mask out overlapping tokens if we're using a stride. + drop_remainder: Whether to drop the last batch if it's not a multiple of the seq_len. + Returns: + An iterator of tokenized texts, one at a time. + """ + concatenated = BatchEncoding(data={k: np.array(list(chain(*v))) for k, v in encoding.items()}) + total_length = len(concatenated.input_ids) + stride = stride or seq_len -@dataclass -class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): - """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" + # Drop the "very last" bit of the dataset that doesn't fit into block size... + if drop_remainder and total_length % stride != 0: + total_length = ((total_length - seq_len + stride) // stride) * stride - cache_dir: Optional[str] = "cache/" + # Split by Chunks of Maximum Length + # we want to take chunks up until we've covered all "total_length" tokens with a sliding window of size "stride" + for begin in range(0, total_length - seq_len + stride, stride): + data = {k: v[begin : begin + seq_len] for k, v in concatenated.items()} - def train_set( - self, - seq_len: int, - monitors: Union[bool, List[MetricsMonitor]] = True, - *, - key: Optional[PRNGKeyArray] = None, - epochs: Optional[int] = None, - ) -> AsyncDataset[np.ndarray]: + if mask_stride_overlap and stride != seq_len: + labels = data.get("labels", data["input_ids"]) + if begin != 0: + labels = _mask_overlap(labels, seq_len, stride) + data["labels"] = labels - ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors) + yield BatchEncoding(data=data) - # add epoch flag here. 
- if ds is None: - raise ValueError("No training set!") - if epochs: - logger.info("Wrapping dataset in epoch dataset") - ds = EpochDataset(ds, max_epochs=epochs) +# -100 is pytorch's label mask +def _mask_overlap(labels, target_len, stride, sentinel=-100): + """Masks out overlapping tokens in a sequence when we're using a stride.""" + labels = copy.deepcopy(labels) + if isinstance(labels, list): + for i in range(target_len - stride): + if i < len(labels): + labels[i] = sentinel + else: + labels[0 : target_len - stride] = sentinel - if self.shuffle is True: - ds = ds.shuffle(key) - elif isinstance(self.shuffle, int) and self.shuffle > 0: - ds = ds.era_shuffle(self.shuffle, key=key) + return labels - return ds # type: ignore - def validation_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[TokenSeqDataset]: - return self.token_seq_dataset("validation", seq_len, monitors) +def _stack_batch_encodings(a: BatchEncoding, b: BatchEncoding) -> BatchEncoding: + """Stacks two batch encodings together, assuming that the keys are the same.""" - def validation_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, AsyncDataset[np.ndarray]]: - validation_set = self.validation_set(seq_len, monitors) - if validation_set is not None: - return {"": validation_set} + def _ensure_batched(x): + if len(x) == 0: + return list(x) + elif isinstance(x[0], Sequence) or isinstance(x[0], np.ndarray): + return list(x) else: - return {} - - @property - def sources(self) -> Mapping[str, LMDatasetSourceConfig]: - return {"": self} - - @cached_property - def _has_validation_set(self): - if len(self.validation_urls) > 0: - return True + return [x] - if self.id is not None: - dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation") - try: - next(iter(dataset)) - return True - except StopIteration: - return False + return BatchEncoding({k: _ensure_batched(a[k]) + _ensure_batched(b[k]) for k in a.keys()}) - return False - def token_seq_dataset( - self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[TokenSeqDataset]: - cache = self.build_or_load_cache(split, monitors=monitors) - if cache is None: - return None - return TokenSeqDataset(cache, seq_len) +@dataclass +class LMDatasetSourceConfig: + """This class represents a dataset source with URLs or hf name/id.""" - def build_or_load_cache( - self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None - ) -> Optional[TreeCache[BatchEncoding]]: - if self.cache_dir is None: - raise ValueError("cache_dir cannot be None") + tags: Optional[List[str]] = None + """tags for the dataset. 
Typically the name of the dataset in the config will be added as a tag as well""" - split_cache_dir = os.path.join(self.cache_dir, split) - name = logger_name or os.path.basename(self.cache_dir) + id: Optional[str] = None # id (or path) for hf dataset + name: Optional[str] = None # name for hf dataset - try: - # TODO: pass in options - return TreeCache.load(split_cache_dir, exemplar={"input_ids": np.zeros(0, dtype=np.int32)}) - except FileNotFoundError: - pass + plaintext: bool = False + stream: bool = True # whether to use streaming when doing hf + text_key: str = "text" # key for the text field in the jsonl file or hf dataset - source = self.get_shard_source(split) - if source is None: - logger.info(f"No data for {split}") - return None + train_urls: List[str] = () # type: ignore + validation_urls: List[str] = () # type:ignore + cache_dir: Optional[str] = None # Optionally override the cache dir for this component - logger.info(f"Building cache for {split}...") + def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]: + if self.id is not None: + try: + ds = WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.stream) + except ValueError as e: + # if the message starts with Bad split, then just return None + if str(e).startswith("Bad split"): + logger.warning(f"Splits {split} not found for {self.id} {self.name}") + return None + else: + raise - if monitors is True: - monitors = [ - LoggingMetricsMonitor(prefix=f"preprocessing/{name}/{split}", commit=False), - LoggerMetricsMonitor(f"preprocessing.{name}.{split}"), - ] - elif monitors is False: - monitors = [] + if len(ds.shard_names) == 0: + return None - bt = BatchTokenizer(self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos) + return ds.map(lambda x: x[self.text_key]) + else: + split_urls = self.urls_for_split(split) + if len(split_urls) == 0: + return None + return TextUrlDataSource(split_urls, self.text_key) - return build_or_load_cache( - split_cache_dir, - source, - bt, - monitors=monitors, - await_finished=False, - options=self.cache_options, - split=split, - ) + def doc_iterator(self, split: str): + if self.id is not None: + dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream) + data = dataset[split] + for doc in data: + yield doc[self.text_key] + else: + urls = self.urls_for_split(split) + yield from TextUrlDataSource(urls, self.text_key) -class SupervisedSourceConfigBase(Protocol): - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - raise NotImplementedError + def urls_for_split(self, split): + if split == "train": + urls = self.train_urls + elif split == "validation": + urls = self.validation_urls + else: + raise ValueError(f"Unknown split {split}") - input_field: str - output_field: str - tags: Optional[List[str]] - cache_dir: str + urls = [globbed for url in urls for globbed in expand_glob(url)] + return urls @dataclass -class LMMixtureDatasetConfig(LMTaskConfig): - """This class represents a mixture of datasets with their associated weights.""" +class LMTaskConfig(abc.ABC): + tokenizer: str = "gpt2" + vocab_size: Optional[int] = None # if using the passthrough tokenizer, this is required + # config related to caching cache_dir: Optional[str] = "cache/" + cache_options: CacheOptions = field(default_factory=CacheOptions) + enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't - # data source configs and weights - configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict) - """ 
configuration of each dataset source (urls, hf dataset id, etc.) """ - train_weights: Dict[str, float] = field(default_factory=dict) - """ weights for each dataset source. They will be normalized to sum to 1. """ - stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) - mixture_block_size: int = 2048 - """ block size for the mixture dataset.""" + ignore_token_id: Optional[int] = None + shuffle: bool | int = False + """whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle. + If you want to shuffle in eras, set this to the era length""" - def __post_init__(self): - if len(self.configs) == 0: - raise ValueError("At least one dataset must be provided") - - if set(self.configs.keys()) != set(self.train_weights.keys()): - raise ValueError( - f"The keys in configs and weights must be the same;got {self.configs.keys()} and" - f" {self.train_weights.keys()}" - ) + @cached_property + def the_tokenizer(self) -> HfTokenizer: + if self.tokenizer == "passthrough": + return PassthroughTokenizer(self.vocab_size) + else: + return load_tokenizer(self.tokenizer) + @abc.abstractmethod def train_set( self, seq_len: int, @@ -582,845 +646,674 @@ def train_set( key: Optional[PRNGKeyArray], epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: - doc_caches = self.build_caches("train", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + pass - if epochs: - raise ValueError("Epochs are not supported for mixture datasets") + @abc.abstractmethod + def validation_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, AsyncDataset[np.ndarray]]: + pass - if key is None: - key = jax.random.PRNGKey(0) + @property + @abc.abstractmethod + def sources(self) -> Mapping[str, LMDatasetSourceConfig]: + pass - mix_key, shuffle_key = jax.random.split(key) + def tagged_eval_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> list[Tuple[AsyncDataset[np.ndarray], List[str]]]: + tags = {name: (config.tags or []) + [name] for name, config in self.sources.items()} + eval_sets = self.validation_sets(seq_len, monitors) - # We shuffle the components and not the overall mixture because this lets us preserve - # the "stable batch" property of the mixture dataset. 
- def shuffle_ds(ds, key): - if self.shuffle is True: - ds = ds.shuffle(key) - elif isinstance(self.shuffle, int): - ds = ds.era_shuffle(self.shuffle, key=key) + return [(eval_sets[name], tags[name]) for name in eval_sets] - return ds - if self.shuffle: - out_token_datasets = {} - key_iter = key_iterator(shuffle_key) - for name, ds in token_datasets.items(): - out_token_datasets[name] = shuffle_ds(ds, next(key_iter)) - token_datasets = out_token_datasets +CANONICAL_INPUT_FIELD = "prompt" +CANONICAL_OUTPUT_FIELD = "response" - mixture = MixtureDataset( - datasets=token_datasets, - weights=self.train_weights, - stop_strategy=self.stop_strategy, - key=mix_key, - block_size=2048, - ) - return mixture +class SupervisedSourceConfigBase(Protocol): + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + raise NotImplementedError - def training_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, TokenSeqDataset]: - doc_caches = self.build_caches("train", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} - return token_datasets + input_field: str + output_field: str + tags: Optional[List[str]] + cache_dir: str - def validation_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, AsyncDataset[np.ndarray]]: - doc_caches = self.build_caches("validation", monitors=monitors) - token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} - return token_datasets - def build_caches( - self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Dict[str, TreeCache[dict]]: - # this is a bit gross, but we want to forward all "Task" config fields to the LMDatasetConfig for building. - # We do this by just grabbing all the fields from the LMDatasetConfig and forwarding them to the - # LMDatasetConfig.build_or_load_cache method. We exclude the cache_dir field. - task_config_fields = set(x.name for x in dataclasses.fields(LMTaskConfig)) - task_config_dict = {k: v for k, v in self.__dict__.items() if k in task_config_fields and k != "cache_dir"} +@dataclass +class LMSupervisedDatasetConfig(SupervisedSourceConfigBase): + """Config for supervised fine-tuning datasets""" - caches = {} - for name, source_config in self.configs.items(): - weight = self.train_weights.get(name, 0) + cache_dir: str = "cache/" - if weight == 0 and split == "train": - continue + # HF dataset config + hf_dataset_name: Optional[str] = None # e.g. "tatsu-lab/alpaca" or "OpenAssistant/oasst1" + hf_dataset_split: str = "train" # which split to use - source_config_dict = dict(**source_config.__dict__) + # Local files config + validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files - if source_config.cache_dir is None: - # replace with the main cache dir/{name} - if self.cache_dir is None: - raise ValueError( - "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but" - f"{name}'s cache_dir is None." 
- ) - cache_dir = os.path.join(self.cache_dir, name) - source_config_dict["cache_dir"] = cache_dir + # Field names in the data + input_field: str = CANONICAL_INPUT_FIELD # name of the input field + output_field: str = CANONICAL_OUTPUT_FIELD # name of output field - dataset = LMDatasetConfig( - **source_config_dict, - **task_config_dict, - ) - cache = dataset.build_or_load_cache(split, monitors) - # drop the data source and corresponding weight if the cache is not built - if cache is None: - logger.warning(f"Skipping {name} for split {split} because no source was provided") - else: - caches[name] = cache + # Optional metadata + tags: Optional[List[str]] = None - # in practice it works best if we block on validation caches - if split == "validation": - for cache in caches.values(): - cache.await_finished() + def __post_init__(self): + warnings.warn( + "LMSupervisedDatasetConfig is deprecated. Use SupervisedHfSourceConfig or " + "SupervisedUrlSourceConfig instead.", + DeprecationWarning, + ) + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + if self.hf_dataset_name is not None: + return WrappedHFDataSource(self.hf_dataset_name, split=self.hf_dataset_split) + elif split != "validation": + raise ValueError("Only validation split is supported for local files") else: - logger.info(f"Not waiting for {split} caches to finish building") + urls = [globbed for url in self.validation_urls for globbed in expand_glob(url)] + return JsonlDataSource(urls) - return caches - @property - def sources(self) -> Mapping[str, LMDatasetSourceConfig]: - return self.configs +@dataclass(frozen=True) +class SupervisedHfSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + id: str + name: str | None = None + streaming: bool = True -def datasource_from_chat_jsonl( - urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" -) -> "ShardedDataSource[dict]": - """Creates a ShardedDataSource from JSONL files containing chat messages. + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None - Args: - urls: Sequence of URLs or glob patterns pointing to JSONL files - messages_field: Field name containing the messages in each JSON object - input_role: Role identifier for input messages - output_role: Role identifier for output messages + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + return WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.streaming).map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) - Returns: - ShardedDataSource configured for chat data - """ - # Expand any glob patterns in the URLs - expanded_urls = [] - for url in urls: - if any(c in url for c in "*?[]"): - expanded_urls.extend(gcs_glob(url)) - else: - expanded_urls.append(url) - return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) +@dataclass(frozen=True) +class SupervisedUrlSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + train_urls: list[str] = dataclasses.field(default_factory=list) + validation_urls: list[str] = dataclasses.field(default_factory=list) + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None -class BatchTokenizer(BatchProcessor[str, dict]): - """ - A batch processor that tokenizes a batch of strings using a tokenizer. 
- By default, this will append eos to the end of the string, even if the tokenizer doesn't. - """ + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + urls = self.train_urls if split == "train" else self.validation_urls + if not urls: + return None - def __init__( - self, - tokenizer: PreTrainedTokenizerBase, - enforce_bos=True, - enforce_eos=True, - *, - override_resources=None, - _workaround_len=LONG_STRING_WORKAROUND, - return_attention_mask=False, - padding=False, - max_length=None, - ): - _maybe_force_tokenizer_parallelism(tokenizer) - self.tokenizer = tokenizer - self.override_resources = override_resources - self.return_attention_mask = return_attention_mask - self.padding = padding - if max_length is not None: - self.max_length = max_length - else: - self.max_length = self.tokenizer.model_max_length + urls = [globbed for url in urls for globbed in expand_glob(url)] - # see if the tokenizer appends bos/eos - # if we don't have an eos/bos token in the tokenizer, skip - if tokenizer.bos_token_id is None: - enforce_bos = False - if tokenizer.eos_token_id is None: - enforce_eos = False + source = UrlDataSource(urls, columns=[self.input_field, self.output_field]) + return source.map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) - # HF's BPE-based tokenizers do not, but the bert and roberta ones do - # TODO: this doesn't necessarily ensure it, I guess, but eh - if enforce_eos or enforce_bos: - input_ids = tokenizer("hi there")["input_ids"] - should_append_eos = input_ids[-1] != tokenizer.eos_token_id and enforce_eos - should_append_bos = input_ids[0] != tokenizer.bos_token_id and enforce_bos - else: - should_append_eos = False - should_append_bos = False - self._need_to_add_eos = should_append_eos - self._need_to_add_bos = should_append_bos - self._workaround_len = _workaround_len +SupervisedSourceConfig: TypeAlias = Union[SupervisedHfSourceConfig, SupervisedUrlSourceConfig] - def __call__(self, batch: Sequence[str]) -> list[dict]: - if self._need_to_add_bos: - batch = [self.tokenizer.bos_token + " " + d for d in batch] - if self._need_to_add_eos: - batch = [d + " " + self.tokenizer.eos_token for d in batch] +def _preprocess_supervised_example( + batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str +) -> dict: + sources = [example[input_field] for example in batch] - if self._needs_long_sequence_workaround: - batch, needs_merge = self._break_for_long_sequences(batch) - else: - needs_merge = [] + targets = [example[output_field] for example in batch] + # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it. 
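+        # Descriptive note (not from the original patch): the separate tokenization of `sources`
+        # is only used to measure the prompt length in tokens (`sources_len`);
+        # _prepare_supervised_examples later uses that length to mask the prompt out of the loss,
+        # while the combined prompt+response tokenization below supplies the actual input_ids.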
+ examples = [s + t for s, t in zip(sources, targets)] + sources_tokenized = tokenizer(sources, padding=False, truncation=True) + examples_tokenized = tokenizer(examples, padding=False, truncation=True) - if self.padding is not False: - encoding = self.tokenizer( - batch, - return_attention_mask=self.return_attention_mask, - verbose=False, - padding=self.padding, - max_length=self.max_length, - truncation=True, - ) # type: ignore - else: - encoding = self.tokenizer( - batch, return_attention_mask=self.return_attention_mask, verbose=False - ) # type: ignore + source_lens = [len(s) for s in sources_tokenized["input_ids"]] - if needs_merge: - new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) - encoding = BatchEncoding(new_encoding) + return { + "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], + "sources_len": np.array(source_lens, dtype=np.int32), + } - # debatch the encoding - unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])] - return unbatched +def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]: + lens = np.array([ex["sources_len"] for ex in ex]) + + # Pad to max length + ex_pad = tokenizer.pad(ex, padding="max_length", max_length=Pos.size) + + # Create examples with appropriate loss masking + out = [] + for ids, len in zip(ex_pad["input_ids"], lens): + causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id, tokenizer.eos_token_id) + out.append(causal) + + return out - def _break_for_long_sequences(self, batch): - orig_lengths = [len(d) for d in batch] - # break any strings that are longer than LONG_STRING_WORKAROUND characters into smaller chunks - orig_batch = batch - batch = [] - needs_merge = [] - for i, d in enumerate(orig_batch): - needs_merge.append(False) - orig_len = orig_lengths[i] - while len(d) > self._workaround_len: - # we'd rather break strings at whitespace, so find the first whitespace - match = ws.search(d, self._workaround_len) - # this is vanishingly unlikely, but if we can't find a whitespace, just break it at the limit - if match is None: - split = len(d) - else: - split = match.start() +@functools.partial(jax.jit, static_argnums=(0, 3, 4)) +def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id, eos_id): + # mask out padding and anything before the start of the target + loss_mask = hax.arange(Pos) >= sources_len - 1 + # don't predict the padding + targets = hax.roll(input_ids, -1, Pos) + loss_mask = loss_mask & (targets != pad_token_id) + loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) + return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id) - batch.append(d[:split]) - needs_merge.append(True) - d = d[split:] - orig_len -= split - batch.append(d) - return batch, needs_merge - @property - def metadata(self) -> Dict[str, Any]: - return { - "tokenizer": self.tokenizer.name_or_path, - "vocab_size": len(self.tokenizer), - "return_attention_mask": self.return_attention_mask, - "padding": self.padding, - "max_length": self.max_length, - "append_bos": self._need_to_add_bos, - "append_eos": self._need_to_add_eos, - } - @property - def output_exemplar(self) -> dict: - return dict(**self.tokenizer("hi there", return_attention_mask=self.return_attention_mask, verbose=False)) - @property - def name_or_path(self): - return self.tokenizer.name_or_path - @property - def vocab_size(self): - return 
self.tokenizer.vocab_size - @staticmethod - def _merge_split_encodings(batch, encoding, needs_merge): - # merge the encodings back together - # we might need to merge multiple encodings together - # needs merge marks the first n-1 encodings that need to be merged for each document - new_encoding = {} - for k, v in encoding.items(): - if len(v) == 0: - continue - if isinstance(v[0], np.ndarray): - assert len(v) == len(batch) - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - v_out.append(np.concatenate(vs_to_merge)) - vs_to_merge = [] - vs_to_merge.append(v[i]) - if len(vs_to_merge) > 0: - v_out.append(np.concatenate(vs_to_merge)) - new_encoding[k] = v_out - elif isinstance(v[0], list): - v_out = [] - vs_to_merge = [] - for i in range(len(batch)): - if not needs_merge[i]: - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - vs_to_merge = [] - vs_to_merge.append(v[i]) - if len(vs_to_merge) > 0: - v_out.append(list(chain(*vs_to_merge))) - new_encoding[k] = v_out - else: - raise ValueError(f"Unknown type {type(v[0])}") - return new_encoding - # TODO remove this when it's resolved https://github.com/huggingface/tokenizers/issues/1495 - @cached_property - def _needs_long_sequence_workaround(self): - if isinstance(self.tokenizer, PreTrainedTokenizerFast): - normalizer = self.tokenizer.backend_tokenizer.normalizer - if normalizer is None: - return False - # if there's a "Replace" normalizer, then we need to do the workaround - # inexplicably there's no way to see inside a Sequence so we also have to assume it needs it - return isinstance(normalizer, (normalizers.Replace, normalizers.Sequence)) - else: - return False - @property - def num_cpus(self) -> int: - if self.override_resources is not None: - cpus = self.override_resources.get("num_cpus", None) - if cpus is not None: - return cpus - return num_cpus_used_by_tokenizer(self.tokenizer) - @property - def num_gpus(self) -> int: - if self.override_resources is not None: - return self.override_resources.get("num_gpus", 0) - return 0 -def concatenate_and_group_texts( - encoding: BatchEncoding, - seq_len: int, - stride: Optional[int] = None, - drop_remainder: bool = True, - mask_stride_overlap=True, -) -> Iterator[BatchEncoding]: - """Groups texts in a batch together. Typically, you'll want to use this with a fairly large - set of texts, e.g. 1000 docs. - You should set mask_stride_overlap to True and drop_remainder to False if you want to use this for test data - Args: - encoding: The batch of texts to concatenate and group. - seq_len: The max length of sequences to emit - stride: The stride to use when grouping texts. If None, then the stride is set to seq_len. - mask_stride_overlap: Whether to mask out overlapping tokens if we're using a stride. - drop_remainder: Whether to drop the last batch if it's not a multiple of the seq_len. +def mk_supervised_datasets( + sources: Mapping[str, SupervisedSourceConfigBase] | SupervisedSourceConfigBase, + split: str, + tokenizer: PreTrainedTokenizerBase, + Pos: hax.Axis, +) -> dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]]: + """ + Create supervised datasets from a set of sources. Returns: - An iterator of tokenized texts, one at a time. + A dictionary of dataset names to tuples of the dataset and the tags associated with the dataset. 
""" - concatenated = BatchEncoding(data={k: np.array(list(chain(*v))) for k, v in encoding.items()}) - total_length = len(concatenated.input_ids) - stride = stride or seq_len - - # Drop the "very last" bit of the dataset that doesn't fit into block size... - if drop_remainder and total_length % stride != 0: - total_length = ((total_length - seq_len + stride) // stride) * stride + out: dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]] = {} - # Split by Chunks of Maximum Length - # we want to take chunks up until we've covered all "total_length" tokens with a sliding window of size "stride" - for begin in range(0, total_length - seq_len + stride, stride): - data = {k: v[begin : begin + seq_len] for k, v in concatenated.items()} + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token - if mask_stride_overlap and stride != seq_len: - labels = data.get("labels", data["input_ids"]) - if begin != 0: - labels = _mask_overlap(labels, seq_len, stride) - data["labels"] = labels + if isinstance(sources, Mapping): + for name, config in sources.items(): + source = config.get_shard_source(split) + if source is None: + continue - yield BatchEncoding(data=data) + ds = _cache_supervised_set( + source, config.cache_dir, tokenizer, Pos, config.input_field, config.output_field + ) + if config.tags is None: + tags = [name] + else: + tags = config.tags + [name] -# -100 is pytorch's label mask -def _mask_overlap(labels, target_len, stride, sentinel=-100): - """Masks out overlapping tokens in a sequence when we're using a stride.""" - labels = copy.deepcopy(labels) - if isinstance(labels, list): - for i in range(target_len - stride): - if i < len(labels): - labels[i] = sentinel + out[name] = (ds, tags) else: - labels[0 : target_len - stride] = sentinel + source = sources.get_shard_source(split) # type: ignore + if source is not None: + ds = _cache_supervised_set( + source, sources.cache_dir, tokenizer, Pos, sources.input_field, sources.output_field + ) + tags = sources.tags or [] + if isinstance(sources, SupervisedHfSourceConfig): + name = sources.id + if sources.name is not None: + name = f"{name}/{sources.name}" - return labels + tags = [name] + tags + else: + name = "default" + out[name] = (ds, tags) + return out -def _stack_batch_encodings(a: BatchEncoding, b: BatchEncoding) -> BatchEncoding: - """Stacks two batch encodings together, assuming that the keys are the same.""" - def _ensure_batched(x): - if len(x) == 0: - return list(x) - elif isinstance(x[0], Sequence) or isinstance(x[0], np.ndarray): - return list(x) - else: - return [x] +def mk_supervised_dataset( + config: SupervisedSourceConfigBase, split: str, tokenizer: HfTokenizer, Pos: hax.Axis +) -> AsyncDataset[LmExample]: - return BatchEncoding({k: _ensure_batched(a[k]) + _ensure_batched(b[k]) for k in a.keys()}) + source = config.get_shard_source(split) + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} -@dataclass -class LMDatasetSourceConfig: - """This class represents a dataset source with URLs or hf name/id.""" + dataset = source.map_batches( # type: ignore + lambda ex: _preprocess_supervised_example(ex, tokenizer, config.input_field, config.output_field), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) - tags: Optional[List[str]] = None - """tags for the dataset. 
Typically the name of the dataset in the config will be added as a tag as well""" + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) - id: Optional[str] = None # id (or path) for hf dataset - name: Optional[str] = None # name for hf dataset + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token - plaintext: bool = False - stream: bool = True # whether to use streaming when doing hf - text_key: str = "text" # key for the text field in the jsonl file or hf dataset + return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) - train_urls: List[str] = () # type: ignore - validation_urls: List[str] = () # type:ignore - cache_dir: Optional[str] = None # Optionally override the cache dir for this component - def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]: - if self.id is not None: - try: - ds = WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.stream) - except ValueError as e: - # if the message starts with Bad split, then just return None - if str(e).startswith("Bad split"): - logger.warning(f"Splits {split} not found for {self.id} {self.name}") - return None - else: - raise +def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field): + """ + Cache a supervised dataset into input_ids and sources_len. + """ + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + dataset = source.map_batches( + lambda ex: _preprocess_supervised_example(ex, tokenizer, input_field, output_field), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True) + ds = cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) + return ds - if len(ds.shard_names) == 0: - return None - return ds.map(lambda x: x[self.text_key]) - else: - split_urls = self.urls_for_split(split) - if len(split_urls) == 0: - return None - return TextUrlDataSource(split_urls, self.text_key) +@dataclass(frozen=True) +class ChatUrlDataSourceConfig: + """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning.""" - def doc_iterator(self, split: str): - if self.id is not None: - dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream) - data = dataset[split] - for doc in data: - yield doc[self.text_key] - else: - urls = self.urls_for_split(split) + cache_dir: str + train_urls: List[str] = field(default_factory=list) + validation_urls: List[str] = field(default_factory=list) - yield from TextUrlDataSource(urls, self.text_key) + # Chat format specific fields + messages_field: str = "messages" + input_role: str = "user" + output_role: str = "assistant" - def urls_for_split(self, split): - if split == "train": - urls = self.train_urls - elif split == "validation": - urls = self.validation_urls - else: - raise ValueError(f"Unknown split {split}") + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + """Gets ShardedDataSource for either training or validation data.""" + urls = self.validation_urls if split == "validation" else self.train_urls - urls = [globbed for url in urls for globbed in expand_glob(url)] - return urls + if not urls: + return None + # Use the datasource_from_chat_jsonl function from sharded_datasource + return datasource_from_chat_jsonl( + urls, 
messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role + ) -@dataclass -class LMTaskConfig(abc.ABC): - tokenizer: str = "gpt2" - vocab_size: Optional[int] = None # if using the passthrough tokenizer, this is required - # config related to caching - cache_dir: Optional[str] = "cache/" - cache_options: CacheOptions = field(default_factory=CacheOptions) - enforce_eos: bool = True # whether to append eos even if the tokenizer doesn't +def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict: + """ + Preprocess chat examples to match the format of preprocess_supervised_example. + Returns a dict with input_ids and sources_len like the supervised case. - ignore_token_id: Optional[int] = None - shuffle: bool | int = False - """whether to shuffle the dataset. True means shuffle the whole dataset, False means don't shuffle. - If you want to shuffle in eras, set this to the era length""" + Args: + batch: List of dicts with input/output pairs + tokenizer: HuggingFace tokenizer + should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically) + """ + sources = [example["input"] for example in batch] + targets = [example["output"] for example in batch] + + # Tokenize sources to get lengths + sources_tokenized = tokenizer(sources, padding=False, truncation=True) + + # Combine for full examples + full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] + examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) + + return { + "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], + "sources_len": np.array([len(s) for s in sources_tokenized["input_ids"]], dtype=np.int32), + } - @cached_property - def the_tokenizer(self) -> HfTokenizer: - if self.tokenizer == "passthrough": - return PassthroughTokenizer(self.vocab_size) - else: - return load_tokenizer(self.tokenizer) - @abc.abstractmethod - def train_set( - self, - seq_len: int, - monitors: Union[bool, List[MetricsMonitor]] = True, - *, - key: Optional[PRNGKeyArray], - epochs: Optional[int] = None, - ) -> AsyncDataset[np.ndarray]: - pass - @abc.abstractmethod - def validation_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Mapping[str, AsyncDataset[np.ndarray]]: - pass - @property - @abc.abstractmethod - def sources(self) -> Mapping[str, LMDatasetSourceConfig]: - pass - def tagged_eval_sets( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> list[Tuple[AsyncDataset[np.ndarray], List[str]]]: - tags = {name: (config.tags or []) + [name] for name, config in self.sources.items()} - eval_sets = self.validation_sets(seq_len, monitors) - return [(eval_sets[name], tags[name]) for name in eval_sets] -CANONICAL_INPUT_FIELD = "prompt" -CANONICAL_OUTPUT_FIELD = "response" -class SupervisedSourceConfigBase(Protocol): - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - raise NotImplementedError +def mk_chat_sft_dataset( + config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis +) -> AsyncDataset[LmExample]: + """Creates a dataset from JSONL files containing chat format data for SFT.""" + source = config.get_shard_source("train") + if source is None: + raise ValueError("No training data source found") - input_field: str - output_field: str - tags: Optional[List[str]] - cache_dir: str + # Set up example structure matching supervised case + output_exemplar = 
{"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id + logger.info(f"Manual EOS Needed: {should_append_eos}") -@dataclass -class LMSupervisedDatasetConfig(SupervisedSourceConfigBase): - """Config for supervised fine-tuning datasets""" + # Process the dataset + dataset = source.map_batches( + lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) - cache_dir: str = "cache/" + # Cache the processed data + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) - # HF dataset config - hf_dataset_name: Optional[str] = None # e.g. "tatsu-lab/alpaca" or "OpenAssistant/oasst1" - hf_dataset_split: str = "train" # which split to use + # Ensure padding token is set (needed by _prepare_supervised_example) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token - # Local files config - validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files + # Reuse the supervised prepare function directly + return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) - # Field names in the data - input_field: str = CANONICAL_INPUT_FIELD # name of the input field - output_field: str = CANONICAL_OUTPUT_FIELD # name of output field - # Optional metadata - tags: Optional[List[str]] = None +@dataclass +class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): + """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" - def __post_init__(self): - warnings.warn( - "LMSupervisedDatasetConfig is deprecated. Use SupervisedHfSourceConfig or " - "SupervisedUrlSourceConfig instead.", - DeprecationWarning, - ) + cache_dir: Optional[str] = "cache/" - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - if self.hf_dataset_name is not None: - return WrappedHFDataSource(self.hf_dataset_name, split=self.hf_dataset_split) - elif split != "validation": - raise ValueError("Only validation split is supported for local files") - else: - urls = [globbed for url in self.validation_urls for globbed in expand_glob(url)] - return JsonlDataSource(urls) + def train_set( + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray] = None, + epochs: Optional[int] = None, + ) -> AsyncDataset[np.ndarray]: + ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors) -@dataclass(frozen=True) -class SupervisedHfSourceConfig(SupervisedSourceConfigBase): - cache_dir: str - id: str - name: str | None = None + # add epoch flag here. 
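+        # Descriptive note (not from the original patch): when `epochs` is set, the token dataset
+        # is wrapped in an EpochDataset before any shuffling, so the shuffle/era_shuffle calls
+        # below operate on the epoch-wrapped stream.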
+ if ds is None: + raise ValueError("No training set!") - streaming: bool = True + if epochs: + logger.info("Wrapping dataset in epoch dataset") + ds = EpochDataset(ds, max_epochs=epochs) - input_field: str = CANONICAL_INPUT_FIELD - output_field: str = CANONICAL_OUTPUT_FIELD - tags: Optional[List[str]] = None + if self.shuffle is True: + ds = ds.shuffle(key) + elif isinstance(self.shuffle, int) and self.shuffle > 0: + ds = ds.era_shuffle(self.shuffle, key=key) - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - return WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.streaming).map( - lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} - ) - - -@dataclass(frozen=True) -class SupervisedUrlSourceConfig(SupervisedSourceConfigBase): - cache_dir: str - train_urls: list[str] = dataclasses.field(default_factory=list) - validation_urls: list[str] = dataclasses.field(default_factory=list) + return ds # type: ignore - input_field: str = CANONICAL_INPUT_FIELD - output_field: str = CANONICAL_OUTPUT_FIELD - tags: Optional[List[str]] = None + def validation_set( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Optional[TokenSeqDataset]: + return self.token_seq_dataset("validation", seq_len, monitors) - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - urls = self.train_urls if split == "train" else self.validation_urls - if not urls: - return None + def validation_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, AsyncDataset[np.ndarray]]: + validation_set = self.validation_set(seq_len, monitors) + if validation_set is not None: + return {"": validation_set} + else: + return {} - urls = [globbed for url in urls for globbed in expand_glob(url)] + @property + def sources(self) -> Mapping[str, LMDatasetSourceConfig]: + return {"": self} - source = UrlDataSource(urls, columns=[self.input_field, self.output_field]) - return source.map( - lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} - ) + @cached_property + def _has_validation_set(self): + if len(self.validation_urls) > 0: + return True + if self.id is not None: + dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation") + try: + next(iter(dataset)) + return True + except StopIteration: + return False -SupervisedSourceConfig: TypeAlias = Union[SupervisedHfSourceConfig, SupervisedUrlSourceConfig] + return False + def token_seq_dataset( + self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Optional[TokenSeqDataset]: + cache = self.build_or_load_cache(split, monitors=monitors) + if cache is None: + return None + return TokenSeqDataset(cache, seq_len) -def _preprocess_supervised_example( - batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str -) -> dict: - sources = [example[input_field] for example in batch] + def build_or_load_cache( + self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None + ) -> Optional[TreeCache[BatchEncoding]]: + if self.cache_dir is None: + raise ValueError("cache_dir cannot be None") - targets = [example[output_field] for example in batch] - # TODO: this seems pretty wasteful since you end up tokenizing twice, but it's how alpaca does it. 
- examples = [s + t for s, t in zip(sources, targets)] - sources_tokenized = tokenizer(sources, padding=False, truncation=True) - examples_tokenized = tokenizer(examples, padding=False, truncation=True) + split_cache_dir = os.path.join(self.cache_dir, split) + name = logger_name or os.path.basename(self.cache_dir) - source_lens = [len(s) for s in sources_tokenized["input_ids"]] + try: + # TODO: pass in options + return TreeCache.load(split_cache_dir, exemplar={"input_ids": np.zeros(0, dtype=np.int32)}) + except FileNotFoundError: + pass - return { - "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], - "sources_len": np.array(source_lens, dtype=np.int32), - } + source = self.get_shard_source(split) + if source is None: + logger.info(f"No data for {split}") + return None + logger.info(f"Building cache for {split}...") -def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]: - lens = np.array([ex["sources_len"] for ex in ex]) - - # Pad to max length - ex_pad = tokenizer.pad(ex, padding="max_length", max_length=Pos.size) - - # Create examples with appropriate loss masking - out = [] - for ids, len in zip(ex_pad["input_ids"], lens): - causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id, tokenizer.eos_token_id) - out.append(causal) - - return out + if monitors is True: + monitors = [ + LoggingMetricsMonitor(prefix=f"preprocessing/{name}/{split}", commit=False), + LoggerMetricsMonitor(f"preprocessing.{name}.{split}"), + ] + elif monitors is False: + monitors = [] + bt = BatchTokenizer(self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos) -@functools.partial(jax.jit, static_argnums=(0, 3, 4)) -def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id, eos_id): - # mask out padding and anything before the start of the target - loss_mask = hax.arange(Pos) >= sources_len - 1 - # don't predict the padding - targets = hax.roll(input_ids, -1, Pos) - loss_mask = loss_mask & (targets != pad_token_id) - loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id) + return build_or_load_cache( + split_cache_dir, + source, + bt, + monitors=monitors, + await_finished=False, + options=self.cache_options, + split=split, + ) -def mk_supervised_datasets( - sources: Mapping[str, SupervisedSourceConfigBase] | SupervisedSourceConfigBase, - split: str, - tokenizer: PreTrainedTokenizerBase, - Pos: hax.Axis, -) -> dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]]: - """ - Create supervised datasets from a set of sources. +class PassthroughTokenizer(PreTrainedTokenizer): + def __init__(self, vocab_size, **kwargs): + self._vocab = {i: i for i in range(vocab_size)} + self._vocab_size = vocab_size + super().__init__(**kwargs) - Returns: - A dictionary of dataset names to tuples of the dataset and the tags associated with the dataset. 
- """ - out: dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]] = {} + @property + def vocab_size(self) -> int: + return self._vocab_size - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + def get_vocab(self): + return self._vocab - if isinstance(sources, Mapping): - for name, config in sources.items(): - source = config.get_shard_source(split) - if source is None: - continue + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]: + return () - ds = _cache_supervised_set( - source, config.cache_dir, tokenizer, Pos, config.input_field, config.output_field - ) + def _tokenize(self, text, **kwargs): + tokens = np.fromstring(text, dtype=int, sep=" ") + return tokens - if config.tags is None: - tags = [name] - else: - tags = config.tags + [name] + def _convert_token_to_id(self, token: str) -> int: + return int(token) - out[name] = (ds, tags) - else: - source = sources.get_shard_source(split) # type: ignore - if source is not None: - ds = _cache_supervised_set( - source, sources.cache_dir, tokenizer, Pos, sources.input_field, sources.output_field - ) - tags = sources.tags or [] - if isinstance(sources, SupervisedHfSourceConfig): - name = sources.id - if sources.name is not None: - name = f"{name}/{sources.name}" + def _convert_id_to_token(self, index: int) -> str: + return str(index) + +@dataclass +class LMMixtureDatasetConfig(LMTaskConfig): + """This class represents a mixture of datasets with their associated weights.""" - tags = [name] + tags - else: - name = "default" - out[name] = (ds, tags) + cache_dir: Optional[str] = "cache/" - return out + # data source configs and weights + configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict) + """ configuration of each dataset source (urls, hf dataset id, etc.) """ + train_weights: Dict[str, float] = field(default_factory=dict) + """ weights for each dataset source. They will be normalized to sum to 1. 
""" + stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) + mixture_block_size: int = 2048 + """ block size for the mixture dataset.""" + def __post_init__(self): + if len(self.configs) == 0: + raise ValueError("At least one dataset must be provided") -def mk_supervised_dataset( - config: SupervisedSourceConfigBase, split: str, tokenizer: HfTokenizer, Pos: hax.Axis -) -> AsyncDataset[LmExample]: + if set(self.configs.keys()) != set(self.train_weights.keys()): + raise ValueError( + f"The keys in configs and weights must be the same;got {self.configs.keys()} and" + f" {self.train_weights.keys()}" + ) - source = config.get_shard_source(split) + def train_set( + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray], + epochs: Optional[int] = None, + ) -> AsyncDataset[np.ndarray]: + doc_caches = self.build_caches("train", monitors=monitors) + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + if epochs: + raise ValueError("Epochs are not supported for mixture datasets") - dataset = source.map_batches( # type: ignore - lambda ex: _preprocess_supervised_example(ex, tokenizer, config.input_field, config.output_field), - batch_size=128, - num_cpus=num_cpus_used_by_tokenizer(tokenizer), - output_exemplar=output_exemplar, - ) + if key is None: + key = jax.random.PRNGKey(0) - cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + mix_key, shuffle_key = jax.random.split(key) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + # We shuffle the components and not the overall mixture because this lets us preserve + # the "stable batch" property of the mixture dataset. + def shuffle_ds(ds, key): + if self.shuffle is True: + ds = ds.shuffle(key) + elif isinstance(self.shuffle, int): + ds = ds.era_shuffle(self.shuffle, key=key) - return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) + return ds + if self.shuffle: + out_token_datasets = {} + key_iter = key_iterator(shuffle_key) + for name, ds in token_datasets.items(): + out_token_datasets[name] = shuffle_ds(ds, next(key_iter)) + token_datasets = out_token_datasets -def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field): - """ - Cache a supervised dataset into input_ids and sources_len. 
- """ - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} - dataset = source.map_batches( - lambda ex: _preprocess_supervised_example(ex, tokenizer, input_field, output_field), - batch_size=128, - num_cpus=num_cpus_used_by_tokenizer(tokenizer), - output_exemplar=output_exemplar, - ) - cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True) - ds = cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) - return ds + mixture = MixtureDataset( + datasets=token_datasets, + weights=self.train_weights, + stop_strategy=self.stop_strategy, + key=mix_key, + block_size=2048, + ) + return mixture -@dataclass(frozen=True) -class ChatUrlDataSourceConfig: - """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning.""" + def training_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, TokenSeqDataset]: + doc_caches = self.build_caches("train", monitors=monitors) + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + return token_datasets - cache_dir: str - train_urls: List[str] = field(default_factory=list) - validation_urls: List[str] = field(default_factory=list) + def validation_sets( + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Mapping[str, AsyncDataset[np.ndarray]]: + doc_caches = self.build_caches("validation", monitors=monitors) + token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + return token_datasets - # Chat format specific fields - messages_field: str = "messages" - input_role: str = "user" - output_role: str = "assistant" + def build_caches( + self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Dict[str, TreeCache[dict]]: + # this is a bit gross, but we want to forward all "Task" config fields to the LMDatasetConfig for building. + # We do this by just grabbing all the fields from the LMDatasetConfig and forwarding them to the + # LMDatasetConfig.build_or_load_cache method. We exclude the cache_dir field. + task_config_fields = set(x.name for x in dataclasses.fields(LMTaskConfig)) + task_config_dict = {k: v for k, v in self.__dict__.items() if k in task_config_fields and k != "cache_dir"} - def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - """Gets ShardedDataSource for either training or validation data.""" - urls = self.validation_urls if split == "validation" else self.train_urls + caches = {} + for name, source_config in self.configs.items(): + weight = self.train_weights.get(name, 0) - if not urls: - return None + if weight == 0 and split == "train": + continue - # Use the datasource_from_chat_jsonl function from sharded_datasource - return datasource_from_chat_jsonl( - urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role - ) + source_config_dict = dict(**source_config.__dict__) + if source_config.cache_dir is None: + # replace with the main cache dir/{name} + if self.cache_dir is None: + raise ValueError( + "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but" + f"{name}'s cache_dir is None." 
+ ) + cache_dir = os.path.join(self.cache_dir, name) + source_config_dict["cache_dir"] = cache_dir -def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict: - """ - Preprocess chat examples to match the format of preprocess_supervised_example. - Returns a dict with input_ids and sources_len like the supervised case. + dataset = LMDatasetConfig( + **source_config_dict, + **task_config_dict, + ) + cache = dataset.build_or_load_cache(split, monitors) + # drop the data source and corresponding weight if the cache is not built + if cache is None: + logger.warning(f"Skipping {name} for split {split} because no source was provided") + else: + caches[name] = cache - Args: - batch: List of dicts with input/output pairs - tokenizer: HuggingFace tokenizer - should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically) - """ - sources = [example["input"] for example in batch] - targets = [example["output"] for example in batch] - - # Tokenize sources to get lengths - sources_tokenized = tokenizer(sources, padding=False, truncation=True) - - # Combine for full examples - full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] - examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) - - return { - "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], - "sources_len": np.array([len(s) for s in sources_tokenized["input_ids"]], dtype=np.int32), - } + # in practice it works best if we block on validation caches + if split == "validation": + for cache in caches.values(): + cache.await_finished() + else: + logger.info(f"Not waiting for {split} caches to finish building") -def mk_chat_sft_dataset( - config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis -) -> AsyncDataset[LmExample]: - """Creates a dataset from JSONL files containing chat format data for SFT.""" - source = config.get_shard_source("train") - if source is None: - raise ValueError("No training data source found") + return caches - # Set up example structure matching supervised case - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + @property + def sources(self) -> Mapping[str, LMDatasetSourceConfig]: + return self.configs - input_ids = tokenizer("hi there")["input_ids"] - should_append_eos = input_ids[-1] != tokenizer.eos_token_id - logger.info(f"Manual EOS Needed: {should_append_eos}") - # Process the dataset - dataset = source.map_batches( - lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), - batch_size=128, - num_cpus=num_cpus_used_by_tokenizer(tokenizer), - output_exemplar=output_exemplar, - ) - # Cache the processed data - cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) +def datasource_from_chat_jsonl( + urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" +) -> "ShardedDataSource[dict]": + """Creates a ShardedDataSource from JSONL files containing chat messages. 
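+
+ An illustrative input line, assuming the default "messages" field and the usual OpenAI-style
+ "role"/"content" keys (the precise message schema is handled by ChatJsonlDataSource):
+ {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}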
- # Ensure padding token is set (needed by _prepare_supervised_example) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + Args: + urls: Sequence of URLs or glob patterns pointing to JSONL files + messages_field: Field name containing the messages in each JSON object + input_role: Role identifier for input messages + output_role: Role identifier for output messages - # Reuse the supervised prepare function directly - return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) + Returns: + ShardedDataSource configured for chat data + """ + # Expand any glob patterns in the URLs + expanded_urls = [] + for url in urls: + if any(c in url for c in "*?[]"): + expanded_urls.extend(gcs_glob(url)) + else: + expanded_urls.append(url) + return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) def preprocess_chat_example_for_packing( batch,