diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 520281793..737f6fa3c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 75ba3bb31..bb3547496 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/Dockerfile b/Dockerfile index 6c013c14d..0f3c2d8cb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/docs/user_guide/evaluators.md b/docs/user_guide/evaluators.md new file mode 100644 index 000000000..0e075cebe --- /dev/null +++ b/docs/user_guide/evaluators.md @@ -0,0 +1,134 @@ +# Evaluations + +Fast-LLM allows you to perform various evaluations during training or as a separate evaluation step. In both cases, you need to use your training config with `training.evaluators` specified. + +For evaluators used during training, both `interval` and `offset` must be specified. Then, start training as usual with: + +`fast-llm train gpt --config path/to/training/config.yaml` + +To perform evaluation as a separate step, use the same training config. Depending on the training progress, either the start model or the latest checkpoint will be loaded, and `interval` and `offset` will be ignored. To start evaluation: + +`fast-llm evaluate gpt --config path/to/training/config.yaml` + +## Currently Supported Evaluators + +- `loss` +- `lm_eval` + +## Loss Evaluator + +To set up loss evaluation, specify a dataset to be used in the `data.datasets` section of the config. You must also define the loss evaluator in the `training.evaluators` config section. See example below. 
+```yaml
+training:
+  evaluations:
+    stack_3b:
+      interval: 10
+      evaluator:
+        type: loss
+        iterations: 10
+        dataset_name: stack_3b
+    fineweb:
+      evaluator:
+        type: loss
+        iterations: 10
+        dataset_name: fineweb
+      interval: 10
+data:
+  datasets:
+    stack_3b:
+      type: memmap
+      path: path/to/memmap/dataset
+    fineweb:
+      type: memmap
+      path: path/to/memmap/dataset1
+```
+
+## Evaluation Harness (`lm_eval`) Evaluator
+
+**Note:** Only data parallelism is currently supported for the `lm_eval` evaluator.
+
+To run `lm_eval` evaluations, version `0.4.9` of `lm_eval` must be installed along with all dependencies required for your evaluation tasks.
+
+The following environment variables may need to be set:
+
+- `HF_HOME`: Path for Hugging Face data caching
+- `WANDB_API_KEY_PATH`: Path to a file containing your Weights & Biases API key (if logging to W&B)
+- `HUGGINGFACE_API_KEY_PATH`: Path to a file containing your Hugging Face Hub token
+- `NLTK_DATA`: Path to a directory that will contain downloaded NLTK packages (needed for some tasks)
+- `HF_ALLOW_CODE_EVAL=1`: Required for some evaluation tasks
+
+You may need to specify additional environment variables depending on the `lm_eval` tasks you want to run.
+
+An `lm_eval` evaluator is configured through the following fields:
+
+### Model Config
+
+The model instantiated for training is reused for evaluation, so you don't need to specify it separately. However, there are some parameters specific to `lm_eval`. See `LmEvalEvaluatorConfig` in `fast_llm/engine/evaluation/config.py` for details.
+
+### CLI Parameters for `lm_eval`
+
+All other parameters are specified as if you were calling the `lm_eval` CLI, using a list of strings. Some CLI parameters are ignored or restricted—specifically those related to model loading, W&B, batch sizes, and device setup, as these are managed by the rest of the Fast-LLM configuration.
+
+Also, the tokenizer must be specified in `data.tokenizer`. If the tokenizer does not have a `bos_token`, it must be specified explicitly in `data.tokenizer.bos_token`. Although `lm_eval` does not use the `bos_token` directly, it is still required because the same tokenizer is used by other Fast-LLM components.
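+
+For example, a minimal tokenizer section could look like the sketch below (the path and the `bos_token` value are placeholders, not defaults):
+
+```yaml
+data:
+  tokenizer:
+    path: path/to/the/tokenizer
+    bos_token: "<s>"
+```
+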
+ +Below is an example of the config: + +```yaml +training: + evaluations: + lm_eval_tasks1: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k,xnli_en,wikitext,ifeval + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` + +It is also possible to run different tasks with different intervals and offsets—for example, to run slower or more comprehensive tasks less frequently.: + +```yaml +training: + evaluations: + gsm8k: + interval: 20 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /path/to/lm_eval/output + - --limit + - "64" + ifeval: + offset: 10 + interval: 40 + evaluator: + type: lm_eval + cli_args: + - --tasks + - ifeval + - --output_path + - /path/to/lm_eval/output + - --limit + - "32" + faster_tasks: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - xnli_en,wikitext + - --output_path + - /path/to/lm_eval/output +data: + tokenizer: + path: path/to/the/tokenizer +``` diff --git a/fast_llm/cli.py b/fast_llm/cli.py index 66ce096d5..c4a13c5dc 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.utils import set_global_variables # Import these submodules to ensure classes are added to the dynamic class registry. import fast_llm.data.auto # isort: skip @@ -20,6 +21,8 @@ def fast_llm_main_wrapper(): # (Pre-)configure logging configure_logging() + # Set global and environment variables before third-party imports. + set_global_variables() try: yield except Exception as e: diff --git a/fast_llm/config.py b/fast_llm/config.py index 0004501bd..c534b11f3 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -735,7 +735,7 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: "Config| dict[str, typing.Any]]", + default: "Config| dict[str, typing.Any]", *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index e82e08013..86f8e7297 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -8,10 +8,13 @@ import contextlib import datetime +import io import logging +import pickle import typing import torch +import torch.monitor from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, @@ -46,6 +49,7 @@ def broadcast( return work else: work.wait() + return None def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None: @@ -110,6 +114,7 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta return work else: work.wait() + return None def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: @@ -119,6 +124,7 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta return work else: work.wait() + return None @contextlib.contextmanager @@ -133,3 +139,118 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No finally: generator.set_state(default_generator.get_state()) default_generator.set_state(old_state) + + +def gather( + tensor: torch.Tensor, + gather_list: list[torch.Tensor] | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, 
+ dst: int = 0, +): + assert group is not None + opts = torch.distributed.GatherOptions() + opts.rootRank = dst + work = group.gather([gather_list] if dst == group.rank() else [], [tensor], opts) + + if async_op: + return work + elif work is not None: + work.wait() + return None + + +def scatter( + tensor: torch.Tensor, + scatter_list: list[torch.Tensor] | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, + src: int = 0, +): + assert group is not None + opts = torch.distributed.ScatterOptions() + opts.rootRank = src + opts.asyncOp = async_op + work = group.scatter( + [tensor if not tensor.is_complex() else torch.view_as_real(tensor)], + [[t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list]] if src == group.rank() else [], + opts, + ) + if async_op: + return work + elif work is not None: + work.wait() + return None + + +def _object_to_tensor(obj: typing.Any) -> torch.Tensor: + f = io.BytesIO() + pickle.Pickler(f).dump(obj) + return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8) + + +def _tensor_to_object(tensor: torch.Tensor) -> typing.Any: + return pickle.Unpickler(io.BytesIO(tensor.numpy(force=True).tobytes())).load() + + +def gather_object( + obj: typing.Any, + group: ProcessGroup | None = None, + dst: int = 0, +) -> list[typing.Any] | None: + assert group is not None + group_rank = group.rank() + group_size = group.size() + device = torch.cuda.current_device() + + obj_tensor = _object_to_tensor(None if group_rank == dst else obj) + sizes = torch.full([group.size()], len(obj_tensor), dtype=torch.int64, device=device) + all_gather_into_tensor(sizes, sizes[group.rank()], group=group) + sizes = sizes.tolist() + max_size = max(sizes) + + input_tensor = torch.empty(max_size, dtype=torch.uint8, device=device) + + if group_rank == dst: + output_tensors = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size)) + gather(input_tensor, output_tensors, dst=dst, group=group) + return [ + obj if rank_ == dst else _tensor_to_object(tensor[:size]) + for rank_, (tensor, size) in enumerate(zip(output_tensors, sizes, strict=True)) + ] + else: + input_tensor[: obj_tensor.numel()].copy_(obj_tensor) + gather(input_tensor, None, dst=dst, group=group) + return None + + +def scatter_object( + scatter_object_input_list: typing.Optional[list[typing.Any]] = None, + group: ProcessGroup | None = None, + src: int = 0, +) -> typing.Any: + assert group is not None + group_rank = group.rank() + group_size = group.size() + device = torch.cuda.current_device() + + if group_rank == src: + tensor_list = [ + _object_to_tensor(None if rank_ == src else obj) for rank_, obj in enumerate(scatter_object_input_list) + ] + sizes = [tensor.numel() for tensor in tensor_list] + max_size = max(sizes) + size_tensor = torch.tensor([[size, max_size] for size in sizes], dtype=torch.int64, device=device) + scatter(size_tensor[group_rank], list(size_tensor.unbind()), src=src, group=group) + scatter_list = list(torch.empty(max_size * group_size, dtype=torch.uint8, device=device).chunk(group_size)) + for scatter_tensor, tensor, size in zip(scatter_list, tensor_list, sizes, strict=True): + scatter_tensor[:size].copy_(tensor) + scatter(scatter_list[src], scatter_list, src=src, group=group) + return scatter_object_input_list[src] + else: + size_tensor = torch.empty(2, dtype=torch.int64, device=device) + scatter(size_tensor, None, src=src, group=group) + size, max_size = size_tensor.tolist() + output_tensor = 
torch.empty(max_size, dtype=torch.uint8, device=device) + scatter(output_tensor, None, src=src, group=group) + return _tensor_to_object(output_tensor[:size]) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 2a1986b63..493361f32 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -177,7 +177,6 @@ def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens - self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 42062a58c..29a784b77 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes.any(): + if image_sizes: image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index fce0f022c..d6d473838 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs( text_sizes, image_sizes = dataset.get_document_sizes() tokens_cumsum = text_sizes.cumsum() Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) - if image_sizes.any(): + if image_sizes: num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) # We use the patch sizes only for the purposes of even splitting and blending weights. # We can always use a different patch size for training without any significant impact diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index f8cfa8c5b..b23037e84 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -2,7 +2,6 @@ import os import pathlib import typing -import warnings import yaml @@ -10,7 +9,7 @@ from fast_llm.engine.config_utils.logging import TensorLogs, TensorLogsConfig, configure_logging from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.utils import log +from fast_llm.utils import log, set_global_variables if typing.TYPE_CHECKING: from fast_llm.engine.distributed.distributed import Distributed @@ -99,20 +98,9 @@ def get_run(self, distributed: "Distributed") -> "Run": TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) - self._set_external_variables() + set_global_variables(not self.run.torch_dynamo_enable) return run - def _set_external_variables(self) -> None: - import torch._dynamo - - # TODO: Find an alternative to get reliable tensor-parallel overlap. 
- if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS", ""): - warnings.warn("Setting CUDA_DEVICE_MAX_CONNECTIONS breaks things.") - if "PYTHONHASHSEED" not in os.environ: - warnings.warn("PYTHONHASHSEED should be set and to the same value for all workers.") - - torch._dynamo.config.disable = not self.run.torch_dynamo_enable # noqa - _MAIN_RANK = 0 @@ -142,7 +130,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..0d971a88a 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -5,9 +6,13 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +43,140 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? 
return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.parallel_group is not None: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? + assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: 
tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +209,22 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim def get_tensor_dim(self, name: str) -> TensorDim: return self._tensor_dims[name] + + # TODO: Replace uses + __getitem__ = get_tensor_dim diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 7223631f8..04e4227f1 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss + from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLmEval, LossEvaluator @config_class() @@ -40,7 +40,7 @@ def 
_from_dict( @config_class(dynamic_type={EvaluatorConfig: "loss"}) -class EvaluatorLossConfig(EvaluatorConfig): +class LossEvaluatorConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False iterations: int | None = Field( @@ -58,7 +58,58 @@ def get_evaluator( batch_config: BatchConfig, data_load_num_proc: int, train_iters: int | None = None, - ) -> "EvaluatorLoss": - from fast_llm.engine.evaluation.evaluator import EvaluatorLoss + ) -> "LossEvaluator": + from fast_llm.engine.evaluation.evaluator import LossEvaluator - return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) + return LossEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class(dynamic_type={EvaluatorConfig: "lm_eval"}) +class LmEvalEvaluatorConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + + cli_args: list[str] = Field( + default_factory=lambda: [], + desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", + ) + + truncation: bool = Field( + default=False, + desc="Whether to use truncation during tokenization (useful when inputs exceed model's max length);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + logits_cache: bool = Field( + default=True, + desc="Whether to enable logits caching for speedup and avoiding recomputation during repeated evaluations;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + add_bos_token: bool = Field( + default=False, + desc="Whether to prepend a beginning-of-sequence (BOS) token, required for some models like LLaMA;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + prefix_token_id: int | None = Field( + default=None, + desc="Token ID to use as a prefix to the input (e.g., for control codes or prompts);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + max_length: int | None = Field( + default=None, + desc="Maximum sequence length including both prompt and newly generated tokens." 
+ " If not set, it is inferred from the Fast-LLM model config or tokenizer.", + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLmEval": + from fast_llm.engine.evaluation.lm_eval.evaluator import LmEvalEvaluator + + return LmEvalEvaluator(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 3fee32baf..3bdc2407f 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.run import Run, log_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, EvaluatorLossConfig +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase, LossEvaluatorConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -20,8 +20,6 @@ from fast_llm.logging import format_metrics from fast_llm.utils import get_and_reset_memory_usage_mib -# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate - logger = logging.getLogger(__name__) @@ -53,7 +51,7 @@ class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): def __init__( self, name: str, - eval_config: EvaluatorLossConfig, + eval_config: LossEvaluatorConfig, batch_config: BatchConfig, data_load_num_proc: int, train_iters: int | None = None, @@ -97,8 +95,8 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: """ -class EvaluatorLoss[ConfigType: EvaluatorLossConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[EvaluatorLossConfig]] = EvaluatorLossConfig +class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig def setup( self, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py new file mode 100644 index 000000000..162ceaf60 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -0,0 +1,90 @@ +import logging +import os +import pathlib +import typing + +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.config import LmEvalEvaluatorConfig +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorSamplingParameters, + TrainingProgress, +) +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner + +if typing.TYPE_CHECKING: + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + +logger = logging.getLogger(__name__) + + +class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig + + _hf_model: "HuggingfaceBaseModelForCausalLM" = None + _flm_wrapper: 
"FastLLMLmEvalWrapper" = None + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + if "HUGGINGFACE_API_KEY_PATH" in os.environ: + os.environ["HF_TOKEN"] = pathlib.Path(os.environ["HUGGINGFACE_API_KEY_PATH"]).open("r").read().strip() + else: + if not "HF_TOKEN" in os.environ: + logger.warning( + "No `HF_TOKEN` or `HUGGINGFACE_API_KEY_PATH` environment variable provided. " + "Assuming the user has already logged in to the Hugging Face Hub." + ) + + from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper + + super().setup(distributed, run, multi_stage, runner, data, phase) + + self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class()( + self._multi_stage, runner=self._runner + ) + + # For reporting purposes, just to indicate it is from Fast-LLM + # as lm_eval.simple_evaluate will take it for results['config']['model'] + self._hf_model.config.name_or_path = type(self._hf_model).__name__ + + self._flm_wrapper = FastLLMLmEvalWrapper( + model=self._hf_model, + tokenizer=self._data.tokenizer.tokenizer, + truncation=self._config.truncation, + logits_cache=self._config.logits_cache, + add_bos_token=self._config.add_bos_token, + prefix_token_id=self._config.prefix_token_id, + max_length=self._config.max_length, + ) + self._is_setup = True + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + + # completed_steps is added to output_path like output_path/runs/run_index/completed_steps/ + completed_steps = 0 if training_progress is None else training_progress.completed_steps + + self._flm_wrapper.run(self._config.cli_args, completed_steps, self._run.index) + + # lm_eval logs to disc, wandb and prints to screen itself + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return None diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py new file mode 100644 index 000000000..8f4dffedf --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -0,0 +1,909 @@ +import copy +import logging + +import jinja2 +import lm_eval.api.instance +import lm_eval.api.model +import lm_eval.evaluator +import lm_eval.models.utils +import lm_eval.utils +import torch +import torch.nn.functional as F +import tqdm.auto +import transformers + +from fast_llm.core.distributed import gather_object, safe_barrier, scatter_object +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.layers.transformer.rotary.config import NoRotaryConfig + +logger = logging.getLogger(__name__) + + +class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): + _DEFAULT_MAX_LENGTH = 2048 + _DEFAULT_MAX_GEN_TOKENS = 256 + + def __init__( + self, + model: HuggingfaceBaseModelForCausalLM, + tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, + truncation: bool | None = False, + logits_cache: bool = True, + add_bos_token: bool | None = False, + prefix_token_id: int | None = None, + max_length: int | None = None, + ): + super().__init__() + + # === Distributed setup === + self._rank = 0 # For 
lm_eval: always run on main rank + self._world_size = 1 + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed + + if ( + self._distributed.config.sequence_data_rank == 0 + and self._distributed.config.pipeline_rank == 0 + and self._distributed.config.tensor_rank == 0 + ): + self._group = self._distributed.batch_data_group + else: + self._group = torch.distributed.GroupMember.NON_GROUP_MEMBER + + # === Model & tokenizer setup === + self._model = model + self._device = model.device + self._config = model.config + + assert isinstance(tokenizer, (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast)) + self._tokenizer = tokenizer + self._tokenizer = lm_eval.models.utils.configure_pad_token(self._tokenizer, model_config=self._config) + + # === Generation/configuration parameters === + self._truncation = truncation + self._logits_cache = logits_cache + self._add_bos_token = add_bos_token + self._max_length = max_length + self._custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + + # === Internal constants === + self._backend = "causal" + self._vocab_size = self._tokenizer.vocab_size + + # === Batch configuration === + self._batch_schedule = 1 + self._batch_sizes = {} # Not used dynamically by lm_eval + self._batch_size_per_gpu = model._inference_runner._batch_config.micro_batch_size + self._batch_size = self._batch_size_per_gpu * self._distributed.config.batch_data_parallel + self._max_batch_size = self._batch_size + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self._tokenizer.eos_token_id + + # overrides from TemplateLM, but not used externally + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self._custom_prefix_token_id is not None: + return self._custom_prefix_token_id + if self._tokenizer.bos_token_id is not None: + return self._tokenizer.bos_token_id + return self._tokenizer.eos_token_id + + @property + def max_length(self): + # if max length manually set, return it + if self._max_length: + return self._max_length + + # check if it is absolute positional encoding and return max_position_embeddings + if hasattr(self._config.fast_llm_config.base_model, "transformer"): + # NOTE: will need to extend if more relative encoding types will be added + if isinstance(self._config.fast_llm_config.base_model.transformer.rotary, NoRotaryConfig): + return self._config.fast_llm_config.base_model.max_position_embeddings + + # check if tokenizer holds model sequence leigh info + if hasattr(self._tokenizer, "model_max_length"): + if self._tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self._tokenizer.model_max_length + + # finally try to get sequence length from batch config + if hasattr(self._model._inference_runner._batch_config, "sequence_length"): + return self._model._inference_runner._batch_config.sequence_length + + return self._DEFAULT_MAX_LENGTH + + # @property + # def device(self): + # # only used for world_size when lm_eval world size > 1 and + # # should not be called with current lm_eval support implementation + # return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer(self): + return self._tokenizer + + @property + def 
tokenizer_name(self) -> str: + return self._tokenizer.name_or_path.replace("/", "__") + + def run(self, cli_args: list[str], completed_steps: int, run_index: int): + if self._distributed.config.rank == 0: + args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params(cli_args, completed_steps, run_index) + simple_eval_kwargs["model"] = self + + # Needed for reporting as batch_size is set from args not lm for reporting in evaluate + simple_eval_kwargs["batch_size"] = self._batch_size + simple_eval_kwargs["max_batch_size"] = self._max_batch_size + + # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7 + # Expected to be a string even if empty and not None in simple_evaluate + simple_eval_kwargs["model_args"] = "" + + results = lm_eval.evaluator.simple_evaluate(**simple_eval_kwargs) + self.stop_workers() + + # Evaluation_tracker save expects model to be either string, but if model is passed + # LM wrapper needs to be deep copyable and json serializable + simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = ( + self._model.config.name_or_path + ) + + if results is not None: + process_lm_eval_results( + args, + results, + simple_eval_kwargs["evaluation_tracker"], + completed_steps, + ) + else: + self.worker_model_invoke() + + # TODO: do we need it here as self.stop_workers() and self.worker_model_invoke() + # already have barrier + safe_barrier(self._distributed.world_group, f"lm_eval Run end") + + def _model_invoke( + self, + input_ids, + attention_mask, + labels, + max_length, + stop, + generate: bool, + continue_generate: bool, + **generation_kwargs, + ): + # TODO: Consider passing true messages and payloads around instead of combining all data into a large tuple. + # Messages could include types like logits, generate, finished. 
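+        # Descriptive summary of the control flow below: rank 0 of the batch-data group splits the
+        # batch into one shard per rank, scatters the shards, runs its own shard locally, then
+        # gathers the per-shard results and re-assembles them. Worker ranks wait inside
+        # `worker_model_invoke` for these scatter/gather calls and exit their loop once a message
+        # with `continue_generate=False` arrives.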
+ + # Group is always None if world size is 1 + if self._group is None: + # Must not be called with continue_generate false on one process + assert continue_generate + return self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + world_size = self._group.size() + + assert self._group.rank() == 0 + + if continue_generate: + assert input_ids is not None + if generate: + assert max_length is not None and stop is not None + + # always divide by world_size, if not full batch, some ranks will get less work or not at all + assert self._batch_size % world_size == 0 + step = self._batch_size // world_size + + input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + attention_mask = [ + attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None + for i in range(world_size) + ] + labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] + + scatter_list = [ + [ + input_ids[i], + attention_mask[i], + labels[i], + max_length, + stop, + generate, + continue_generate, + generation_kwargs, + ] + for i in range(world_size) + ] + else: + scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + scatter_object( + scatter_list, + group=self._group, + ) + ) + + if not continue_generate: + return None + + assert len(input_ids) > 0 + + result = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + gather_list = gather_object(result, group=self._group) + # Clean gather list from empty shards + gather_list = [el for el in gather_list if len(el) > 0] + + # If it was model generate tensors could be of different length + # so we aggregate results to list instead of a tensor + if generate: + result = sum((el.tolist() for el in gather_list), []) + else: + assert all(el.device.type == "cpu" for el in gather_list) + result = torch.cat(gather_list, dim=0) + + return result + + def worker_model_invoke(self): + assert self._group is not None + # if isinstance(self.group, dist.ProcessGroup): + if not isinstance(self._group, int): + # groups is None for world_size 1 + assert self._group.rank() != 0 + # on worker ranks the function need to wait to be called multiple times + while True: + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + scatter_object( + None, + group=self._group, + ) + ) + + # Stop signal was send, end waiting/processing loop + if not continue_generate: + break + + # if some data was received, work, otherwise return empty tensor + if len(input_ids) > 0: + result = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + else: + result = input_ids + + gather_object(result, group=self._group) + else: + # TODO: implement distributed model support + assert self._group == torch.distributed.GroupMember.NON_GROUP_MEMBER + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def stop_workers(self): + # Group is always None if world size is 1 + if self._group is None: + return + self._model_invoke(None, None, None, None, None, None, continue_generate=False) + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def _model_invoke_inner( + self, input_ids, attention_mask, labels, max_length, stop, generate: bool, 
**generation_kwargs + ): + if generate: + return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) + else: + return self._model_call_inner(input_ids, attention_mask, labels) + + def _model_call(self, input_ids, attention_mask=None, labels=None): + return self._model_invoke( + input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True + ) + + def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + return self._model_invoke( + input_ids, + attention_mask, + None, + max_length, + stop, + generate=True, + continue_generate=True, + **generation_kwargs, + ) + + def _model_call_inner(self, input_ids, attention_mask=None, labels=None): + """ + :param input_ids: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attention_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + + # TODO: do we need no_grad for fast_llm model? + with torch.no_grad(): + # We move logits to the CPU because they will be copied across processes and nodes + # in a multi-GPU, multi-node setup and eventually collected on the main rank. + # We cannot afford to accumulate them on rank 0 GPU, as GPU memory may already be tight. + # CPU tensors are slower, but we typically have much more CPU RAM available. + + # TODO: Check if it's possible to move some of the _loglikelihood_tokens work here + # and pass only the results around instead of the full logits. + # Computing errors here is also preferable, as copying logits across nodes and GPUs + # is inefficient and can involve gigabytes of data. 
+ return self._model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits.cpu() + + def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self._tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + ) + + return self._model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self._tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self._backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self._add_bos_token} + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self._tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: list[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self._tokenizer.padding_side + self._tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self._backend == "causal": + add_special_tokens = {"add_special_tokens": False or self._add_bos_token} + + encoding = self._tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. 
Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self._tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: + if self._backend == "causal": + assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self._backend == "seq2seq": + assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm.auto.tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + # The tokenizer may raise: "Token indices sequence length is longer than the specified maximum sequence length for this model" + # This is expected and fine, as the sequence will be split into chunks of max_length later. 
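+            # Descriptive note: each window produced below is a (context, continuation) pair of
+            # token ids. The continuations tile the document so every token is scored exactly once,
+            # with window sizes bounded by `self.max_length`.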
+ rolling_token_windows: list[tuple[list[int], list[int]]] = list( + map( + lm_eval.utils.make_disjoint_window, + lm_eval.utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self._device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self._batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial("loglikelihood_rolling", (string,), request_total) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self._batch_schedule) + if sched in self._batch_sizes: + return self._batch_sizes[sched] + if (len(self._batch_sizes) > 1) and (self._batch_sizes[sched - 1] == self._max_batch_size): + # if previous batch size is already maximal, skip recomputation + self._batch_sizes[sched] = self._max_batch_size + return self._batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self._batch_schedule}. Detecting largest batch size") + self._batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self._batch_sizes[sched]}") + return self._batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: list[tuple[tuple[str, str], list[int], list[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> list[tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + # NOTE: for the sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + # NOTE: the group_fn Defines the key to group and lookup one-token continuations + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + re_ord = lm_eval.models.utils.Collator( + requests, + sort_fn=lambda req: (-(len(req[1]) + len(req[2])), tuple(req[1]) + tuple(req[2])), + group_by="contexts" if self._backend == "causal" and self._logits_cache else None, + group_fn=lambda req: req[-2] + req[-1][:-1], + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = self._batch_size if self._batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_fn = ( + self._batch_scheduler + if self._batch_size == "auto" and n_reordered_requests > 0 and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self._backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + logger.warning( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + elif self._backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self._device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? 
+ # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self._device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen + + padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self._backend == "causal": + batched_inps = lm_eval.models.utils.pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self._backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] + batched_conts = lm_eval.models.utils.pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = lm_eval.models.utils.pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attention_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + # TODO: Consider moving this part to per-shard execution in a multi-GPU and multi-node setup + # to avoid copying logits between GPUs and nodes, and to enable performing logits computations on the GPU. + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self._backend == "causal" else None + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + # NOTE: Currently, computations are performed on the CPU due to limited GPU memory. + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device="cpu").unsqueeze(0) # [1, seq] + + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. 
instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial("loglikelihood", request_str, answer) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: + res = [] + + pbar = tqdm.auto.tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self._batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self._batch_size + if self._batch_size != "auto" + else adaptive_batch_size if adaptive_batch_size is not None else 0 + ) + batch_fn = self._batch_scheduler if self._batch_size == "auto" and not adaptive_batch_size else None + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + # NOTE: for sort_fn, the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + re_ords = lm_eval.models.utils.Collator( + [reg.args for reg in requests], + sort_fn=lambda req: (-len(self.tok_encode(req[0])), req[0]), + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = lm_eval.models.utils.handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self._DEFAULT_MAX_GEN_TOKENS + + # set the max length in tokens of inputs ("context_enc") + if self._backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert ( + max_ctx_len > 0 + ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." 
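+                # For example, with max_length=2048 and max_gen_toks=256, causal contexts are
+                # left-truncated to at most 2048 - 256 = 1792 tokens so that the requested
+                # continuation still fits in the model window.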
+ elif self._backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + input_ids, attention_mask = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self._truncation, + ) + input_ids = input_ids.to(self._device) + attention_mask = attention_mask.to(self._device) + + if "max_length" not in kwargs: + kwargs["max_length"] = input_ids.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + input_ids=input_ids, + attention_mask=attention_mask, + stop=until, + **kwargs, + ) + + # cont_toks_list = cont.tolist() + cont_toks_list = cont + + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self._backend == "causal": + cont_toks = cont_toks[input_ids.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + logger.warning("Failed to apply chat template. removing the system role in chat history.") + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self._tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py new file mode 100644 index 000000000..afcfc1a9c --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -0,0 +1,244 @@ +import argparse +import json +import logging +import os +import pathlib +import sys +from pathlib import Path + +import lm_eval.__main__ +import lm_eval.evaluator +import lm_eval.loggers +import lm_eval.tasks +import lm_eval.utils + +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +def parse_eval_args(parser: argparse.ArgumentParser, args: list[str]) -> argparse.Namespace: + lm_eval.__main__.check_argument_types(parser) + return parser.parse_args(args) + + +def prepare_lm_eval_simple_eval_params( + cli_args: list[str], + completed_steps: int, + run_index: int, +) -> tuple[argparse.Namespace, dict[str, any]]: + """ + Parses CLI arguments for an LM evaluation run and prepares keyword arguments + for the `evaluate` function. + + This function wraps argument parsing, environment configuration, task resolution, + and metadata setup needed for evaluation with Fast-LLM's `lm_eval` wrapper. It also + handles special cases like hub token injection, dynamic sample loading, and task + listing commands. 
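+    The model, W&B, batch size, and device CLI flags must be left at their defaults;
+    Fast-LLM sets these on the model directly (enforced by the assertions below).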
+ + Args: + cli_args (list[str]): Command-line arguments, excluding the program name. + completed_steps (int): Current number of completed training steps, used to + uniquely tag evaluation output paths. + run_index (int): index of the current run of Fast-LLM experiment + + Returns: + tuple: + - argparse.Namespace: Parsed CLI arguments. + - dict: Keyword arguments to pass into `simple_evaluate`, including task list, + tracker, cache settings, random seeds, and generation parameters. + + Raises: + ValueError: If required fields like `--tasks` or `--output_path` are missing + when needed, or if misconfigured combinations are detected. + SystemExit: If special task listing flags are used. + """ + parser = lm_eval.__main__.setup_parser() + parser.add_argument( + "--no_defaults", + action="store_true", + ) + args = parse_eval_args(parser, cli_args) + + # NOTE: all this args are set by fast_llm on the model directly or not used here + Assert.eq(args.wandb_args, "") + Assert.eq(args.wandb_config_args, "") + Assert.eq(args.model, "hf") + Assert.eq(args.model_args, "") + Assert.eq(int(args.batch_size), 1) + Assert.none(args.max_batch_size) + Assert.none(args.device) + + # update the evaluation tracker args with the output path and the HF token + evaluation_tracker_args = "" + if args.output_path: + args.output_path = str(pathlib.Path(args.output_path) / f"runs/{run_index}/{completed_steps}") + evaluation_tracker_args += f",output_path={args.output_path}" + + evaluation_tracker_args = lm_eval.utils.simple_parse_args_string(evaluation_tracker_args) + evaluation_tracker = lm_eval.loggers.EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError("Specify --output_path if providing --log_samples or --predict_only") + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + args.include_path = args.include_path.split(",") + logger.info(f"Including paths: {args.include_path}") + metadata = ( + lm_eval.utils.simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args if isinstance(args.model_args, dict) else {} + ) | (args.metadata if isinstance(args.metadata, dict) else lm_eval.utils.simple_parse_args_string(args.metadata)) + + task_manager = lm_eval.tasks.TaskManager( + verbosity=args.verbosity, + include_path=args.include_path, + include_defaults=not args.no_defaults, + metadata=metadata, + ) + + if args.limit: + logger.warning(" --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.") + if args.samples: + assert args.limit is None, "If --samples is not None, then --limit must be None." 
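+        # `--samples` accepts either a path to a JSON file or an inline JSON string.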
+ if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = lm_eval.utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + logger.error( + f"Tasks were not found: {missing}\n" + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all" + " available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG'" + " to troubleshoot task registration issues." 
+ ) + + logger.info(f"Selected Tasks: {task_names}") + + request_caching_args = lm_eval.evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) + + eval_kwargs = dict( + tasks=task_names, + num_fewshot=args.num_fewshot, + # batch_size=args.batch_size, + # max_batch_size=args.max_batch_size, + # device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + return args, eval_kwargs + + +def process_lm_eval_results( + args: argparse.Namespace, + results: dict[str, any], + evaluation_tracker: lm_eval.loggers.EvaluationTracker, + completed_steps: int | None, +) -> None: + if results is not None: + completed_steps = 0 if completed_steps is None else completed_steps + import wandb + + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps(results, indent=2, default=lm_eval.utils.handle_non_serializable, ensure_ascii=False) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging if we have the run to log to + # we expect the rest of the fast_llm code will finish the run. + if wandb.run is not None: + try: + wandb_logger = lm_eval.loggers.WandbLogger(init_args={"step": completed_steps}) + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name]) + + if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub: + evaluation_tracker.recreate_metadata_card() + + # TODO: convert to logging entries instead? 
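+    # Print a one-line run summary followed by the aggregated results table(s).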
+ print( + f"{results["config"]["model"]}, gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {results["config"]["batch_size"]}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + print(lm_eval.utils.make_table(results)) + if "groups" in results: + print(lm_eval.utils.make_table(results, "groups")) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 3c2db428d..54a82492b 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -1,3 +1,4 @@ +import logging import os import pathlib import typing @@ -14,6 +15,8 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class HuggingfacePreTrainedModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig @@ -41,6 +44,8 @@ def __init__( # The HF constructor performs a deep copy of the config, # but config.fast_llm_config may contain non-picklable items like process groups. # Temporarily remove it before the call and restore it afterward. + # TODO: Find a clean solution — overriding __deepcopy__ doesn't work here + # because internally they use copy.deepcopy(self.__dict__). fast_llm_config = config.fast_llm_config config.fast_llm_config = None super().__init__(config, **kwargs) @@ -64,6 +69,11 @@ def __init__( with transformers.modeling_utils.no_init_weights(): self.post_init() + if fast_llm_model.config.multi_stage.zero_stage == 3: + logger.warning( + "zero_stage=3 is used for the model; forward and generate will be extremely slow during inference." + ) + @classmethod def from_pretrained( cls, diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 0b8bb94f2..df9259abd 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -113,7 +113,8 @@ def forward( losses, metrics, ) - self._log_layer_forward(output, kwargs, i) + if output is not None: + self._log_layer_forward(output, kwargs, i) # TODO: very slow and memory consuming, only use for debugging for now # TODO: decide if and how we want to return diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,12 +185,15 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. 
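+            # Sanity check: the initializer must not have changed the parameter's shape.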
+ Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index efe8f714b..3dbec5348 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -54,7 +54,7 @@ class IntervalConfig(Config): def _validate(self) -> None: if self.interval: - with self._set_implicit_default(): + with self._set_implicit_default(None): self.offset %= self.interval super()._validate() @@ -388,7 +388,7 @@ def _validate(self) -> None: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. - Assert.eq(self.model.distributed.tensor_parallel, 1) + # Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 64408bb06..5f5511a15 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -303,7 +303,16 @@ def _run_training(self) -> None: else: metrics = {} done = True - self._evaluator_runner.run(metrics=metrics) + self._evaluator_runner.run( + metrics=metrics, + # This is set to ensure that evaluators like lm_eval log results at the correct step if a checkpoint was loaded. + training_progress=TrainingProgress( + done=done, + completed_steps=self._completed_steps, + consumed_samples=self._consumed_samples, + consumed_tokens=self._consumed_tokens, + ), + ) if done and PhaseType.test in self._samples_per_split: log_main_rank(lambda: f"Running test phase ...") @@ -318,7 +327,7 @@ def _run_training(self) -> None: log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=True) def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Tracking loss. @@ -339,6 +348,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.prefetch_factor, ) + has_test_phase = PhaseType.test in self._samples_per_split + log_main_rank("Training ...") # TODO: Synchronization is probably unnecessary. 
@@ -456,7 +467,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: ) if is_main_rank() and metrics: - self._wandb.log_metrics(self._completed_steps, metrics) + self._wandb.log_metrics(self._completed_steps, metrics, commit=not (done and has_test_phase)) stop = done or self._config.training.shutdown.enabled(self._completed_steps) diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index 185b89c28..724b5b718 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -44,12 +44,12 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): else: self._wandb = None - def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]]) -> None: + def log_metrics(self, completed_steps: int, metrics: dict[str, dict[str, float | int]], commit: bool) -> None: # Note: metrics modified in-place if self._wandb is not None: import wandb - wandb.log(metrics, step=completed_steps) # noqa + wandb.log(metrics, step=completed_steps, commit=commit) # noqa def alert(self, title, text, level="INFO", wait=0.001) -> None: if self._wandb is not None and self._config.alert.post_alerts: diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7a289b579..eaeaa0d18 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -151,7 +151,8 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() return loss, grad @@ -225,7 +226,8 @@ def _torch_reverse_kl_forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. - Much simpler and more reliable than custom implementation! + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
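+    Reverse KL here means KL(q||p) = sum_x q(x) * (log q(x) - log p(x)), with q the student and
+    p the teacher. Per-split losses are summed across the group and divided by the group size,
+    which is equivalent to a mean over equally sized splits.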
""" Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) @@ -253,14 +255,10 @@ def _torch_reverse_kl_forward_backward( with torch.enable_grad(): logits_ = logits.detach().requires_grad_(grad_output is not None) - # Use log_softmax for consistency instead of _fused_softmax scaled_logits = logits_ * logits_scale_factor scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - # Convert to probabilities for kl_div - # student_probs_ = torch.exp(student_log_probs) - # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -277,7 +275,8 @@ def _torch_reverse_kl_forward_backward( loss = (loss_per_sample * loss_mask).mean() if group is not None and target_format != TargetFormat.labels: - all_reduce(loss, op=ReduceOp.MEAN, group=group) + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= group.size() if grad_output is not None: loss.backward(torch.full_like(loss, grad_output)) @@ -297,6 +296,7 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + vocab_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -340,6 +340,18 @@ def reverse_kl_forward_backward( if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? - return _torch_reverse_kl_forward_backward( - logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group, teacher_softmax_temperature - ) + if vocab_parallel: + Assert.eq(teacher_softmax_temperature, 1) + Assert.eq(logits_scale_factor, 1) + raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") + else: + return _torch_reverse_kl_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + group, + teacher_softmax_temperature, + ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." 
out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..eed2d134f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) @@ -234,12 +238,24 @@ def _get_targets( else: lm_target = None - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: + targets = (dpo_target, lm_target, distillation_target) + # If we do distillation, no need to split it here as it has already been split in the embedding layer! + # if we do CPT/language modeling, we need to split the targets here! + if ( + self._config.distillation_model is not None + and self._sequence_parallel_logits + and not self._parallel_embeddings + and not self._sequence_parallel + ) or (self._config.distillation_model is None and self._sequence_parallel_logits): + # We dont split targets if they already have been split in the embedding layer! targets = [ None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) for target in targets ] + # Loss mask may need to be split. It was not split in the embedding layer as it is not used there. + if loss_mask is not None and self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) + targets = (*targets, loss_mask) if not any(target is not None for target in targets): # Simplify so we don't have to check every time. 
targets = None @@ -408,6 +424,7 @@ def _logits_cross_entropy_forward_backward( target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), + vocab_parallel=logits.shape[-1] != self._config.vocab_size, ) elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..c06d85148 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,28 +1,34 @@ import enum from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.utils import Assert +from fast_llm.tensor import Initializer +from fast_llm.utils import Assert, div class SSMDimNames: - model_dim = "model_dim" # Model dimension (D) - state_dim = "state_dim" # State dimension (N) - conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers - inner_dim = "inner_dim" # Inner dimension after expansion - dt_rank = "dt_rank" # Rank of Δ - inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba - inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2 - inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2 - x_proj_dim = "x_proj_dim" # X projection dimension - head_dim = "head_dim" # Dimension of the mamba2 head (P) - conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers - qk_heads = "qk_heads" # Number of QK heads - v_heads = "v_heads" # Number of V heads - - # Mamba 2 - x_proj_dim_2 = "x_proj_dim_2" # d_xb + # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. 
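+    # All dimension names are prefixed with "ssm_" to reduce the risk of such conflicts.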
+ state = "ssm_state" # State dimension (N), aka head size / num channels + head_dim = "ssm_head_dim" + head_groups = "ssm_head_groups" + group_heads = "ssm_group_heads" + + convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers + + dt_rank = "ssm_dt_rank" + + # Composite dimensions + composite_heads = "ssm_composite_heads" + composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" + composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + + # Concatenated dimensions + concatenated_convolution = "ssm_concatenated_convolution" + concatenated_x_projection = "ssm_x_concatenated_x_projection" + concatenated_inner_projection = "ssm_concatenated_inner_projection" class SSMBlockType(enum.StrEnum): @@ -35,6 +41,32 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + + +class DTInitType(enum.StrEnum): + constant = "constant" + random = "random" + + def get_init_method(self, scale: float) -> Initializer: + from fast_llm.tensor import init_fill_, init_uniform_centered_ + + return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) + @config_class() class SSMConfig(LLMBlockConfig): @@ -45,79 +77,87 @@ class SSMConfig(LLMBlockConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) + + # Model dimensions + # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, desc="Conv kernel dimension for Mamba blocks.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # Layer parameters - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - + # [MambaLayer, Mamba2] dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) + # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, desc="Number of QK heads for Mamba2 blocks.", hint=FieldHint.architecture, ) + # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, desc="Number of V heads for Mamba2 blocks.", hint=FieldHint.architecture, ) - activation_type: ActivationType = Field( + # c_size [MambaLayer, Mamba2, DiscreteMamba2]? + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + # xb_size [Mamba2] + d_xb: int = Field( default=None, - desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", + desc="Dimension of the xB in Mamba2 blocks.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( + + # Model options + # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] + add_bias_linear: bool = Field( default=False, - desc="debug_ssm", - hint=FieldHint.optional, + desc="Whether to use bias in SSM layers", + hint=FieldHint.architecture, ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] + activation_type: ActivationType = Field( + default=None, + hint=FieldHint.architecture, ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # repeat_xb_before_conv [Mamba2] + repeat_kv_before_conv: bool = Field( + default=True, + desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", + hint=FieldHint.architecture, ) - - d_inner: None | int = Field( - default=None, - desc="Inner dimension for Mamba2 blocks.", - hint=FieldHint.core, + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", @@ -125,43 +165,38 @@ class SSMConfig(LLMBlockConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # Mamba 2 - repeat_kv_before_conv: bool = Field( - default=True, - desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - d_xb: int = Field( - default=None, - desc="Dimension of the xB in Mamba2 blocks.", - hint=FieldHint.architecture, - ) - dt_init: str = Field( - default="random", + # Initialization + # dt_weight_initialization_method [Mamba2] + dt_init: DTInitType = Field( + default=DTInitType.random, desc="Initialization method for dt", hint=FieldHint.core, ) - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", + # dt_weight_initialization_scale [Mamba2] + dt_scale: float = Field( + default=1.0, + desc="Scale for dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + # dt_bias_initialization_min [MambaLayer, Mamba2] dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", + # dt_bias_initialization_max [MambaLayer, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_scale: float = Field( - default=1.0, - desc="Scale for dt", + # dt_bias_initialization_floor [MambaLayer, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) @@ -172,3 +207,79 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) + + def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # Head groups are configured 
differently depending on the block type. + if block_type == SSMBlockType.mamba: + num_heads = div(self.d_inner, self.state_size) + num_head_groups = num_heads + elif block_type == SSMBlockType.mamba2: + num_heads = div(self.d_inner, self.state_size) + num_head_groups = div(self.d_xb, self.state_size) + elif block_type == SSMBlockType.mamba2_discrete: + # TODO: Use different variables? + num_heads = self.n_v_heads + num_head_groups = self.n_qk_heads + else: + raise NotImplementedError(block_type) + + tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) + if block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + else: + head_dim = state + + tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) + tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) + tensor_space.add_tensor_dim( + heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) + ) + tensor_space.add_tensor_dim( + heads_and_head_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) + ) + ) + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state) + ) + ) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) + + # DT projection + if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): + tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) + + if block_type == SSMBlockType.mamba: + tensor_space.add_tensor_dim( + ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) + ) + # TODO: Use composition instead + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) + ) + ) + elif block_type == SSMBlockType.mamba2: + # TODO: Factor out state? 
+ tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), + ) + ) + elif block_type == SSMBlockType.mamba2_discrete: + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), + ) + ) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_convolution, + (heads_and_head_dim, head_groups_and_state, head_groups_and_state), + ) + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..64377b93c 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,14 +1,16 @@ import logging -import math +import typing import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -30,229 +32,195 @@ _causal_conv1d_available = False -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) - - -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Other options are all experimental and should not need to be configured. 
- """ - # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() - self.config: SSMConfig = config - bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) - - self.d_model = td_model.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.chunk_size = config.chunk_size - self.n_qk_heads = td_n_qk_heads.size - self.n_v_heads = td_n_v_heads.size - self.conv_kernel_size = td_conv_kernel.size - - self.act = config.activation_type.activation_fn - self.activation_name = config.activation_type.name + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + self._config: SSMConfig = config + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) + heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + + # local_head_groups = head_groups / TP + self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + # local_heads = local_head_groups * group_heads + self._local_heads = heads_dim.size + # local_inner_size = local_heads * head_size + self._local_inner_size = inner_dim.size + # local_bc_size = local_head_groups * state + self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size # TODO: double check initializations # Projections - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - self.z_bias = ( - ParameterMeta.from_dims( - (td_inner,), + if not config.add_bias_linear: + self.z_bias = ParameterMeta.from_dims( + (inner_dim,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - if not bias - else 0.0 - ) - self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) - ), # see 
https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_n_qk_heads,), + (heads_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - # out_proj - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - def forward(self, hidden_states, kwargs): - """ - ON variable names and pep8: keeping some variable names as in the original code for clarity. - - Args: - u: (B, L, D), - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ - if kwargs[TransformerKwargs.sequence_first]: - raise NotImplementedError(f"Sequence-first not supported for SSMs.") - + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - input_ = hidden_states - outputs = {} - # assert state is None - batch, seqlen, dim = input_.shape - - state = None - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = torch.nn.functional.pad(input_, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) + padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size + if padded_length != sequence_length: + assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) + + # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) + # -> (batch/local_or_padded_sequence, local_sequence/batch, inner_projection) + # inner_projection: (batch, local_or_padded_sequence, hidden) -> (batch, padded_sequence, local_inner_size) + inner_projection = self.in_proj(input_) + # Standardize to (batch, padded_sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) - ( - xBC, - z, - A_log, - ) = torch.split( - xBCzA_log, + print("QAIKOFNMJOWENM inner_projection", inner_projection.shape) + xBC, z, A_log = torch.split( + inner_projection, [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, + self._local_inner_size + 2 * self._local_bc_size, + self._local_inner_size, + self._local_heads, ], dim=-1, ) - - if state is not None: - # If we 
just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = einops.rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_( - torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0)) - ) # Update state (B D W) + print("QAIKOFNMJOWENM xBC", xBC.shape, self._local_inner_size, self._local_bc_size) + print("QAIKOFNMJOWENM z", z.shape) + print("QAIKOFNMJOWENM A_log", A_log.shape) # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) + # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) + xBC = self.convolutional_forward(xBC, padded_length) x, B, C = torch.split( xBC, [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, + self._local_inner_size, + self._local_bc_size, + self._local_bc_size, ], dim=-1, ) - x = einops.rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = einops.rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = einops.rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + # x: (batch, padded_sequence, local_heads * head_size) -> (batch, padded_sequence, local_heads, head_size) + x = einops.rearrange(x, "b l (h n) -> b l h n", h=self._local_heads) + + # b,c: (batch, padded_sequence, local_head_groups * state) -> (batch, padded_sequence, local_head_groups, state) + B = einops.rearrange(B, "b l (h n) -> b l h n", h=self._local_head_groups) + C = einops.rearrange(C, "b l (h n) -> b l h n", h=self._local_head_groups) # SSM forward - result = _mamba_chunk_scan_combined( - x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1), + y = _mamba_chunk_scan_combined( + x=self._apply_a_log(x, A_log), dt=A_log, dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), + A=-torch.ones(self._local_heads, device=A_log.device), B=B, C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + chunk_size=self._config.chunk_size, + return_final_states=False, ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = einops.rearrange(y + Du, "b l h p -> b l (h p)") # Norm and gate - out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) + if not self._config.add_bias_linear: + z = z + self.z_bias - # TODO: since we do not support inference for now, we only return the hidden states for now. - return outputs["hidden_states"], None + # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) + y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? 
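+                # Convert back from (batch, sequence, ...) to the sequence-first layout
+                # expected by the rest of the block.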
+ y = y.transpose(0, 1).contiguous() + # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) + # -> (batch/local_sequence, local_sequence/batch, hidden) + return self.out_proj(y) + + @torch.compile + def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: + return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) def convolutional_forward(self, xBC, padded_len): """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self.activation_name in ( - "silu", - "swish", - "identity", + if _causal_conv1d_available and self._config.activation_type in ( + ActivationType.silu, + ActivationType.identity, ): xBC = _causal_conv1d_fn( xBC.transpose(1, 2), - einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + self.conv1d_weight.squeeze(1), self.conv1d_bias, - activation=None if self.activation_name == "identity" else self.activation_name, + activation=( + None + if self._config.activation_type == ActivationType.identity + else self._config.activation_type.value + ), ).transpose(1, 2) else: - xBC = self.act( + xBC = self._config.activation_type.activation_fn( torch.nn.functional.conv1d( xBC.transpose(1, 2), self.conv1d_weight, bias=self.conv1d_bias, groups=self.conv1d_weight.shape[0], - padding=self.conv_kernel_size - 1, + padding=self._config.conv_kernel_dimension - 1, )[..., :padded_len].transpose(1, 2) ) return xBC diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 8a61a8969..1ae25e44c 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,17 +1,17 @@ import logging -import math import typing -import einops import torch -from fast_llm.engine.config_utils.tensor_space import 
TensorDim, TensorSpace -from fast_llm.layers.common.linear import Linear +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ -from fast_llm.utils import get_lr_scale - -logger = logging.getLogger(__name__) +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -27,246 +27,202 @@ except (ImportError, RuntimeError): _causal_conv1d_available = False - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def bias_init_method(conv_weight): - fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - return init_uniform_(-bound, bound) +logger = logging.getLogger(__name__) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_head_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() - self.config: SSMConfig = config - bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( - self.config.mamba_lr_scale, layer_lr_scale + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) + xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) + hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) + dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) + + self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size + self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._group_heads = div(self._local_heads, self._local_head_groups) + self._local_inner_size = inner_dim.size + self._local_xb_size = xb_dim.size + + conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + self.conv1d_weight = ParameterMeta.from_dims( + ( + conv1d_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, ) - - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) - - self.repeat_kv_before_conv = config.repeat_kv_before_conv - - self.d_state = td_state.size - self.d_model = td_model.size - self.d_xb = td_xb.size - self.d_inner = td_inner.size - self.dt_rank = tdt_rank.size - - if self.repeat_kv_before_conv: - self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), - ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, - ) - - self.conv1d_bias = ParameterMeta.from_dims( - (td_inner,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - else: - self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), - init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), - ), - ) - self.conv1d_bias = ParameterMeta.from_dims( - (td_xb,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale - ) - - self.activation = "silu" - - self.num_xb_head = td_xb.size // td_state.size - self.num_C_head = td_inner.size // td_state.size - self.repeat_group = self.num_C_head // self.num_xb_head - - self.in_proj = Linear( - td_model, - td_inner_proj, - bias=bias, - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, ) - - # Initialize special dt projection to preserve variance at initialization - dt_scale = config.dt_scale # 1.0 - dt_init_std = 
self.dt_rank**-0.5 * dt_scale - if config.dt_init == "constant": - dt_init = init_fill_(dt_init_std) - elif config.dt_init == "random": - dt_init = init_uniform_(-dt_init_std, dt_init_std) - else: - raise NotImplementedError - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min # or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - logger.info( - f"Initializing {meta.tensor_name} with shape {meta.shape}, tensor shape {tensor.shape} from value shape {value.shape}" - ) - # TODO: fix and remove try-except - try: - return tensor.copy_(value) - except RuntimeError as e: - logger.error(f"Failed to copy value to tensor: {e}") - return tensor.fill_(0.0) - return init_ - - self.dt_proj = Linear( - tdt_rank, - td_inner, + self.dt_in_proj = Linear( + hidden_dim, + dt_rank_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + lr_scale=lr_scale, + ) + self.dt_proj = OutputParallelLinear( + dt_rank_dim, + inner_dim, bias=False, - weight_init_method=dt_init, - lr_scale=mamba_layer_lr_scale, + # Initialize special dt projection to preserve variance at initialization + weight_init_method=self._config.dt_init.get_init_method( + self._config.dt_rank**-0.5 * self._config.dt_scale + ), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, ) - # define bias outside the linear layer since its also used in the selective_scan_fn - logger.info(f"td_inner: {td_inner}, inv_dt: {inv_dt.shape}") + # define bias outside the linear layer since it's also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) - - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 - logger.info(f"A_log: {A_log.shape}, td_inner: {td_inner}, td_state: {td_state}") self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), - init_method=init_from_tensor_(A_log), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + init_method=init_A(self._config.state_size, self._config.d_inner), + lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) - - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + 
self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(self._config.d_inner), + sequence_parallel=self._sequence_parallel, + # TODO: lr_scale? ) - def forward(self, hidden_states, kwargs): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - outputs = {} - - conv_state, ssm_state = None, None - - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + assert _causal_conv1d_available + + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) + inner_projection = self.in_proj(input_) + dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) + + sequence_length = inner_projection.size(1) + + z, x, b, c = torch.split( + inner_projection, + [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], + dim=2, + ) - x = einops.rearrange(x, "b l d -> b d l") - z = einops.rearrange(z, "b l d -> b d l") + # z: (batch, sequence, local_heads * state) -> (batch, local_heads * state, sequence) + z = z.transpose(1, 2) - B = einops.rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) # B, n_group, L, H - B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() - C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() + # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) + x = x.transpose(1, 2) + if self._config.repeat_kv_before_conv: + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + else: + x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") + x = ( + x.unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + .flatten(1, 2) + ) - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner - dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L + # b: (batch, sequence, local_head_groups * state) -> (batch, local_heads, state, sequence) + b = ( + b.transpose(1, 2) + .unflatten(1, (self._local_head_groups, self._config.state_size)) + .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) + ) - if self.repeat_kv_before_conv: - assert self.repeat_group > 0 - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + # c: (batch, sequence, heads * state) -> (batch, heads, state, sequence) + c = c.transpose(1, 2).unflatten(1, (self._local_heads, self._config.state_size)) - assert 
self.activation in ["silu", "swish"] - if _causal_conv1d_available: - x = _causal_conv1d_fn( - x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), - bias=self.conv1d_bias, - activation=self.activation, - ) # B, L, D - else: - raise RuntimeError("Causal conv1d is not available. Please install causal_conv1d.") + # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) + dt = dt.transpose(1, 2) - if not self.repeat_kv_before_conv: - x = einops.rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) y = selective_scan_fn( x, dt, - A, - B, - C, + -torch.exp(self.A_log.float()), + b, + c, self.D.float(), - z=z, - delta_bias=self.dt_proj_bias.float(), # self.dt_proj.bias.float(), + z, + delta_bias=self.dt_proj_bias.float(), delta_softplus=True, - return_last_state=False, ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) - - y = einops.rearrange(y, "b d l -> b l d") - out = self.out_proj(y) - outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - return outputs["hidden_states"], None + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) + y = y.transpose(1, 2)[:, :sequence_length] + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? + y = y.transpose(0, 1).contiguous() + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) + return self.out_proj(y) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..64c8227fc 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,14 +1,17 @@ +import logging import math -from typing import Callable +import typing -import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ +from fast_llm.utils import Assert, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -17,6 +20,8 @@ except (ImportError, RuntimeError): _mamba_available = False +logger = logging.getLogger(__name__) + """ Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. For now it only supports training and not inference. 
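The hunk below replaces the copy-based `init_dtprojbias` with an in-place version: it samples `dt` log-uniformly between `dt_min` and `dt_max`, applies the `dt_init_floor`, and stores the softplus inverse so that `softplus(bias)` recovers `dt` at initialization. A standalone sketch of that scheme follows; the helper name and the default bounds are illustrative assumptions, not part of the patch.

```python
import math

import torch


def dt_bias_init(n: int, dt_min: float = 1e-3, dt_max: float = 0.1, dt_init_floor: float = 1e-4) -> torch.Tensor:
    # Draw dt log-uniformly in [dt_min, dt_max], then apply the lower floor.
    dt = torch.empty(n).uniform_(math.log(dt_min), math.log(dt_max)).exp_().clamp_min_(dt_init_floor)
    # Store softplus^-1(dt) = log(exp(dt) - 1), so softplus(bias) gives back dt.
    return dt + torch.log(-torch.expm1(-dt))


bias = dt_bias_init(8)
recovered_dt = torch.nn.functional.softplus(bias)
assert ((recovered_dt >= 1e-4) & (recovered_dt <= 0.1 + 1e-6)).all()
```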
@@ -24,171 +29,137 @@ """ -def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - # S4D real initialization - # TODO: adopt this initialization to work for tensor parallel setting! - A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - if tensor.shape != A_log.shape: - if tensor.numel() == A_log.numel(): - tensor_view = tensor.view(d_inner, d_state) - tensor_view.copy_(A_log) - else: - raise ValueError(f"Tensor size {tensor.numel()} doesn't match expected size {A_log.numel()}") - else: - tensor.copy_(A_log) - return tensor - - return init_ - - -def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict -) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: +def init_A(d_state, d_inner) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + if tensor.numel() != d_state * d_inner: + raise ValueError("init_A is not supported for tensor slices.") + torch.log( + torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) + .unsqueeze(0) + .expand(d_inner, d_state), + out=tensor, + ) + + return LambdaInitializer(init_, requires_global_initialization=True) + + +def init_dtprojbias(dt_max: float, dt_min: float, dt_init_floor: float) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ).clamp(min=dt_init_floor) + tensor.uniform_(math.log(dt_min), math.log(dt_max), generator=generator).exp_().clamp_min_(dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - tensor.copy_(inv_dt) - return tensor + tensor.add_(torch.log(-torch.expm1(-tensor))) - return init_ + return LambdaInitializer(init_) -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() - self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" + self._config = config + # TODO: It's not silu? 
+ Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - self.d_conv = td_conv_kernel.size - self.d_inner = td_inner.size - self.d_state = td_state.size - self.d_model = td_model.size - self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - - self.in_proj_weight = ParameterMeta.from_dims( - (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + # TODO: Backward compatibility? + # TODO: lr_scale? + self.in_proj = Linear( + hidden_dim, + tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), + bias=False, + weight_init_method=init_kaiming_(hidden_dim.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + ( + inner_dim, + tensor_space.get_tensor_dim(DefaultDimNames.scalar), + tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), + ), + init_method=init_kaiming_(inner_dim.size), + lr_scale=lr_scale, ) - self.conv1d_bias = None - - self.activation = "silu" - self.act = torch.nn.SiLU() - self.x_proj = Linear( - td_inner, - td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + inner_dim, + tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), + weight_init_method=init_kaiming_(inner_dim.size), bias=False, - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), - lr_scale=mamba_layer_lr_scale, + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + init_method=init_kaiming_(self._config.dt_rank), + lr_scale=lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), - init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs - ), - lr_scale=mamba_layer_lr_scale, + (inner_dim,), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (td_inner, td_state), + (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), weight_decay=False, - init_method=init_A(self.d_state, self.d_inner), - 
lr_scale=mamba_layer_lr_scale, + init_method=init_A(self._config.state_size, inner_dim.size), + lr_scale=lr_scale, ) # D "skip" parameter self.D = ParameterMeta.from_dims( - (td_inner,), + (inner_dim,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + lr_scale=lr_scale, ) self.out_proj = Linear( - td_inner, - td_model, + inner_dim, + hidden_dim, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=kaiming_init_(td_model.size), - lr_scale=mamba_layer_lr_scale, - **factory_kwargs, + weight_init_method=init_kaiming_(hidden_dim.size), + lr_scale=lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input - def forward(self, hidden_states, kwargs): + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - batch, seqlen, dim = hidden_states.shape - - # We do matmul and transpose BLH -> HBL at the same time - xz = einops.rearrange( - self.in_proj_weight @ einops.rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self._debug_mode: - print("XZ: ", xz.shape) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( - xz, + in_proj, self.conv1d_weight, - self.conv1d_bias, + None, self.x_proj.weight, self.dt_proj_weight, self.out_proj.weight, self.out_proj.bias, # is None here - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) + if kwargs[TransformerKwargs.sequence_first]: + out = out.transpose(0, 1) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index fbd6dd0c4..7b8bc98c8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -8,9 +8,9 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -45,27 +45,19 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. 
""" - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + _mixer_name: typing.ClassVar[str] = "attn" + + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs self._config = config - self._tensor_space = tensor_space - # TODO Soham: fix assert - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) @@ -89,7 +81,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) @@ -159,10 +151,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -181,19 +173,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - @property def _query_dims(self): return ( @@ -220,27 +199,6 @@ def _context_dims(self): self._transformer_dim_names.composite_dense, ) - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -307,7 +265,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: 
make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -348,7 +306,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -402,7 +360,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[self._transformer_kwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 83f1110c1..ecf2c3fea 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs @@ -42,7 +42,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = 
config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -72,9 +72,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 38a80beff..d2f3bfba8 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. 
+ """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._transformer_dim_names = config._transformer_dim_names @@ -37,19 +97,20 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -57,7 +118,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -70,7 +131,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] @@ -140,18 +201,21 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) -class VisionTransformerLayer(TransformerLayer): +class VisionTransformerBlock(TransformerBlock): _name: str = "Vision transformer layer" diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1e439e72e..be172af96 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -24,7 +24,7 @@ VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock, VisionTransformerBlock from fast_llm.layers.vision_encoder.adapter import 
VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConv @@ -79,11 +79,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -100,7 +100,7 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ @@ -120,10 +120,10 @@ def get_layers(self) -> list[Layer]: return [ *(self.get_embedding_layers()), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -342,19 +342,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -374,10 +375,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config labels_cloned = False @@ -406,16 +407,32 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) if self._config.vision_encoder.enabled: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) if self._config.vision_encoder.image_break_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_break_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask if self._config.vision_encoder.image_end_token is not None: if not labels_cloned: labels = labels.clone() labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + loss_mask = torch.where( + labels == self._config.vision_encoder.image_end_token, False, loss_mask + ) + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + # TODO: Check that this works. Can we remove previous loss_masking? + if self._config.distillation_model is not None: + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask, torch.ones_like(labels, dtype=torch.bool)) + loss_mask = torch.where(labels == -100, False, loss_mask) + kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) @@ -452,7 +469,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[self.embedding_layer_index] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[self.embedding_layer_index + 1 : -1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 70362a40e..886fa7a32 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,12 +6,11 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for 
the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,48 +40,16 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. + ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. """ super().setup_tensor_space(tensor_space) - d_inner: int = self.ssm.d_inner - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, self.ssm.dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, self.ssm.dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + if self.ssm_block_type is not None: + self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +63,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * 
self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. + Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): @@ -217,12 +175,6 @@ def _validate(self): logger.warning( "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." 
) - if ( - self.base_model.sequence_first - or self.distributed.sequence_data_parallel > 1 - or self.distributed.sequence_tensor_parallel - ): - raise NotImplementedError(f"Sequence-first not supported for SSMs.") super()._validate() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 640615e0e..64afbea06 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -22,7 +23,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import ( CommonLlamaHuggingfaceCheckpointHandler, LlavaHuggingfaceCheckpointHandler, @@ -55,11 +56,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -215,7 +216,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -223,8 +224,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -289,6 +290,9 @@ def _create_weight_converters( ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{offset+i+1}.mixer.dt_in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) @@ -847,6 +851,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForCausalLM": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, ), ] diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index da7984c70..35e9b6885 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -849,9 +849,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // 
self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -939,8 +938,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -950,7 +958,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 9896d91d1..b056d3a00 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -4,6 +4,16 @@ from .configuration_llava_hybrid import LlavaHybridConfig +try: + # In the fast-llm repo, import from the SSM modeling file + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielThinkerSSMHybridModel, + HybridMambaAttentionDynamicCache, + ) +except ImportError: + # In the exported checkpoint, import from local file + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlavaHybridConfig): @@ -42,7 +52,6 @@ def __init__(self, config: LlavaHybridConfig): assert ( config.text_config.model_type == "apriel_ssm_thinker_hybrid" ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel self.language_model = AprielThinkerSSMHybridModel(config.text_config) self.post_init() @@ -69,8 +78,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache - # Copy of the method from `AprielThinkerSSMHybridForCausalLM` # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index df15907d2..29f115bd9 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer 
-from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -32,7 +29,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -40,52 +36,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -95,63 +70,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. 
""" - layers = self.get_embedding_layers() + layers: list[Layer] = self.get_embedding_layers() # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b89ed4a04 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,3 +1,4 @@ +import abc import functools import math import typing @@ -5,7 +6,7 @@ import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -164,16 +165,18 @@ def local_to_global( *, distributed: Distributed, ) -> tuple[torch.Tensor, ...]: + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. 
# TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +185,30 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank def global_to_local( self, tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. + # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. expand: bool = False, ) -> torch.Tensor: """ Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. """ # Take a trivial slice to convert safetensor slices. - tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] - - return tensor_ if expand else tensor_.reshape(self.shape) + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim, expand) + if not expand: + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +242,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -247,7 +252,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. 
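# Reviewer's sketch (not part of the patch): what the wrapping below amounts to. The `zero_init_`
# helper is made up for illustration; `init_fill_(0.0)` later in this file is the real equivalent.
def zero_init_(meta, tensor, generator) -> None:
    # A bare callable with the (meta, tensor, generator) signature fills the tensor in place.
    tensor.zero_()

# Passing `zero_init_` directly as `init_method` is assumed to be equivalent to passing
# `LambdaInitializer(zero_init_)`: both end up stored as an `Initializer` in `param_init_method`.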
+ assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -272,7 +281,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, @@ -299,6 +308,10 @@ def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -330,11 +343,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass - return init_ + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +376,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: 
torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 472f5e9b7..58285d408 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -2,8 +2,10 @@ import itertools import logging import math +import os import signal import typing +import warnings from typing import Callable if typing.TYPE_CHECKING: @@ -395,6 +397,26 @@ def interrupted(self): return self._interrupted +def set_global_variables(disable_torch_dynamo: bool = False) -> None: + # Set global and environment variables. This needs to be called before importing any third-party package. + # TODO: Find an alternative to get reliable tensor-parallel overlap. + if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS", ""): + warnings.warn("Setting CUDA_DEVICE_MAX_CONNECTIONS breaks things.") + # All distributed workers need the same hash seed for consistent hashing. + if "PYTHONHASHSEED" not in os.environ: + warnings.warn("PYTHONHASHSEED should be set and to the same value for all workers.") + # On systems with more than 64 cores, numexpr may log an error and ignore the thread setting. + if "NUMEXPR_MAX_THREADS" not in os.environ: + import multiprocessing + + os.environ["NUMEXPR_MAX_THREADS"] = str(multiprocessing.cpu_count()) + + if disable_torch_dynamo: + import torch._dynamo + + torch._dynamo.config.disable = True # noqa + + _global_max_allocated = 0 _global_max_reserved = 0 diff --git a/mkdocs.yaml b/mkdocs.yaml index ab71bc231..85fd4bff0 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -179,6 +179,7 @@ nav: - Configuration: user_guide/configuration.md - Multi-Stage: user_guide/multi-stage.md - Parallelism: user_guide/parallelism.md + - Evaluators: user_guide/evaluators.md - Developer Guide: - Configuration: developer_guide/configuration.md - Model: diff --git a/setup.cfg b/setup.cfg index 715f6b630..baa6e4adc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,9 +48,14 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 + cartesia_pytorch>=0.0.2 + +GENERATION = + lm_eval>=0.4.9 + # Required for supporting vision inputs VISION = diff --git a/tests/conftest.py b/tests/conftest.py index 298117e1d..19bdfe5d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,13 @@ import pytest import xdist.scheduler -from fast_llm.utils import get_and_reset_memory_usage_mib +from fast_llm.utils import get_and_reset_memory_usage_mib, set_global_variables from tests.utils.depends import DependencyManager +# TODO: Is this early enough? +set_global_variables() # isort: skip + + if worker_name := os.environ.get("PYTEST_XDIST_WORKER"): if gpus := os.environ.get("CUDA_VISIBLE_DEVICES"): # We set the device through "CUDA_VISIBLE_DEVICES", and this needs to happen before importing torch. @@ -225,7 +229,9 @@ def pytest_terminal_summary(terminalreporter): terminalreporter.write_sep("=", "Highest gpu memory usage", bold=True) sorted_nodeids = sorted( resource_reports.keys(), - key=lambda nodeid: resource_reports[nodeid]["max_reserved"], + key=lambda nodeid: ( + resource_reports[nodeid]["max_reserved"] if "max_reserved" in resource_reports[nodeid] else 0 + ), reverse=True, ) for nodeid in sorted_nodeids[: terminalreporter.config.getoption("--show-gpu-memory")]: diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index da719a42d..665faf7ed 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -284,10 +284,15 @@ def test_load_pretrained( @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. + # TODO: Stress the importance of this test as the main correctness test for most models. # TODO: Review test. Move to test_generate? 
fast_llm_path = get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) hf_path = get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) - model_ref = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + try: + hf_class = model_testing_config.huggingface_model_for_causal_lm_class + except NotImplementedError: + pytest.skip(f"Hugging Face wrapper not implemented for {model_testing_config.name}.") + model_ref = hf_class.from_pretrained( CheckpointLoadConfig( path=get_convert_path(), format=DistributedCheckpointFormat, @@ -298,8 +303,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): 0, model_ref.config.fast_llm_config.base_model.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda" ) output_ref = model_ref(test_input) - model_from_fast_llm = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained(fast_llm_path) - model_from_hf = model_testing_config.huggingface_model_for_causal_lm_class.from_pretrained( + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py new file mode 100644 index 000000000..b9e2aa8c3 --- /dev/null +++ b/tests/models/test_lm_eval.py @@ -0,0 +1,124 @@ +import pathlib +import shutil + +import pytest + +from tests.utils.dataset import TOKENIZER_PATH, download_santacoder_tokenizer +from tests.utils.distributed_configs import DistributedTestingConfig +from tests.utils.model_configs import ModelTestingGroup +from tests.utils.utils import requires_cuda + +# NOTE: These tests only verify that the functionality runs without crashing. +# NOTE: The tokenizer is from a LLaMA-style model, which may not be suitable for all models, +# but it should be sufficient since we are not concerned with actual accuracy in this tests. + + +@pytest.fixture(scope="module") +def tokenizer_path(): + download_santacoder_tokenizer() + return TOKENIZER_PATH + + +@pytest.fixture(scope="function") +def get_lm_eval_config(tokenizer_path, monkeypatch): + # TODO: Investigate why loading the tokenizer here gives a vocab_size + # smaller than 49157, which is the size when loaded by Fast-LLM. + import lm_eval.evaluator + + # lm_eval gathers lots of system info when reporting results, and this is extremely slow, so we skip here. + monkeypatch.setattr(lm_eval.evaluator, "add_env_info", lambda x: None, raising=True) + + def do_get_lm_eval_config(base_path): + import lm_eval.tasks + + task_dir = pathlib.Path(lm_eval.tasks.__file__).parent.resolve() + return [ + f"data.tokenizer.path={tokenizer_path}", + f"model.base_model.vocab_size=49157", + "training.evaluators.evaluation_test.interval=2", + "training.evaluators.evaluation_test.evaluator.type=lm_eval", + "training.evaluators.evaluation_test.evaluator.cli_args=" + f'["--tasks=wikitext",' + f'"--output_path={str(base_path / "lm_eval")}",' + # lm_eval loads all available tasks by default which is slow. 
+ f'"--include_path={str(task_dir / "wikitext")}",' + f'"--no_defaults",' + f'"--limit=1",' + f'"--batch_size=1",' + f'"--verbosity=DEBUG"]', + ] + + return do_get_lm_eval_config + + +# "gsm8k,xnli_en,wikitext" + + +@requires_cuda +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): + run_test_script_for_all_models( + distributed_testing_config=DistributedTestingConfig( + name="lm_eval_in_training", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_in_training") + + ["training.checkpoint.interval=2"], + ) + ) + + +@pytest.fixture(scope="module") +def copy_training_output(run_test_script_base_path: pathlib.Path): + def do_copy_training_output(distributed_testing_config: DistributedTestingConfig): + self_path = run_test_script_base_path / distributed_testing_config.name + shutil.copytree(run_test_script_base_path / distributed_testing_config.compare, self_path) + + return do_copy_training_output + + +@requires_cuda +@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_evaluation_last_checkpoint( + run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config, copy_training_output +): + distributed_testing_config = DistributedTestingConfig( + name="lm_eval_evaluation_last_checkpoint", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_evaluation_last_checkpoint"), + compare="lm_eval_in_training", + ) + copy_training_output(distributed_testing_config) + run_test_script_for_all_models(distributed_testing_config=distributed_testing_config, runnable_type="evaluate") + + +@requires_cuda +@pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +@pytest.mark.model_testing_group(ModelTestingGroup.generate) +def test_lm_eval_evaluation_from_pretrained( + run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config +): + run_test_script_for_all_models( + distributed_testing_config=DistributedTestingConfig( + name="lm_eval_evaluation_from_pretrained", + config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_evaluation_from_pretrained") + + [ + "pretrained.format=distributed", + f"pretrained.path={run_test_script_base_path/'lm_eval_in_training/checkpoint/2'}", + "pretrained.model_weights=True", + ], + ) + ) + + +# TODO: rewrite for a new distributed test function +# @requires_cuda +# @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) +# @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed) +# def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): +# run_test_script_for_all_models( +# distributed_testing_config=DistributedTestingConfig( +# name="lm_eval_in_training_dp2", +# config_args=get_lm_eval_config(run_test_script_base_path / "lm_eval_in_training_dp2") +# + ["training.checkpoint.interval=1"], +# num_gpus=2, +# ) +# ) diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert 
attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 52b51c8a5..694faa55b 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -14,21 +14,15 @@ from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel -try: - from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel -except ImportError: - LMHeadModel = None - +@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") @pytest.mark.slow -@pytest.mark.skipif( - LMHeadModel is None, - reason=f"cartesia_pytorch.Llamba not installed", -) def test_load_from_llamba_checkpoint(): """ Test that loading the Llamba-1B checkpoint with Fast-LLM and with Hugging Face produces the same results.
""" + import cartesia_pytorch.Llamba.llamba + vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json batch_size = 2 seq_length = 32 @@ -38,7 +32,7 @@ def test_load_from_llamba_checkpoint(): x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_model = LMHeadModel.from_pretrained(path, strict=True).to("cuda") + hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) hf_logits = hf_model(x)["logits"].cpu() del hf_model diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a4136c40e..b770675d4 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -24,6 +24,13 @@ MODEL_TEST_VOCAB_SIZE = 384 +def download_santacoder_tokenizer(): + if not TOKENIZER_FILE.is_file(): + import transformers + + transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + + def get_test_dataset( prefix: pathlib.Path = DATASET_PREFIX, seed: int = 1234, @@ -32,10 +39,7 @@ def get_test_dataset( vocab_size: int = TEST_VOCAB_SIZE, max_spans: int = 0, ): - if not TOKENIZER_FILE.is_file(): - import transformers - - transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) + download_santacoder_tokenizer() if not ( prefix.with_suffix(".idx").is_file() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 643ca6c27..96982e510 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -28,6 +28,10 @@ from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig +from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration + EvaluatorsConfig, +) + _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) @@ -75,6 +79,17 @@ def trainer_config(self) -> TrainerConfig: # See `RunnableConfig._from_parsed_args` return self.trainer_config_class.from_dict(self.trainer_config_class._parse_updates(self.config_args)) + @functools.cached_property + def evaluators_config_class(self) -> type[EvaluatorsConfig]: + # EvaluatorsConfig is a base class that, during parse_and_run, replaces itself with the appropriate TrainingConfig subclass. + # Therefore, the arguments passed to EvaluatorsConfig.parse_and_run must include the model type as the first element. + return EvaluatorsConfig + + @functools.cached_property + def evaluators_config(self) -> EvaluatorsConfig: + # See `RunnableConfig._from_parsed_args` + return self.evaluators_config_class.from_dict(self.evaluators_config_class._parse_updates(self.config_args)) + @functools.cached_property def model_config_class(self) -> type[FastLLMModelConfig]: # TODO: Ok to assume the model and trainer have the same name? @@ -456,7 +471,7 @@ def _update_and_add_testing_config( ) _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. + # Tests hybrid Mamba, llamba converter. 
"llama", "llava", extra_args=[ @@ -497,10 +512,8 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m']", - "model.base_model.ssm.state_size=8", - "model.base_model.ssm.chunk_size=32", - "model.base_model.ssm.n_qk_heads=8", - "model.base_model.ssm.n_v_heads=8", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=16", ], megatron_args=None, checkpoint_format=LLambaHuggingfaceCheckpointFormat, @@ -508,46 +521,60 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.broken, # TODO: Fix and bring back to `testing_groups` + ModelTestingGroup.convert: ModelTestingGroupAction.broken, ModelTestingGroup.generate: ModelTestingGroupAction.broken, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - # TODO: Fix and bring back to `testing_groups` - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, + ModelTestingGroup.distributed: ModelTestingGroupAction.not_implemented, }, compare_factor=2.0, - # SSMs don't support sequence-first configurations. - skip_tests=("sf", "sdp", "stp", "ms"), + # Micro-sequence split not supported. + skip_tests=("sdp", "ms"), ) - _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", - "hybrid_discrete_mamba2", + # Tests hybrid Mamba 2. + "llama", + "hybrid_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.d_xb=256", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) + _update_and_add_testing_config( - # Tests hybrid ssm, llamba converter. - "llamba", - "hybrid_mamba2", + # Tests hybrid discrete Mamba 2. 
+ "llama", + "hybrid_discrete_mamba2", model_type="hybrid_ssm", extra_args=[ - "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.d_inner=512", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.n_qk_heads=8", + "model.base_model.ssm.n_v_heads=16", + "model.base_model.ssm.chunk_size=32", ], megatron_args=None, checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, @@ -557,8 +584,12 @@ def _update_and_add_testing_config( ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + # TODO: Implement + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=2.0, + # Micro-sequence split and sequence-first not supported. + skip_tests=("sdp", "ms"), ) _update_and_add_testing_config( diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index b8f996a82..7d706ebdb 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -69,12 +69,13 @@ def do_run_test_script_for_all_models( distributed_testing_config: DistributedTestingConfig, model_testing_config: ModelTestingConfig, base_path: pathlib.Path, + runnable_type: str = "train", ): Assert.leq(distributed_testing_config.num_gpus, DistributedConfig.default_world_size) get_model_test_dataset() args = [ "fast-llm", - "train", + runnable_type, model_testing_config.model_type, *model_testing_config.config_args, *distributed_testing_config.config_args, @@ -83,7 +84,12 @@ def do_run_test_script_for_all_models( f"run.experiment_dir={base_path/distributed_testing_config.name}", ] print(" ".join(args)) - model_testing_config.trainer_config_class.parse_and_run(args[3:]) + if runnable_type == "train": + model_testing_config.trainer_config_class.parse_and_run(args[3:]) + elif runnable_type == "evaluate": + model_testing_config.evaluators_config_class.parse_and_run(args[2:]) + else: + raise ValueError(f"Unknown runnable_type {runnable_type}") @pytest.fixture(scope="function")