From 7709e65d05652906936ada2cdb53c31ab4e68663 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Apr 2025 06:51:19 +0000 Subject: [PATCH 01/97] WIP: multimodal support --- fast_llm/data/config.py | 39 ++++++++++++++++++ fast_llm/data/image_processor.py | 40 +++++++++++++++++++ .../data/preparator/gpt_memmap/prepare.py | 3 ++ fast_llm/data/processor.py | 11 +++++ setup.cfg | 2 + 5 files changed, 95 insertions(+) create mode 100644 fast_llm/data/image_processor.py create mode 100644 fast_llm/data/processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 1586d370d..351dcaaef 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,3 +34,42 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) + + +@config_class() +class ImageProcessorConfig(Config): + """ + Configuration for the image processor + """ + + # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) + patch_size: list[int] = Field( + default_factory=lambda: [16, 16], + desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", + hint=FieldHint.optional, + ) + max_height: int = Field( + default=1024, + desc="Maximum height of the image. Image will be resized if larger", + hint=FieldHint.optional, + ) + max_width: int = Field( + default=1024, + desc="Maximum width of the image. Image will be resized if larger", + hint=FieldHint.optional, + ) + mean: list[float] = Field( + default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], + desc="Mean RGB values for pixel normalization", + hint=FieldHint.optional, + ) + std: list[float] = Field( + default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], + desc="Standard deviation RGB values for pixel normalization", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Diminisher factor for pixel normalization", + hint=FieldHint.optional, + ) diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py new file mode 100644 index 000000000..cf4c6e938 --- /dev/null +++ b/fast_llm/data/image_processor.py @@ -0,0 +1,40 @@ +import math + +import torch +from torchvision.transforms.v2 import functional as F + +from fast_llm.data.config import ImageProcessorConfig + + +class ImageProcessor: + def __init__(self, config: ImageProcessorConfig): + self.patch_size = config.patch_size + self.mean = config.mean / config.rescale_factor + self.std = config.std / config.rescale_factor + self.max_height = config.max_height + self.max_width = config.max_width + assert ( + self.max_height % self.patch_size[0] == 0 + ), "max_height must be divisible by patch_size[0]. Found {max_height} and {self.patch_size[0]}" + assert ( + self.max_width % self.patch_size[1] == 0 + ), "max_width must be divisible by patch_size[1]. 
Found {max_width} and {self.patch_size[1]}" + + def resize(self, image: torch.Tensor) -> torch.Tensor: + # Resize the image to the specified size + height = image.shape[0] + width = image.shape[1] + ratio = max(height / self.max_height, width / self.max_width) + if ratio > 1: + height = math.ceil(height / ratio) + width = math.ceil(width / ratio) + else: + height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + + # TODO: options for interpolation mode + return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + + def normalize(self, image: torch.Tensor) -> torch.Tensor: + # Normalize the image using the mean and std + return F.normalize(image, mean=self.mean, std=self.std) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df1..5cfad9ec5 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -38,6 +38,9 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _tokenizer: Tokenizer _data_type: DataType + def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + pass + def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py new file mode 100644 index 000000000..43b1cda83 --- /dev/null +++ b/fast_llm/data/processor.py @@ -0,0 +1,11 @@ +from fast_llm.data.tokenizer import Tokenizer + + +class MultiModalProcessor: + """ + Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
+ """ + + def __init__(self, tokenizer: Tokenizer, image_processor=None): + self._tokenizer = tokenizer + self._image_processor = image_processor diff --git a/setup.cfg b/setup.cfg index c21f02a7b..3c1dad9da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ OPTIONAL = # Miscellanous requests>=2.32.3 tqdm>=4.66.3 + # Vision Tools + torchvision>=0.20.0 DEV = # Pre-commit git hook From 0db2bd21218fa133d4a1e41223552ece8f3044a7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 06:19:10 +0000 Subject: [PATCH 02/97] rough idea for memmap --- fast_llm/data/config.py | 18 ++++ fast_llm/data/dataset/gpt/memmap.py | 59 ++++++++++-- fast_llm/data/dataset/gpt/sampled.py | 2 + fast_llm/data/image_processor.py | 3 + fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 92 ++++++++++++++----- fast_llm/data/processor.py | 11 --- 7 files changed, 145 insertions(+), 48 deletions(-) delete mode 100644 fast_llm/data/processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 351dcaaef..8c2c3c28e 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -73,3 +73,21 @@ class ImageProcessorConfig(Config): desc="Diminisher factor for pixel normalization", hint=FieldHint.optional, ) + + +@config_class() +class MultiModalProcessorConfig(Config): + """ + Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` + """ + + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Configuration for the tokenizer.", + hint=FieldHint.core, + ) + image_processor: ImageProcessorConfig = Field( + default_factory=ImageProcessorConfig, + desc="Configuration for the image processor.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ef060b008..c8b2592f1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -38,10 +38,14 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_images = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -82,6 +86,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + if self._has_images and self._version >= 3: + self._image_sizes = np.frombuffer() self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -151,7 +157,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + im_lengths = [] + im_positions = [] + total_images = 0 pointers = [] offset = 0 # number of spans for each document @@ -160,8 +170,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) + pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily + # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first document 
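For reference, a minimal sketch of the versioned ".idx" header handshake the `_init` hunk above is moving toward (version 3 adding an image flag). The magic value, field order and struct formats below are assumptions for illustration, not the final on-disk layout.

import struct

MEMMAP_INDEX_HEADER = b"fast_llm\x00"  # placeholder 9-byte magic, matching the read(9) above

def write_index_header(stream, has_spans: bool, has_images: bool, version: int = 3) -> None:
    stream.write(MEMMAP_INDEX_HEADER)
    stream.write(struct.pack("<Q", version))           # 8-byte version
    stream.write(struct.pack("<B", int(has_spans)))    # v2+: loss-masking spans present?
    stream.write(struct.pack("<B", int(has_images)))   # v3+: images present?

def read_index_header(stream) -> tuple[int, bool, bool]:
    assert stream.read(9) == MEMMAP_INDEX_HEADER
    version = struct.unpack("<Q", stream.read(8))[0]
    has_spans = version >= 2 and bool(struct.unpack("<B", stream.read(1))[0])
    has_images = version >= 3 and bool(struct.unpack("<B", stream.read(1))[0])
    return version, has_spans, has_images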
@@ -174,10 +186,18 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image, image_position in zip(document.images, document.image_positions): + im_lengths.append(image.size) + im_positions.append(document.image_positions) + bin_stream.write(image.tobytes(order="C")) # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) + im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) @@ -186,7 +206,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -194,27 +214,46 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) + # TODO Soham: else condition might not be necessary + if total_images: + n_images = np.array(n_images, dtype=np.int32) + im_lengths = np.array(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + im_lengths = np.array([]) + im_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" torch.Tensor: def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) + + def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: + return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c37..60262743e 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.config import MultiModalProcessorConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -153,9 +153,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", + data_processor: MultiModalProcessorConfig = Field( + default_factory=MultiModalProcessorConfig, + desc="Configuration for data processing. 
Describes the tokenizer and image processor", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 5cfad9ec5..d4180986e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -23,9 +23,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -35,45 +35,79 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer + # _tokenizer: Tokenizer + _data_processor: MultiModalProcessor _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - for text in batch[self._config.dataset.field] + # input_ids = [ + # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + # for text in batch[self._config.dataset.field] + # ] + input_ids, images, image_token_positions = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), + ) + for input_ids, images, image_token_positions in [ + self._data_processor.tokenize(text, ims, im_char_positions) + for text, ims, im_char_positions in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], + ) + ] + ] + ), + ) + num_tokens = [ + len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) ] - num_tokens = [len(x) for x in input_ids] return { "input_ids": input_ids, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, images, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) + for input_ids, token_spans, images, image_token_positions in [ + self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans] + batch[self._config.dataset.field], + batch[self._config.dataset.loss_masking_spans], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], ) ] ] ), ) - num_tokens = [len(x) for x in input_ids] + num_tokens = [ + len(x) + 
self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + ] return { "input_ids": input_ids, "token_spans": token_spans, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -83,15 +117,27 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + # TODO Soham: simplify this + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._config.dataset.loss_masking_spans + else None + ), + images if self._config.dataset.images else None, + image_positions if self._config.dataset.image_positions else None, + ) + # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample( + # np.array(item["input_ids"], dtype=self._data_type.numpy), + # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + # ) + # else: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -169,12 +215,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + # Load Processor + self._processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._tokenizer.vocab_size) + get_unsigned_integer_type(self.processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py deleted file mode 100644 index 43b1cda83..000000000 --- a/fast_llm/data/processor.py +++ /dev/null @@ -1,11 +0,0 @@ -from fast_llm.data.tokenizer import Tokenizer - - -class MultiModalProcessor: - """ - Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
- """ - - def __init__(self, tokenizer: Tokenizer, image_processor=None): - self._tokenizer = tokenizer - self._image_processor = image_processor From 0d89f68d7c4d5a40f5fa7e2651ac61b75da31aa5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 06:10:33 +0000 Subject: [PATCH 03/97] faster image size reading --- fast_llm/data/dataset/gpt/memmap.py | 54 ++++++++++++------ fast_llm/data/image_processor.py | 17 +++--- fast_llm/data/preparator/gpt_memmap/config.py | 18 +++++- .../data/preparator/gpt_memmap/prepare.py | 55 ++++++++++++------- setup.cfg | 3 + 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c8b2592f1..069240540 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,12 +34,12 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 2: self._spans = [] @@ -73,9 +74,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -83,18 +83,40 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + offset += ( + self._num_spans.nbytes + + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + + sum([x.nbytes for x in self._spans]) + ) if self._has_images and self._version >= 3: - self._image_sizes = np.frombuffer() + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._im_lengths = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum() * 3, + offset=offset + self._n_images.nbytes, + ) + self._im_positions = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum(), + offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, + ) self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) + # TODO Soham: fix num_tokens to include images self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -110,6 +132,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # TODO Soham: get images def get( self, idx: 
int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -170,10 +193,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily - # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first document @@ -186,23 +207,25 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) for image, image_position in zip(document.images, document.image_positions): - im_lengths.append(image.size) + # assume 3 channels (RGB) for all images + im_lengths.append(np.array(image.shape[1:])) im_positions.append(document.image_positions) bin_stream.write(image.tobytes(order="C")) + total_im_size += image.size # Update metadata doc_length = len(document.token_ids) doc_lengths.append(doc_length) - im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays @@ -214,15 +237,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) - # TODO Soham: else condition might not be necessary if total_images: n_images = np.array(n_images, dtype=np.int32) - im_lengths = np.array(im_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.array([]) - im_positions = np.array([]) + im_lengths = np.stack(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 473db11a2..c5cbe9095 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -9,8 +9,8 @@ class ImageProcessor: def __init__(self, config: ImageProcessorConfig): self.patch_size = config.patch_size - self.mean = config.mean / config.rescale_factor - self.std = config.std / config.rescale_factor + self.mean = [x / config.rescale_factor for x in config.mean] + self.std = [x / config.rescale_factor for x in config.std] self.max_height = config.max_height self.max_width = config.max_width assert ( @@ -20,16 +20,19 @@ def __init__(self, config: ImageProcessorConfig): self.max_width % self.patch_size[1] == 0 ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - def resize(self, image: torch.Tensor) -> torch.Tensor: + def resize(self, image): # Resize the image to the specified size - height = image.shape[0] - width = image.shape[1] + # TODO Soham: resize for patches only during train? + # TODO Soham: convert all images to tensor? 
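One possible answer to the TODO above, assuming inputs arrive either as PIL images or as HWC uint8 arrays: convert them once to CHW tensors so the torchvision v2 functional ops below always see the same format. This is a sketch, not the chosen design.

import numpy as np
import torch

def to_chw_tensor(image) -> torch.Tensor:
    array = np.array(image)                   # PIL.Image or ndarray -> HWC uint8 copy
    if array.ndim == 2:                        # grayscale: replicate to 3 channels
        array = np.stack([array] * 3, axis=-1)
    return torch.from_numpy(array).permute(2, 0, 1).contiguous()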
+ # height = image.shape[0] + # width = image.shape[1] + height, width = image.size ratio = max(height / self.max_height, width / self.max_width) if ratio > 1: height = math.ceil(height / ratio) width = math.ceil(width / ratio) else: - height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) # TODO: options for interpolation mode @@ -40,4 +43,4 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: return F.normalize(image, mean=self.mean, std=self.std) def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) + return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 60262743e..8a15d96c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,15 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + image_paths: None | str = Field( + default=None, desc="Field containing images within the document", hint=FieldHint.optional + ) + image_positions: None | str = Field( + default=None, desc="Field containing image positions within a document", hint=FieldHint.optional + ) + images: None | str = Field( + default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." @@ -142,6 +151,12 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) + tokenize_batch_size: int = Field( + default=1000, + desc="Batch size for tokenization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -165,8 +180,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) + # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.tokenizer.path is not None + assert self.data_processor.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d4180986e..0199cb400 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -13,6 +15,7 @@ import tqdm import transformers import yaml +from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -42,37 +45,43 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass + # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? 
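On the TODO above: a sketch of one merged helper, assuming a tokenizer that exposes `tokenize(text, image_positions)` returning `(token_ids, image_token_positions)` and `tokenize_with_spans(text, char_spans)` returning `(token_ids, token_spans)`, as this series sketches.

import numpy as np

def tokenize_batch(tokenizer, texts, image_positions=None, char_spans=None, dtype=np.int64):
    image_positions = image_positions or [None] * len(texts)
    char_spans = char_spans or [None] * len(texts)
    out = {"input_ids": [], "image_positions": [], "token_spans": [], "num_tokens": []}
    for text, im_pos, spans in zip(texts, image_positions, char_spans):
        if spans is None:
            ids, im_token_pos = tokenizer.tokenize(text, im_pos)
            token_spans = []
        else:
            ids, token_spans = tokenizer.tokenize_with_spans(text, spans)
            im_token_pos = []
        out["input_ids"].append(np.array(ids, dtype=dtype))
        out["image_positions"].append(np.array(im_token_pos, dtype=np.int32))
        out["token_spans"].append(np.array(token_spans, dtype=np.int32).reshape(-1, 2))
        out["num_tokens"].append(len(ids))
    return out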
def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: # input_ids = [ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, images, image_token_positions = map( + input_ids, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), - np.array(images, dtype=np.uint8), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, images, image_token_positions in [ - self._data_processor.tokenize(text, ims, im_char_positions) - for text, ims, im_char_positions in zip( + for input_ids, image_token_positions in [ + self._data_processor.tokenize( + text, + im_char_positions, + ) + for text, im_char_positions in zip( batch[self._config.dataset.field], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) - num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + # TODO Soham: is this ok? Should we get num_image_tokens separately? + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + return { "input_ids": input_ids, - "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -92,16 +101,17 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], - batch[self._config.dataset.loss_masking_spans], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + batch.get(self._config.dataset.images, itertools.repeat(None)), + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) + for x, doc_images in zip(input_ids, images) ] return { "input_ids": input_ids, @@ -117,7 +127,6 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - # TODO Soham: simplify this for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -126,8 +135,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - images if self._config.dataset.images else None, - image_positions if self._config.dataset.image_positions else None, + # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, + [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: # for item in tqdm.tqdm(shard_dataset, desc=f"Saving 
shard {shard_idx}", unit="docs"): @@ -215,12 +225,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load Processor - self._processor = MultiModalProcessor(config=self._config.data_processor) + # Load the data processor + self._data_processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self.processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -269,6 +279,9 @@ def run(self) -> None: tokenize_fn = self._tokenize_batch_with_spans else: tokenize_fn = self._tokenize_batch + # Avoid decoding bytes to images unless asked + if self._config.dataset.images is not None: + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( diff --git a/setup.cfg b/setup.cfg index 3c1dad9da..57913f83d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,9 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools + # TODO Soham: use pillow-simd instead of pillow? + webp>=0.4.0 + pillow-simd>=9.5.0 torchvision>=0.20.0 DEV = From 3866a5330fcf299ba8347b8e3aed057b598b5185 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Apr 2025 07:04:41 +0000 Subject: [PATCH 04/97] solidify prepare --- fast_llm/data/config.py | 60 ++++----- fast_llm/data/data/gpt/config.py | 1 + fast_llm/data/data/gpt/data.py | 9 +- fast_llm/data/dataset/gpt/config.py | 15 ++- fast_llm/data/dataset/gpt/memmap.py | 127 ++++++++++++++---- fast_llm/data/dataset/gpt/sampled.py | 52 +++++-- fast_llm/data/image_processor.py | 25 ++-- fast_llm/data/preparator/gpt_memmap/config.py | 10 +- .../data/preparator/gpt_memmap/prepare.py | 49 ++++--- fast_llm/data/tokenizer.py | 30 ++++- fast_llm/layers/language_model/config.py | 14 ++ 11 files changed, 291 insertions(+), 101 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 8c2c3c28e..f1a0fd58a 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -43,36 +43,36 @@ class ImageProcessorConfig(Config): """ # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - patch_size: list[int] = Field( - default_factory=lambda: [16, 16], - desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", - hint=FieldHint.optional, - ) - max_height: int = Field( - default=1024, - desc="Maximum height of the image. Image will be resized if larger", - hint=FieldHint.optional, - ) - max_width: int = Field( - default=1024, - desc="Maximum width of the image. 
Image will be resized if larger", - hint=FieldHint.optional, - ) - mean: list[float] = Field( - default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - desc="Mean RGB values for pixel normalization", - hint=FieldHint.optional, - ) - std: list[float] = Field( - default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - desc="Standard deviation RGB values for pixel normalization", - hint=FieldHint.optional, - ) - rescale_factor: float = Field( - default=255.0, - desc="Diminisher factor for pixel normalization", - hint=FieldHint.optional, - ) + # patch_size: list[int] = Field( + # default_factory=lambda: [16, 16], + # desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", + # hint=FieldHint.optional, + # ) + # max_height: int = Field( + # default=1024, + # desc="Maximum height of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # max_width: int = Field( + # default=1024, + # desc="Maximum width of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # mean: list[float] = Field( + # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], + # desc="Mean RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # std: list[float] = Field( + # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], + # desc="Standard deviation RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # rescale_factor: float = Field( + # default=255.0, + # desc="Diminisher factor for pixel normalization", + # hint=FieldHint.optional, + # ) @config_class() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index c98a781e6..652342b58 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,6 +21,7 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) + use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a0940e7c6..5bd9d09e2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,10 +32,16 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None +# TODO: do we need a separate use_images? 
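A minimal sketch of the image half of the collate step, assuming each `GPTSample` carries optional `images` (a list of HWC uint8 arrays) and `image_positions`; images stay as per-sample lists because their shapes differ, so any stacking is left to the model side.

import torch

def _collate_images(batch):
    if not any(getattr(sample, "images", None) for sample in batch):
        return None, None
    images = [[torch.from_numpy(image) for image in (sample.images or [])] for sample in batch]
    positions = [torch.tensor(sample.image_positions or [], dtype=torch.int64) for sample in batch]
    return images, positions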
def gpt_data_collate_fn( - batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool + batch: list[GPTSample], + use_loss_masking_spans: bool, + cross_document_attention: bool, + use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -170,6 +176,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, + use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0f04884b6..45d27e7d0 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,6 +57,11 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_images: bool | None = Field( + default=None, + desc="Use images in the dataset.", + hint=FieldHint.feature, + ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -73,6 +78,7 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True + patch_size: list[int] | None = None @config_class() @@ -178,11 +184,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 069240540..87bd3a8eb 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,10 +28,18 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -93,30 +103,48 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) + self._n_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = np.frombuffer( - self._index_bin_buffer, - 
dtype=np.int32, - count=self._n_images.sum() * 3, - offset=offset + self._n_images.nbytes, - ) - self._im_positions = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._n_images.sum(), - offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, - ) + self._im_lengths = [] + self._im_positions = [] + images_seen = 0 + # TODO Soham: verify correctness, reshaping into width, height? + for n_images in self._n_images: + self._im_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() + self._im_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize + + images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - # if num_tokens is not None: - # assert self._num_tokens == num_tokens + # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign + # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + if num_pixels is not None: + assert self._n_pixels == num_pixels + if num_tokens is not None: + assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -133,6 +161,42 @@ def __del__(self): del self._index_bin_buffer_mmap # TODO Soham: get images + def get( + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + # , patch_size: tuple(int), max_height: int, max_width: int + ): + # TODO Soham: Handle truncations? 
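On the truncation question above, a small sketch: keep only the images whose token span intersects the requested window, assuming per-document `image_positions` (token index of each image) and per-image token counts are available from the index.

def images_in_window(image_positions, image_token_counts, offset, length):
    end = offset + length
    kept = []
    for index, (position, n_tokens) in enumerate(zip(image_positions, image_token_counts)):
        if position >= end or position + n_tokens <= offset:
            continue  # image falls entirely outside [offset, offset + length)
        kept.append(index)
    return kept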
+ # if self._has_images: + # doc_size = self._document_sizes[idx] + # n_images = self._n_images[idx] + # image_positions = self._im_positions[idx] + # image_lengths = self._im_lengths[idx] + # image_tokens_seen = 0 + # for idx in range(n_images): + # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) + # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) + # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): + # continue + token_ids = np.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=self._document_sizes[idx] - offset if length is None else length, + offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + ) + if self._has_images: + image_positions = self._im_positions[idx] + images = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8).itemsize, + count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -164,16 +228,25 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens + @property + def has_images(self) -> bool: + return self._has_images + + # TODO: image sizes def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return self._document_sizes, self._im_lengths - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._document_sizes[index].item() + ( + sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + if self._has_images + else 0 + ) @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -211,12 +284,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.images: n_images.append(len(document.images)) total_images += len(document.images) - for image, image_position in zip(document.images, document.image_positions): + for image in document.images: # assume 3 channels (RGB) for all images - im_lengths.append(np.array(image.shape[1:])) - im_positions.append(document.image_positions) - bin_stream.write(image.tobytes(order="C")) - total_im_size += image.size + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + pixels = np.array(img) + im_lengths.append(np.array(pixels.shape[:2])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.append(document.image_positions) # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 22e3396b4..288018b12 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,6 +12,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.image_processor 
import ImageProcessor
 from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import Assert
@@ -89,11 +90,17 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._num_samples = sampling.num_samples
         self._sequence_length = sampling.sequence_length
+        self._patch_size = sampling.patch_size
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
         self._truncate_documents = sampling.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")
 
+        if self._indexed_dataset.has_images and self._truncate_documents:
+            raise RuntimeError(
+                "Truncating documents with images is not supported. Please turn off truncation to use images."
+            )
+
         if sampling.cache_directory is None:
             self._document_shuffling = MemmapArray()
             self._token_cumsum_shuffled = MemmapArray()
@@ -126,9 +133,15 @@ def _sample(self) -> None:
         Create a `GPTSampledDataset` with the requested parameters.
         """
         # Get the document sizes, the main information needed for sampling.
-        document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)
+        # TODO Soham: verify numpy correctness
+        document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
+        document_sizes = torch.from_numpy(document_sizes).to(self._device)
+        image_token_sizes = torch.zeros_like(document_sizes).to(self._device)
+        for i, sizes in enumerate(image_sizes):
+            image_token_sizes[i] = ((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])).sum()
+
         documents_per_epoch = document_sizes.numel()
-        tokens_per_epoch = document_sizes.sum().item()
+        tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item()
 
         # Calculate basic stats.
         if not self._truncate_documents:
@@ -136,14 +149,14 @@ def _sample(self) -> None:
                 "The C++ extension for dataset sampling is missing."
                 " Please make sure Fast-LLM is installed correctly."
             )
-            long_docs_filter = document_sizes > self._sequence_length + 1
+            long_docs_filter = document_sizes + image_token_sizes > self._sequence_length + 1
            ignored_documents = sum(long_docs_filter)
             if ignored_documents:
                 log_main_rank(
                     f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.",
                     log_fn=logger.warning,
                 )
-            tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
+            tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item()
             if tokens_per_epoch == 0:
                 raise RuntimeError(
                     f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
@@ -177,6 +190,7 @@ def _sample(self) -> None:
             "num_samples": self._num_samples,
             "unshuffled_epochs": unshuffled_epochs,
             "sequence_length": self._sequence_length,
+            "patch_size": self._patch_size,
             "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
@@ -258,7 +272,7 @@ def _sample(self) -> None:
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
         if unshuffled_epochs > 0:
             token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
-                document_sizes,
+                document_sizes + image_token_sizes,
                 offset=0,
                 # TODO: Allowing for max 100% extra tokens for padding, is that enough?
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -282,6 +296,9 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) ], offset=num_tokens_unshuffled, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -360,6 +377,9 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + image_positions = [] + image_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -367,7 +387,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) if not self._truncate_documents: if document_size > self._sequence_length + 1: @@ -398,6 +418,12 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) + # TODO Soham: handle images with loss masking spans + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: @@ -411,6 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -420,9 +447,16 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + images = [im for img_list in images for im in img_list] + Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index c5cbe9095..567c81469 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -26,21 +26,26 @@ def resize(self, image): # TODO Soham: convert all images to tensor? 
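Looking back at the packing loop in `__getitem__` above, this is the position bookkeeping it is aiming for, written out as a standalone helper; `doc_image_positions` are token indices within one document and `doc_patch_counts` the per-image patch counts, both assumed available.

def shift_image_positions(doc_image_positions, doc_patch_counts, tokens_so_far, image_tokens_so_far):
    shifted = []
    for position, n_patches in zip(doc_image_positions, doc_patch_counts):
        # each image lands after all tokens and image patches already packed before it
        shifted.append(position + tokens_so_far + image_tokens_so_far)
        image_tokens_so_far += n_patches
    return shifted, image_tokens_so_far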
# height = image.shape[0] # width = image.shape[1] - height, width = image.size - ratio = max(height / self.max_height, width / self.max_width) - if ratio > 1: - height = math.ceil(height / ratio) - width = math.ceil(width / ratio) - else: - height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) - width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) # TODO: options for interpolation mode return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + # TODO Soham: move to utils + @classmethod + def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + ratio = max(height / max_height, width / max_width) + return ( + (math.ceil(height / ratio), math.ceil(width / ratio)) + if ratio > 1 + else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) + ) + def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) - def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) + @classmethod + # TODO Soham: move to utils + def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 8a15d96c8..89fe904cd 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import MultiModalProcessorConfig +from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -168,9 +168,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - data_processor: MultiModalProcessorConfig = Field( - default_factory=MultiModalProcessorConfig, - desc="Configuration for data processing. 
Describes the tokenizer and image processor", + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Tokenizer configuration.", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( @@ -182,7 +182,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.data_processor.tokenizer.path is not None + assert self.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0199cb400..4965dfdfc 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -10,12 +10,12 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm import transformers import yaml -from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -26,9 +26,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -38,8 +38,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - # _tokenizer: Tokenizer - _data_processor: MultiModalProcessor + _tokenizer: Tokenizer _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -60,7 +59,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ np.array(image_token_positions, dtype=np.int32), ) for input_ids, image_token_positions in [ - self._data_processor.tokenize( + self._tokenizer.tokenize( text, im_char_positions, ) @@ -73,17 +72,18 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ ), ) num_tokens = [len(x) for x in input_ids] - # TODO Soham: is this ok? Should we get num_image_tokens separately? 
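Sketching the "separate num_image_tokens" option raised above: read only the image headers to count both raw pixels (for shard sizing) and patch tokens (for sequence budgeting). The resize rule mirrors `ImageProcessor.get_resize_dims` and is an assumption here, as are the parameter names.

import io
import math
import PIL.Image

def image_stats(image_bytes_list, patch_size, max_height, max_width):
    num_pixels, num_image_tokens = 0, 0
    for raw in image_bytes_list:
        with PIL.Image.open(io.BytesIO(raw)) as image:
            width, height = image.size  # lazy: reads the header without decoding pixels
        num_pixels += width * height * 3
        ratio = max(height / max_height, width / max_width)
        if ratio > 1:
            height, width = math.ceil(height / ratio), math.ceil(width / ratio)
        else:
            height = patch_size[0] * math.ceil(height / patch_size[0])
            width = patch_size[1] * math.ceil(width / patch_size[1])
        num_image_tokens += (height // patch_size[0]) * (width // patch_size[1])
    return num_pixels, num_image_tokens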
+ num_pixels = [0] * len(input_ids) for idx, images in enumerate(batch.get("images", [])): for bytes_im in images: - with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: width, height = im.size - num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -98,7 +98,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict np.array(image_token_positions, dtype=np.int32), ) for input_ids, token_spans, images, image_token_positions in [ - self._data_processor.tokenize_with_spans(text, char_spans) + self._tokenizer.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), @@ -109,16 +109,20 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict ] ), ) - num_tokens = [ - len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) - for x, doc_images in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(images): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "token_spans": token_spans, "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: @@ -136,7 +140,8 @@ def _document_generator(): else None ), # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: @@ -157,6 +162,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -225,12 +231,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load the data processor - self._data_processor = MultiModalProcessor(config=self._config.data_processor) + # Load tokenizer + self._tokenizer = Tokenizer(config=self._config.tokenizer) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -293,6 +299,12 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = 
sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._config.dataset.images + else 0 + ) + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -391,7 +403,8 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() + # TODO Soham: handle pixels (could still work with number of tokens?) + sizes_cumsum = dataset.get_document_sizes()[0].cumsum() Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee8..0e7d54709 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -35,13 +35,41 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) + def tokenize(self, text, image_positions=None): + if not image_positions: + return self._tokenize(text), [], [] + image_idx = 0 + char_pos = 0 + token_ids = [] + image_token_positions = [] + beginning_of_text = True + while image_idx < len(image_positions): + if image_positions[image_idx] > len(text): + raise ValueError( + f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + ) + curr_text = text[char_pos : image_positions[image_idx]] + tokenized_text = self._tokenize( + curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + ) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions = len(token_ids) + char_pos = image_positions[image_idx] + image_idx += 1 + if char_pos < len(text): + curr_text = text[char_pos:] + tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) + return token_ids, image_token_positions + def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] ) -> tuple[list[int], list[tuple[int, int]]]: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3bd796033..75c5418bb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,6 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig from fast_llm.utils import Assert @@ -198,3 +199,16 @@ def _validate(self) -> None: if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, 
self.init_method_max_embed) super()._validate() + + +class MultiModalBaseConfig: + language_model: LanguageModelBaseConfig = Field( + default_factory=LanguageModelBaseConfig, + desc="Configuration for the language model.", + hint=FieldHint.core, + ) + vision_model: VisionArchitectureConfig = Field( + default_factory=VisionArchitectureConfig, + desc="Configuration for the vision inputs.", + hint=FieldHint.core, + ) From 841398396714e5c3b346d6d2c46dcb37f532c167 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 07:55:31 +0000 Subject: [PATCH 05/97] wip --- fast_llm/data/data/gpt/config.py | 1 - fast_llm/data/data/gpt/data.py | 31 ++- fast_llm/data/dataset/gpt/config.py | 7 +- fast_llm/data/dataset/gpt/indexed.py | 12 +- fast_llm/data/dataset/gpt/memmap.py | 97 +++++---- fast_llm/data/dataset/gpt/sampled.py | 32 ++- fast_llm/data/image_processor.py | 10 +- fast_llm/engine/schedule/config.py | 15 ++ fast_llm/layers/language_model/config.py | 13 +- fast_llm/models/gpt/config.py | 4 + fast_llm/models/gpt/conversion.py | 258 +++++++++++++++++++++-- fast_llm/models/gpt/model.py | 12 +- fast_llm/models/gpt/trainer.py | 3 + 13 files changed, 400 insertions(+), 95 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 652342b58..c98a781e6 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,7 +21,6 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) - use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 5bd9d09e2..22e4730c9 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -36,12 +36,11 @@ class GPTBatch: image_positions: list[torch.Tensor] | None = None -# TODO: do we need a separate use_images? 
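# Minimal sketch of the collation change below: images and image positions are
# kept as per-sample lists of tensors instead of being stacked, since image
# shapes vary across samples. Names mirror the surrounding code but the helper
# itself is illustrative.
import torch


def collate_images(samples: list) -> tuple[list, list]:
    batch_images, batch_image_positions = [], []
    for sample in samples:
        if getattr(sample, "images", None) is not None:
            batch_images.append([torch.from_numpy(image) for image in sample.images])
        else:
            batch_images.append(None)
        if getattr(sample, "image_positions", None) is not None:
            batch_image_positions.append(torch.from_numpy(sample.image_positions))
        else:
            batch_image_positions.append(None)
    return batch_images, batch_image_positions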
+# TODO: collate images def gpt_data_collate_fn( batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool, - use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -50,8 +49,24 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + batch_images = [] + for sample in batch: + if sample.images is not None: + batch_images.append([torch.from_numpy(image) for image in sample.images]) + else: + batch_images.append(None) + batch_image_positions = [] + for sample in batch: + if sample.image_positions is not None: + batch_image_positions.append(torch.from_numpy(sample.image_positions)) + else: + batch_image_positions.append(None) return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths, + images=batch_images if any(batch_images) else None, + image_positions=batch_image_positions if any(batch_image_positions) else None, ) @@ -73,6 +88,9 @@ def __init__( vocab_size: int, max_sequence_length: int, cross_document_attention: bool = True, + patch_size: list[int] | None = None, + max_image_height: int | None = None, + max_image_width: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). @@ -82,6 +100,9 @@ def __init__( self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention + self._patch_size = patch_size + self._max_image_height = max_image_height + self._max_image_width = max_image_width def setup( self, @@ -129,6 +150,9 @@ def setup( tokenizer=self._tokenizer, truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, + patch_size=self._patch_size, + image_height=self._max_image_height, + image_width=self._max_image_width, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -176,7 +200,6 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 45d27e7d0..8022a05f7 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,11 +57,6 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - use_images: bool | None = Field( - default=None, - desc="Use images in the dataset.", - hint=FieldHint.feature, - ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -79,6 +74,8 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True cross_document_attention: bool = True patch_size: list[int] | None = None + image_height: int | None = None + image_width: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 
688ea6a70..209c6e317 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,6 +11,7 @@ class GPTIndexedDataset(IndexedDataset): + # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ @@ -44,10 +45,15 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes = self._dataset.get_document_sizes() + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._dataset.get_document_size(self._begin + index, patch_size) + + @property + def has_images(self) -> bool: + return self._dataset.has_images class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 87bd3a8eb..43fba843c 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -103,17 +103,17 @@ def _init( + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) - self._n_pixels = 0 + self._num_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = [] - self._im_positions = [] + self._image_lengths = [] + self._image_positions = [] images_seen = 0 # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: - self._im_lengths.append( + self._image_lengths.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -121,8 +121,8 @@ def _init( offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() - self._im_positions.append( + self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -140,14 +140,14 @@ def _init( # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) - self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: - assert self._n_pixels == num_pixels + assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -169,7 +169,7 @@ def get( use_loss_masking_spans: bool = False, # , patch_size: tuple(int), max_height: int, max_width: int ): - # TODO Soham: Handle truncations? 
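# Sketch of the image read path in `get` below: pixels are stored as a flat
# uint8 buffer right after the document's token ids, and the stored
# (height, width) pairs are used to slice that buffer back into per-image
# arrays. Inputs here are illustrative stand-ins for the memory-mapped data.
import numpy as np


def split_pixels(pixels: np.ndarray, image_lengths: np.ndarray) -> list[np.ndarray]:
    images, start = [], 0
    for height, width in image_lengths:
        n_pixels = int(height) * int(width) * 3
        images.append(pixels[start : start + n_pixels].reshape(height, width, 3))
        start += n_pixels
    return images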
+ # TODO Soham: handle spans # if self._has_images: # doc_size = self._document_sizes[idx] # n_images = self._n_images[idx] @@ -188,34 +188,42 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) if self._has_images: - image_positions = self._im_positions[idx] - images = np.frombuffer( + image_positions = self._image_positions[idx] + pixels = np.frombuffer( self._bin_buffer, - dtype=np.dtype(np.uint8).itemsize, - count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + dtype=np.dtype(np.uint8), + count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) + images = [] + start = 0 + for image_length in self._image_lengths[idx]: + # TODO Soham: verify reshape dimension order + n_pixels = image_length.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + start += n_pixels + # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - # adjust the spans for the offset and length - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + # def get( + # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + # ) -> GPTSample: + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # sample_spans = None + # if use_loss_masking_spans and self._spans is not None: + # sample_spans = self._spans[idx] + # # adjust the spans for the offset and length + # sample_spans = sample_spans[ + # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + # ] + # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: @@ -233,20 +241,21 @@ def has_images(self) -> bool: return self._has_images # TODO: image sizes - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes, self._im_lengths + return self._document_sizes, self._image_lengths def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._document_sizes[index].item() + ( - sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - if self._has_images - else 0 - ) + # return self._document_sizes[index].item() + ( + # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + # if self._has_images + # else 0 + # ) + return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -255,7 +264,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents = 0 doc_lengths = [] n_images = [] - im_lengths = [] + image_lengths = [] im_positions = [] total_images = 0 pointers = [] @@ -288,7 +297,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: pixels = np.array(img) - im_lengths.append(np.array(pixels.shape[:2])) + image_lengths.append(np.array(pixels.shape[:2])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) @@ -316,7 +325,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP n_images = np.array(n_images, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.stack(im_lengths, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) @@ -347,7 +356,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Number of images per document idx_stream.write(n_images.tobytes(order="C")) # n_pixels * 3 per image - idx_stream.write(im_lengths.tobytes(order="C")) + idx_stream.write(image_lengths.tobytes(order="C")) # Position of each image in the document idx_stream.write(im_positions.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 288018b12..8acbf9ee6 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -91,11 +91,14 @@ def __init__( self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size + self._image_height = sampling.image_height + self._image_width = sampling.image_width self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not supported. Please turn off truncation to use images." 
@@ -137,8 +140,9 @@ def _sample(self) -> None: document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum(sizes[0, :] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1]) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -387,15 +391,26 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + + image_sizes = [ + ImageProcessor.get_num_patches_from_dims( + *ImageProcessor.get_resize_dims( + *image_length, self._image_height, self._image_width, self._patch_size + ), + self._patch_size, + ) + for image_length in image_lengths + ] + image_tokens = sum(image_sizes) if not self._truncate_documents: - if document_size > self._sequence_length + 1: + if document_size + image_tokens > self._sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._sequence_length + 1) - if document_size + tokens_in_sample > self._sequence_length + 1: + if document_size + image_tokens + tokens_in_sample > self._sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -408,7 +423,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: + if token_count + document_size + image_tokens >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -422,7 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + image_tokens_added += image_tokens images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: @@ -433,7 +448,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. 
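# Sketch of how the per-document image token count used above is derived: each
# stored image size is capped to (max_height, max_width) while preserving the
# aspect ratio, or rounded up to whole patches when already small enough, and
# then converted to a patch count. This mirrors the combination of
# get_resize_dims and get_num_patches_from_dims in this patch; the helper name
# is illustrative.
import math


def image_tokens(height: int, width: int, max_h: int, max_w: int, patch_size: list[int]) -> int:
    ratio = max(height / max_h, width / max_w)
    if ratio > 1:
        height, width = math.ceil(height / ratio), math.ceil(width / ratio)
    else:
        height = patch_size[0] * math.ceil(height / patch_size[0])
        width = patch_size[1] * math.ceil(width / patch_size[1])
    return (height // patch_size[0]) * (width // patch_size[1])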
document_sampling_index += 1 - token_count += document_size + token_count += document_size + image_tokens sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) @@ -447,7 +462,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) return GPTSample( diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 567c81469..edfeceb95 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -33,7 +33,7 @@ def resize(self, image): # TODO Soham: move to utils @classmethod - def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): ratio = max(height / max_height, width / max_width) return ( (math.ceil(height / ratio), math.ceil(width / ratio)) @@ -47,5 +47,9 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: @classmethod # TODO Soham: move to utils - def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) + def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) + + @classmethod + def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: + return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a3..16cfaf713 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,6 +55,21 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) + patch_size: list[int] | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_height: int | None = Field( + default=None, + desc="Maximum image height for each image token", + hint=FieldHint.optional, + ) + max_image_width: int | None = Field( + default=None, + desc="Maximum image width for each image token", + hint=FieldHint.optional, + ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 75c5418bb..0175296c5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -125,6 +125,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + vision_encoder: VisionArchitectureConfig | None = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -200,8 +205,14 @@ def _validate(self) -> None: 
Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + super().setup_tensor_space(tensor_space) + + if self.vision_encoder is not None: + self.vision_encoder.setup_tensor_space(tensor_space) + -class MultiModalBaseConfig: +class MultiModalBaseConfig(BaseModelConfig): language_model: LanguageModelBaseConfig = Field( default_factory=LanguageModelBaseConfig, desc="Configuration for the language model.", diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5a21368fa..c90da81b3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,6 +48,10 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + @config_class() class GPTArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30ae80416..30f54f06d 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -32,6 +32,7 @@ LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -163,54 +164,65 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + fast_llm_offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter( + f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + ) + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i) + converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) return converters - def _create_transformer_layer_converters(self, i: int, ignore_export: bool = False) -> list[WeightConverter]: + def _create_transformer_layer_converters( + self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ # Self-attn ( - f"layers.{i+1}.self_attn.query", - f"model.layers.{i}.self_attn.q_proj", + f"layers.{i+fast_llm_offset}.self_attn.query", + f"{hf_base_prefix}model.layers.{i}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+1}.self_attn.key_value", - (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"{hf_base_prefix}model.layers.{i}.self_attn.k_proj", + f"{hf_base_prefix}model.layers.{i}.self_attn.v_proj", + ), transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - 
f"layers.{i+1}.self_attn.dense", - f"model.layers.{i}.self_attn.o_proj", + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"{hf_base_prefix}model.layers.{i}.self_attn.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+1}.norm_1", - f"model.layers.{i}.input_layernorm", + f"layers.{i+fast_llm_offset}.norm_1", + f"{hf_base_prefix}model.layers.{i}.input_layernorm", norm_bias, WeightConverter, ), ( - f"layers.{i+1}.norm_2", - f"model.layers.{i}.post_attention_layernorm", + f"layers.{i+fast_llm_offset}.norm_2", + f"{hf_base_prefix}model.layers.{i}.post_attention_layernorm", norm_bias, WeightConverter, ), @@ -226,17 +238,23 @@ def _create_transformer_layer_converters(self, i: int, ignore_export: bool = Fal # MLP if ignore_export: converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_1", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_2", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) - converters += [IgnoreExportWeightConverter(f"layers.{i+1}.mlp.router.weight", ())] + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] else: - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+fast_llm_offset}", f"{hf_base_prefix}model.layers.{i}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -245,15 +263,20 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter( + f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" + ) + ) # MTP-heads > 0 are thrown away + # TODO Soham: handle offset with MTP for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
@@ -531,6 +554,196 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + lm_converters = super()._create_config_converters() + for converter in lm_converters: + if converter.fast_llm_names[0][0] == "transformer": + converter.export_names[0] = ("text_config", *converter.export_names[0]) + return lm_converters + [ + # Multimodal adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=( + ( + "text_config", + "hidden_size", + ) + ), + ), + # Image processing and conv layer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + export_names=( + ( + "vision_config", + "image_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + export_names=( + ( + "vision_config", + "patch_size", + ) + ), + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ) + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ) + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + export_names=( + ( + "vision_config", + "num_channels", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + export_names=( + ( + "vision_config", + "attention_dropout", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), + export_names=(("vision_config", "initializer_range"),), + ), + ] + + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + vision_transformer_converters = [] + for i in range(num_layers): + vision_transformer_converters += [ + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", + ), + WeightConverter( + 
f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"vision_tower.transformer.layers.{i}.attention_norm.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"vision_tower.transformer.layers.{i}.ffn_norm.weight", + ), + ] + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converter = WeightConverter( + "layers.0._vision_encoder.patch_conv.weight", + "vision_tower.patch_conv.weight", + ) + vision_transformer_converters = self._create_vision_transformer_converters() + adapter_converters = [ + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.weight", + "multi_modal_projector.linear_1.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.bias", + "multi_modal_projector.linear_1.bias", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.weight", + "multi_modal_projector.linear_2.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.bias", + "multi_modal_projector.linear_2.bias", + ), + ] + return [patch_conv_converter] + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + return vision_encoder_converter + lm_converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat @@ -580,4 +793,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e878530cf..674116413 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -26,6 +26,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.encoder import VisionEncoder from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -100,7 +101,10 @@ def get_layers(self) -> list[Layer]: LanguageModelEmbedding(self._config, self._tensor_space), LanguageModelHead(self._config, self._tensor_space, 0), ] 
- return [ + return ( + [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] + ) + [ + # return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( @@ -312,11 +316,11 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self._config.vision_encoder is not None] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + return self.layers[(self._config.vision_encoder is not None) + 1 : -1] @property def model_head(self) -> LanguageModelHead: @@ -331,7 +335,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self._config.vision_encoder is not None, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 376d8b840..b801fbd3d 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -21,6 +21,9 @@ def _get_data(self) -> GPTData: vocab_size=self._config.model.base_model.vocab_size, max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, + patch_size=self._config.batch.patch_size, + max_image_height=self._config.batch.max_image_height, + max_image_width=self._config.batch.max_image_width, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 6521e41920fe8b17f207b32f58c43978bfcc8a46 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 19:00:23 +0000 Subject: [PATCH 06/97] vision model --- fast_llm/layers/vision_encoder/adapter.py | 44 ++++++++ fast_llm/layers/vision_encoder/config.py | 128 ++++++++++++++++++++++ fast_llm/layers/vision_encoder/encoder.py | 89 +++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 fast_llm/layers/vision_encoder/adapter.py create mode 100644 fast_llm/layers/vision_encoder/config.py create mode 100644 fast_llm/layers/vision_encoder/encoder.py diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 000000000..234c451a9 --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,44 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.transformer.config import TransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"): + super().__init__() + self._name = name + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self.layer_1 = LinearBase( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = LinearBase( + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ): + return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 000000000..d410f92dc --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,128 @@ +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType + + +class VisionEncoderDimNames: + out_channels = "vision_out_channels" + intermediate_size = "vision_intermediate_size" + patch_height = "vision_patch_height" + patch_width = "vision_patch_width" + + +@config_class() +class PatchConvConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the convolution layers to apply on the image patches + """ + in_channels: int = Field( + default=3, + desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", + hint=FieldHint.optional, + ) + bias: bool = Field( + default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional + ) + height: int = Field( + default=16, + desc="Height of the image patches considered as tokens", + ) + width: int | None = Field( + default=16, + desc="Width of the image patches considered as tokens", + ) + + +@config_class() +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the vision encoder, which transforms images into embeddings + """ + path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + hidden_size: int = Field( + default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional + ) + intermediate_size: int = Field( + default=4096, + desc="The size of the intermediate (feed-forward) layers in the transformer model.", + hint=FieldHint.optional, + ) + num_hidden_layers: int = Field( + default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional + ) + num_attention_heads: int = Field( + default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional + ) + num_channels: int = Field( + default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional + ) + image_size: int = Field( + default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional + ) + patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) + hidden_act: str = Field( + default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional + ) + attention_dropout: float = Field( + default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional + ) + rope_theta: float = Field( + default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional + ) + initializer_range: float = Field( + default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional + ) + + +@config_class() +class VisionArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + + encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.optional, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # tensor_space.add_tensor_dim( + # CompositeTensorDim(VisionEncoderDimNames.) 
+ # ) + + # patch_convolution: PatchConvConfig = Field( + # default_factory=PatchConvConfig, + # desc="Configuration for the convolution layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # normalization: NormalizationArchitectureConfig = Field( + # default_factory=NormalizationArchitectureConfig, + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # transformer: TransformerArchitectureConfig = Field( + # default_factory=TransformerArchitectureConfig, + # desc="Configuration for the transformer layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # patch_rotary: RotaryArchitectureConfig = Field( + # default_factory=RotaryArchitectureConfig, + # desc="Configuration for the rotary positional embeddings applied to the image patches.", + # hint=FieldHint.optional + # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py new file mode 100644 index 000000000..2ea5c1e4f --- /dev/null +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -0,0 +1,89 @@ +import functools +import typing + +import torch +from transformers import PixtralVisionConfig, PixtralVisionModel + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class VisionEncoder(Layer): + """ + A vision encoder layer for creating token embeddings from vision model + """ + + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + super().__init__() + + self._config = config.vision_encoder + self._distributed_config = tensor_space.distributed_config + with torch.device("meta"): + if self._config.encoder.path: + self._vision_encoder = PixtralVisionModel.from_pretrained( + self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch + ) + else: + self._vision_encoder = PixtralVisionModel( + PixtralVisionConfig( + hidden_size=self._config.hidden_size, + intermediate_size=self._config.intermediate_size, + num_hidden_layers=self._config.num_hidden_layers, + num_attention_heads=self._config.num_attention_heads, + num_channels=self._config.num_channels, + image_size=self._config.image_size, + patch_size=self._config.patch_size, + hidden_act=self._config.hidden_act, + attention_dropout=self._config.attention_dropout, + rope_theta=self._config.rope_theta, + initializer_range=self._config.initializer_range, + ) + ) + param_names = [] + # gather all names first. 
PyTorch complains if we do it in the loop + for name, param in self._vision_encoder.named_parameters(): + param_names.append(name) + for name in param_names: + # exclude .weight/.bias + *module_path, stem = name.split(".")[:-1] + module = functools.reduce(getattr, module_path, self._vision_encoder) + param = self._vision_encoder.get_parameter(name) + setattr( + module, + stem, + ParameterMeta.from_dims( + tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + init_method=init_normal_(), + ), + # ParameterMeta( + # param, + # tensor_name=name, + # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + # init_method=init_normal_(), + # allow_no_grad=True, + # ), + ) + self._adapter = VisionAdapter( + intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space=tensor_space, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision Output", + dtype=self._distributed_config.training_dtype.torch, + ) + return self._vision_encoder(input_) From daf586f8d6a428398674771bce71a61a7e32cdbf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 21:49:28 +0000 Subject: [PATCH 07/97] wip --- fast_llm/models/gpt/config.py | 5 ++-- fast_llm/models/gpt/conversion.py | 43 ++++++++++++++++--------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index c90da81b3..ca73b879e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,8 +48,8 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" -class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "pixtral" +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" @config_class() @@ -109,6 +109,7 @@ class GPTModelConfig(FastLLMModelConfig): Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30f54f06d..ad74ad53e 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -30,9 +30,9 @@ GPTArchitectureConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, - PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -367,7 +367,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), @@ -554,23 +554,24 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class 
PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() for converter in lm_converters: - if converter.fast_llm_names[0][0] == "transformer": - converter.export_names[0] = ("text_config", *converter.export_names[0]) + if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): + # Llava uses a different name for the text config + # if converter.fast_llm_names[0][0] == "transformer": + converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + # if converter.fast_llm_names[0][0] == "transformer": + # converter.export_names[0] = ("text_config", *converter.export_names[0]) return lm_converters + [ # Multimodal adapter RenameParamConverter( fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=( - ( - "text_config", - "hidden_size", - ) - ), + export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer RenameParamConverter( @@ -579,7 +580,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "image_size", - ) + ), ), ), RenameParamConverter( @@ -588,7 +589,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "patch_size", - ) + ), ), ), # Vision Transformer @@ -598,7 +599,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_hidden_layers", - ) + ), ), ), RenameParamConverter( @@ -607,7 +608,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_size", - ) + ), ), ), RenameParamConverter( @@ -616,7 +617,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_attention_heads", - ) + ), ), ), RenameParamConverter( @@ -625,7 +626,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "intermediate_size", - ) + ), ), ), MappedConfigParamConverter( @@ -634,7 +635,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_act", - ) + ), ), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, @@ -645,7 +646,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_channels", - ) + ), ), ), RenameParamConverter( @@ -654,7 +655,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "attention_dropout", - ) + ), ), ), RenameParamConverter( @@ -793,5 +794,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, - PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, } From ef4488d4f94b9c19b04f409917b3091b8e8601e8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Apr 2025 22:47:08 +0000 Subject: [PATCH 08/97] wip --- fast_llm/layers/language_model/config.py | 7 ++- fast_llm/layers/vision_encoder/config.py | 11 +++++ fast_llm/layers/vision_encoder/encoder.py | 35 +++++++-------- 
fast_llm/models/gpt/conversion.py | 54 ++++++++++++++++------- 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0175296c5..ec80a9334 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -44,6 +44,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) + vision_encoder: None | VisionArchitectureConfig = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -125,7 +130,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: VisionArchitectureConfig | None = Field( + vision_encoder: None | VisionArchitectureConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index d410f92dc..76af3d371 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,6 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationType class VisionEncoderDimNames: @@ -42,6 +43,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + pre_norm: NormalizationType = Field( + default=NormalizationType.rms_norm, + desc="The type of normalization to use before the transformer layers.", + hint=FieldHint.optional, + ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional ) @@ -75,6 +81,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): initializer_range: float = Field( default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional ) + activation_type: ActivationType = Field( + default=ActivationType.silu, + desc="The activation function used in the hidden layers. 
Default: SiLU.", + hint=FieldHint.optional, + ) @config_class() diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 2ea5c1e4f..88064b51a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -31,17 +31,17 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): else: self._vision_encoder = PixtralVisionModel( PixtralVisionConfig( - hidden_size=self._config.hidden_size, - intermediate_size=self._config.intermediate_size, - num_hidden_layers=self._config.num_hidden_layers, - num_attention_heads=self._config.num_attention_heads, - num_channels=self._config.num_channels, - image_size=self._config.image_size, - patch_size=self._config.patch_size, - hidden_act=self._config.hidden_act, - attention_dropout=self._config.attention_dropout, - rope_theta=self._config.rope_theta, - initializer_range=self._config.initializer_range, + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, ) ) param_names = [] @@ -49,8 +49,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): for name, param in self._vision_encoder.named_parameters(): param_names.append(name) for name in param_names: - # exclude .weight/.bias - *module_path, stem = name.split(".")[:-1] + *module_path, stem = name.split(".") module = functools.reduce(getattr, module_path, self._vision_encoder) param = self._vision_encoder.get_parameter(name) setattr( @@ -60,14 +59,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), init_method=init_normal_(), ), - # ParameterMeta( - # param, - # tensor_name=name, - # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - # init_method=init_normal_(), - # allow_no_grad=True, - # ), ) + none_params = [key for key, value in module._parameters.items() if value is None] + for key in none_params: + module._parameters.pop(key) self._adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad74ad53e..f730d79c6 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -173,14 +173,16 @@ def _create_weight_converters( # Embeddings converters.append( WeightConverter( - f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" ) ) converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) + converters += self._create_transformer_layer_converters( + i, hf_base_prefix=hf_base_prefix, fast_llm_offset=fast_llm_offset + ) 
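        # Illustration of how the prefix/offset arguments compose (an editorial
        # sketch, not part of this patch): with the values the Llava handler passes
        # later in this commit, hf_base_prefix="language_model." and fast_llm_offset=2,
        # the embedding converter above pairs
        #   "layers.1.word_embeddings_weight" <-> "language_model.model.embed_tokens.weight"
        # and each transformer-layer converter presumably shifts its Fast-LLM layer
        # index by the same offset while gaining the "language_model." prefix on the
        # Hugging Face side. Quick check of the f-string above:
        #   >>> fast_llm_offset = 2
        #   >>> f"layers.{fast_llm_offset - 1}.word_embeddings_weight"
        #   'layers.1.word_embeddings_weight'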
return converters @@ -560,6 +562,9 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() + lm_converters[-2] = ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ) for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -674,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -718,30 +723,45 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: "layers.0._vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) + # TODO Soham: use _get_weight_and_bias_converters? 
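        # A minimal sketch of the helper hinted at by the TODO above, under an
        # assumed name and signature (illustrative only; the existing
        # _get_weight_and_bias_converters is not shown in this patch):
        #
        #   def _norm_weight_and_bias(fast_llm_prefix: str, hf_prefix: str, use_bias: bool) -> list[WeightConverter]:
        #       converters = [WeightConverter(f"{fast_llm_prefix}.weight", f"{hf_prefix}.weight")]
        #       if use_bias:
        #           converters.append(WeightConverter(f"{fast_llm_prefix}.bias", f"{hf_prefix}.bias"))
        #       return converters
        #
        #   # usage sketch: _norm_weight_and_bias("layers.0._vision_encoder.ln_pre", "vision_tower.ln_pre", use_bias)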
+ layer_norm_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.weight", + "vision_tower.ln_pre.weight", + ) + if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + layer_norm_bias_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.bias", + "vision_tower.ln_pre.bias", + ) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.weight", + "layers.0._adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.bias", + "layers.0._adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), + # TODO Soham: conditionally add bias WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.weight", + "layers.0._adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.bias", + "layers.0._adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return [patch_conv_converter] + vision_transformer_converters + adapter_converters + return ( + [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] + + vision_transformer_converters + + adapter_converters + ) def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) return vision_encoder_converter + lm_converters From 6d9d595b921bc5139a910fe843d6ece3403445fb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Apr 2025 15:23:50 +0000 Subject: [PATCH 09/97] missing files --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/engine/multi_stage/stage_base.py | 3 + fast_llm/layers/multi_modal/embedding.py | 83 +++++++++++++++++++ fast_llm/layers/vision_encoder/config.py | 57 ++++++++++++- fast_llm/layers/vision_encoder/encoder.py | 26 +++--- .../layers/vision_encoder/preprocessing.py | 74 +++++++++++++++++ fast_llm/models/gpt/conversion.py | 44 +++++----- fast_llm/models/gpt/model.py | 59 +++++++++++-- setup.cfg | 2 +- 9 files changed, 309 insertions(+), 45 deletions(-) create mode 100644 fast_llm/layers/multi_modal/embedding.py create mode 100644 fast_llm/layers/vision_encoder/preprocessing.py diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 43fba843c..99bfbfa42 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -200,7 +200,7 @@ def get( for image_length in self._image_lengths[idx]: # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) @@ -296,8 +296,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - pixels = np.array(img) - image_lengths.append(np.array(pixels.shape[:2])) + pixels = np.array(img).transpose(2, 0, 1) # 
HWC to CHW + image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0f83c862d..e97ef0410 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -161,6 +161,9 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) + # TODO Soham: clean way to get around check? + if meta is None: + continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 000000000..a92fdc4e5 --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.tensor import TensorMeta + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + self.vision_encoder = VisionEncoder(config, tensor_space) + + def _forward( + self, + input_: torch.Tensor, + position_ids: torch.Tensor | None, + images: torch.Tensor | None, + image_sizes: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + ) -> torch.Tensor: + image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + # TODO Soham: offset position ids + img_tokens_seen = 0 + image_idx = 0 + text_embeddings = super()._forward(input_, position_ids) + embeddings = [] + for sample_idx, positions in enumerate(image_positions): + embedding_parts = [] + for position in positions: + image = images[image_idx] + image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( + image.size[2] // self._config.vision_encoder.encoder.patch_size + ) + image_idx += 1 + img_tokens_seen += image_tokens + embedding_parts.append(text_embeddings[sample_idx, :position]) + embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) + embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) + embeddings.append(torch.cat(embedding_parts, dim=0)) + embeddings = torch.stack(embeddings, dim=0) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + 
kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + return self._forward( + input_, + kwargs.get(LanguageModelKwargs.position_ids), + kwargs.get(VisionModelKwargs.images), + kwargs.get(VisionModelKwargs.image_sizes), + kwargs.get(VisionModelKwargs.image_positions), + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 76af3d371..5e4722513 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType @@ -12,6 +12,17 @@ class VisionEncoderDimNames: patch_width = "vision_patch_width" +class VisionModelKwargs: + images = "images" + image_positions = "image_positions" + image_height = "image_height" + image_width = "image_width" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + + @config_class() class PatchConvConfig(BaseModelArchitectureConfig): _abstract = False @@ -88,6 +99,45 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): ) +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + @config_class() class VisionArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -107,6 +157,11 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) + normalization: ImageNormalizationConfig = Field( + default_factory=ImageNormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 88064b51a..b028fa1fa 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -9,10 +9,11 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +# TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ A vision encoder layer for creating token embeddings from vision model @@ -25,11 +26,14 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._distributed_config = tensor_space.distributed_config with torch.device("meta"): if self._config.encoder.path: - self._vision_encoder = PixtralVisionModel.from_pretrained( + self.vision_encoder = PixtralVisionModel.from_pretrained( self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch ) else: - self._vision_encoder = PixtralVisionModel( + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + self.vision_encoder = PixtralVisionModel( PixtralVisionConfig( hidden_size=self._config.encoder.hidden_size, intermediate_size=self._config.encoder.intermediate_size, @@ -46,12 +50,12 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): ) param_names = [] # gather all names first. 
PyTorch complains if we do it in the loop - for name, param in self._vision_encoder.named_parameters(): + for name, param in self.vision_encoder.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self._vision_encoder) - param = self._vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_encoder) + param = self.vision_encoder.get_parameter(name) setattr( module, stem, @@ -60,10 +64,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): init_method=init_normal_(), ), ) - none_params = [key for key, value in module._parameters.items() if value is None] - for key in none_params: - module._parameters.pop(key) - self._adapter = VisionAdapter( + # none_params = [key for key, value in module._parameters.items() if value is None] + # for key in none_params: + # module._parameters.pop(key) + self.adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, ) @@ -81,4 +85,4 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self._vision_encoder(input_) + return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 000000000..7ebfd3d7d --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,74 @@ +import typing + +import torch +import torchvision.transforms.v2.functional as F + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + return ( + (int(height / ratio), int(width / ratio)) + if ratio > 1 + else (max_height * (height // max_height), max_width * (width // max_width)) + ) + + +def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) + # TODO: options for interpolation mode? + return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + + +def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: + """ + Normalize the image using the specified mean and standard deviation. 
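
    Illustrative usage (an editorial example; the mean/std values are the
    ImageNormalizationConfig defaults, and the input is assumed to be a float CHW
    tensor already divided by the rescale factor):

        normalize(image,
                  mean=[0.48145466, 0.4578275, 0.40821073],
                  std=[0.26862954, 0.26130258, 0.27577711])

    torchvision's F.normalize then subtracts the per-channel mean and divides by
    the per-channel std.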
+ """ + return F.normalize(image, mean=mean, std=std) + + +def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: + """ + Pad images on the right and bottom with 0s untitl max_height and max_width + """ + width_padding = max(0, max_height - image.size(1)) + depth_padding = max(0, max_width - image.size(2)) + return F.pad(image, (0, 0, width_padding, depth_padding), 0) + + +class VisionPreprocessor: + def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get("images") + im_height = kwargs.get(VisionModelKwargs.image_height) + im_width = kwargs.get(VisionModelKwargs.image_width) + kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + images = [ + pad( + normalize( + resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + mean=kwargs[VisionModelKwargs.image_mean], + std=kwargs[VisionModelKwargs.image_std], + ), + max_height=im_height, + max_width=im_width, + ) + for image in images + ] + images = torch.stack(images, dim=0).to( + # TODO Soham: is this needed? + device=self._tensor_space.distributed.device, + dtype=self._distributed_config.training_dtype.torch, + ) + kwargs[VisionModelKwargs.images] = images diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f730d79c6..3caaee5ad 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -679,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", 
f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,48 +720,48 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0._vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? + layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.vision_encoder.ln_pre.weight", "vision_tower.ln_pre.weight", ) + layernorm_converters.append(layer_norm_converter) + layer_norm_converter if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.vision_encoder.ln_pre.bias", "vision_tower.ln_pre.bias", ) + layernorm_converters.append(layer_norm_bias_converter) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._adapter.layer_1.weight", + "layers.0.vision_encoder.adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._adapter.layer_1.bias", + "layers.0.vision_encoder.adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), # TODO Soham: conditionally add bias WeightConverter( - "layers.0._adapter.layer_2.weight", + "layers.0.vision_encoder.adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._adapter.layer_2.bias", + "layers.0.vision_encoder.adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return ( - [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] - + vision_transformer_converters - + adapter_converters - ) + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) return vision_encoder_converter + lm_converters diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 674116413..0890051ea 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,6 +14,7 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.config import ( 
RoutingType, TransformerDimNames, @@ -26,7 +27,8 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -72,6 +74,9 @@ def __init__( else: self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space) + if self._config.vision_encoder: + self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + def get_output_layers(self) -> list[Layer]: return [ layer @@ -98,14 +103,19 @@ def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - LanguageModelEmbedding(self._config, self._tensor_space), + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), LanguageModelHead(self._config, self._tensor_space, 0), ] - return ( - [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] - ) + [ - # return [ - LanguageModelEmbedding(self._config, self._tensor_space), + return [ + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), *[ TransformerLayer( self._config.transformer, @@ -139,6 +149,30 @@ def preprocess_meta( sequence_length -= 1 micro_sequence_length = sequence_length + if self._config.vision_encoder: + image_height = batch_meta.max_image_height + image_width = batch_meta.max_image_width + image_mean = [ + self._config.vision_encoder.normalization.mean_r, + self._config.vision_encoder.normalization.mean_g, + self._config.vision_encoder.normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.normalization.std_r, + self._config.vision_encoder.normalization.std_g, + self._config.vision_encoder.normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + vision_kwargs = { + VisionModelKwargs.image_height: image_height, + VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_mean: image_mean, + VisionModelKwargs.image_std: image_std, + VisionModelKwargs.image_rescale_factor: image_rescale_factor, + } + else: + vision_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -189,6 +223,7 @@ def preprocess_meta( TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, } + common_kwargs.update(vision_kwargs) preprocessed_meta = [] for sequence_k_past in range( @@ -271,6 +306,16 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) + if batch.images is not None: + kwargs_meta[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs_meta[VisionModelKwargs.image_positions] = 
batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs_meta) + # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents diff --git a/setup.cfg b/setup.cfg index 57913f83d..52676c799 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ CORE = # Required for some optional features and tools. OPTIONAL = # Huggingface tools - transformers>=4.44.2 + transformers>=4.48.3 hf-transfer>=0.1.8 datasets>=3.1.0 huggingface-hub>=0.28.1 From 6cb8f5d0e85e8b1bd24470e387b4c0d259124201 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 16:31:46 +0000 Subject: [PATCH 10/97] make it work, barely --- Dockerfile | 1 + fast_llm/data/data/gpt/data.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 17 ++- fast_llm/layers/multi_modal/embedding.py | 76 +++++----- fast_llm/layers/vision_encoder/adapter.py | 19 +-- fast_llm/layers/vision_encoder/config.py | 17 ++- fast_llm/layers/vision_encoder/encoder.py | 134 ++++++++++++++---- .../layers/vision_encoder/preprocessing.py | 31 +++- fast_llm/models/gpt/conversion.py | 30 ++-- fast_llm/models/gpt/model.py | 26 ++-- 10 files changed, 240 insertions(+), 117 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8c2efa85e..b8e1f8887 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ && apt-get install --no-install-recommends -y acl git-lfs \ + && apt-get install --no-install-recommends -y libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 22e4730c9..cffaa734f 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -49,10 +49,12 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = False batch_images = [] for sample in batch: if sample.images is not None: batch_images.append([torch.from_numpy(image) for image in sample.images]) + has_images = True else: batch_images.append(None) batch_image_positions = [] @@ -65,8 +67,8 @@ def gpt_data_collate_fn( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, - images=batch_images if any(batch_images) else None, - image_positions=batch_image_positions if any(batch_image_positions) else None, + images=batch_images if has_images else None, + image_positions=batch_image_positions if has_images else None, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8acbf9ee6..973c1db53 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -433,13 +433,22 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) - # TODO Soham: handle images with loss masking spans + start_pos = 0 for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + 
image_tokens_added) image_tokens_added += image_tokens + start_pos = im_position + token_ids.append(sample.token_ids[start_pos:]) + # TODO Soham: remove this + # if len(sample.images) == 1: + # sample.images.append(sample.images[0]) + # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) - token_ids.append(sample.token_ids) + # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) @@ -452,7 +461,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) + # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -464,7 +473,7 @@ def __getitem__(self, index: int) -> typing.Any: ) images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None - Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + Assert.eq(len(token_ids), self._sequence_length + 1) return GPTSample( token_ids=token_ids, diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a92fdc4e5..3b62c60b7 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -25,59 +25,59 @@ def __init__( super().__init__(config, tensor_space) self.vision_encoder = VisionEncoder(config, tensor_space) - def _forward( + def forward( self, input_: torch.Tensor, - position_ids: torch.Tensor | None, - images: torch.Tensor | None, - image_sizes: torch.Tensor | None, - image_positions: list[torch.Tensor] | None, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, ) -> torch.Tensor: - image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + # return self._forward( + # input_, + # kwargs.get(LanguageModelKwargs.position_ids), + # kwargs.get(VisionModelKwargs.images), + # kwargs.get(VisionModelKwargs.image_sizes), + # kwargs.get(VisionModelKwargs.image_positions), + # ) # TODO Soham: offset position ids + images = kwargs.pop(VisionModelKwargs.images)[:1] + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] + image_embeddings = self.vision_encoder(images, kwargs) + embeddings = super()._forward(input_, position_ids) img_tokens_seen = 0 image_idx = 0 - text_embeddings = super()._forward(input_, position_ids) - embeddings = [] for sample_idx, positions in enumerate(image_positions): - embedding_parts = [] - for position in positions: + # embedding_parts = [] + for position in positions[:1]: image = images[image_idx] - image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( - image.size[2] // self._config.vision_encoder.encoder.patch_size + image_tokens = (image.size(1) // 
self._config.vision_encoder.encoder.patch_size) * ( + image.size(2) // self._config.vision_encoder.encoder.patch_size ) + embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ + sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + ] + # embedding_parts.append(text_embeddings[sample_idx, :position]) + # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) image_idx += 1 img_tokens_seen += image_tokens - embedding_parts.append(text_embeddings[sample_idx, :position]) - embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) - embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) - embeddings.append(torch.cat(embedding_parts, dim=0)) - embeddings = torch.stack(embeddings, dim=0) + # embedding_parts.append(text_embeddings[sample_idx, position:]) + # TODO Soham: debug from here + # embeddings.append(torch.cat(embedding_parts, dim=0)) + # embeddings = torch.stack(embeddings, dim=0) with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + # assert embeddings.size(1) == 8192 + del image_embeddings + del images return embeddings.to(self._residual_dtype) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Embedding output", - dtype=self._residual_dtype, - ) - return self._forward( - input_, - kwargs.get(LanguageModelKwargs.position_ids), - kwargs.get(VisionModelKwargs.images), - kwargs.get(VisionModelKwargs.image_sizes), - kwargs.get(VisionModelKwargs.image_positions), - ) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 234c451a9..b8436f72e 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -1,16 +1,13 @@ -import typing - import torch -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames from fast_llm.tensor import init_normal_ -class VisionAdapter(Layer): +class VisionAdapter(torch.nn.Module): """ Vision adapter layer for the LLM. 
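
    Applies two linear layers in sequence, layer_2(layer_1(x)), projecting
    out_channels -> adapter intermediate_size -> transformer hidden size. As an
    illustrative shape example (based on the config defaults, not fixed by this
    class), a [num_patches, 1024] Pixtral feature tensor comes out as
    [num_patches, lm_hidden].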
""" @@ -19,14 +16,14 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str super().__init__() self._name = name input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - self.layer_1 = LinearBase( + self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - self.layer_2 = LinearBase( + self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, @@ -34,11 +31,5 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str bias_init_method=init_normal_(), ) - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ): + def forward(self, input_: torch.Tensor): return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5e4722513..65ae8e502 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,7 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.common.config import NormalizationConfig class VisionEncoderDimNames: @@ -10,9 +10,11 @@ class VisionEncoderDimNames: intermediate_size = "vision_intermediate_size" patch_height = "vision_patch_height" patch_width = "vision_patch_width" + kv_channels = "vision_kv_channels" class VisionModelKwargs: + patch_size = "patch_size" images = "images" image_positions = "image_positions" image_height = "image_height" @@ -21,6 +23,9 @@ class VisionModelKwargs: image_mean = "image_normalization_mean" image_std = "image_normalization_std" image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + kv_channels = "vit_kv_channels" @config_class() @@ -54,10 +59,8 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationType = Field( - default=NormalizationType.rms_norm, - desc="The type of normalization to use before the transformer layers.", - hint=FieldHint.optional, + pre_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional @@ -168,6 +171,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # TODO Soham: add a check for kv channels + tensor_space.add_tensor_dim( + TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) + ) # tensor_space.add_tensor_dim( # 
CompositeTensorDim(VisionEncoderDimNames.) # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index b028fa1fa..bbcebf251 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,7 +2,8 @@ import typing import torch -from transformers import PixtralVisionConfig, PixtralVisionModel +from transformers import PixtralVisionConfig +from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -13,6 +14,33 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +def position_ids_in_meshgrid(patch_embeddings_list, max_width): + positions = [] + for patch in patch_embeddings_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + # TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ @@ -25,37 +53,49 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config.vision_encoder self._distributed_config = tensor_space.distributed_config with torch.device("meta"): - if self._config.encoder.path: - self.vision_encoder = PixtralVisionModel.from_pretrained( - self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch - ) - else: - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - self.vision_encoder = PixtralVisionModel( - PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - ) + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. 
set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + config = PixtralVisionConfig( + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, + ) + self.patch_conv = torch.nn.Conv2d( + in_channels=3, + out_channels=self._config.encoder.hidden_size, + kernel_size=self._config.encoder.patch_size, + stride=self._config.encoder.patch_size, + bias=False, + ) + self.patch_conv.weight = ParameterMeta.from_dims( + tuple( + TensorDim(f"patch_conv_weight_{idx}", size) + for idx, size in enumerate(self.patch_conv.weight.shape) + ), + init_method=init_normal_(), + ) + self.norm = self._config.encoder.pre_norm.get_layer( + tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + ) + self.vision_transformer = PixtralTransformer(config) + # self.vision_encoder = PixtralVisionModel(config) param_names = [] # gather all names first. PyTorch complains if we do it in the loop - for name, param in self.vision_encoder.named_parameters(): + for name, param in self.vision_transformer.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_encoder) - param = self.vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_transformer) + param = self.vision_transformer.get_parameter(name) setattr( module, stem, @@ -72,6 +112,38 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tensor_space=tensor_space, ) + def _forward( + self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int + ) -> torch.Tensor: + patch_embeddings = self.patch_conv(input_) + patch_embeddings_list = [ + embedding[..., : image_size[0], : image_size[1]] + for embedding, image_size in zip(patch_embeddings, image_sizes) + ] + patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) + patch_embeddings = self.norm(patch_embeddings) + position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) + freqs = inv_freq[position_ids] + with torch.autocast(device_type=input_.device.type): + cos = freqs.cos() + sin = freqs.sin() + cos = cos.to(dtype=input_.dtype) + sin = sin.to(dtype=input_.dtype) + + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings + ) + + (out,) = self.vision_transformer( + patch_embeddings, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + output_attentions=False, + return_dict=False, + ) + + return self.adapter(out) + def forward( self, input_: torch.Tensor, @@ -85,4 +157,10 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return self._forward( + input_, + kwargs[VisionModelKwargs.image_sizes][:1], + 
kwargs[VisionModelKwargs.rotary_inv_freq], + image_width=kwargs[VisionModelKwargs.image_width], + ) + # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfd3d7d..57ee3a0b2 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -40,7 +40,27 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return F.pad(image, (0, 0, width_padding, depth_padding), 0) + return F.pad(image, (0, 0, depth_padding, width_padding), 0) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) class VisionPreprocessor: @@ -53,7 +73,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") im_height = kwargs.get(VisionModelKwargs.image_height) im_width = kwargs.get(VisionModelKwargs.image_width) - kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( @@ -72,3 +93,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) kwargs[VisionModelKwargs.images] = images + kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionModelKwargs.rope_theta], + kwargs[VisionModelKwargs.kv_channels], + im_height, + kwargs[VisionModelKwargs.patch_size], + ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 3caaee5ad..bd7da7979 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -597,6 +597,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), # Vision Transformer RenameParamConverter( fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), @@ -679,39 +683,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", 
f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,20 +724,20 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? 
layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.norm.weight", "vision_tower.ln_pre.weight", ) layernorm_converters.append(layer_norm_converter) layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.norm.bias", "vision_tower.ln_pre.bias", ) layernorm_converters.append(layer_norm_bias_converter) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0890051ea..ffbd22816 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -164,11 +164,16 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { + VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, VisionModelKwargs.image_height: image_height, VisionModelKwargs.image_width: image_width, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: image_std, VisionModelKwargs.image_rescale_factor: image_rescale_factor, + VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, + VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.kv_channels + ).size, } else: vision_kwargs = {} @@ -306,16 +311,6 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) - if batch.images is not None: - kwargs_meta[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs_meta[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs_meta) - # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. 
pasts = presents @@ -349,6 +344,15 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + if batch.images is not None: + kwargs[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs[VisionModelKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: From 5761a2d52cf4e7e5fcfd38ec19750be48cb06f8e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 18:24:54 +0000 Subject: [PATCH 11/97] fix --- fast_llm/data/dataset/gpt/memmap.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 99bfbfa42..54bf6826a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -49,11 +49,14 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack("= 3: + if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) @@ -333,10 +336,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 onwards optionally add loss-masking spans - # Version 3 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) + # Placeholder flag for preference spans + idx_stream.write(struct.pack(" 0 else 0)) # Data type From d45d60061068b316c3e49d633ea0e8adbc2d52ef Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 May 2025 05:43:50 +0000 Subject: [PATCH 12/97] fixes --- fast_llm/data/config.py | 57 ------------------- fast_llm/data/data/gpt/data.py | 9 +-- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/memmap.py | 15 +---- fast_llm/data/dataset/gpt/sampled.py | 18 ++---- fast_llm/data/image_processor.py | 55 ------------------ fast_llm/engine/schedule/config.py | 11 +--- fast_llm/layers/vision_encoder/config.py | 57 +------------------ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 30 +++++++--- fast_llm/models/gpt/model.py | 6 +- fast_llm/models/gpt/trainer.py | 3 +- 12 files changed, 44 insertions(+), 224 deletions(-) delete mode 100644 fast_llm/data/image_processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index f1a0fd58a..1586d370d 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,60 +34,3 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) - - -@config_class() -class ImageProcessorConfig(Config): - """ - Configuration for the image processor - """ - - # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - # patch_size: list[int] = Field( - # default_factory=lambda: [16, 16], - # desc="Size for each path extracted from the image. 
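Note on the memmap fix above: it bumps the index format to version 4 and reserves a flag byte for preference spans, so loss-masking spans, preference spans, and images each gate on their own version check. The struct format strings are garbled in this hunk, so the following is a hedged reconstruction of the intended header layout (field widths and formats are assumptions, not taken from the source):

import struct

def write_index_header(stream, magic: bytes, has_spans: bool, has_images: bool) -> None:
    # Assumed layout, reconstructed from the surrounding reader/writer code:
    # 9-byte magic, little-endian version, then one flag byte per optional feature.
    stream.write(magic)                               # MEMMAP_INDEX_HEADER in memmap.py
    stream.write(struct.pack("<Q", 4))                # version 4: adds the image flag
    stream.write(struct.pack("<B", int(has_spans)))   # version >= 2: loss-masking spans
    stream.write(struct.pack("<B", 0))                # version >= 3: preference spans (placeholder)
    stream.write(struct.pack("<B", int(has_images)))  # version >= 4: images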
Each patch corresponds to a token for the transformer", - # hint=FieldHint.optional, - # ) - # max_height: int = Field( - # default=1024, - # desc="Maximum height of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # max_width: int = Field( - # default=1024, - # desc="Maximum width of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # mean: list[float] = Field( - # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - # desc="Mean RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # std: list[float] = Field( - # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - # desc="Standard deviation RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # rescale_factor: float = Field( - # default=255.0, - # desc="Diminisher factor for pixel normalization", - # hint=FieldHint.optional, - # ) - - -@config_class() -class MultiModalProcessorConfig(Config): - """ - Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` - """ - - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", - hint=FieldHint.core, - ) - image_processor: ImageProcessorConfig = Field( - default_factory=ImageProcessorConfig, - desc="Configuration for the image processor.", - hint=FieldHint.core, - ) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index cffaa734f..34b86f213 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -91,8 +91,7 @@ def __init__( max_sequence_length: int, cross_document_attention: bool = True, patch_size: list[int] | None = None, - max_image_height: int | None = None, - max_image_width: int | None = None, + max_image_size: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). 
@@ -103,8 +102,7 @@ def __init__( self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention self._patch_size = patch_size - self._max_image_height = max_image_height - self._max_image_width = max_image_width + self._max_image_size = max_image_size def setup( self, @@ -153,8 +151,7 @@ def setup( truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, patch_size=self._patch_size, - image_height=self._max_image_height, - image_width=self._max_image_width, + image_size=self._max_image_size, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 8022a05f7..65adf0bda 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,9 +73,8 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True - patch_size: list[int] | None = None - image_height: int | None = None - image_width: int | None = None + patch_size: int | None = None + image_size: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 54bf6826a..8651b8fcd 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -170,20 +170,8 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - # , patch_size: tuple(int), max_height: int, max_width: int ): # TODO Soham: handle spans - # if self._has_images: - # doc_size = self._document_sizes[idx] - # n_images = self._n_images[idx] - # image_positions = self._im_positions[idx] - # image_lengths = self._im_lengths[idx] - # image_tokens_seen = 0 - # for idx in range(n_images): - # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) - # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) - # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): - # continue token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, @@ -299,6 +287,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode == "L": + # Convert grayscale to RGB + img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 973c1db53..0ba3f0e13 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,9 +12,9 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.image_processor import ImageProcessor from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert try: @@ -91,8 +91,7 @@ def 
__init__( self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size - self._image_height = sampling.image_height - self._image_width = sampling.image_width + self._image_size = sampling.image_size self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents @@ -142,7 +141,7 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -394,10 +393,8 @@ def __getitem__(self, index: int) -> typing.Any: document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) image_sizes = [ - ImageProcessor.get_num_patches_from_dims( - *ImageProcessor.get_resize_dims( - *image_length, self._image_height, self._image_width, self._patch_size - ), + get_num_patches( + *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), self._patch_size, ) for image_length in image_lengths @@ -443,10 +440,6 @@ def __getitem__(self, index: int) -> typing.Any: image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - # TODO Soham: remove this - # if len(sample.images) == 1: - # sample.images.append(sample.images[0]) - # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: @@ -461,7 +454,6 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py deleted file mode 100644 index edfeceb95..000000000 --- a/fast_llm/data/image_processor.py +++ /dev/null @@ -1,55 +0,0 @@ -import math - -import torch -from torchvision.transforms.v2 import functional as F - -from fast_llm.data.config import ImageProcessorConfig - - -class ImageProcessor: - def __init__(self, config: ImageProcessorConfig): - self.patch_size = config.patch_size - self.mean = [x / config.rescale_factor for x in config.mean] - self.std = [x / config.rescale_factor for x in config.std] - self.max_height = config.max_height - self.max_width = config.max_width - assert ( - self.max_height % self.patch_size[0] == 0 - ), "max_height must be divisible by patch_size[0]. Found {max_height} and {self.patch_size[0]}" - assert ( - self.max_width % self.patch_size[1] == 0 - ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - - def resize(self, image): - # Resize the image to the specified size - # TODO Soham: resize for patches only during train? - # TODO Soham: convert all images to tensor? 
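Note on the sampling change above: each image's token count is now derived directly from its resized dimensions, where images larger than the cap are scaled down preserving aspect ratio, smaller ones are rounded up to the next patch multiple, and the token count is the number of patch-size tiles. A small worked sketch of that arithmetic, mirroring get_resize_dims/get_num_patches under the assumption of a square cap and patch_size=16:

import math

def resize_dims(height, width, max_size, patch_size=16):
    ratio = max(height / max_size, width / max_size)
    if ratio > 1:  # downscale, keeping the aspect ratio
        return int(height / ratio), int(width / ratio)
    # otherwise round up to the next multiple of the patch size
    return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)

def num_patches(height, width, patch_size=16):
    return (height // patch_size) * (width // patch_size)

h, w = resize_dims(2048, 512, max_size=1024)   # -> (1024, 256)
assert num_patches(h, w) == 64 * 16            # 1024 image tokens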
- # height = image.shape[0] - # width = image.shape[1] - height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) - - # TODO: options for interpolation mode - return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) - - # TODO Soham: move to utils - @classmethod - def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): - ratio = max(height / max_height, width / max_width) - return ( - (math.ceil(height / ratio), math.ceil(width / ratio)) - if ratio > 1 - else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) - ) - - def normalize(self, image: torch.Tensor) -> torch.Tensor: - # Normalize the image using the mean and std - return F.normalize(image, mean=self.mean, std=self.std) - - @classmethod - # TODO Soham: move to utils - def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) - - @classmethod - def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: - return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 16cfaf713..9cf8f8b57 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,19 +55,14 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: list[int] | None = Field( + patch_size: int | None = Field( default=None, desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_height: int | None = Field( + max_image_size: int | None = Field( default=None, - desc="Maximum image height for each image token", - hint=FieldHint.optional, - ) - max_image_width: int | None = Field( - default=None, - desc="Maximum image width for each image token", + desc="Maximum image height and width", hint=FieldHint.optional, ) num_micro_sequences: int = Field( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 65ae8e502..b83a118b5 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -8,8 +8,7 @@ class VisionEncoderDimNames: out_channels = "vision_out_channels" intermediate_size = "vision_intermediate_size" - patch_height = "vision_patch_height" - patch_width = "vision_patch_width" + patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" @@ -17,8 +16,7 @@ class VisionModelKwargs: patch_size = "patch_size" images = "images" image_positions = "image_positions" - image_height = "image_height" - image_width = "image_width" + image_size = "image_size" image_sizes = "image_sizes" image_mean = "image_normalization_mean" image_std = "image_normalization_std" @@ -28,30 +26,6 @@ class VisionModelKwargs: kv_channels = "vit_kv_channels" -@config_class() -class PatchConvConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the convolution layers to apply on the image patches - """ - in_channels: int = Field( - default=3, - desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", - hint=FieldHint.optional, - ) - bias: bool = Field( - default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional - ) - height: int = Field( - default=16, - desc="Height of the image patches considered as tokens", - ) - width: int | None = Field( - default=16, - desc="Width of the image patches considered as tokens", - ) - - @config_class() class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -169,33 +143,8 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) # TODO Soham: add a check for kv channels tensor_space.add_tensor_dim( TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) ) - # tensor_space.add_tensor_dim( - # CompositeTensorDim(VisionEncoderDimNames.) - # ) - - # patch_convolution: PatchConvConfig = Field( - # default_factory=PatchConvConfig, - # desc="Configuration for the convolution layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # normalization: NormalizationArchitectureConfig = Field( - # default_factory=NormalizationArchitectureConfig, - # desc="Configuration for the normalization layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # transformer: TransformerArchitectureConfig = Field( - # default_factory=TransformerArchitectureConfig, - # desc="Configuration for the transformer layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # patch_rotary: RotaryArchitectureConfig = Field( - # default_factory=RotaryArchitectureConfig, - # desc="Configuration for the rotary positional embeddings applied to the image patches.", - # hint=FieldHint.optional - # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index bbcebf251..8c694d28a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -161,6 +161,6 @@ def forward( input_, kwargs[VisionModelKwargs.image_sizes][:1], kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_width], + image_width=kwargs[VisionModelKwargs.image_size], ) # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 57ee3a0b2..154c1a16d 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -5,9 +6,17 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.utils import div -def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + 
Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. If the image is larger than the max dimensions, it will be resized to fit within them. @@ -17,12 +26,12 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> return ( (int(height / ratio), int(width / ratio)) if ratio > 1 - else (max_height * (height // max_height), max_width * (width // max_width)) + else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) ) -def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: - resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) # TODO: options for interpolation mode? return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) @@ -71,14 +80,17 @@ def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_height) - im_width = kwargs.get(VisionModelKwargs.image_width) - image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + im_height = kwargs.get(VisionModelKwargs.image_size) + im_width = kwargs.get(VisionModelKwargs.image_size) + patch_size = kwargs[VisionModelKwargs.patch_size] + image_sizes = [ + get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + ] kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( - resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], mean=kwargs[VisionModelKwargs.image_mean], std=kwargs[VisionModelKwargs.image_std], ), @@ -97,5 +109,5 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionModelKwargs.rope_theta], kwargs[VisionModelKwargs.kv_channels], im_height, - kwargs[VisionModelKwargs.patch_size], + patch_size, ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ffbd22816..c273f09b1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -150,8 +150,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_height = batch_meta.max_image_height - image_width = batch_meta.max_image_width + image_size = batch_meta.max_image_size image_mean = [ self._config.vision_encoder.normalization.mean_r, self._config.vision_encoder.normalization.mean_g, @@ -165,8 +164,7 @@ def preprocess_meta( image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_height: image_height, - VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_size: image_size, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: 
image_std, VisionModelKwargs.image_rescale_factor: image_rescale_factor, diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b801fbd3d..bc16829b3 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,8 +22,7 @@ def _get_data(self) -> GPTData: max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, patch_size=self._config.batch.patch_size, - max_image_height=self._config.batch.max_image_height, - max_image_width=self._config.batch.max_image_width, + max_image_size=self._config.batch.max_image_size, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 74a99b8ec047e31acd514a48237196ed9da761be Mon Sep 17 00:00:00 2001 From: root Date: Tue, 6 May 2025 17:44:50 +0000 Subject: [PATCH 13/97] changes --- fast_llm/engine/schedule/config.py | 21 +- fast_llm/functional/config.py | 2 + fast_llm/layers/language_model/config.py | 20 +- fast_llm/layers/multi_modal/embedding.py | 52 +-- fast_llm/layers/transformer/attention.py | 109 +++--- fast_llm/layers/transformer/config.py | 96 +++-- fast_llm/layers/transformer/mlp.py | 22 +- fast_llm/layers/transformer/preprocessing.py | 139 ++++++-- fast_llm/layers/transformer/transformer.py | 18 +- fast_llm/layers/vision_encoder/adapter.py | 39 ++- fast_llm/layers/vision_encoder/config.py | 178 ++++++---- fast_llm/layers/vision_encoder/encoder.py | 141 ++------ .../layers/vision_encoder/preprocessing.py | 153 ++++++-- fast_llm/models/gpt/conversion.py | 330 +++++++++++------- fast_llm/models/gpt/model.py | 113 ++++-- fast_llm/tools/cli.py | 1 - 16 files changed, 886 insertions(+), 548 deletions(-) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 9cf8f8b57..517a9cff5 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,16 +55,6 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) - max_image_size: int | None = Field( - default=None, - desc="Maximum image height and width", - hint=FieldHint.optional, - ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", @@ -81,6 +71,17 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + patch_size: int | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005e..c5da0f9b1 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -82,6 +82,8 @@ def _set_activation_fn_map() -> None: ActivationType.squared_relu: "relu2", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} +_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ec80a9334..887952d7a 
100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig from fast_llm.utils import Assert @@ -34,6 +34,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" @@ -44,7 +45,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionArchitectureConfig = Field( + vision_encoder: None | VisionEncoderArchitectureConfig = Field( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -130,7 +131,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionArchitectureConfig = FieldUpdate( + vision_encoder: None | VisionEncoderConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -215,16 +216,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: if self.vision_encoder is not None: self.vision_encoder.setup_tensor_space(tensor_space) - - -class MultiModalBaseConfig(BaseModelConfig): - language_model: LanguageModelBaseConfig = Field( - default_factory=LanguageModelBaseConfig, - desc="Configuration for the language model.", - hint=FieldHint.core, - ) - vision_model: VisionArchitectureConfig = Field( - default_factory=VisionArchitectureConfig, - desc="Configuration for the vision inputs.", - hint=FieldHint.core, - ) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 3b62c60b7..a3abe7813 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -7,8 +7,8 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionModelKwargs -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta @@ -23,7 +23,6 @@ def __init__( tensor_space: TensorSpace, ): super().__init__(config, tensor_space) - self.vision_encoder = VisionEncoder(config, tensor_space) def forward( self, @@ -38,46 +37,29 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # return self._forward( - # input_, - # kwargs.get(LanguageModelKwargs.position_ids), - # kwargs.get(VisionModelKwargs.images), - # kwargs.get(VisionModelKwargs.image_sizes), - # kwargs.get(VisionModelKwargs.image_positions), - # ) - # TODO Soham: offset position ids - images = 
kwargs.pop(VisionModelKwargs.images)[:1] + # image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] - image_embeddings = self.vision_encoder(images, kwargs) - embeddings = super()._forward(input_, position_ids) - img_tokens_seen = 0 + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + # get text embeddings + embeddings = super()._forward(tokens, position_ids) image_idx = 0 - for sample_idx, positions in enumerate(image_positions): - # embedding_parts = [] - for position in positions[:1]: - image = images[image_idx] - image_tokens = (image.size(1) // self._config.vision_encoder.encoder.patch_size) * ( - image.size(2) // self._config.vision_encoder.encoder.patch_size - ) - embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ - sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens ] - # embedding_parts.append(text_embeddings[sample_idx, :position]) - # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) + image_embedding_offset += num_image_tokens image_idx += 1 - img_tokens_seen += image_tokens - # embedding_parts.append(text_embeddings[sample_idx, position:]) - # TODO Soham: debug from here - # embeddings.append(torch.cat(embedding_parts, dim=0)) - # embeddings = torch.stack(embeddings, dim=0) + with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - # assert embeddings.size(1) == 8192 - del image_embeddings - del images + return embeddings.to(self._residual_dtype) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5c..3a3f40239 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,7 +14,9 @@ TransformerDimNames, TransformerKwargs, TransformerSubLayerName, + VisionTransformerConfig, ) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -57,24 +59,6 @@ class Attention(torch.nn.Module): A self-attention layer. 
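Note on the reworked `MultiModalEmbedding.forward` above: `input_` now carries the vision encoder's patch embeddings while the text tokens are embedded separately, and the patch embeddings are copied into each sample at the recorded image positions, one contiguous block of `get_num_patches(*size, patch_size)` rows per image. A toy illustration of that scatter (hypothetical tensors, no tensor or sequence parallelism):

import torch

hidden = 8
text_embeddings = torch.zeros(1, 10, hidden)    # one sample, 10 token positions
patch_embeddings = torch.randn(1, 6, hidden)    # 6 image patch tokens for this sample
image_positions, image_token_counts = [2], [6]  # image placeholder starts at token 2

offset = 0
for position, count in zip(image_positions, image_token_counts):
    text_embeddings[0, position : position + count] = patch_embeddings[0, offset : offset + count]
    offset += count
assert torch.equal(text_embeddings[0, 2:8], patch_embeddings[0])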
""" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - def __init__( self, config: TransformerConfig, @@ -82,12 +66,19 @@ def __init__( layer_index, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,19 +92,19 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size + self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -122,7 +113,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -133,7 +124,7 @@ def __init__( # Output. 
self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -199,7 +190,7 @@ def _attn_fused( def _get_meta( self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} + hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} return TensorMeta.from_dims( tuple( hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) @@ -209,6 +200,32 @@ def _get_meta( dtype=input_.dtype, ) + @property + def query_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def kv_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.group_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def context_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_dense, + ) + def _debug_log( self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> None: @@ -307,12 +324,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -339,23 +356,23 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self._KV_DIMS, + self.kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[self._transformer_kwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[self._transformer_kwargs.rotary_freq_k]) window_size = self._decide_window_size() if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -365,12 +382,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -380,7 +397,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -390,25 +407,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[self._transformer_kwargs.attention_mask], + kwargs[self._transformer_kwargs.attention_mask_value], ) if self._debug_transformer: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query", self.query_dims, kwargs) self._debug_log( key, "key", - self._KV_DIMS, + self.kv_dims, kwargs, ) self._debug_log( value, "value", - self._KV_DIMS, + self.kv_dims, kwargs, ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug_log(input_, "context", self.context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
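Note on the attention changes above: the same layer is reused for the vision transformer by swapping in a parallel set of dimension and kwargs names rather than subclassing, with the namespace chosen once in `__init__` via an `isinstance` check, and `causal` simply turned off for the bidirectional ViT case. The dispatch pattern, reduced to its essentials with illustrative stand-in classes:

class TransformerConfig:
    causal = True

class VisionTransformerConfig(TransformerConfig):
    causal = False

class TransformerDimNames:
    hidden = "hidden"

class VisionTransformerDimNames:
    hidden = "vit_hidden"

class Attention:
    def __init__(self, config: TransformerConfig):
        # The subclass must be tested first: a VisionTransformerConfig is also a TransformerConfig.
        if isinstance(config, VisionTransformerConfig):
            self._dim_names = VisionTransformerDimNames
        else:
            self._dim_names = TransformerDimNames
        self._causal = config.causal

assert Attention(VisionTransformerConfig())._causal is False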
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 4806e37ec..6b0d7ad68 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -84,6 +84,7 @@ class TransformerKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" sequence_length = "sequence_length" + micro_batch_size = "micro_batch_size" # TODO: Move grad_output = "grad_output" @@ -98,6 +99,8 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" + # TODO Soham: generic name? + pixtral = "pixtral" @config_class() @@ -166,6 +169,15 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig): pass +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -398,63 +410,73 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + if type == "vision": + # TODO Soham: better way to get around circular imports? + from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames + + transformer_dim_names = VisionTransformerDimNames + else: + transformer_dim_names = TransformerDimNames + # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( 
- CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim( + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -656,6 +678,11 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.init_method_std is None: @@ -718,3 +745,30 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: Assert.is_(self.window_size, None) return use_flash_attention + + +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. 
Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + +@config_class() +class VisionTransformerConfig(TransformerConfig): + """ + Configuration for the Vision Transformer (ViT) model. + """ + + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) + rotary: VisionRotaryConfig = FieldUpdate( + default_factory=VisionRotaryConfig, + desc="Configuration for the rotary positional embeddings.", + hint=FieldHint.feature, + ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 9b90beffb..1b494fc0b 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,7 +8,14 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + TransformerSubLayerName, + VisionTransformerConfig, +) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -18,6 +25,13 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs + init_method_1 = init_normal_( std=config.init_method_std_mlp_1, min_val=config.init_method_min_mlp_1, @@ -29,8 +43,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -41,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cbafe6c97..542b4d42e 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -12,7 +12,9 @@ TransformerConfig, TransformerDimNames, TransformerKwargs, + VisionTransformerConfig, ) +from 
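The isinstance dispatch added to MLPBase (and repeated in the preprocessors and TransformerLayer below) only works because the vision check comes first: VisionTransformerConfig subclasses TransformerConfig, so the reverse order would send vision configs down the text path. A stand-in sketch of that ordering constraint:

class TransformerConfig: ...
class VisionTransformerConfig(TransformerConfig): ...

def pick_namespace(config: TransformerConfig) -> str:
    # The subclass must be tested before the base class.
    if isinstance(config, VisionTransformerConfig):
        return "vision"   # vit_* dim names and kwargs
    if isinstance(config, TransformerConfig):
        return "text"     # default transformer dim names and kwargs
    raise TypeError(type(config))

assert pick_namespace(VisionTransformerConfig()) == "vision"
assert pick_namespace(TransformerConfig()) == "text"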
fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -129,63 +131,122 @@ def get_rotary_frequencies( return frequencies +def get_2d_rotary_frequencies( + config: RotaryConfig, + height, + width, + kv_channels, + *, + device="cuda", +) -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(height, device=device, dtype=torch.float64) + width_positions = torch.arange(width, device=device, dtype=torch.float64) + frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + # TODO Soham: apply scaling + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, width, 1), + angles_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies + + class RotaryEmbeddingPreprocessor: _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 def __init__( self, config: RotaryConfig, tensor_space: TensorSpace, ): + # if isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + # elif isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # TODO Soham: better way to do this? 
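get_2d_rotary_frequencies above splits the rotary channel pairs between the two image axes: even-indexed frequencies rotate with the patch row, odd-indexed ones with the patch column, and the phases are flattened to one entry per (row, col) patch position. A CPU-only sketch of the same construction with illustrative sizes (the real code additionally handles the non-complex layout and scaling):

import torch

def rotary_2d(height: int, width: int, kv_channels: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per channel pair, as in standard RoPE.
    freqs = theta ** -torch.arange(0, 1, 2 / kv_channels, dtype=torch.float64)
    rows = torch.arange(height, dtype=torch.float64)
    cols = torch.arange(width, dtype=torch.float64)
    # Even pairs follow the row index, odd pairs follow the column index.
    angles_h = torch.outer(rows, freqs[::2])     # (height, kv_channels // 4)
    angles_w = torch.outer(cols, freqs[1::2])    # (width, kv_channels // 4)
    angles = torch.cat(
        [
            angles_h[:, None, :].expand(height, width, -1),
            angles_w[None, :, :].expand(height, width, -1),
        ],
        dim=-1,
    ).reshape(height * width, kv_channels // 2)
    # Complex phase e^(i * angle) per (flattened patch position, channel pair).
    return torch.polar(torch.ones_like(angles), angles)

print(rotary_2d(height=4, width=4, kv_channels=8).shape)  # torch.Size([16, 4])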
+ if config.type == RotaryEmbeddingType.pixtral: + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + else: + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._kv_channels_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels) + self._tensor_cache_max_sequence_length: int = -1 - def create_tensors(self, sequence_length: int) -> None: + def create_tensors(self, sequence_length: int, num_patches: None | int = None) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) + if self._config.type == RotaryEmbeddingType.pixtral: + self._rotary_embedding_frequencies = get_2d_rotary_frequencies( + self._config, + num_patches, + num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + else: + self._rotary_embedding_frequencies = get_rotary_frequencies( + self._config, + sequence_length, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if self._config.type == RotaryEmbeddingType.pixtral: + position_ids = kwargs[self._transformer_kwargs.patch_position_ids] + # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + else: + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - sequence_q : sequence_k + ] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=self._transformer_kwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=self._transformer_kwargs.rotary_freq_k, ) @@ -202,6 +263,12 @@ def __init__( config: 
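In the pixtral branch of preprocess above, the precomputed table covers the largest possible patch grid, and every image selects its rows through flattened (row * grid_width + col) position ids, so differently sized images share one cache. A small sketch with a random stand-in for the frequency table:

import torch

max_grid, kv_channels = 4, 8
table = torch.randn(1, max_grid * max_grid, 1, kv_channels)  # stand-in for the cached frequencies

def patch_position_ids(patch_h: int, patch_w: int, max_grid: int) -> torch.Tensor:
    # Same flattening as position_ids_in_meshgrid: row-major over the full grid.
    rows, cols = torch.meshgrid(torch.arange(patch_h), torch.arange(patch_w), indexing="ij")
    return (rows * max_grid + cols).reshape(-1)

ids = patch_position_ids(patch_h=2, patch_w=3, max_grid=max_grid)  # a 2 x 3 patch image
print(ids.tolist())          # [0, 1, 2, 4, 5, 6]
print(table[:, ids].shape)   # torch.Size([1, 6, 1, 8]), one frequency row per patch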
TransformerConfig, tensor_space: TensorSpace, ): + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -231,22 +298,22 @@ def create_tensors(self, sequence_length: int) -> None: def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[self._transformer_kwargs.attention_mask] = ( + kwargs[self._transformer_kwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -254,12 +321,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=self._transformer_kwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=self._transformer_kwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -270,6 +337,12 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs def preprocess(self, kwargs: dict[str, typing.Any]) -> None: """ @@ -281,7 +354,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. 
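The backup-attention path combines a causal window over the current micro-sequence with a mask that keeps attention inside each packed document. The sketch below rebuilds that combination in isolation; tagging positions with a document index is a simplification of how the preprocessor derives it from sequence_lengths:

import torch

def causal_window(sequence_q: int, sequence_k: int) -> torch.Tensor:
    # Causal mask restricted to the last sequence_q query rows.
    full = torch.ones(sequence_k, sequence_k).tril_().bool()
    return full[sequence_k - sequence_q : sequence_k, :sequence_k]

def document_mask(sample_lens: list[int]) -> torch.Tensor:
    # Positions may only attend within the document they belong to.
    doc_ids = torch.cat([torch.full((length,), i) for i, length in enumerate(sample_lens)])
    return doc_ids[None, :] == doc_ids[:, None]

mask = causal_window(sequence_q=5, sequence_k=5) & document_mask([3, 2])
print(mask.int())  # block-diagonal, lower-triangular 5 x 5 mask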
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_lengths = kwargs.get(self._transformer_kwargs.sequence_lengths) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if sequence_q < kwargs[TransformerKwargs.sequence_length]: @@ -318,17 +391,17 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() + kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 311403fc9..ba4e5139f 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,9 +8,15 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + VisionTransformerConfig, +) from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -30,6 +36,12 @@ def __init__( return_input: bool = False, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._dropout_p = self._config.hidden_dropout @@ -39,7 +51,7 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -66,7 +78,7 @@ def name(self) -> str: return f"Transformer layer 
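FlashAttnVarlenPreprocessor reduces the packed documents to the metadata flash-attention's varlen kernels expect: cumulative boundaries prefixed with zero, plus the longest single document. A standalone sketch of that bookkeeping:

import torch

def varlen_metadata(sequence_lengths: list[list[int]]) -> tuple[torch.Tensor, int]:
    # Flatten per-sample document lengths, then build cu_seqlens and max_seqlen.
    seqlens = torch.cat([torch.tensor(lens, dtype=torch.int32) for lens in sequence_lengths])
    cu_seqlens = torch.cat(
        (torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, dim=0, dtype=torch.int32))
    )
    return cu_seqlens, int(seqlens.max())

cu_seqlens, max_seqlen = varlen_metadata([[3, 2], [4, 1]])
print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 5, 9, 10] 4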
{self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index b8436f72e..bf5f3f1aa 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -1,35 +1,54 @@ +import typing + import torch +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import TransformerDimNames -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames -from fast_llm.tensor import init_normal_ +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ -class VisionAdapter(torch.nn.Module): +class VisionAdapter(Layer): """ Vision adapter layer for the LLM. """ - def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - self._name = name input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? 
self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - def forward(self, input_: torch.Tensor): - return self.layer_2(self.layer_1(input_)) + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b83a118b5..7c650bf93 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,20 +1,55 @@ -from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig class VisionEncoderDimNames: out_channels = "vision_out_channels" - intermediate_size = "vision_intermediate_size" + adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" -class VisionModelKwargs: +class VisionTransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "vit_batch" + # TODO: Distinguish micro-sequence? 
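Functionally, the rewritten VisionAdapter is a two-layer MLP mapping the vision hidden size through adapter_size to the language-model hidden size, with the configured activation (GELU by default) in between. A plain-torch equivalent, ignoring the Fast-LLM tensor dims and parallelism:

import torch

class AdapterSketch(torch.nn.Module):
    def __init__(self, vision_hidden: int, adapter_size: int, lm_hidden: int):
        super().__init__()
        self.layer_1 = torch.nn.Linear(vision_hidden, adapter_size, bias=True)
        self.layer_2 = torch.nn.Linear(adapter_size, lm_hidden, bias=True)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # Project patch embeddings into the language model's hidden space.
        return self.layer_2(torch.nn.functional.gelu(self.layer_1(patches)))

adapter = AdapterSketch(vision_hidden=1024, adapter_size=5120, lm_hidden=4096)
print(adapter(torch.randn(2, 16, 1024)).shape)  # torch.Size([2, 16, 4096])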
+ sequence_q = "vit_sequence_q" + sequence_q_tp = "vit_sequence_q_tp" + sequence_k = "vit_sequence_k" + hidden = "vit_hidden" + # Self-attention dimensions + head_groups = "vit_head_groups" + group_heads = "vit_group_heads" + key_and_value = "vit_key_value" + kv_channels = "vit_kv_channels" + composite_heads = "vit_composite_heads" + composite_query = "vit_composite_query" + composite_key_value = "vit_composite_key_value" + composite_dense = "vit_composite_dense" + # MLP dimensions + mlp = "vit_mlp" + gate_and_up = "vit_gate_and_up" + composite_gated_mlp = "vit_composite_gated_mlp" + experts = "vit_experts" + top_experts = "vit_top_experts" + shared_experts = "vit_shared_experts" + unshared_experts = "vit_unshared_experts" + composite_expert_mlp = "vit_composite_expert_mlp" + composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" + composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" + + +class VisionEncoderKwargs: patch_size = "patch_size" images = "images" + image_patches = "image_patches" image_positions = "image_positions" image_size = "image_size" image_sizes = "image_sizes" @@ -24,56 +59,34 @@ class VisionModelKwargs: rope_theta = "vit_rope_theta" rotary_inv_freq = "vit_rotary_inv_freq" kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" -@config_class() -class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the vision encoder, which transforms images into embeddings - """ - path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationConfig = Field( - default_factory=NormalizationConfig, - ) - hidden_size: int = Field( - default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional - ) - intermediate_size: int = Field( - default=4096, - desc="The size of the intermediate (feed-forward) layers in the transformer model.", - hint=FieldHint.optional, - ) - num_hidden_layers: int = Field( - default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional - ) - num_attention_heads: int = Field( - default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional - ) - num_channels: int = Field( - default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional - ) - image_size: int = Field( - default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional - ) - patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) - hidden_act: str = Field( - default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional - ) - attention_dropout: float = Field( - default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional - ) - rope_theta: float = Field( - default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional - ) - initializer_range: float = Field( - default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional - ) - activation_type: ActivationType = Field( - default=ActivationType.silu, - desc="The activation function used in the hidden layers. 
Default: SiLU.", - hint=FieldHint.optional, - ) +# TODO Soham: do we need all of them? +class VisionTransformerKwargs: + rotary_freq_q = "vit_rotary_freq_q" + rotary_freq_k = "vit_rotary_freq_k" + attention_mask = "vit_attention_mask" + attention_mask_value = "vit_attention_mask_value" + sequence_lengths = "vit_sequence_lengths" + cu_seqlens_q = "vit_cu_seqlens_q" + cu_seqlens_k = "vit_cu_seqlens_k" + max_seqlen_q = "vit_max_seqlen_q" + max_seqlen_k = "vit_max_seqlen_k" + # TODO: Review these + presents = "vit_presents" + past_key_values = "vit_past_key_values" + sequence_first = "vit_sequence_first" + hidden_dims = "vit_hidden_dims" + sequence_q_dim = "vit_sequence_q_dim" + sequence_k_dim = "vit_sequence_k_dim" + sequence_length = "vit_sequence_length" + micro_batch_size = "vit_micro_batch_size" + # TODO: Move + grad_output = "vit_grad_output" + patch_position_ids = "patch_position_ids" @config_class() @@ -116,35 +129,70 @@ class ImageNormalizationConfig(Config): @config_class() -class VisionArchitectureConfig(BaseModelArchitectureConfig): +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False - encoder: VisionEncoderArchitectureConfig = Field( - default_factory=VisionEncoderArchitectureConfig, - desc="Configuration for the vision encoder that transforms images into embeddings.", + transformer: TransformerArchitectureConfig = Field( + default_factory=TransformerArchitectureConfig, + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", - hint=FieldHint.optional, + hint=FieldHint.core, ) adapter_activation_type: ActivationType = Field( default=ActivationType.gelu, desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) - normalization: ImageNormalizationConfig = Field( + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim( + VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + ) + ) + self.transformer.setup_tensor_space(tensor_space, type="vision") + + +@config_class() +class VisionEncoderConfig(VisionEncoderArchitectureConfig, BaseModelConfig): + transformer: VisionTransformerConfig = FieldUpdate( + default_factory=VisionTransformerConfig, + desc="Configuration for the transformer architecture.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = FieldUpdate( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_activation_type: ActivationType = FieldUpdate( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) def setup_tensor_space(self, tensor_space: TensorSpace): - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) - # TODO Soham: add a check for kv channels - tensor_space.add_tensor_dim( - TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) - ) + super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 8c694d28a..9369037d4 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -1,26 +1,20 @@ -import functools import typing import torch -from transformers import PixtralVisionConfig -from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -def position_ids_in_meshgrid(patch_embeddings_list, max_width): +def position_ids_in_meshgrid(patch_embeddings_list, max_size): positions = [] for patch in patch_embeddings_list: height, width = patch.shape[-2:] mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") 
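Taken together, the reworked config replaces the old vision_encoder.encoder.* fields with a nested transformer section plus encoder-level patch and adapter options. An illustrative shape of that nesting as a plain dict; the field names come from the diff, the numeric values are placeholders rather than code defaults:

vision_encoder = {
    "patch_size": 16,
    "patch_norm": {"type": "rms_norm"},
    "adapter_size": 5120,
    "adapter_activation_type": "gelu",
    "image_normalization": {"rescale_factor": 255.0},
    "transformer": {
        "num_layers": 24,
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "ffn_hidden_size": 4096,
        "causal": False,
        "rotary": {"type": "pixtral", "theta": 10000.0},
    },
}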
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_width + v_grid + ids = h_grid * max_size + v_grid positions.append(ids[:, 0]) return torch.cat(positions) @@ -41,108 +35,24 @@ def generate_block_attention_mask(patch_embeds_list, tensor): return causal_mask -# TODO Soham: should this just be nn.Module? -class VisionEncoder(Layer): - """ - A vision encoder layer for creating token embeddings from vision model - """ - - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - - self._config = config.vision_encoder - self._distributed_config = tensor_space.distributed_config + # TODO Soham: device=meta with torch.device("meta"): - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - config = PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - self.patch_conv = torch.nn.Conv2d( + self.conv = torch.nn.Conv2d( in_channels=3, - out_channels=self._config.encoder.hidden_size, - kernel_size=self._config.encoder.patch_size, - stride=self._config.encoder.patch_size, + out_channels=config.transformer.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, bias=False, + dtype=tensor_space.distributed_config.training_dtype.torch, ) - self.patch_conv.weight = ParameterMeta.from_dims( - tuple( - TensorDim(f"patch_conv_weight_{idx}", size) - for idx, size in enumerate(self.patch_conv.weight.shape) - ), + self.conv.weight = ParameterMeta.from_dims( + tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), init_method=init_normal_(), ) - self.norm = self._config.encoder.pre_norm.get_layer( - tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - ) - self.vision_transformer = PixtralTransformer(config) - # self.vision_encoder = PixtralVisionModel(config) - param_names = [] - # gather all names first. 
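PatchConv relies on a standard trick: a Conv2d whose kernel size equals its stride projects each non-overlapping patch_size x patch_size tile to a single hidden vector, so no explicit patch splitting happens inside the layer. A minimal demonstration with illustrative sizes:

import torch

patch_size, hidden_size = 16, 1024
conv = torch.nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)
image = torch.randn(1, 3, 64, 48)                # 4 x 3 patches
embeddings = conv(image)                         # (1, hidden_size, 4, 3)
tokens = embeddings.flatten(2).transpose(1, 2)   # (1, 12, hidden_size): one token per patch
print(tokens.shape)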
PyTorch complains if we do it in the loop - for name, param in self.vision_transformer.named_parameters(): - param_names.append(name) - for name in param_names: - *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_transformer) - param = self.vision_transformer.get_parameter(name) - setattr( - module, - stem, - ParameterMeta.from_dims( - tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - init_method=init_normal_(), - ), - ) - # none_params = [key for key, value in module._parameters.items() if value is None] - # for key in none_params: - # module._parameters.pop(key) - self.adapter = VisionAdapter( - intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), - tensor_space=tensor_space, - ) - - def _forward( - self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int - ) -> torch.Tensor: - patch_embeddings = self.patch_conv(input_) - patch_embeddings_list = [ - embedding[..., : image_size[0], : image_size[1]] - for embedding, image_size in zip(patch_embeddings, image_sizes) - ] - patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) - patch_embeddings = self.norm(patch_embeddings) - position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) - freqs = inv_freq[position_ids] - with torch.autocast(device_type=input_.device.type): - cos = freqs.cos() - sin = freqs.sin() - cos = cos.to(dtype=input_.dtype) - sin = sin.to(dtype=input_.dtype) - - attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings - ) - - (out,) = self.vision_transformer( - patch_embeddings, - attention_mask=attention_mask, - position_embeddings=(cos, sin), - output_attentions=False, - return_dict=False, - ) - - return self.adapter(out) + self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) def forward( self, @@ -150,17 +60,14 @@ def forward( kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> torch.Tensor: + hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision Output", - dtype=self._distributed_config.training_dtype.torch, - ) - return self._forward( - input_, - kwargs[VisionModelKwargs.image_sizes][:1], - kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_size], - ) - # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return TensorMeta.from_dims(hidden_dims) + # we don't need images after this point + # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) + patch_embeddings = self.norm(self.conv(input_)) + patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + # Hack to pass patch embeddings to the next layer + # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 154c1a16d..abae6f11a 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -5,7 +5,12 @@ import torchvision.transforms.v2.functional as F from 
fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderArchitectureConfig, + VisionEncoderKwargs, + VisionTransformerKwargs, +) from fast_llm.utils import div @@ -23,11 +28,11 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa If the image is smaller, it will be resized to the nearest multiple of the patch size. """ ratio = max(height / max_height, width / max_width) - return ( - (int(height / ratio), int(width / ratio)) - if ratio > 1 - else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) - ) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -72,42 +77,128 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_s return torch.cat((inv_freq, inv_freq), dim=-1) +def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: + positions = [] + for h, w in image_sizes: + patch_height = h // patch_size + patch_width = w // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + positions.append(ids[:, 0]) + return positions + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + class VisionPreprocessor: - def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_size) - im_width = kwargs.get(VisionModelKwargs.image_size) - patch_size = kwargs[VisionModelKwargs.patch_size] + images = kwargs.get(VisionEncoderKwargs.images) + im_height = kwargs.get(VisionEncoderKwargs.image_size) + im_width = kwargs.get(VisionEncoderKwargs.image_size) + patch_size = kwargs[VisionEncoderKwargs.patch_size] image_sizes = [ - get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + for ims in images ] - kwargs[VisionModelKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ - pad( + [ normalize( - resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], - mean=kwargs[VisionModelKwargs.image_mean], - std=kwargs[VisionModelKwargs.image_std], - ), - max_height=im_height, - max_width=im_width, - ) - 
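The fix to get_resize_dims makes both branches return patch-aligned sizes: the image is first scaled down to fit the maximum dimensions, then both sides are rounded up to a multiple of the patch size. A standalone version with two worked examples:

import math

def resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
    ratio = max(height / max_height, width / max_width)
    if ratio > 1:
        # Scale down to fit within the maximum dimensions.
        height, width = int(height / ratio), int(width / ratio)
    # Round both sides up to the nearest patch multiple.
    return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)

print(resize_dims(2000, 1200, 1024, 1024, 16))  # (1024, 624)
print(resize_dims(500, 300, 1024, 1024, 16))    # (512, 304)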
for image in images + resize(image, im_height, im_width, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=kwargs[VisionEncoderKwargs.image_mean], + std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images ] - images = torch.stack(images, dim=0).to( - # TODO Soham: is this needed? - device=self._tensor_space.distributed.device, - dtype=self._distributed_config.training_dtype.torch, - ) - kwargs[VisionModelKwargs.images] = images - kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( - kwargs[VisionModelKwargs.rope_theta], - kwargs[VisionModelKwargs.kv_channels], + # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + for imgs, sizes in zip(images, image_sizes): + # TODO Soham: should this be micro_sequence_length? + # sum( + # get_num_patches(*size, patch_size) for size in sizes + # ) + seq_patches = [] + for image, size in zip(imgs, sizes): + seqlen = get_num_patches(*size, patch_size) + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + # TODO Soham: remove + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], im_height, patch_size, ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + # TODO Soham: handle sequence data parallel + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd7da7979..d599a1148 100644 --- 
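The patch extraction above uses unfold to cut each resized image into non-overlapping patch_size tiles and then pads the flattened patch sequence up to the training sequence length (flash attention cannot take arbitrary masks, so padding goes at the end and is tracked through cu_seqlens). A standalone sketch of the unfold step, batched explicitly for portability:

import torch

patch_size = 16
image = torch.randn(3, 64, 48)  # already resized to patch multiples: 4 x 3 patches
# unfold lays out every patch as one column of length 3 * patch_size * patch_size.
cols = torch.nn.functional.unfold(image.unsqueeze(0), kernel_size=patch_size, stride=patch_size)
patches = cols[0].T.reshape(-1, 3, patch_size, patch_size)  # (12, 3, 16, 16)

# Pad the patch sequence to the fixed sequence length, as the preprocessor does per sample.
sequence_length = 32
padded = torch.zeros(sequence_length, 3, patch_size, patch_size)
padded[: patches.size(0)] = patches
print(patches.shape, padded.shape)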
a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -165,7 +165,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, hf_base_prefix: str = "", - fast_llm_offset: int = 0, + fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers @@ -187,9 +187,18 @@ def _create_weight_converters( return converters def _create_transformer_layer_converters( - self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, ) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ @@ -565,6 +574,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[-2] = ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ) + # TODO Soham: cleaner way to get language model config converters for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -579,31 +589,36 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - export_names=( - ( - "vision_config", - "image_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - export_names=( - ( - "vision_config", - "patch_size", - ), - ), + # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + # export_names=( + # ( + # "vision_config", + # "image_size", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + # export_names=( + # ( + # "vision_config", + # "patch_size", + # ), + # ), + # ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm, ), # Vision Transformer RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), export_names=( ( "vision_config", @@ -612,7 +627,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), export_names=( ( "vision_config", @@ -621,7 +636,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), export_names=( ( "vision_config", @@ -630,144 +645,213 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), export_names=( ( "vision_config", - "intermediate_size", + "num_key_value_heads", ), ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), export_names=( ( "vision_config", - "num_channels", + "intermediate_size", ), ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), export_names=( ( "vision_config", - "attention_dropout", + "hidden_act", ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), - export_names=(("vision_config", "rope_theta"),), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + # TODO Soham: add this config param for completeness? 
+ # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + # export_names=( + # ( + # "vision_config", + # "num_channels", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), + # export_names=( + # ( + # "vision_config", + # "attention_dropout", + # ), + # ), + # ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), - export_names=(("vision_config", "initializer_range"),), + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), ), + # TODO Soham: add this config param in vision encoder for completeness? + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), + # export_names=(("vision_config", "initializer_range"),), + # ), ] def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers vision_transformer_converters = [] - for i in range(num_layers): - vision_transformer_converters += [ - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", - f"vision_tower.transformer.layers.{i}.attention_norm.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", - f"vision_tower.transformer.layers.{i}.ffn_norm.weight", - ), - ] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.patch_conv.weight", - "vision_tower.patch_conv.weight", - ) - # TODO Soham: use _get_weight_and_bias_converters? 
- layernorm_converters = [] - layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.norm.weight", - "vision_tower.ln_pre.weight", - ) - layernorm_converters.append(layer_norm_converter) - layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: - layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.norm.bias", - "vision_tower.ln_pre.bias", - ) - layernorm_converters.append(layer_norm_bias_converter) + patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 adapter_converters = [ - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.weight", - "multi_modal_projector.linear_1.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.bias", - "multi_modal_projector.linear_1.bias", - ), - # TODO Soham: conditionally add bias - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.weight", - "multi_modal_projector.linear_2.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.bias", - "multi_modal_projector.linear_2.bias", - ), + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) return vision_encoder_converter + lm_converters + def _create_vision_transformer_layer_converters( + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, + ) -> list[WeightConverter]: + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + names_bias_cls = [ + # Self-attn + ( + f"layers.{i+fast_llm_offset}.self_attn.query", + f"vision_tower.transformer.layers.{i}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"vision_tower.transformer.layers.{i}.attention.k_proj", + 
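The converter offsets follow the new layer layout: the patch convolution is Fast-LLM layer 0, the vision transformer blocks occupy indices 1 through num_layers, the adapter sits at num_layers + 1, and the language model resumes at num_layers + 3. A small sketch of the resulting name pairs, assuming that layout and showing only one weight per block:

def vision_weight_pairs(num_vit_layers: int) -> list[tuple[str, str]]:
    pairs = [("layers.0.conv.weight", "vision_tower.patch_conv.weight")]
    for i in range(num_vit_layers):
        pairs.append(
            (f"layers.{i + 1}.self_attn.query.weight",
             f"vision_tower.transformer.layers.{i}.attention.q_proj.weight")
        )
    adapter = num_vit_layers + 1
    pairs.append((f"layers.{adapter}.layer_1.weight", "multi_modal_projector.linear_1.weight"))
    return pairs

for fast_llm_name, hf_name in vision_weight_pairs(num_vit_layers=2):
    print(f"{fast_llm_name:30s} -> {hf_name}")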
f"vision_tower.transformer.layers.{i}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"vision_tower.transformer.layers.{i}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{i+fast_llm_offset}.norm_1", + f"vision_tower.transformer.layers.{i}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.norm_2", + f"vision_tower.transformer.layers.{i}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + () if ignore_export else hf_prefix, + use_bias, + cls=IgnoreExportWeightConverter if ignore_export else cls, + ) + + # MLP + if ignore_export: + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] + else: + converters += self._get_vision_transformer_mlp_converters( + f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" + ) + return converters + + def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c273f09b1..6aef273f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,10 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -76,6 +79,10 @@ def __init__( if self._config.vision_encoder: self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( + self._config.vision_encoder.transformer.rotary, self._tensor_space + ) def get_output_layers(self) -> list[Layer]: return [ @@ -99,22 +106,35 @@ def get_output_layers(self) -> list[Layer]: ] ] + def 
get_vision_layers(self) -> list[Layer]: + patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) + vit_layers = [ + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + patch_conv, + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers(self._config, self._tensor_space) ), LanguageModelHead(self._config, self._tensor_space, 0), ] return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers() ), *[ TransformerLayer( @@ -152,24 +172,24 @@ def preprocess_meta( if self._config.vision_encoder: image_size = batch_meta.max_image_size image_mean = [ - self._config.vision_encoder.normalization.mean_r, - self._config.vision_encoder.normalization.mean_g, - self._config.vision_encoder.normalization.mean_b, + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, ] image_std = [ - self._config.vision_encoder.normalization.std_r, - self._config.vision_encoder.normalization.std_g, - self._config.vision_encoder.normalization.std_b, + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, ] - image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_size: image_size, - VisionModelKwargs.image_mean: image_mean, - VisionModelKwargs.image_std: image_std, - VisionModelKwargs.image_rescale_factor: image_rescale_factor, - VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, - VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, } @@ -218,6 +238,18 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) + if self._config.vision_encoder: + vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + 
vision_kwargs.update( + { + VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -225,6 +257,7 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) @@ -253,6 +286,9 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if self._config.vision_encoder: + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) @@ -294,6 +330,11 @@ def preprocess( self._rotary_embedding_preprocessor.create_tensors(sequence_length) if not self._use_flash_attention: self._backup_attention_preprocessor.create_tensors(sequence_length) + if self._config.vision_encoder and self._config.vision_encoder.transformer.rotary.enabled: + max_num_patches = ( + common_kwargs[VisionEncoderKwargs.image_size] // common_kwargs[VisionEncoderKwargs.patch_size] + ) + self._vision_rotary_embedding_preprocessor.create_tensors(sequence_length, max_num_patches) preprocessed = [] presents = None @@ -342,32 +383,38 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - if batch.images is not None: - kwargs[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess(kwargs) - preprocessed.append((tokens, kwargs)) + if batch.images is not None: + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch.images + ] + kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) + self._vision_rotary_embedding_preprocessor.preprocess(kwargs) + kwargs[LanguageModelKwargs.tokens] = tokens + preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder is not None] + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[(self._config.vision_encoder is not None) + 1 : -1] + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] @property def model_head(self) -> LanguageModelHead: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index e9df18ed2..b1f14ccc5 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ 
-32,7 +32,6 @@ def fast_llm(args=None): sys.exit(1) except Exception: # noqa logger.critical(traceback.format_exc()) - sys.exit(1) if __name__ == "__main__": From 99ad5d9bda84eea74e377c8cc75f7184bb0dcc76 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 18:34:24 +0000 Subject: [PATCH 14/97] patches and fixes --- fast_llm/layers/language_model/config.py | 10 ++++++---- fast_llm/layers/vision_encoder/config.py | 2 ++ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 20 ++++++++++++++++++- fast_llm/models/gpt/model.py | 10 +++++++--- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 887952d7a..ef0e7a5cc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -45,8 +45,9 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionEncoderArchitectureConfig = Field( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) + vision_encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) @@ -131,8 +132,9 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionEncoderConfig = FieldUpdate( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config + vision_encoder: VisionEncoderConfig = FieldUpdate( + default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 7c650bf93..283513727 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -7,6 +7,7 @@ class VisionEncoderDimNames: + in_channels = "vision_in_channels" out_channels = "vision_out_channels" adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" @@ -62,6 +63,7 @@ class VisionEncoderKwargs: max_image_tokens = "max_image_tokens" patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" # TODO Soham: do we need all of them? 
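The layer-index arithmetic used across these patches (the converter offsets, and the embedding/transformer_layers properties in model.py) all follows from the stack built by get_vision_layers(): patch convolution at index 0, the vision transformer layers at indices 1 through N, the adapter at N + 1, the multimodal embedding at N + 2, and the language-model transformer layers from N + 3 onward. A minimal Python sketch of that indexing; the helper name is illustrative and not part of the code base:

def multimodal_layer_offsets(num_vision_layers: int) -> dict[str, int]:
    # Assumed stack: patch conv, num_vision_layers vision transformer layers,
    # adapter, multimodal embedding, then the language-model layers.
    return {
        "patch_conv": 0,
        "first_vision_layer": 1,
        "adapter": num_vision_layers + 1,  # offset used for the multi_modal_projector converters
        "embedding": num_vision_layers + 2,  # index returned by the `embedding` property
        "first_lm_layer": num_vision_layers + 3,  # fast_llm_offset passed to the LM converters
    }

# Example: a 24-layer vision transformer places the adapter at 25, the embedding at 26,
# and the first language-model transformer layer at 27.
print(multimodal_layer_offsets(24))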
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 9369037d4..ed6fbc92a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -63,7 +63,7 @@ def forward( ) -> torch.Tensor: hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims(hidden_dims) + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) # we don't need images after this point # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) patch_embeddings = self.norm(self.conv(input_)) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index abae6f11a..c087cf6d0 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,13 +4,16 @@ import torch import torchvision.transforms.v2.functional as F -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( VisionEncoderArchitectureConfig, + VisionEncoderDimNames, VisionEncoderKwargs, + VisionTransformerDimNames, VisionTransformerKwargs, ) +from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -104,6 +107,21 @@ def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: Tensor self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[VisionEncoderDimNames] + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + VisionTransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6aef273f6..5425a1e13 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -286,12 +286,16 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if not self._use_flash_attention: + self._backup_attention_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder: + self._vision_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder.transformer.rotary.enabled: self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) - preprocessed_meta.append((tokens, kwargs)) + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta From 
bcb557aca291afcbb2e19969d2e7e1da16a93612 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:44:40 +0000 Subject: [PATCH 15/97] fix dependency --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b8e1f8887..149a498e0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ - && apt-get install --no-install-recommends -y libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From a6f5364d33c8d80ff46ea592612362fd03f85f30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:49:53 +0000 Subject: [PATCH 16/97] remove for testing --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 149a498e0..b7e42d4dc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,8 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs \ + # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From 73b431b22d0c4b54d41d25a4dcf0738c5a1b1711 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 21:57:17 +0000 Subject: [PATCH 17/97] missing file --- .../layers/transformer/vision_transformer.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 fast_llm/layers/transformer/vision_transformer.py diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py new file mode 100644 index 000000000..94a9c70af --- /dev/null +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -0,0 +1,55 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.tensor import TensorMeta + + +class VisionTransformerLayer(TransformerLayer): + """ + A vision transformer layer to encode image patches + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Vision transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[VisionTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) + + # TODO Soham: remove this since we only need to call the parent method + # def forward( + # self, + # input_: torch.Tensor, + # kwargs: dict[str, typing.Any], + # losses:
dict[str, typing.Any] | None = None, + # metrics: dict[str, typing.Any] | None = None, + # ) -> torch.Tensor: + # if isinstance(input_, TensorMeta): + # return self._get_meta(input_, "output", kwargs) + # # Hack for now to compute the patch embeddings + # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( + # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics + # ) + # return input_ From 6d6567673450e3e97ae07879957a55875ec80caf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 8 May 2025 06:11:54 +0000 Subject: [PATCH 18/97] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0ba3f0e13..2f80ee77d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -485,7 +485,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a3abe7813..b7d79dd37 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -43,7 +43,8 @@ def forward( image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) # get text embeddings - embeddings = super()._forward(tokens, position_ids) + # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? + embeddings = super()._forward(tokens, position_ids).clone() image_idx = 0 for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 66e708170d98bd476e679fcbaf6fbf761b284388 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 9 May 2025 18:39:55 +0000 Subject: [PATCH 19/97] fixes --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/language_model/config.py | 6 +----- fast_llm/layers/vision_encoder/config.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 7 ++++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index cb6d6c8d4..54564a212 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -278,7 +278,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 451044207..ab5707804 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,12 +5,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl -<<<<<<< HEAD -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig -======= from fast_llm.layers.transformer.config import TransformerConfig ->>>>>>> main +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b15f90bdb..345b118ed 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,9 +1,9 @@ -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig +from fast_llm.layers.transformer.config import VisionTransformerConfig class VisionEncoderDimNames: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c087cf6d0..7bd8a2aa1 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,10 +4,11 @@ import torch import torchvision.transforms.v2.functional as F +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( - VisionEncoderArchitectureConfig, + VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames, @@ -101,8 +102,8 @@ def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tenso return ids[:, 0] -class VisionPreprocessor: - def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): +class VisionPreprocessor(Preprocessor): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config From 7f86a7f1889065ca06dade517d0cc69ef8b83215 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 06:18:03 +0000 Subject: [PATCH 20/97] fix --- fast_llm/data/dataset/gpt/sampled.py | 19 ++- fast_llm/data/tokenizer.py | 153 ++++++++++++------ fast_llm/engine/schedule/config.py | 2 +- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/transformer/config.py | 43 ++--- fast_llm/layers/vision_encoder/config.py | 1 - 
.../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/model.py | 19 ++- 8 files changed, 153 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 54564a212..f99a9d3ef 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -138,7 +138,9 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) + image_token_sizes[i] = sum( + (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + ) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -195,7 +197,7 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, - "patch_size": self._patch_size, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -405,12 +407,19 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size( + document_index, self._parameters.patch_size + ) image_sizes = [ get_num_patches( - *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), - self._patch_size, + *get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, ) for image_length in image_lengths ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0e7d54709..10b8b2c64 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,64 +42,119 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text, image_positions=None): + def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + """ + Tokenize the input text and return the tokenized input_ids along with token spans. 
+ """ + # if not image_positions and not char_spans: + # return self._tokenize(text), [], [] if not image_positions: - return self._tokenize(text), [], [] + image_positions = [] + if not char_spans: + char_spans = [] + image_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] beginning_of_text = True - while image_idx < len(image_positions): - if image_positions[image_idx] > len(text): - raise ValueError( - f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - ) - curr_text = text[char_pos : image_positions[image_idx]] - tokenized_text = self._tokenize( - curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - ) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions = len(token_ids) - char_pos = image_positions[image_idx] - image_idx += 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - token_ids.extend(tokenized_text) - return token_ids, image_token_positions - - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. - """ - input_ids = [] - token_spans = [] - char_pos = 0 - beginning_of_text = True for start, end in char_spans: + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position <= start: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + image_idx += 1 + char_pos = image_position + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + token_ids.extend(tokenized_text) + char_pos = start + len(token_ids) + span_length = 0 + while image_position <= end: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + if char_pos < end: + if end >= len(text) - 1: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 + else: 
+ tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + + # def tokenize(self, text, image_positions=None): + # if not image_positions: + # return self._tokenize(text), [], [] + # image_idx = 0 + # char_pos = 0 + # token_ids = [] + # image_token_positions = [] + # beginning_of_text = True + # while image_idx < len(image_positions): + # if image_positions[image_idx] > len(text): + # raise ValueError( + # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + # ) + # curr_text = text[char_pos : image_positions[image_idx]] + # tokenized_text = self._tokenize( + # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + # ) + # beginning_of_text = False + # token_ids.extend(tokenized_text) + # image_token_positions = len(token_ids) + # char_pos = image_positions[image_idx] + # image_idx += 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + # token_ids.extend(tokenized_text) + # return token_ids, image_token_positions + + # def tokenize_with_spans( + # self, text: str, char_spans: list[tuple[int, int]] + # ) -> tuple[list[int], list[tuple[int, int]]]: + # """ + # Perform span-aware tokenization and return the tokenized input_ids along with token spans. + # """ + # input_ids = [] + # token_spans = [] + # char_pos = 0 + # beginning_of_text = True + # for start, end in char_spans: + # if char_pos < start: + # curr_text = text[char_pos:start] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # input_ids.extend(tokenized_text) + # curr_text = text[start : end + 1] + # if end >= len(text) - 1: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # else: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) + # input_ids.extend(tokenized_text) + # char_pos = end + 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # input_ids.extend(tokenized_text) + # return input_ids, token_spans def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 10f87835b..48daf0e69 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,7 +55,7 @@ class BatchConfig(Config): desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_size: int | None = Field( + image_size: int | None = Field( default=None, desc="Maximum image height and width", hint=FieldHint.optional, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ab5707804..78de218f1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -167,7 +167,7 @@ def _validate(self) -> None: raise NotImplementedError("Multi-token prediction not supported with distillation.") def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space, type="vision" if self.vision_encoder is not None else None) + 
self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 55320a1b5..38dc9ec48 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -169,6 +169,7 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -668,59 +669,61 @@ def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + ) + 
tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 345b118ed..4dde28bee 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -176,4 +176,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") - super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7bd8a2aa1..46bf0ab3f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -123,7 +123,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c80c05f94..b832f1b04 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -77,14 +77,10 @@ def __init__( self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) if self._config.vision_encoder: - self._preprocessors.append( - VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - ) + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( - RotaryEmbeddingPreprocessor( - self._config.vision_encoder.transformer.rotary, self._tensor_space - ) + 
RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) # if self._config.vision_encoder.transformer.rotary.enabled: @@ -167,7 +163,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_size = batch_meta.max_image_size + image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, @@ -411,8 +407,6 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) if batch.images is not None: kwargs[VisionEncoderKwargs.images] = [ [ @@ -423,7 +417,12 @@ def preprocess( ] kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions kwargs[LanguageModelKwargs.tokens] = tokens - preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + + for preprocessor in self._preprocessors: + preprocessor.preprocess(tokens, kwargs) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) From 3a8a99d62c559f97f35d37dc4c2133d5e0a77a73 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 15:23:33 +0000 Subject: [PATCH 21/97] more fixes after merge --- fast_llm/layers/transformer/preprocessing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 01b953976..870463df2 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -15,7 +15,11 @@ TransformerKwargs, VisionTransformerConfig, ) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderKwargs, + VisionTransformerDimNames, + VisionTransformerKwargs, +) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -163,6 +167,7 @@ def get_2d_rotary_frequencies( return frequencies + class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _mask: torch.Tensor @@ -216,7 +221,11 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + if self._config.type == RotaryEmbeddingType.pixtral: + max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) + else: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if self._config.type == RotaryEmbeddingType.pixtral: From d16284ee0b96598e63e74c27b6b09e7e70d9d367 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:32:32 +0000 Subject: [PATCH 22/97] conv cleanup --- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 1 - fast_llm/layers/vision_encoder/config.py | 6 +++ fast_llm/layers/vision_encoder/encoder.py | 39 ++++++++++--------- fast_llm/models/gpt/conversion.py | 6 ++- 
setup.cfg | 1 - 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 8651b8fcd..5d3df5983 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -163,7 +163,6 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - # TODO Soham: get images def get( self, idx: int, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 89fe904cd..38d90ed42 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -180,7 +180,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) - # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: assert self.tokenizer.path is not None if self.dataset.data_type is not None: diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 4dde28bee..be3fb38cb 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -144,6 +144,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Patch size for the image encoder.", hint=FieldHint.core, ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) patch_norm: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", @@ -169,6 +174,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) # TODO Soham: add a check for presence of kv channels parameter (head_dim) tensor_space.add_tensor_dim( TensorDim( diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index ed6fbc92a..59212c58f 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -3,7 +3,7 @@ import torch from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -38,21 +38,25 @@ def generate_block_attention_mask(patch_embeds_list, tensor): class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - # TODO Soham: device=meta - with torch.device("meta"): - self.conv = torch.nn.Conv2d( - in_channels=3, - out_channels=config.transformer.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, - bias=False, - dtype=tensor_space.distributed_config.training_dtype.torch, - ) - self.conv.weight = ParameterMeta.from_dims( - tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), - init_method=init_normal_(), + self._tensor_space = tensor_space + # TODO Soham: 
lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + ), + init_method=init_normal_(), + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) ) + else: + self.bias = None self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) + self.stride = config.patch_size def forward( self, @@ -64,10 +68,7 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - # we don't need images after this point - # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) - patch_embeddings = self.norm(self.conv(input_)) + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) - # Hack to pass patch embeddings to the next layer - # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings return patch_embeddings diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4b08d564a..6aa3aaf1f 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -728,7 +728,9 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) layernorm_converters = [ WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), ] @@ -745,7 +747,7 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] - return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() diff --git a/setup.cfg b/setup.cfg index 3b5eea402..25f8af8bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,6 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools - # TODO Soham: use pillow-simd instead of pillow? 
webp>=0.4.0 pillow-simd>=9.5.0 torchvision>=0.20.0 From b3134aade1428641c47ea831d50701a43ee222ca Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:35:17 +0000 Subject: [PATCH 23/97] more conv cleanup --- fast_llm/engine/multi_stage/stage_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4d9cd8488..fd50f55c5 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -162,9 +162,6 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - # TODO Soham: clean way to get around check? - if meta is None: - continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 From c8aa66ec3793e222e0412afa0b142869f513e431 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:08:16 +0000 Subject: [PATCH 24/97] images + loss-masks --- fast_llm/data/dataset/gpt/memmap.py | 94 +++++++++++++------ .../data/preparator/gpt_memmap/prepare.py | 9 +- fast_llm/data/tokenizer.py | 81 ++++------------ 3 files changed, 92 insertions(+), 92 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 5d3df5983..73fb3903a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,6 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert, div @@ -114,7 +115,6 @@ def _init( self._image_lengths = [] self._image_positions = [] images_seen = 0 - # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: self._image_lengths.append( np.frombuffer( @@ -141,8 +141,6 @@ def _init( self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign - # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: assert self._num_pixels == num_pixels @@ -163,21 +161,54 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # def get( + # self, + # idx: int, + # offset: int = 0, + # image_offset: int = 0, + # length: int | None = None, + # use_loss_masking_spans: bool = False, + # ): + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # if self._has_images: + # image_positions = self._image_positions[idx] + # pixels = np.frombuffer( + # self._bin_buffer, + # dtype=np.dtype(np.uint8), + # count=self._image_lengths[idx].prod(initial=3), + # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + # ) + # images = [] + # start = 0 + # for image_length in self._image_lengths[idx]: + # n_pixels = image_length.prod(initial=3) + # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + # start += n_pixels + # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - ): - # TODO Soham: handle spans + patch_size: int | None = None, + image_size: int | None = None, + ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None if self._has_images: + # Truncations with images are not yet supported image_positions = self._image_positions[idx] pixels = np.frombuffer( self._bin_buffer, @@ -188,32 +219,39 @@ def get( images = [] start = 0 for image_length in self._image_lengths[idx]: - # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels - # TODO Soham: return loss_masking_spans - return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - - # def get( - # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - # ) -> GPTSample: - # token_ids = np.frombuffer( - # self._bin_buffer, - # dtype=self._dtype, - # count=self._document_sizes[idx] - offset if length is None else length, - # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - # ) - # sample_spans = None - # if use_loss_masking_spans and self._spans is not None: - # sample_spans = self._spans[idx] - # # adjust the spans for the offset and length - # sample_spans = sample_spans[ - # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - # ] - # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + sample_spans = None + if use_loss_masking_spans and self._spans is not None: + sample_spans = self._spans[idx] + sample_spans = 
sample_spans[ + (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + ] + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + if images: + image_idx = 0 + for span in sample_spans: + additional_tokens = 0 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position >= span[0] and image_position <= span[1]: + image_tokens = get_num_patches( + get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + patch_size, + ) + additional_tokens += image_tokens + image_idx += 1 + image_position = ( + image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + ) + span[1] += additional_tokens + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + loss_masking_spans=sample_spans, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2a3778df6..b6d817730 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -50,21 +50,24 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, image_token_positions = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans, dtype=np.int32).reshape(-1, 2), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, image_token_positions in [ + for input_ids, token_spans, image_token_positions in [ self._tokenizer.tokenize( text, + loss_mask_spans, im_char_positions, ) - for text, im_char_positions in zip( + for text, loss_mask_spans, im_char_positions in zip( batch[self._config.dataset.field], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 10b8b2c64..c44715d80 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,7 +42,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ Tokenize the input text and return the tokenized input_ids along with token spans. 
""" @@ -57,14 +57,15 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = 0 token_ids = [] image_token_positions = [] + token_spans = [] beginning_of_text = True + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") for start, end in char_spans: - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position <= start: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) image_idx += 1 char_pos = image_position image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") @@ -75,11 +76,12 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = start len(token_ids) span_length = 0 + token_start = len(token_ids) while image_position <= end: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) span_length += len(tokenized_text) char_pos = image_position image_idx += 1 @@ -96,65 +98,22 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li beginning_of_text = False token_ids.extend(tokenized_text) span_length += len(tokenized_text) + char_pos = end + 1 + token_spans.append((token_start, token_start + span_length - 1)) - # def tokenize(self, text, image_positions=None): - # if not image_positions: - # return self._tokenize(text), [], [] - # image_idx = 0 - # char_pos = 0 - # token_ids = [] - # image_token_positions = [] - # beginning_of_text = True - # while image_idx < len(image_positions): - # if image_positions[image_idx] > len(text): - # raise ValueError( - # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - # ) - # curr_text = text[char_pos : image_positions[image_idx]] - # tokenized_text = self._tokenize( - # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - # ) - # beginning_of_text = False - # token_ids.extend(tokenized_text) - # image_token_positions = len(token_ids) - # char_pos = image_positions[image_idx] - # image_idx += 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - # token_ids.extend(tokenized_text) - # return token_ids, image_token_positions + while image_position <= len(text): + image_position = image_positions[image_idx] + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) + char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) - # def tokenize_with_spans( - # self, text: str, char_spans: list[tuple[int, int]] - # ) -> tuple[list[int], list[tuple[int, int]]]: - # """ - # Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
- # """ - # input_ids = [] - # token_spans = [] - # char_pos = 0 - # beginning_of_text = True - # for start, end in char_spans: - # if char_pos < start: - # curr_text = text[char_pos:start] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # input_ids.extend(tokenized_text) - # curr_text = text[start : end + 1] - # if end >= len(text) - 1: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # else: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - # input_ids.extend(tokenized_text) - # char_pos = end + 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # input_ids.extend(tokenized_text) - # return input_ids, token_spans + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) From 0baae59dc9c4d7401a98b253b03fb41323219910 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:21:39 +0000 Subject: [PATCH 25/97] minor fixes --- fast_llm/data/dataset/gpt/indexed.py | 4 ++-- fast_llm/data/dataset/gpt/memmap.py | 8 +------- fast_llm/data/dataset/gpt/sampled.py | 6 ++---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 209c6e317..f8260413d 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -48,8 +48,8 @@ def get_document_sizes(self) -> np.ndarray: doc_sizes, im_sizes = self._dataset.get_document_sizes() return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._dataset.get_document_size(self._begin + index, patch_size) + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) @property def has_images(self) -> bool: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 73fb3903a..af632d5b4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -268,7 +268,6 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images - # TODO: image sizes def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. 
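# Illustration only: how the two arrays returned by get_document_sizes()
# combine into a per-document token count. The fixed 16x16 patch size and the
# absence of image resizing are assumptions made for this sketch, not part of
# the patch.
def total_tokens_per_document(dataset, patch_size: int = 16) -> list[int]:
    doc_sizes, image_lengths = dataset.get_document_sizes()
    return [
        int(text_tokens) + sum((h // patch_size) * (w // patch_size) for h, w in images)
        for text_tokens, images in zip(doc_sizes, image_lengths)
    ]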
@@ -277,12 +276,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ return self._document_sizes, self._image_lengths - def get_document_size(self, index: int, patch_size: list[int]) -> int: - # return self._document_sizes[index].item() + ( - # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - # if self._has_images - # else 0 - # ) + def get_document_size(self, index: int) -> int: return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f99a9d3ef..2a1df4430 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -407,9 +407,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size( - document_index, self._parameters.patch_size - ) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -582,7 +580,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) From 48855be3c9413298a38af9a94ee25eb56167815f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:30:55 +0000 Subject: [PATCH 26/97] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2a1df4430..a8ad574c1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -135,12 +135,24 @@ def _sample(self) -> None: # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + image_token_sizes = [] # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum( - (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + ) + for size in sizes + ) ) + image_token_sizes = image_token_sizes.to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() From f35e003d82b05e4787bc791928e1955262d4ba6a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:34:37 +0000 Subject: [PATCH 27/97] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index a8ad574c1..ce92d1c1f 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -434,14 +434,15 @@ def __getitem__(self, 
index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) + document_size += image_tokens if not self._truncate_documents: - if document_size + image_tokens > self._parameters.sequence_length + 1: + if document_size > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + image_tokens + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -454,7 +455,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size + image_tokens >= token_start: + if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -488,7 +489,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. document_sampling_index += 1 - token_count += document_size + image_tokens + token_count += document_size sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) From 4eb34cb0c4a4be901d079aaf0997e048035dbce6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:41:02 +0000 Subject: [PATCH 28/97] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ce92d1c1f..01459fa0a 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -96,7 +96,7 @@ def __init__( # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( - "Truncating documents with images is not supported. Please turn off truncation to use images." + "Truncating documents with images is not yet supported. Please turn off truncation to use images." ) if sampling.cache_directory is None: @@ -132,11 +132,9 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. 
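# Illustration only: the packing rule used in __getitem__ above, reduced to a
# single decision. document_size is assumed to already include image patch
# tokens, and document truncation is assumed to be disabled.
def packing_decision(document_size: int, token_count: int, sequence_length: int) -> str:
    if document_size > sequence_length + 1:
        return "skip"  # the document can never fit into one sample
    tokens_in_sample = token_count % (sequence_length + 1)
    if document_size + tokens_in_sample > sequence_length + 1:
        return "pad"  # pad out the current sample; the document starts the next one
    return "append"  # the document continues the current sample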
- # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = [] - # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( @@ -476,7 +474,6 @@ def __getitem__(self, index: int) -> typing.Any: start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) images.append(sample.images) - # TODO Soham: add offsets for loss masking spans if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From ebb9e276a3b97b3571e26c346a986be67d8e87cc Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:13:09 +0000 Subject: [PATCH 29/97] cleanup --- fast_llm/data/dataset/gpt/indexed.py | 1 - .../layers/transformer/vision_transformer.py | 16 ---------------- 2 files changed, 17 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index f8260413d..6e9bef96d 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,7 +11,6 @@ class GPTIndexedDataset(IndexedDataset): - # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 94a9c70af..3588956c7 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -37,19 +37,3 @@ def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - # TODO Soham: remove this since we only need to call the parent method - # def forward( - # self, - # input_: torch.Tensor, - # kwargs: dict[str, typing.Any], - # losses: dict[str, typing.Any] | None = None, - # metrics: dict[str, typing.Any] | None = None, - # ) -> torch.Tensor: - # if isinstance(input_, TensorMeta): - # return self._get_meta(input_, "output", kwargs) - # # Hack for now to compute the patch embeddings - # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( - # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics - # ) - # return input_ From 51098ef106b72a0528c71558aba1405993d96aa0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:45:07 +0000 Subject: [PATCH 30/97] fix --- fast_llm/data/dataset/gpt/sampled.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 01459fa0a..fc2ddb6a0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -150,7 +150,7 @@ def _sample(self) -> None: for size in sizes ) ) - image_token_sizes = image_token_sizes.to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -417,7 +417,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths = 
self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -432,7 +432,7 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) - document_size += image_tokens + document_size = text_size + image_tokens if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, From 60b87fa766a77a183a4aa998ae914a2d22b1e195 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 16:39:46 +0000 Subject: [PATCH 31/97] prepare cleanup --- fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 5 +++ .../data/preparator/gpt_memmap/prepare.py | 44 ++++++++++--------- fast_llm/data/tokenizer.py | 2 - 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index af632d5b4..e1297b14a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -108,6 +108,8 @@ def _init( + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 + self._image_lengths = None + self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index fc2ddb6a0..91f8ca8fa 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -93,7 +93,6 @@ def __init__( self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") - # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not yet supported. Please turn off truncation to use images." diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 38d90ed42..53f8e4688 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -173,6 +173,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Tokenizer configuration.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
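# Illustration only: the image_patch_size field above is used purely to fold
# pixel counts into token counts when choosing split boundaries. Made-up
# dimensions for a single 512x1024 image:
image_patch_size = 16
height, width = 512, 1024
num_pixels = height * width                                 # per colour channel
approx_image_tokens = num_pixels // image_patch_size**2     # 2048 tokens for splitting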
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b6d817730..c5a1b339c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -44,12 +44,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass - # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - # input_ids = [ - # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - # for text in batch[self._config.dataset.field] - # ] input_ids, token_spans, image_token_positions = map( list, zip( @@ -85,6 +80,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ return { "input_ids": input_ids, "image_positions": image_token_positions, + "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, } @@ -282,12 +278,7 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.loss_masking_spans is not None: - if self._config.dataset.loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") - tokenize_fn = self._tokenize_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + tokenize_fn = self._tokenize_batch # Avoid decoding bytes to images unless asked if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) @@ -336,7 +327,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -376,7 +367,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: int, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -406,11 +401,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - # TODO Soham: handle pixels (could still work with number of tokens?) 
- sizes_cumsum = dataset.get_document_sizes()[0].cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -423,8 +427,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c44715d80..0acb65e47 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -46,8 +46,6 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li """ Tokenize the input text and return the tokenized input_ids along with token spans. 
""" - # if not image_positions and not char_spans: - # return self._tokenize(text), [], [] if not image_positions: image_positions = [] if not char_spans: From f8a5532f16df73794bbed793721a2b507bb8b280 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 22:21:27 +0000 Subject: [PATCH 32/97] slightly better conversion --- fast_llm/models/gpt/conversion.py | 328 +++++++++++++----------------- 1 file changed, 146 insertions(+), 182 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6aa3aaf1f..4363c96c6 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -167,20 +167,16 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings converters.append( - WeightConverter( - f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" - ) + WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") ) - converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) + converters += self._create_lm_head_converters() for i in range(num_layers): converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") @@ -565,196 +561,111 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + # lm_converters = super()._create_config_converters() lm_converters = super()._create_config_converters() - lm_converters[-2] = ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ) - # TODO Soham: cleaner way to get language model config converters - for converter in lm_converters: - if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): - # Llava uses a different name for the text config - # if converter.fast_llm_names[0][0] == "transformer": + for idx, converter in enumerate(lm_converters): + if converter.export_names == (("model_type",),): + continue + elif converter.export_names == (("architectures",),): + ignore_index = idx + if converter.export_names: converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - # if converter.fast_llm_names[0][0] == "transformer": - # converter.export_names[0] = ("text_config", *converter.export_names[0]) - return lm_converters + [ - # Multimodal adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - # Image processing and conv layer - # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - # export_names=( - # ( - # "vision_config", - # "image_size", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - # export_names=( - # ( - # "vision_config", - # "patch_size", - # ), - # ), - # ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", + + return ( + lm_converters[:ignore_index] + + lm_converters[ignore_index + 1 :] + + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + # Vision Adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("text_config", "hidden_size"),), + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), + export_names=( + ( + "vision_config", + "num_key_value_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ), ), ), - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - 
fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - # TODO Soham: add this config param for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), - # export_names=( - # ( - # "vision_config", - # "num_channels", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), - # export_names=( - # ( - # "vision_config", - # "attention_dropout", - # ), - # ), - # ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), - ), - # TODO Soham: add this config param in vision encoder for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), - # export_names=(("vision_config", "initializer_range"),), - # ), - ] - - def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) - ) - - return vision_transformer_converters - - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] - - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = 
self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # TODO Soham: call _create_transformer_layer_converters with llava's custom offset - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + ] + ) def _create_vision_transformer_layer_converters( self, @@ -850,6 +761,59 @@ def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix ), ] + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + vision_transformer_converters = [] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 + adapter_converters = [ + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), + ] + + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + # Embeddings + lm_converters = [ + WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") + ] + for i in range(self._model.config.base_model.transformer.num_layers): + lm_converters += self._create_transformer_layer_converters( + 
fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + ) + lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) + return vision_encoder_converter + lm_converters + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 490651e4b074073e60e36910c6d6d0ed1fa46c21 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 06:51:31 +0000 Subject: [PATCH 33/97] cleanup, sequence parallelism --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 1 + fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 96 +++++++++++++++---- fast_llm/layers/vision_encoder/config.py | 16 ++++ fast_llm/layers/vision_encoder/encoder.py | 8 +- .../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/conversion.py | 10 +- fast_llm/models/gpt/model.py | 26 +++-- 9 files changed, 126 insertions(+), 37 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 6e9bef96d..cbe77ff0a 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index e1297b14a..1efc312e8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -209,6 +209,7 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) images = None + image_positions = None if self._has_images: # Truncations with images are not yet supported image_positions = self._image_positions[idx] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 78de218f1..e46e104c2 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -175,7 +175,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? 
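# Illustration only: the layer indexing implied by the Llava converters above,
# for V vision-transformer layers followed by the language model.
#   layers.0            patch conv (+ pre-norm)
#   layers.1 .. V       vision transformer blocks
#   layers.V+1          multimodal adapter / projector
#   layers.V+2          word embeddings
#   layers.V+3 ...      language-model transformer blocks, then the head(s)
def language_model_layer_index(block_index: int, num_vision_layers: int) -> int:
    return num_vision_layers + 3 + block_index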
tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - if self.vision_encoder is not None: + if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index b7d79dd37..52eaaac34 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -3,6 +3,7 @@ import torch from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -10,6 +11,7 @@ from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert class MultiModalEmbedding(LanguageModelEmbedding): @@ -24,6 +26,78 @@ def __init__( ): super().__init__(config, tensor_space) + @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. + """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + input_ = gather(input_, group, dim=0) + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # TODO Soham: get image positions for current split. Maybe in preprocessing? 
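# Illustration only: the core of the image/text merge performed in both
# branches above, for a single batch-first sample without tensor or sequence
# parallelism. get_num_patches is the helper already imported in this file.
def merge_image_embeddings(embeddings, image_embeddings, positions, sizes, patch_size):
    offset = 0
    for position, (height, width) in zip(positions, sizes):
        num_image_tokens = get_num_patches(height, width, patch_size)
        # image patch embeddings overwrite the placeholder token embeddings
        embeddings[position : position + num_image_tokens] = image_embeddings[
            offset : offset + num_image_tokens
        ]
        offset += num_image_tokens
    return embeddings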
+ # for positions in image_positions: + # if positions > self._distributed_config.tensor_rank + embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + def forward( self, input_: torch.Tensor, @@ -42,25 +116,5 @@ def forward( image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) - # get text embeddings - # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? - embeddings = super()._forward(tokens, position_ids).clone() - image_idx = 0 - for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): - image_embedding_offset = 0 - for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] - image_embedding_offset += num_image_tokens - image_idx += 1 - - with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator - ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - return embeddings.to(self._residual_dtype) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index be3fb38cb..e9bfd7d1c 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,3 +1,5 @@ +import enum + from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -130,10 +132,20 @@ class ImageNormalizationConfig(Config): ) +class VisionEncoderType(str, enum.Enum): + none = "none" + pixtral = "pixtral" + + @config_class() class VisionEncoderConfig(BaseModelConfig): _abstract = False + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. 
Choices: none, pixtral.", + hint=FieldHint.architecture, + ) transformer: VisionTransformerConfig = Field( default_factory=VisionTransformerConfig, desc="Configuration for the vision transformer architecture.", @@ -182,3 +194,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 59212c58f..a67053d56 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,6 +2,7 @@ import torch +from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs @@ -39,6 +40,8 @@ class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -68,7 +71,10 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + group = self._tensor_space.distributed.tensor_group input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) patch_embeddings = self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 46bf0ab3f..db726e24f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -153,7 +153,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: cu_seqlens = [0] max_seqlen = -1 for imgs, sizes in zip(images, image_sizes): - # TODO Soham: should this be micro_sequence_length? # sum( # get_num_patches(*size, patch_size) for size in sizes # ) @@ -172,6 +171,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) + # TODO Soham: should this be micro_sequence_length? 
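# Illustration only: the usual varlen-attention bookkeeping the code above
# follows, with made-up patch counts for three images packed into a 512-token
# buffer; the real code derives the counts from the image sizes.
patch_counts = [64, 256, 16]
cu_seqlens = [0]
for count in patch_counts:
    cu_seqlens.append(cu_seqlens[-1] + count)            # [0, 64, 320, 336]
sequence_length = 512
padding_size = sequence_length - cu_seqlens[-1]          # 176 trailing padding patches
max_seqlen = max(*patch_counts, padding_size)            # 256, the longest sequence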
padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] if padding_size > max_seqlen: max_seqlen = padding_size diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4363c96c6..ad4df7378 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -27,6 +27,7 @@ from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTModelConfig, @@ -172,9 +173,7 @@ def _create_weight_converters( num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append( - WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") - ) + converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) converters += self._create_lm_head_converters() @@ -250,7 +249,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -575,6 +574,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[:ignore_index] + lm_converters[ignore_index + 1 :] + [ + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral + ), ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b832f1b04..4219ac324 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -76,7 +76,7 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( @@ -129,7 +129,7 @@ def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if self._config.vision_encoder is None + if not self._config.vision_encoder.enabled else self.get_vision_layers() ), *[ @@ -162,7 +162,7 @@ def preprocess_meta( sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, @@ -231,7 +231,7 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: vision_hidden_dim = 
self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) vision_hidden_dims = ( (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) @@ -298,7 +298,7 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) else: @@ -430,11 +430,17 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + else: + return self.layers[0] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + else: + return self.layers[1:-1] @property def model_head(self) -> LanguageModelHead: @@ -449,7 +455,11 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (self._config.vision_encoder is not None, *self.model_head_indices), + # TODO Soham: make embedding layer index a property + ( + self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), + *self.model_head_indices, + ), ) } elif self._config.prediction_heads > 1: From 24e1b83f15c0ec89cb866b5438283533218bc005 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 07:19:49 +0000 Subject: [PATCH 34/97] fix conv --- fast_llm/layers/vision_encoder/encoder.py | 28 +++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index a67053d56..cff874793 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -61,6 +61,25 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size + @torch.compile + def _forward( + self, + input_: torch.Tensor, + hidden_dims: tuple[TensorMeta, ...], + ): + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + batch_dim, sequence_q_dim, hidden_dim = hidden_dims + if self._sequence_parallel: + patch_embeddings = patch_embeddings.reshape( + sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size + ) + patch_embeddings = split(patch_embeddings, group=group, dim=0) + else: + patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) + return patch_embeddings + def forward( self, input_: torch.Tensor, @@ -71,10 +90,5 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = 
self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) - if self._sequence_parallel: - patch_embeddings = split(patch_embeddings, group=group, dim=0) - return patch_embeddings + + return self._forward(input_, hidden_dims) From 0f1612a63c84b355c45b282cb10f174c6a9a7da3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 14:57:47 +0000 Subject: [PATCH 35/97] wip fixes --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 53 ++++++++++++++++------------ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index cbe77ff0a..56c4c8927 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 91f8ca8fa..9fbb218ee 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -133,23 +133,26 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = [] - for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_patches( - *get_resize_dims( - *size, - self._parameters.image_size, - self._parameters.image_size, + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), self._parameters.patch_size, - ), - self._parameters.patch_size, + ) + for size in sizes ) - for size in sizes ) - ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -463,16 +466,20 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - for idx, im_position in enumerate(sample.image_positions): - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens - start_pos = im_position + if sample.image_positions: + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + 
token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += image_tokens + start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - images.append(sample.images) + if sample.images: + images.append(sample.images) + else: + images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From 2e48c5f282e4e5b1e460e96efdc9e42b2c0743db Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 22:26:10 +0000 Subject: [PATCH 36/97] fix --- fast_llm/layers/multi_modal/embedding.py | 11 ++++-- fast_llm/layers/vision_encoder/config.py | 1 + fast_llm/layers/vision_encoder/encoder.py | 34 +++++++------------ .../layers/vision_encoder/preprocessing.py | 5 ++- fast_llm/models/gpt/model.py | 3 ++ 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 52eaaac34..9a035d8fd 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -62,9 +62,14 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + if self._sequence_parallel: + embeddings[position : position + num_image_tokens, sample_idx] = input_[ + image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx + ] + else: + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index e9bfd7d1c..fdbe2726f 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -66,6 +66,7 @@ class VisionEncoderKwargs: patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" # TODO Soham: do we need all of them? 
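# Illustrative sketch (not part of the patch series): the scatter that MultiModalEmbedding._forward
# above performs, shown standalone for the batch-first (non-sequence-parallel) branch. The helper
# name, the stand-in patch count, and the toy shapes below are assumptions for illustration only;
# they approximate, not reproduce, Fast-LLM's get_num_patches() and tensor layout.
import torch

def scatter_image_embeddings(embeddings, patch_embeddings, positions, sizes, patch_size=16):
    # embeddings: (batch, sequence, hidden) token embeddings with placeholder rows at image positions
    # patch_embeddings: (batch, num_patches, hidden) vision-encoder output, images concatenated per sample
    # positions / sizes: per-sample lists of image start offsets and (height, width) pairs
    for sample_idx, (sample_positions, sample_sizes) in enumerate(zip(positions, sizes)):
        image_embedding_offset = 0
        for position, (height, width) in zip(sample_positions, sample_sizes):
            # stand-in for get_num_patches(height, width, patch_size)
            num_image_tokens = (height // patch_size) * (width // patch_size)
            embeddings[sample_idx, position : position + num_image_tokens] = patch_embeddings[
                sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens
            ]
            image_embedding_offset += num_image_tokens
    return embeddings

if __name__ == "__main__":  # tiny smoke test with fabricated shapes
    emb = torch.zeros(1, 64, 8)
    patches = torch.ones(1, 16, 8)
    out = scatter_image_embeddings(emb, patches, positions=[[4]], sizes=[[(64, 64)]], patch_size=16)
    assert out[0, 4:20].sum() == 16 * 8  # 16 patch rows of width 8 were copied in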
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index cff874793..1df7f889c 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,6 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -61,25 +62,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size - @torch.compile - def _forward( - self, - input_: torch.Tensor, - hidden_dims: tuple[TensorMeta, ...], - ): - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = self.norm(input_.flatten(1)) - batch_dim, sequence_q_dim, hidden_dim = hidden_dims - if self._sequence_parallel: - patch_embeddings = patch_embeddings.reshape( - sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size - ) - patch_embeddings = split(patch_embeddings, group=group, dim=0) - else: - patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) - return patch_embeddings - def forward( self, input_: torch.Tensor, @@ -90,5 +72,15 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - - return self._forward(input_, hidden_dims) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[TransformerKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + if self._sequence_parallel: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index db726e24f..7ebfb5228 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -152,16 +152,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 + sequence_first = kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes # ) seq_patches = [] + sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): seqlen = get_num_patches(*size, patch_size) if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen seq_patches.append( torch.cat( [ @@ -172,7 +175,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) ) # TODO Soham: should this be 
micro_sequence_length? - padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4219ac324..9fff50bc7 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -185,6 +185,9 @@ def preprocess_meta( VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, + VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.out_channels + ).size, } else: vision_kwargs = {} From d529d37d881849afff40e57609ef4d10a916b742 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 17 May 2025 17:42:24 +0000 Subject: [PATCH 37/97] fix image position --- fast_llm/data/dataset/gpt/sampled.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 9fbb218ee..780b18878 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -412,6 +412,7 @@ def __getitem__(self, index: int) -> typing.Any: images = [] image_positions = [] image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -471,11 +472,13 @@ def __getitem__(self, index: int) -> typing.Any: # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) + text_tokens_added += len(token_ids[-1]) token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) if sample.images: images.append(sample.images) else: From 3c22ddafc27e02a6f5af31ad7022a6d315cb3f03 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 17 May 2025 17:45:04 +0000 Subject: [PATCH 38/97] cleanup --- .../layers/transformer/vision_transformer.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 3588956c7..72bd95ddd 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,33 +1,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): - """ - A vision transformer layer to encode image patches - """ - - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index: int, - return_input: bool = False, - ): - super().__init__(config, tensor_space, layer_index, return_input) - - hidden_dim = 
self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) - - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) - @property def name(self) -> str: return f"Vision transformer layer {self._layer_index}" From f0c8d830da9c4ea43df478a6cafbbb48bf910111 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 20 May 2025 07:05:01 +0000 Subject: [PATCH 39/97] cleanup --- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/transformer/attention.py | 17 +- fast_llm/layers/transformer/config.py | 259 ++++++++++++------ fast_llm/layers/transformer/mlp.py | 17 +- fast_llm/layers/transformer/preprocessing.py | 58 ++-- fast_llm/layers/transformer/transformer.py | 24 +- .../layers/transformer/vision_transformer.py | 12 +- fast_llm/layers/vision_encoder/config.py | 60 +--- fast_llm/models/gpt/model.py | 19 +- fast_llm/utils.py | 7 + 10 files changed, 239 insertions(+), 235 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e46e104c2..cdb27d9ef 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -46,7 +46,6 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) vision_encoder: VisionEncoderConfig = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b16f17405..3180b6cb8 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,14 +9,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -66,12 +59,8 @@ def __init__( layer_index, ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space # TODO Soham: fix assert diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 38dc9ec48..9a6bec07d 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -28,60 
+28,109 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - micro_batch_size = "micro_batch_size" - # TODO: Move - grad_output = "grad_output" +class BaseTransformerDimNames: + _kwargs_attributes = { + "batch": "batch", + "sequence_q": "sequence_q", + "sequence_q_tp": "sequence_q_tp", + "sequence_k": "sequence_k", + "hidden": "hidden", + "head_groups": "head_groups", + "group_heads": "group_heads", + "key_and_value": "key_value", + "kv_channels": "kv_channels", + "composite_heads": "composite_heads", + "composite_query": "composite_query", + "composite_key_value": "composite_key_value", + "composite_dense": "composite_dense", + "mlp": "mlp", + "gate_and_up": "gate_and_up", + "composite_gated_mlp": "composite_gated_mlp", + "experts": "experts", + "top_experts": "top_experts", + "shared_experts": "shared_experts", + "unshared_experts": "unshared_experts", + "composite_expert_mlp": "composite_expert_mlp", + "composite_gated_expert_mlp": "composite_gated_expert_mlp", + "composite_shared_expert_mlp": "composite_shared_expert_mlp", + "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", + } + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerDimNames(BaseTransformerDimNames, prefix=""): + pass + + +class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): + pass + + +class BaseTransformerKwargs: + _kwargs_attributes = { + "rotary_freq_q": "rotary_freq_q", + "rotary_freq_k": "rotary_freq_k", + "attention_mask": "attention_mask", + "attention_mask_value": "attention_mask_value", + "sequence_lengths": "sequence_lengths", + "cu_seqlens_q": "cu_seqlens_q", 
+ "cu_seqlens_k": "cu_seqlens_k", + "max_seqlen_q": "max_seqlen_q", + "max_seqlen_k": "max_seqlen_k", + "presents": "presents", + "past_key_values": "past_key_values", + "sequence_first": "sequence_first", + "hidden_dims": "hidden_dims", + "sequence_q_dim": "sequence_q_dim", + "sequence_k_dim": "sequence_k_dim", + "sequence_length": "sequence_length", + "micro_batch_size": "micro_batch_size", + "grad_output": "grad_output", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerKwargs(BaseTransformerKwargs, prefix=""): + pass + + +class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): + patch_position_ids = "patch_position_ids" + + +# class TransformerKwargs: +# rotary_freq_q = "rotary_freq_q" +# rotary_freq_k = "rotary_freq_k" +# attention_mask = "attention_mask" +# attention_mask_value = "attention_mask_value" +# sequence_lengths = "sequence_lengths" +# cu_seqlens_q = "cu_seqlens_q" +# cu_seqlens_k = "cu_seqlens_k" +# max_seqlen_q = "max_seqlen_q" +# max_seqlen_k = "max_seqlen_k" +# # TODO: Review these +# presents = "presents" +# past_key_values = "past_key_values" +# sequence_first = "sequence_first" +# hidden_dims = "hidden_dims" +# sequence_q_dim = "sequence_q_dim" +# sequence_k_dim = "sequence_k_dim" +# sequence_length = "sequence_length" +# micro_batch_size = "micro_batch_size" +# # TODO: Move +# grad_output = "grad_output" class TransformerLossNames: @@ -98,6 +147,11 @@ class RotaryEmbeddingType(str, enum.Enum): pixtral = "pixtral" +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + @config_class() class RotaryConfig(BaseModelConfig): _abstract = False @@ -160,6 +214,14 @@ def _validate(self) -> None: if self.triton and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -169,6 +231,14 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" @@ -259,6 +329,11 @@ def _validate(self) -> None: @config_class() class TransformerConfig(BaseModelConfig): _abstract = False + transformer_type: TransformerType = Field( + default=TransformerType.lm_decoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", @@ -658,72 +733,71 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: - if type == "vision": - # TODO Soham: better way to get around circular imports? Maybe add a type class variable to TransformerConfig? 
- from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames - - transformer_dim_names = VisionTransformerDimNames - else: - transformer_dim_names = TransformerDimNames + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - transformer_dim_names.group_heads, + self.transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) - 
tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -739,6 +813,14 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return use_flash_attention + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -755,6 +837,11 @@ class VisionTransformerConfig(TransformerConfig): Configuration for the Vision Transformer (ViT) model. """ + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.image_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) causal: bool = FieldUpdate( default=False, desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", @@ -765,3 +852,11 @@ class VisionTransformerConfig(TransformerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index dcea463a8..42393a413 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,14 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -25,12 +18,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs init_method_1 = init_normal_( std=config.init_method_std_mlp_1, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 870463df2..97c6c0f3f 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -7,19 +7,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - RotaryConfig, - RotaryEmbeddingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType, TransformerConfig, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -178,19 +167,8 @@ def __init__( config: RotaryConfig, tensor_space: TensorSpace, ): - # if isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs - # elif isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # TODO Soham: better way to do this? 
- if config.type == RotaryEmbeddingType.pixtral: - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - else: - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space @@ -273,12 +251,14 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -348,12 +328,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 5590be322..8bd1394e1 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,15 +9,9 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from 
fast_llm.tensor import TensorMeta @@ -35,12 +29,8 @@ def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout @@ -80,6 +70,14 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 72bd95ddd..7f39f9cff 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -2,14 +2,20 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + @property - def name(self) -> str: - return f"Vision transformer layer {self._layer_index}" + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index fdbe2726f..70504901b 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -16,39 +16,6 @@ class VisionEncoderDimNames: kv_channels = "vision_kv_channels" -class VisionTransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "vit_batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "vit_sequence_q" - sequence_q_tp = "vit_sequence_q_tp" - sequence_k = "vit_sequence_k" - hidden = "vit_hidden" - # Self-attention dimensions - head_groups = "vit_head_groups" - group_heads = "vit_group_heads" - key_and_value = "vit_key_value" - kv_channels = "vit_kv_channels" - composite_heads = "vit_composite_heads" - composite_query = "vit_composite_query" - composite_key_value = "vit_composite_key_value" - composite_dense = "vit_composite_dense" - # MLP dimensions - mlp = "vit_mlp" - gate_and_up = "vit_gate_and_up" - composite_gated_mlp = "vit_composite_gated_mlp" - experts = "vit_experts" - top_experts = "vit_top_experts" - shared_experts = "vit_shared_experts" - unshared_experts = "vit_unshared_experts" - composite_expert_mlp = "vit_composite_expert_mlp" - composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" - composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" - - class VisionEncoderKwargs: patch_size = "patch_size" images = "images" @@ -69,31 +36,6 @@ class VisionEncoderKwargs: out_channels = "vit_out_channels" -# TODO Soham: do we need all of them? -class VisionTransformerKwargs: - rotary_freq_q = "vit_rotary_freq_q" - rotary_freq_k = "vit_rotary_freq_k" - attention_mask = "vit_attention_mask" - attention_mask_value = "vit_attention_mask_value" - sequence_lengths = "vit_sequence_lengths" - cu_seqlens_q = "vit_cu_seqlens_q" - cu_seqlens_k = "vit_cu_seqlens_k" - max_seqlen_q = "vit_max_seqlen_q" - max_seqlen_k = "vit_max_seqlen_k" - # TODO: Review these - presents = "vit_presents" - past_key_values = "vit_past_key_values" - sequence_first = "vit_sequence_first" - hidden_dims = "vit_hidden_dims" - sequence_q_dim = "vit_sequence_q_dim" - sequence_k_dim = "vit_sequence_k_dim" - sequence_length = "vit_sequence_length" - micro_batch_size = "vit_micro_batch_size" - # TODO: Move - grad_output = "vit_grad_output" - patch_position_ids = "patch_position_ids" - - @config_class() class ImageNormalizationConfig(Config): mean_r: float = Field( @@ -194,7 +136,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads ) ) - self.transformer.setup_tensor_space(tensor_space, type="vision") + self.transformer.setup_tensor_space(tensor_space) @property def enabled(self) -> bool: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9fff50bc7..c1d9df90f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -433,17 +433,18 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] - else: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property def transformer_layers(self) -> list[TransformerLayer]: + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + return self._config.vision_encoder.transformer.num_layers + 2 else: - return self.layers[1:-1] + return 0 @property def model_head(self) -> LanguageModelHead: @@ -458,11 +459,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - # TODO Soham: 
make embedding layer index a property - ( - self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), - *self.model_head_indices, - ), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 51e0eee59..c5b7f07ae 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,3 +336,10 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def prefix_class_vars(cls, prefix: str, base_cls: type): + for attr, value in vars(base_cls).items(): + if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): + setattr(cls, attr, prefix + value) + return cls From ca33ee83b22bea5c45a946a13209572b6aa73680 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 20:59:14 +0000 Subject: [PATCH 40/97] cleaner, extensible multimodal config --- fast_llm/layers/transformer/config.py | 44 +- fast_llm/layers/transformer/preprocessing.py | 12 - fast_llm/layers/transformer/transformer.py | 18 +- .../layers/transformer/vision_transformer.py | 14 +- fast_llm/layers/vision_encoder/config.py | 5 + .../layers/vision_encoder/preprocessing.py | 12 +- fast_llm/models/gpt/config.py | 30 + fast_llm/models/gpt/conversion.py | 984 ++++++++++++------ fast_llm/models/gpt/model.py | 3 +- 9 files changed, 740 insertions(+), 382 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9a6bec07d..a634bc3c8 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -60,7 +60,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, attr, f"{cls._prefix}_{value}") class TransformerDimNames(BaseTransformerDimNames, prefix=""): @@ -737,67 +737,69 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - self.transformer_dim_names.group_heads, + self._transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim( + kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) + ) tensor_space.add_tensor_dim( - 
CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, self.num_experts)) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 
97c6c0f3f..af1a53f68 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -251,12 +251,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs self._config = config @@ -328,12 +322,6 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 8bd1394e1..2c79883b3 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -70,14 +70,6 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" - @property - def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs - - @property - def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: @@ -157,3 +149,11 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + # @property + # def _transformer_kwargs(self) -> TransformerKwargs: + # return TransformerKwargs + + # @property + # def _transformer_dim_names(self) -> TransformerDimNames: + # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 7f39f9cff..c2cfe9f23 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,21 +1,21 @@ import torch from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.transformer.config import VisionTransformerKwargs from fast_llm.layers.transformer.transformer import TransformerLayer -from 
fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs + # @property + # def _transformer_kwargs(self) -> VisionTransformerKwargs: + # return VisionTransformerKwargs - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames + # @property + # def _transformer_dim_names(self) -> VisionTransformerDimNames: + # return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 70504901b..6932c8fc0 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -119,6 +119,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", hint=FieldHint.core, ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfb5228..5009123f0 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,14 +6,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderConfig, - VisionEncoderDimNames, - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -152,7 +146,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 - sequence_first = kwargs.get(TransformerKwargs.sequence_first) + kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 162015768..d7d32221d 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -51,12 +51,22 @@ class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" + class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. 
Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + @config_class() class GPTBatchConfig(BatchConfig): @@ -140,6 +150,7 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -154,6 +165,25 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad4df7378..0b0796ed2 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,8 +6,10 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( AutoStateDictCheckpointHandler, ConstantExportParamConverter, @@ -22,7 +24,7 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType @@ -36,6 +38,7 @@ MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -112,73 +115,70 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return 
super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass +class TransformerWeightConverterMixin: - def _create_weight_converters( + def _get_weight_and_bias_converters( self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters - # Embeddings - converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: + num_layers = self._model.config.base_model.transformer.num_layers + prediction_heads = self._model.config.base_model.prediction_heads + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] - converters += self._create_lm_head_converters() + # Next-token prediction head + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias + ) + # Output weights + if self._model.config.base_model.tie_word_embeddings: + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) + else: + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) - for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + # MTP-heads > 0 are thrown away + # 
TODO Soham: handle offset with MTP + for i in range(1, prediction_heads): + logger.warning( + f"The model weights for the multi-token prediction head {i} are discarded during conversion." + ) + mtp_transformer_layer_index = num_layers - 1 + 2 * i + # MTP transformer layer + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True + ) + # MTP output norm + converters += self._get_weight_and_bias_converters( + f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter + ) return converters @@ -249,71 +249,81 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) - else: - converters.append( - WeightConverter( - f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" - ) - ) +class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ - # MTP-heads > 0 are thrown away - # TODO Soham: handle offset with MTP - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
- ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] - return converters + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass - def _get_weight_and_bias_converters( + def _create_weight_converters( self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" ) + return converters @@ -555,266 +565,592 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat +class 
PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - # lm_converters = super()._create_config_converters() - lm_converters = super()._create_config_converters() - for idx, converter in enumerate(lm_converters): - if converter.export_names == (("model_type",),): - continue - elif converter.export_names == (("architectures",),): - ignore_index = idx - if converter.export_names: - converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - - return ( - lm_converters[:ignore_index] - + lm_converters[ignore_index + 1 :] - + [ - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral - ), - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), - # Vision Adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", - ), + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", - ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", - ), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", - ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", - ), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - 
"hidden_act", - ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), + export_names=(("head_dim",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "rotary", + "theta", + ), ), - ] - ) + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] def _create_vision_transformer_layer_converters( - self, - i: int, - ignore_export: bool = False, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, - type: str | None = None, + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" ) -> list[WeightConverter]: - if type is not None: - if type == "vision": - transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer - else: - transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - names_bias_cls = [ + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ # Self-attn ( - f"layers.{i+fast_llm_offset}.self_attn.query", - f"vision_tower.transformer.layers.{i}.attention.q_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.key_value", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", ( - f"vision_tower.transformer.layers.{i}.attention.k_proj", - f"vision_tower.transformer.layers.{i}.attention.v_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", ), 
transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.dense", - f"vision_tower.transformer.layers.{i}.attention.o_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+fast_llm_offset}.norm_1", - f"vision_tower.transformer.layers.{i}.attention_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", norm_bias, WeightConverter, ), ( - f"layers.{i+fast_llm_offset}.norm_2", - f"vision_tower.transformer.layers.{i}.ffn_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", norm_bias, WeightConverter, ), ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: converters += self._get_weight_and_bias_converters( fast_llm_prefix, - () if ignore_export else hf_prefix, + hf_prefix, use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, + cls, ) - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_1", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_2", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] - else: - converters += self._get_vision_transformer_mlp_converters( - f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" - ) + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) return converters - def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), - ] + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) - def _create_vision_transformer_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - 
vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] ) - return vision_transformer_converters + return converters - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # Embeddings - lm_converters = [ - WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") - ] - for i in range(self._model.config.base_model.transformer.num_layers): - lm_converters += self._create_transformer_layer_converters( - fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if 
"text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} ) - lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for 
export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + +# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + +# @classmethod +# def _create_config_converters(cls) -> list[ParamConverter]: +# # lm_converters = super()._create_config_converters() +# lm_converters = super()._create_config_converters() +# for idx, converter in enumerate(lm_converters): +# if converter.export_names == (("model_type",),): +# continue +# elif converter.export_names == (("architectures",),): +# ignore_index = idx +# if converter.export_names: +# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + +# return ( +# lm_converters[:ignore_index] +# + lm_converters[ignore_index + 1 :] +# + [ +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral +# ), +# ConstantExportParamConverter( +# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] +# ), +# # Vision Adapter +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_size"),), +# export_names=(("text_config", "hidden_size"),), +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# # Vision Transformer +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), +# export_names=( +# ( +# "vision_config", +# "num_hidden_layers", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "hidden_size", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), +# export_names=( +# ( +# "vision_config", +# "num_attention_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", 
"transformer", "head_groups"),), +# export_names=( +# ( +# "vision_config", +# "num_key_value_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "intermediate_size", +# ), +# ), +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), +# export_names=( +# ( +# "vision_config", +# "hidden_act", +# ), +# ), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), +# export_names=(("projector_hidden_act",),), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), +# export_names=(("vision_config", "rope_theta"),), +# ), +# ] +# ) + +# def _create_vision_transformer_layer_converters( +# self, +# i: int, +# ignore_export: bool = False, +# hf_base_prefix: str = "", +# fast_llm_offset: int = 1, +# type: str | None = None, +# ) -> list[WeightConverter]: +# if type is not None: +# if type == "vision": +# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer +# else: +# transformer_config: TransformerConfig = self._model.config.base_model.transformer +# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm +# converters = [] +# names_bias_cls = [ +# # Self-attn +# ( +# f"layers.{i+fast_llm_offset}.self_attn.query", +# f"vision_tower.transformer.layers.{i}.attention.q_proj", +# transformer_config.add_attn_qkv_bias, +# QueryWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.key_value", +# ( +# f"vision_tower.transformer.layers.{i}.attention.k_proj", +# f"vision_tower.transformer.layers.{i}.attention.v_proj", +# ), +# transformer_config.add_attn_qkv_bias, +# KeyValueWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.dense", +# f"vision_tower.transformer.layers.{i}.attention.o_proj", +# transformer_config.add_attn_dense_bias, +# WeightConverter, +# ), +# # Norm +# ( +# f"layers.{i+fast_llm_offset}.norm_1", +# f"vision_tower.transformer.layers.{i}.attention_norm", +# norm_bias, +# WeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.norm_2", +# f"vision_tower.transformer.layers.{i}.ffn_norm", +# norm_bias, +# WeightConverter, +# ), +# ] +# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: +# converters += self._get_weight_and_bias_converters( +# fast_llm_prefix, +# () if ignore_export else hf_prefix, +# use_bias, +# cls=IgnoreExportWeightConverter if ignore_export else cls, +# ) + +# # MLP +# if ignore_export: +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_1", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_2", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += 
[IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] +# else: +# converters += self._get_vision_transformer_mlp_converters( +# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" +# ) +# return converters + +# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: +# return [ +# SplitWeightConverter( +# f"{fast_llm_prefix}.mlp.layer_1.weight", +# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), +# ), +# MLPLayer2Converter( +# f"{fast_llm_prefix}.mlp.layer_2.weight", +# f"{hf_prefix}.feed_forward.down_proj.weight", +# self._model.config.base_model, +# ), +# ] + +# def _create_vision_transformer_converters(self) -> list[WeightConverter]: +# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers +# vision_transformer_converters = [] +# for layer in range(num_layers): +# # TODO Soham: check if args are correct +# vision_transformer_converters.extend( +# self._create_vision_transformer_layer_converters( +# layer, +# ignore_export=False, +# hf_base_prefix="vision_tower.transformer.layers.", +# fast_llm_offset=1, +# type="vision", +# ) +# ) + +# return vision_transformer_converters + +# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: +# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] +# if self._model.config.base_model.vision_encoder.conv_bias: +# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) +# layernorm_converters = [ +# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), +# ] +# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: +# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + +# vision_transformer_converters = self._create_vision_transformer_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 +# adapter_converters = [ +# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), +# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), +# # TODO Soham: add bias based on config +# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), +# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), +# ] + +# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + +# def _create_weight_converters(self) -> list[WeightConverter]: +# vision_encoder_converter = self._create_vision_encoder_weight_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 +# # Embeddings +# lm_converters = [ +# WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") +# ] +# for i in range(self._model.config.base_model.transformer.num_layers): +# lm_converters += self._create_transformer_layer_converters( +# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +# ) +# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) +# return vision_encoder_converter + lm_converters class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): @@ -950,4 +1286,6 @@ class 
AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c1d9df90f..72ff1b887 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -20,6 +20,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, + VisionTransformerDimNames, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -29,7 +30,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig From f3a4a74a086f5cb81da86195a00d6549cf66844b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 21:00:11 +0000 Subject: [PATCH 41/97] cleanup --- fast_llm/models/gpt/conversion.py | 262 ------------------------------ 1 file changed, 262 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 0b0796ed2..356525471 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -891,268 +891,6 @@ def _create_weight_converters(self): return converters -# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): -# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat - -# @classmethod -# def _create_config_converters(cls) -> list[ParamConverter]: -# # lm_converters = super()._create_config_converters() -# lm_converters = super()._create_config_converters() -# for idx, converter in enumerate(lm_converters): -# if converter.export_names == (("model_type",),): -# continue -# elif converter.export_names == (("architectures",),): -# ignore_index = idx -# if converter.export_names: -# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - -# return ( -# lm_converters[:ignore_index] -# + lm_converters[ignore_index + 1 :] -# + [ -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral -# ), -# ConstantExportParamConverter( -# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] -# ), -# # Vision Adapter -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_size"),), -# export_names=(("text_config", "hidden_size"),), -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# # Vision Transformer -# 
RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), -# export_names=( -# ( -# "vision_config", -# "num_hidden_layers", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "hidden_size", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), -# export_names=( -# ( -# "vision_config", -# "num_attention_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), -# export_names=( -# ( -# "vision_config", -# "num_key_value_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "intermediate_size", -# ), -# ), -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), -# export_names=( -# ( -# "vision_config", -# "hidden_act", -# ), -# ), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), -# export_names=(("projector_hidden_act",),), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), -# export_names=(("vision_config", "rope_theta"),), -# ), -# ] -# ) - -# def _create_vision_transformer_layer_converters( -# self, -# i: int, -# ignore_export: bool = False, -# hf_base_prefix: str = "", -# fast_llm_offset: int = 1, -# type: str | None = None, -# ) -> list[WeightConverter]: -# if type is not None: -# if type == "vision": -# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer -# else: -# transformer_config: TransformerConfig = self._model.config.base_model.transformer -# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm -# converters = [] -# names_bias_cls = [ -# # Self-attn -# ( -# f"layers.{i+fast_llm_offset}.self_attn.query", -# f"vision_tower.transformer.layers.{i}.attention.q_proj", -# transformer_config.add_attn_qkv_bias, -# QueryWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.key_value", -# ( -# f"vision_tower.transformer.layers.{i}.attention.k_proj", -# f"vision_tower.transformer.layers.{i}.attention.v_proj", -# ), -# transformer_config.add_attn_qkv_bias, -# KeyValueWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.dense", -# f"vision_tower.transformer.layers.{i}.attention.o_proj", -# transformer_config.add_attn_dense_bias, -# WeightConverter, -# ), -# # Norm -# ( -# f"layers.{i+fast_llm_offset}.norm_1", -# f"vision_tower.transformer.layers.{i}.attention_norm", -# norm_bias, -# WeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.norm_2", -# f"vision_tower.transformer.layers.{i}.ffn_norm", -# norm_bias, -# WeightConverter, -# ), -# ] -# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: -# converters += 
self._get_weight_and_bias_converters( -# fast_llm_prefix, -# () if ignore_export else hf_prefix, -# use_bias, -# cls=IgnoreExportWeightConverter if ignore_export else cls, -# ) - -# # MLP -# if ignore_export: -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_1", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_2", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] -# else: -# converters += self._get_vision_transformer_mlp_converters( -# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" -# ) -# return converters - -# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: -# return [ -# SplitWeightConverter( -# f"{fast_llm_prefix}.mlp.layer_1.weight", -# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), -# ), -# MLPLayer2Converter( -# f"{fast_llm_prefix}.mlp.layer_2.weight", -# f"{hf_prefix}.feed_forward.down_proj.weight", -# self._model.config.base_model, -# ), -# ] - -# def _create_vision_transformer_converters(self) -> list[WeightConverter]: -# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers -# vision_transformer_converters = [] -# for layer in range(num_layers): -# # TODO Soham: check if args are correct -# vision_transformer_converters.extend( -# self._create_vision_transformer_layer_converters( -# layer, -# ignore_export=False, -# hf_base_prefix="vision_tower.transformer.layers.", -# fast_llm_offset=1, -# type="vision", -# ) -# ) - -# return vision_transformer_converters - -# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: -# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] -# if self._model.config.base_model.vision_encoder.conv_bias: -# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) -# layernorm_converters = [ -# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), -# ] -# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: -# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - -# vision_transformer_converters = self._create_vision_transformer_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 -# adapter_converters = [ -# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), -# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), -# # TODO Soham: add bias based on config -# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), -# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), -# ] - -# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - -# def _create_weight_converters(self) -> list[WeightConverter]: -# vision_encoder_converter = self._create_vision_encoder_weight_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 -# # Embeddings -# lm_converters = [ -# WeightConverter(f"layers.{offset - 
1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") -# ] -# for i in range(self._model.config.base_model.transformer.num_layers): -# lm_converters += self._create_transformer_layer_converters( -# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" -# ) -# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) -# return vision_encoder_converter + lm_converters - - class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 3b955b1600ba09c5b7844113b6fc55ee3916f261 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 22:23:07 +0000 Subject: [PATCH 42/97] fixes for pixtral --- fast_llm/models/gpt/conversion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 356525471..b7f9f7733 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -572,6 +572,12 @@ class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, Huggi @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter( + fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), @@ -646,6 +652,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("rope_theta",),), ), RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -803,6 +811,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), ] @classmethod From 49daf581600175c884265c00df4aaf04a9dc0f74 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 23:27:52 +0000 Subject: [PATCH 43/97] model fixes --- fast_llm/layers/multi_modal/embedding.py | 2 -- fast_llm/layers/transformer/config.py | 11 +---------- fast_llm/layers/vision_encoder/encoder.py | 4 ++-- fast_llm/models/gpt/model.py | 5 +++-- 4 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9a035d8fd..c67f82b41 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -55,7 +55,6 @@ def _forward( embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -82,7 +81,6 @@ def _forward( # for positions in image_positions: # if positions > self._distributed_config.tensor_rank embeddings = torch.embedding(self.word_embeddings_weight, tokens) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a634bc3c8..49babb06b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -99,7 +99,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) class TransformerKwargs(BaseTransformerKwargs, prefix=""): @@ -824,15 +824,6 @@ def _transformer_dim_names(self) -> TransformerDimNames: return TransformerDimNames -@config_class() -class VisionRotaryConfig(RotaryConfig): - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.pixtral, - desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", - hint=FieldHint.feature, - ) - - @config_class() class VisionTransformerConfig(TransformerConfig): """ diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 1df7f889c..20749af48 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,7 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -69,7 +69,7 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 72ff1b887..cbce66f2e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,6 +21,7 @@ TransformerKwargs, TransformerLossNames, VisionTransformerDimNames, + VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -30,7 +31,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.config import 
VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -244,7 +245,7 @@ def preprocess_meta( ) vision_kwargs.update( { - VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, } ) From b5ed9f4f6fdd6205225f730a136edb2f211c9f95 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 19:07:00 +0000 Subject: [PATCH 44/97] more cleanup --- fast_llm/data/data/gpt/data.py | 4 +- fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 13 +- fast_llm/data/tokenizer.py | 2 +- fast_llm/engine/schedule/config.py | 5 - fast_llm/functional/config.py | 8 +- fast_llm/layers/multi_modal/embedding.py | 1 - fast_llm/layers/transformer/attention.py | 18 +- fast_llm/layers/transformer/config.py | 23 --- fast_llm/layers/transformer/transformer.py | 8 - .../layers/transformer/vision_transformer.py | 8 - fast_llm/layers/vision_encoder/adapter.py | 1 - fast_llm/layers/vision_encoder/config.py | 21 ++- .../{encoder.py => patch_conv.py} | 6 +- .../layers/vision_encoder/preprocessing.py | 5 - fast_llm/models/gpt/conversion.py | 161 +++++++++--------- fast_llm/models/gpt/model.py | 11 +- fast_llm/models/gpt/trainer.py | 9 +- fast_llm/tools/cli.py | 1 + fast_llm/utils.py | 7 - 20 files changed, 129 insertions(+), 191 deletions(-) rename fast_llm/layers/vision_encoder/{encoder.py => patch_conv.py} (95%) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4fcd42ae1..31a19e148 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -51,13 +51,13 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append([torch.from_numpy(image) for image in sample.images]) has_images = True else: - batch_images.append(None) + batch_images.append([]) batch_image_positions = [] for sample in batch: if sample.image_positions is not None: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: - batch_image_positions.append(None) + batch_image_positions.append([]) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 53f8e4688..2e9243807 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -151,12 +151,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) - tokenize_batch_size: int = Field( - default=1000, - desc="Batch size for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -170,7 +164,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) tokenizer: TokenizerConfig = Field( default_factory=TokenizerConfig, - desc="Tokenizer configuration.", + desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) image_patch_size: int = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c5a1b339c..fa46ee92e 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ 
b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -138,20 +138,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) - # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample( - # np.array(item["input_ids"], dtype=self._data_type.numpy), - # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - # ) - # else: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -279,7 +268,7 @@ def run(self) -> None: if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") tokenize_fn = self._tokenize_batch - # Avoid decoding bytes to images unless asked + # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0acb65e47..1cbc1ec56 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -44,7 +44,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ - Tokenize the input text and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. 
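[Editor's note] As context for the tokenizer change above, a hedged sketch of the piecewise approach such a tokenize() can take: the text between consecutive character-level image positions is tokenized separately, and the running token count gives the token-level insertion point for each image. `tokenize_piece` is a hypothetical stand-in for the real tokenizer call, not the Tokenizer API.

def tokenize_with_image_positions(text, image_char_positions, tokenize_piece):
    # Tokenize the text segment before each image and record the token index
    # at which that image's patch tokens will later be spliced in.
    token_ids, token_image_positions, start = [], [], 0
    for char_pos in sorted(image_char_positions):
        token_ids.extend(tokenize_piece(text[start:char_pos]))
        token_image_positions.append(len(token_ids))
        start = char_pos
    token_ids.extend(tokenize_piece(text[start:]))
    return token_ids, token_image_positions
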
""" if not image_positions: image_positions = [] diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48daf0e69..204abdf1c 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -50,11 +50,6 @@ class BatchConfig(Config): hint=FieldHint.setup, ) # Image inputs - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) image_size: int | None = Field( default=None, desc="Maximum image height and width", diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 233ea339d..480fa067e 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,14 +80,14 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} -_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index c67f82b41..8c541e983 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -114,7 +114,6 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3180b6cb8..e88f64a30 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -191,7 +191,7 @@ def _get_meta( ) @property - def query_dims(self): + def _query_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -200,7 +200,7 @@ def query_dims(self): ) @property - def kv_dims(self): + def _kv_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -209,7 +209,7 @@ def kv_dims(self): ) @property - def context_dims(self): + def _context_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -346,11 +346,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - 
self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) + self._debug_log(query, "query_rotary_input", self._query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self.kv_dims, + self._kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings @@ -402,20 +402,20 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ) if self._debug_transformer: - self._debug_log(query, "query", self.query_dims, kwargs) + self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, "key", - self.kv_dims, + self._kv_dims, kwargs, ) self._debug_log( value, "value", - self.kv_dims, + self._kv_dims, kwargs, ) - self._debug_log(input_, "context", self.context_dims, kwargs) + self._debug_log(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 49babb06b..b8d153672 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -110,29 +110,6 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" -# class TransformerKwargs: -# rotary_freq_q = "rotary_freq_q" -# rotary_freq_k = "rotary_freq_k" -# attention_mask = "attention_mask" -# attention_mask_value = "attention_mask_value" -# sequence_lengths = "sequence_lengths" -# cu_seqlens_q = "cu_seqlens_q" -# cu_seqlens_k = "cu_seqlens_k" -# max_seqlen_q = "max_seqlen_q" -# max_seqlen_k = "max_seqlen_k" -# # TODO: Review these -# presents = "presents" -# past_key_values = "past_key_values" -# sequence_first = "sequence_first" -# hidden_dims = "hidden_dims" -# sequence_q_dim = "sequence_q_dim" -# sequence_k_dim = "sequence_k_dim" -# sequence_length = "sequence_length" -# micro_batch_size = "micro_batch_size" -# # TODO: Move -# grad_output = "grad_output" - - class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 2c79883b3..392ebb889 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -149,11 +149,3 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - - # @property - # def _transformer_kwargs(self) -> TransformerKwargs: - # return TransformerKwargs - - # @property - # def _transformer_dim_names(self) -> TransformerDimNames: - # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index c2cfe9f23..7c1be0d16 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -9,14 +9,6 @@ class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - # @property - # def _transformer_kwargs(self) -> VisionTransformerKwargs: - # return VisionTransformerKwargs - - # @property - # def _transformer_dim_names(self) -> VisionTransformerDimNames: - # return VisionTransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 
bf5f3f1aa..41ea065d0 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -20,7 +20,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) self._activation_type = config.adapter_activation_type - # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 6932c8fc0..f788b5149 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,11 +1,12 @@ import enum -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.utils import Assert class VisionEncoderDimNames: @@ -129,18 +130,24 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) - # TODO Soham: add a check for presence of kv channels parameter (head_dim) - tensor_space.add_tensor_dim( - TensorDim( - VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads - ) - ) self.transformer.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/patch_conv.py similarity index 95% rename from fast_llm/layers/vision_encoder/encoder.py rename to fast_llm/layers/vision_encoder/patch_conv.py index 20749af48..68f22200a 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -43,6 +43,7 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -52,10 +53,13 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): 
self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), ), init_method=init_normal_(), + lr_scale=self._lr_scale, ) if config.conv_bias: self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_sclae=self._lr_scale, ) else: self.bias = None diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5009123f0..d85442a3e 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -103,7 +103,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._distributed_config = self._tensor_space.distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # kwargs[VisionEncoderDimNames] kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( ( TensorDim( @@ -141,16 +140,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): - # sum( - # get_num_patches(*size, patch_size) for size in sizes - # ) seq_patches = [] sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index b7f9f7733..95bbebde2 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -115,8 +115,7 @@ def import_weight( return (merged_weight.t().contiguous(),) -class TransformerWeightConverterMixin: - +class WeightAndBiasConverterMixin: def _get_weight_and_bias_converters( self, fast_llm_prefix: str | tuple[str, ...], @@ -145,6 +144,83 @@ def _get_weight_and_bias_converters( ) return converters + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + 
RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] + + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass + + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) + + return converters + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads @@ -250,83 +326,6 @@ def _create_transformer_layer_converters( return converters -class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass - - def _create_weight_converters( - self, - hf_base_prefix: str = "", - offset: int = 0, - ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - - # Embeddings - converters.append( - WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") - ) - - converters += 
self._create_lm_head_converters(hf_base_prefix, offset=offset) - - for i in range(num_layers): - converters += self._create_transformer_layer_converters( - f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" - ) - - return converters - - class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat @@ -565,7 +564,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -770,7 +769,7 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 -class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index cbce66f2e..586b511ba 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -32,7 +32,7 @@ from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.layers.vision_encoder.encoder import PatchConv +from fast_llm.layers.vision_encoder.patch_conv import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -84,11 +84,6 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - # if self._config.vision_encoder.transformer.rotary.enabled: - # self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - # self._config.vision_encoder.transformer.rotary, self._tensor_space - # ) def get_output_layers(self) -> list[Layer]: layers = [] @@ -178,14 +173,14 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, VisionEncoderKwargs.image_size: image_size, VisionEncoderKwargs.image_mean: image_mean, VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( - VisionEncoderDimNames.kv_channels + VisionTransformerDimNames.kv_channels ).size, VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.out_channels diff --git 
a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 482fea02f..840b80926 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,10 +30,15 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, - "patch_size": self._config.batch.patch_size, - "image_size": self._config.batch.image_size, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "image_size": self._config.batch.image_size, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 4d218c3ff..0cc02f426 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -36,6 +36,7 @@ def fast_llm(args=None): if sys.gettrace(): raise logger.critical(traceback.format_exc()) + sys.exit(1) if __name__ == "__main__": diff --git a/fast_llm/utils.py b/fast_llm/utils.py index c5b7f07ae..51e0eee59 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,10 +336,3 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) - - -def prefix_class_vars(cls, prefix: str, base_cls: type): - for attr, value in vars(base_cls).items(): - if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): - setattr(cls, attr, prefix + value) - return cls From dc888c8fc6596b0ba7483b4eaf184ba7015e2063 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 23:05:57 +0000 Subject: [PATCH 45/97] image break token in sampling --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/sampled.py | 45 +++++++++++++++++-- fast_llm/layers/vision_encoder/config.py | 5 +++ .../layers/vision_encoder/preprocessing.py | 10 +++++ fast_llm/models/gpt/trainer.py | 1 + 5 files changed, 58 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 44d1f4cc9..004a062c2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,6 +76,7 @@ class GPTSamplingParameters(SamplingParameters): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + image_break_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 780b18878..de8e1d75c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,7 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert try: @@ -138,7 +138,7 @@ def _sample(self) -> None: for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( - get_num_patches( + get_num_image_tokens( *get_resize_dims( *size, self._parameters.image_size, @@ -146,6 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -211,6 +212,7 @@ def _sample(self) -> None: "sequence_length": self._parameters.sequence_length, "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -423,7 +425,7 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ - get_num_patches( + get_num_image_tokens( *get_resize_dims( *image_length, self._parameters.image_size, @@ -431,6 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -473,7 +476,41 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + if self._parameters.image_break_token is not None: + # Calculate patch dimensions for the image + width, height = get_resize_dims( + image_lengths[idx][0], + image_lengths[idx][1], + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + num_patches_w = math.ceil(width / self._parameters.patch_size) + num_patches_h = math.ceil(height / self._parameters.patch_size) + + # Calculate the token count considering break tokens + tokens_per_row = num_patches_w + total_tokens = num_patches_h * tokens_per_row + ( + num_patches_h - 1 + ) # Add break tokens after each row except last + + # Create image token placeholder array + image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + + # Add break tokens after each row except the last row + for row in range(num_patches_h - 1): + position = (row + 1) * tokens_per_row + row + image_token_array[position] = self._parameters.image_break_token + + token_ids.append(image_token_array) + + # Update image_tokens_added to reflect actual number of tokens added + image_tokens_added += total_tokens + else: + # Just add placeholders for all image tokens without break tokens + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_tokens_added += image_sizes[idx] 
image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index f788b5149..5b972f128 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -130,6 +130,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied is applied.", + hint=FieldHint.optional, + ) adapter_lr_scale: float | None = Field( default=None, desc="Custom learning rate scale for the adapter weights.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index d85442a3e..5cffbff58 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -19,6 +19,16 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + return height_patches * (width_patches + image_break) + + def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. 
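[Editor's note] A standalone numeric check of the placeholder construction introduced above, using a hypothetical 3x4 patch grid and break-token id 10. Note that this loop reserves num_patches_h - 1 break slots, which the later "fix img break" commit aligns the get_num_image_tokens helper with.

import numpy as np

num_patches_h, num_patches_w, image_break_token = 3, 4, 10
tokens_per_row = num_patches_w
total_tokens = num_patches_h * tokens_per_row + (num_patches_h - 1)  # 12 patches + 2 breaks = 14

image_token_array = np.full((total_tokens,), -100, dtype=np.int64)  # -100 marks patch slots
for row in range(num_patches_h - 1):
    image_token_array[(row + 1) * tokens_per_row + row] = image_break_token
# -> [-100 -100 -100 -100 10 -100 -100 -100 -100 10 -100 -100 -100 -100]
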
diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 840b80926..d1b6d19e2 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -37,6 +37,7 @@ def _get_sampling_parameters( { "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From af3e2dbcb19bec618d88dbf1bfb913fe8940caf7 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 23 May 2025 22:47:04 +0000 Subject: [PATCH 46/97] minor fixes --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 13 ++-- fast_llm/layers/multi_modal/embedding.py | 64 ++++++++++++++----- .../layers/vision_encoder/preprocessing.py | 2 +- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 1efc312e8..a202d2e1f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -201,6 +201,7 @@ def get( use_loss_masking_spans: bool = False, patch_size: int | None = None, image_size: int | None = None, + image_break: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -239,9 +240,10 @@ def get( additional_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_patches( + image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, + image_break=image_break, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index de8e1d75c..f441d9b9e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -146,7 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -433,7 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -476,6 +476,7 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) + image_positions.append(text_tokens_added + im_position + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image @@ -491,12 +492,12 @@ def __getitem__(self, index: int) -> 
typing.Any: # Calculate the token count considering break tokens tokens_per_row = num_patches_w - total_tokens = num_patches_h * tokens_per_row + ( + resized_image_tokens = num_patches_h * tokens_per_row + ( num_patches_h - 1 ) # Add break tokens after each row except last # Create image token placeholder array - image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): @@ -506,13 +507,11 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(image_token_array) # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += total_tokens + image_tokens_added += resized_image_tokens else: # Just add placeholders for all image tokens without break tokens token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_tokens_added += image_sizes[idx] - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 8c541e983..12b58a764 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,9 +9,9 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class MultiModalEmbedding(LanguageModelEmbedding): @@ -60,15 +60,32 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - if self._sequence_parallel: - embeddings[position : position + num_image_tokens, sample_idx] = input_[ - image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx - ] - else: - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row for both sequence parallel and non-parallel cases + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Always use full patch_width + tokens_in_row = patch_width + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ + row_start_src : row_start_src + tokens_in_row, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : 
row_start_dst + tokens_in_row] = input_[ + sample_idx, row_start_src : row_start_src + tokens_in_row + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) @@ -85,10 +102,27 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Copy row by row + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._use_absolute_position_embeddings: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5cffbff58..c5c14a262 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) + return height_patches * (width_patches + image_break) - 1 def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From 6d56be085309a4e0f74c24c5bad4aa8aea442708 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 24 May 2025 19:43:34 +0000 Subject: [PATCH 47/97] fix img break --- fast_llm/data/dataset/gpt/sampled.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f441d9b9e..2c068742c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -476,19 +476,19 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + im_position + image_tokens_added) + image_positions.append(text_tokens_added + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image - width, height = get_resize_dims( + height, width = get_resize_dims( image_lengths[idx][0], image_lengths[idx][1], self._parameters.image_size, self._parameters.image_size, self._parameters.patch_size, ) - num_patches_w = math.ceil(width / 
self._parameters.patch_size) num_patches_h = math.ceil(height / self._parameters.patch_size) + num_patches_w = math.ceil(width / self._parameters.patch_size) # Calculate the token count considering break tokens tokens_per_row = num_patches_w diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c5c14a262..8404adae9 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) - 1 + return height_patches * width_patches + (height_patches - 1 if image_break else 0) def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From ce9164647d3a582b8a13fd3646a66f3a019c8966 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 27 May 2025 23:34:57 +0000 Subject: [PATCH 48/97] fixes --- fast_llm/layers/language_model/embedding.py | 5 ++++- fast_llm/layers/multi_modal/embedding.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed1..f51f40df7 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - embeddings = torch.embedding(self.word_embeddings_weight, input_) + # mask padded tokens + input_mask = input_ >= 0 + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 12b58a764..f40df3f09 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -60,7 +60,11 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) # Calculate the patch dimensions patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) @@ -97,7 +101,10 @@ def _forward( # TODO Soham: get image positions for current split. Maybe in preprocessing? 
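[Editor's note] A minimal sketch of the padded-token masking introduced in this commit (and mirrored for the multimodal embedding in the hunk that follows): image-placeholder slots carry the id -100, which is not a valid embedding row, so they are clamped to 0 for the lookup and the resulting rows are zeroed out before the patch embeddings are written in. Shapes and values are illustrative.

import torch

vocab_size, hidden = 8, 4
weight = torch.randn(vocab_size, hidden)
tokens = torch.tensor([[1, -100, -100, 3]])  # -100 marks image-token slots

mask = tokens >= 0
# Clamp invalid ids to 0 for the lookup, then zero those rows out.
embeddings = torch.embedding(weight, tokens * mask) * mask.unsqueeze(2)
assert torch.all(embeddings[0, 1] == 0) and torch.all(embeddings[0, 2] == 0)
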
# for positions in image_positions: # if positions > self._distributed_config.tensor_rank - embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 204b3e9f27e6d12168f72a4ae045fc7ab9dbe475 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 28 May 2025 06:04:47 +0000 Subject: [PATCH 49/97] fix image embeddings offset --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 68 +++++++------- fast_llm/layers/multi_modal/embedding.py | 89 ++++++++----------- fast_llm/layers/vision_encoder/config.py | 7 +- .../layers/vision_encoder/preprocessing.py | 31 ++++++- fast_llm/models/gpt/trainer.py | 1 + 7 files changed, 109 insertions(+), 90 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 004a062c2..bb3ff717a 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -77,6 +77,7 @@ class GPTSamplingParameters(SamplingParameters): patch_size: int | None = None image_size: int | None = None image_break_token: int | None = None + image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a202d2e1f..d83064b1e 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -202,6 +202,7 @@ def get( patch_size: int | None = None, image_size: int | None = None, image_break: bool = False, + image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -244,6 +245,7 @@ def get( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, image_break=image_break, + image_end=image_end, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2c068742c..6c8e9fe71 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -15,7 +15,7 @@ from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div try: from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa @@ -147,6 +147,7 @@ def _sample(self) -> None: ), self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) for size in sizes ) @@ -213,6 +214,7 @@ def _sample(self) -> None: "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -424,18 +426,23 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = 
self._indexed_dataset.get_document_size(document_index) + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] image_sizes = [ get_num_image_tokens( - *get_resize_dims( - *image_length, - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ), + *image_length, self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) - for image_length in image_lengths + for image_length in resized_image_lengths ] image_tokens = sum(image_sizes) document_size = text_size + image_tokens @@ -468,6 +475,8 @@ def __getitem__(self, index: int) -> typing.Any: offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, + # image_break=self._parameters.image_break_token is not None, + # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 if sample.image_positions: @@ -477,41 +486,30 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) image_positions.append(text_tokens_added + image_tokens_added) - # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: - # Calculate patch dimensions for the image - height, width = get_resize_dims( - image_lengths[idx][0], - image_lengths[idx][1], - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ) - num_patches_h = math.ceil(height / self._parameters.patch_size) - num_patches_w = math.ceil(width / self._parameters.patch_size) - - # Calculate the token count considering break tokens - tokens_per_row = num_patches_w - resized_image_tokens = num_patches_h * tokens_per_row + ( - num_patches_h - 1 - ) # Add break tokens after each row except last + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) # Create image token placeholder array - image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): - position = (row + 1) * tokens_per_row + row + position = (row + 1) * num_patches_w + row image_token_array[position] = self._parameters.image_break_token - - token_ids.append(image_token_array) - - # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += resized_image_tokens + # add end token if specified, else break token + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token else: - # Just add placeholders for all image tokens without break tokens - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_tokens_added += image_sizes[idx] + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + 
token_ids.append(image_token_array) + image_tokens_added += image_sizes[idx] start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index f40df3f09..4dd4a46eb 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div @@ -60,37 +60,30 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) - - # Process row by row for both sequence parallel and non-parallel cases - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) - - # Always use full patch_width - tokens_in_row = patch_width - - if self._sequence_parallel: - # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ - row_start_src : row_start_src + tokens_in_row, sample_idx - ] - else: - # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + tokens_in_row] = input_[ - sample_idx, row_start_src : row_start_src + tokens_in_row - ] - - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ + row_start_src : row_start_src + patch_width, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + image_embedding_offset += num_patches if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -109,28 +102,24 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): 
image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) - # Process row by row - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) - # Copy row by row - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + image_embedding_offset += num_patches if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5b972f128..267941741 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -132,7 +132,12 @@ class VisionEncoderConfig(BaseModelConfig): ) image_break_token: int | None = Field( default=None, - desc="Token id to separate image rows. If None, no token id is applied is applied.", + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. 
If None, no token id is applied.", hint=FieldHint.optional, ) adapter_lr_scale: float | None = Field( diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8404adae9..41da4fb6f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,6 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta @@ -19,14 +20,19 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) -def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: """ Calculate the number of image tokens. If image_break is True, we consider 1 additional token after every row of patches. """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * width_patches + (height_patches - 1 if image_break else 0) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -150,16 +156,32 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? 
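# Worked example for get_num_image_tokens above, assuming patch_size=16 and a resized
# 64x48 image: 64//16 x 48//16 = 4x3 = 12 patches, plus 4 break tokens (one per row of
# patches) when image_break is set, or plus a single trailing token when only image_end
# is set, i.e. 16 or 13 image tokens respectively.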
+ labels = labels.clone() + patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) - for imgs, sizes in zip(images, image_sizes): + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): seq_patches = [] sample_cu_seqlen = 0 - for image, size in zip(imgs, sizes): + for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -204,6 +226,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) + kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index d1b6d19e2..a4f0b0b42 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -38,6 +38,7 @@ def _get_sampling_parameters( "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From fd08eac092f508b50219d4314f22a54af8efe768 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 00:10:38 +0000 Subject: [PATCH 50/97] heterogeneous data fixes --- fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/functional/cross_entropy.py | 2 +- fast_llm/functional/triton/mlp.py | 4 ++-- .../layers/vision_encoder/preprocessing.py | 21 +++++++++++++++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b3..b1c7df819 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -121,7 +121,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec7..53b5979ed 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss[loss_mask] loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ee3ba304c..0fb71bd56 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,7 +50,7 @@ def 
triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -100,7 +100,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 41da4fb6f..8fad35722 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -170,7 +170,13 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): - seq_patches = [] + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] sample_cu_seqlen = 0 for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) @@ -211,9 +217,16 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - position_ids = torch.cat( - [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] - ).to(device=self._tensor_space.distributed.device) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks patch_position_ids.append( torch.cat( From 1e3652aeae78f930fdd1c58d09b45681adec2047 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 15:25:49 +0000 Subject: [PATCH 51/97] convert to rgb --- fast_llm/data/dataset/gpt/memmap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index d83064b1e..703809417 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -325,10 +325,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - if img.mode == "L": - # Convert grayscale to RGB + if img.mode != "RGB": + # Convert all images to RGB img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
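# Images are written to the .bin stream as raw uint8 CHW buffers; only (height, width)
# is recorded in the index, since the channel count is fixed at 3 after the RGB conversion.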
image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size From 2aabf353752eeb9290f470cd76e44da8482c0456 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 30 May 2025 20:48:27 +0000 Subject: [PATCH 52/97] fix sequence parallel image patches --- fast_llm/layers/multi_modal/embedding.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 4dd4a46eb..9e11df3f3 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -48,11 +48,17 @@ def _forward( """ Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) if self._parallel_embeddings: token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa - embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() @@ -61,13 +67,18 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + continue if self._config.vision_encoder.image_break_token is not None: patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) - for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_dst < patch_start_offset: + continue if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case @@ -84,6 +95,9 @@ def _forward( sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: From b6d48589ad500034efdecb3727a5d163702f60e2 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 01:50:12 +0000 Subject: [PATCH 53/97] fixes --- fast_llm/layers/multi_modal/embedding.py | 46 +++++++++++++------ .../layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9e11df3f3..76060a004 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -59,8 +59,6 @@ def _forward( token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # 
noqa - if self._use_absolute_position_embeddings: - embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -70,34 +68,56 @@ def _forward( if image_embedding_offset + num_patches < patch_start_offset: continue if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) if row_start_src > patch_end_offset: break - if row_start_dst < patch_start_offset: + if row_start_src + patch_width <= patch_start_offset: continue + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ - row_start_src : row_start_src + patch_width, sample_idx + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: - embeddings[sample_idx, position : position + num_patches] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_patches + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] image_embedding_offset += num_patches if image_embedding_offset > patch_end_offset: break embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -118,8 +138,8 @@ def _forward( 
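# Same per-image copy as in the parallel-embeddings branch above: row by row with a
# reserved break-token slot per row when image_break_token is set, otherwise one
# contiguous block of `num_patches` embeddings starting at `position`.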
for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8fad35722..ab0d23787 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -205,7 +205,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size - cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) patches.append( torch.cat( [ From 25a650bf588e8a20b02a4b6f6b991aa42993808b Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 17:10:16 +0000 Subject: [PATCH 54/97] no compile for embeddings --- fast_llm/layers/multi_modal/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 76060a004..7f09347bf 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -26,7 +26,7 @@ def __init__( ): super().__init__(config, tensor_space) - @torch.compile + # @torch.compile def _forward( self, input_: torch.Tensor, From c904da5def23c6db1abb775971f6790a4bec8272 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 1 Jun 2025 17:48:59 +0000 Subject: [PATCH 55/97] fix sampling --- fast_llm/data/dataset/gpt/sampled.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 6c8e9fe71..8d216b3d4 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -453,7 +453,7 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -464,6 +464,7 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + continue # Determine if the document belongs to the requested sample. 
if token_count + document_size >= token_start: From 7a4701c522431eb94a873f59a220e13691c007b9 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Mon, 2 Jun 2025 00:15:54 -0700 Subject: [PATCH 56/97] sampling and preprocessing bugs --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/vision_encoder/preprocessing.py | 6 +++--- fast_llm/models/gpt/model.py | 13 ++++++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8d216b3d4..f58b009a1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count > token_start: + if token_count >= token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index ab0d23787..76b0aa284 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -137,6 +137,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) patch_size = kwargs[VisionEncoderKwargs.patch_size] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] for ims in images @@ -156,7 +157,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) labels = kwargs[LanguageModelKwargs.labels] if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): @@ -239,9 +239,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) - kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], @@ -249,7 +249,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_size, ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) - kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids # TODO Soham: handle sequence data parallel kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 @@ -259,3 +258,4 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/model.py 
b/fast_llm/models/gpt/model.py index 586b511ba..45cf4a4fe 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,15 +407,22 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - if batch.images is not None: + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[VisionEncoderKwargs.images] = [ [ img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) for img in images ] - for images in batch.images + for images in batch_images ] - kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: From 067f901bc8bc0b51148f2531d1f929f74b90081a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 18:35:24 +0000 Subject: [PATCH 57/97] speed up sampling --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f58b009a1..2972632cb 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -166,7 +166,7 @@ def _sample(self) -> None: " Please make sure Fast-LLM is installed correctly." ) long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 - ignored_documents = sum(long_docs_filter) + ignored_documents = long_docs_filter.sum() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", From f24325eaf768de0dac5a3e4c7f879a3bf0d5f3cc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 22:15:57 +0000 Subject: [PATCH 58/97] cap image size reduction --- fast_llm/layers/vision_encoder/preprocessing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 76b0aa284..a9115c97c 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -50,9 +50,22 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: - resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) + target_height, target_width = get_resize_dims( + image.size(1), image.size(2), max_height, max_width, patch_size=patch_size + ) + height, width = image.size(1), image.size(2) + while height > 2 * target_height or width > 2 * target_width: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + intermediate_max_width = max(target_width, width // 2) + intermediate_max_height = max(target_height, height // 2) + height, width = get_resize_dims( + height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size + ) + image = F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + # TODO: options for interpolation mode? 
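# With max_height = max_width = 256 and patch_size = 16, a 2048x1536 input is first
# halved stepwise by the loop above (2048x1536 -> 1024x768 -> 512x384) and only then
# resized to the 256x192 target, avoiding the single large interpolation step from the
# linked PyTorch issue.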
- return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + return F.resize(image, size=(target_height, target_width), interpolation=F.InterpolationMode.BICUBIC) def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: From 0f376643df53c2831c2a164436fd7aba92cb4f80 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 23:15:50 +0000 Subject: [PATCH 59/97] fix span offset with images --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 703809417..6f5a963f4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -240,6 +240,18 @@ def get( for span in sample_spans: additional_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position < span[0]: + image_tokens = get_num_image_tokens( + get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + patch_size, + image_break=image_break, + image_end=image_end, + ) + additional_tokens += image_tokens + image_idx += 1 + image_position = ( + image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + ) while image_position >= span[0] and image_position <= span[1]: image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), From ff8fecc46e37c76a84ba8036b3781e3c1c9c447e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 23:35:47 +0000 Subject: [PATCH 60/97] fix span offset with images --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 6f5a963f4..a3f2f9019 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -237,8 +237,9 @@ def get( sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset if images: image_idx = 0 + prev_image_tokens = 0 for span in sample_spans: - additional_tokens = 0 + span_image_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position < span[0]: image_tokens = get_num_image_tokens( @@ -247,11 +248,12 @@ def get( image_break=image_break, image_end=image_end, ) - additional_tokens += image_tokens + span_image_tokens += image_tokens image_idx += 1 image_position = ( image_positions[image_idx] if image_idx < len(image_positions) else float("inf") ) + prev_image_tokens += image_tokens while image_position >= span[0] and image_position <= span[1]: image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), @@ -259,12 +261,14 @@ def get( image_break=image_break, image_end=image_end, ) - additional_tokens += image_tokens + span_image_tokens += image_tokens image_idx += 1 image_position = ( image_positions[image_idx] if image_idx < len(image_positions) else float("inf") ) - span[1] += additional_tokens + span[0] += prev_image_tokens + span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens return GPTSample( token_ids=token_ids, images=images, From c663cbb69b6334e2802e04113c9aa33ca7e4f8c3 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Jun 2025 00:06:02 +0000 Subject: [PATCH 61/97] move image logic to sampled --- 
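The logic being moved keeps loss-masking spans aligned after image placeholder tokens are
spliced into the token stream. A minimal standalone sketch of that adjustment, using
hypothetical names (shift_spans, image_token_counts) rather than the actual helpers:

def shift_spans(spans, image_positions, image_token_counts):
    # Shift (start, end) spans by the image tokens inserted before or inside them.
    shifted, image_idx, added = [], 0, 0
    for start, end in spans:
        # Images that end before the span only shift it.
        while image_idx < len(image_positions) and image_positions[image_idx] < start:
            added += image_token_counts[image_idx]
            image_idx += 1
        new_start = start + added
        # Images inside the span also stretch it.
        while image_idx < len(image_positions) and image_positions[image_idx] <= end:
            added += image_token_counts[image_idx]
            image_idx += 1
        shifted.append((new_start, end + added))
    return shifted

# One 16-token image at text position 5 pushes a span (10, 20) to (26, 36).
assert shift_spans([(10, 20)], [5], [16]) == [(26, 36)]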
fast_llm/data/dataset/gpt/memmap.py | 35 ---------------------------- fast_llm/data/dataset/gpt/sampled.py | 31 +++++++++++++++++++++--- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a3f2f9019..ce24f3b97 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,6 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -235,40 +234,6 @@ def get( ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - if images: - image_idx = 0 - prev_image_tokens = 0 - for span in sample_spans: - span_image_tokens = 0 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - while image_position < span[0]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - image_end=image_end, - ) - span_image_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - prev_image_tokens += image_tokens - while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - image_end=image_end, - ) - span_image_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - span[0] += prev_image_tokens - span[1] += prev_image_tokens + span_image_tokens - prev_image_tokens += span_image_tokens return GPTSample( token_ids=token_ids, images=images, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2972632cb..d0a867510 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -476,8 +476,6 @@ def __getitem__(self, index: int) -> typing.Any: offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, - # image_break=self._parameters.image_break_token is not None, - # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 if sample.image_positions: @@ -520,12 +518,39 @@ def __getitem__(self, index: int) -> typing.Any: images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + prev_image_tokens = 0 + image_idx = 0 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position 
= ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + prev_image_tokens += span_image_tokens span = np.clip( loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) # Go to the next document. From f52f02bf71ac17abd64cca4f7ecae4de9eea4cb2 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Jun 2025 00:06:32 +0000 Subject: [PATCH 62/97] cleanup --- fast_llm/data/dataset/gpt/memmap.py | 34 ----------------------------- 1 file changed, 34 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ce24f3b97..21c096b38 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -162,46 +162,12 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - # def get( - # self, - # idx: int, - # offset: int = 0, - # image_offset: int = 0, - # length: int | None = None, - # use_loss_masking_spans: bool = False, - # ): - # token_ids = np.frombuffer( - # self._bin_buffer, - # dtype=self._dtype, - # count=self._document_sizes[idx] - offset if length is None else length, - # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - # ) - # if self._has_images: - # image_positions = self._image_positions[idx] - # pixels = np.frombuffer( - # self._bin_buffer, - # dtype=np.dtype(np.uint8), - # count=self._image_lengths[idx].prod(initial=3), - # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, - # ) - # images = [] - # start = 0 - # for image_length in self._image_lengths[idx]: - # n_pixels = image_length.prod(initial=3) - # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) - # start += n_pixels - # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - patch_size: int | None = None, - image_size: int | None = None, - image_break: bool = False, - image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, From 02f6d8fa114dfd25c0ffbf52868131ba52011f20 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 01:50:33 +0000 Subject: [PATCH 63/97] cleanup --- fast_llm/data/dataset/gpt/memmap.py | 68 +++++++------- fast_llm/data/dataset/gpt/sampled.py | 9 +- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/multi_modal/embedding.py | 9 +- fast_llm/layers/transformer/config.py | 89 ++++++------------- fast_llm/layers/transformer/preprocessing.py | 9 +- fast_llm/layers/vision_encoder/config.py | 16 ++-- fast_llm/layers/vision_encoder/patch_conv.py | 1 - .../layers/vision_encoder/preprocessing.py | 4 +- fast_llm/models/gpt/conversion.py | 3 +- 10 files changed, 92 insertions(+), 117 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 76637565b..372415249 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -114,6 +114,34 @@ def _init( + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) + # read preference spans + self._chosen_spans = None + self._rejected_spans = None + if 
self._has_preference_spans and self._version >= 3: + self._chosen_spans = [] + self._rejected_spans = [] + for idx in range(self._num_documents): + self._chosen_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + + rejected_span_offset = offset + np.array(self._chosen_spans).nbytes + for idx in range(self._num_documents): + self._rejected_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes + self._num_pixels = 0 self._image_lengths = None self._image_positions = None @@ -147,36 +175,6 @@ def _init( ) images_seen += n_images - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -215,7 +213,9 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) images = None + image_positions = None if self._has_images: + image_positions = self._image_positions[idx] # Truncations with images are not yet supported, so we get all images from the document pixels = np.frombuffer( self._bin_buffer, @@ -275,6 +275,8 @@ def get( return GPTSample( token_ids=token_ids, + images=images, + image_positions=image_positions, loss_masking_spans=sample_spans, chosen_span=chosen_span, rejected_span=rejected_span, @@ -384,10 +386,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.stack(image_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) + image_lengths = np.array([]) + im_positions = np.array([]) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ddd45539c..092a1c1c9 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -519,10 +519,10 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for 
padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count >= token_start: + if token_count > token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) @@ -531,6 +531,11 @@ def __getitem__(self, index: int) -> typing.Any: # Move on to the next sample. token_count += padding_size continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size >= token_start: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c50a26ab9..ff4d5ec97 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -48,7 +48,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) vision_encoder: VisionEncoderConfig = Field( - default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 7f09347bf..fa5c0356b 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -59,8 +59,12 @@ def _forward( token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + # Cloning since we will modify the embeddings in-place embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) + # the embeddings tensor are full-sized, but we might get a split of the patch embeddings + # We need to determine the offset in the embeddings tensor for each sample + # and also account for the special image tokens if applicable for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): @@ -86,13 +90,11 @@ def _forward( ) # row_end_src = min(row_start_src + patch_width, patch_end_offset) if self._sequence_parallel: - # Copy with dimensions swapped for sequence parallel case embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ input_start_index:input_end_index, sample_idx ] tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: - # Copy with normal dimension ordering embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ sample_idx, input_start_index:input_end_index ] @@ -125,9 +127,6 @@ def _forward( tokens = split(tokens, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - # TODO Soham: get image positions for current split. Maybe in preprocessing? 
- # for positions in image_positions: - # if positions > self._distributed_config.tensor_rank # mask padded tokens token_mask = tokens >= 0 masked_tokens = tokens * token_mask diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1052e01ea..3bb302dd6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -5,7 +5,7 @@ import typing import warnings -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -120,13 +120,7 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" - # TODO Soham: generic name? - pixtral = "pixtral" - - -class TransformerType(str, enum.Enum): - lm_decoder = "lm_decoder" - image_encoder = "image_encoder" + rope_2d = "rope_2d" @config_class(registry=True) @@ -193,28 +187,17 @@ def _validate(self) -> None: @property def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames + if self.type == RotaryEmbeddingType.rope_2d: + return VisionTransformerDimNames + else: + return TransformerDimNames @property def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs - - -@config_class() -class VisionRotaryConfig(RotaryConfig): - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.pixtral, - desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", - hint=FieldHint.feature, - ) - - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames - - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs + if self.type == RotaryEmbeddingType.rope_2d: + return VisionTransformerKwargs + else: + return TransformerKwargs for name in RotaryEmbeddingType: @@ -315,10 +298,15 @@ def _validate(self) -> None: TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) -@config_class() +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + +@config_class(registry=True) class TransformerConfig(BaseModelConfig): _abstract = False - transformer_type: TransformerType = Field( + type: TransformerType = Field( default=TransformerType.lm_decoder, desc="Type of the transformer. Choices: lm_decoder, image_encoder.", hint=FieldHint.architecture, @@ -803,39 +791,20 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: @property def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs + if self.type == TransformerType.image_encoder: + return VisionTransformerKwargs + else: + return TransformerKwargs @property def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames - - -@config_class() -class VisionTransformerConfig(TransformerConfig): - """ - Configuration for the Vision Transformer (ViT) model. - """ - - transformer_type: TransformerType = FieldUpdate( - default=TransformerType.image_encoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", - hint=FieldHint.architecture, - ) - causal: bool = FieldUpdate( - default=False, - desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", - hint=FieldHint.feature, - ) - rotary: VisionRotaryConfig = FieldUpdate( - default_factory=VisionRotaryConfig, - desc="Configuration for the rotary positional embeddings.", - hint=FieldHint.feature, - ) + if self.type == TransformerType.image_encoder: + return VisionTransformerDimNames + else: + return TransformerDimNames - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames +for name in TransformerType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + TransformerConfig.register_subclass(name.value, TransformerConfig) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index e5cb5fb89..ae74724c4 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -137,7 +137,6 @@ def get_2d_rotary_frequencies( height_positions = torch.arange(height, device=device, dtype=torch.float64) width_positions = torch.arange(width, device=device, dtype=torch.float64) frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) - # TODO Soham: apply scaling angles_h = torch.outer(height_positions, frequencies[::2]) angles_w = torch.outer(width_positions, frequencies[1::2]) angles = torch.cat( @@ -182,7 +181,7 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) return self._tensor_cache_max_sequence_length = sequence_length - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: self._rotary_embedding_frequencies = get_2d_rotary_frequencies( self._config, num_patches, @@ -199,16 +198,16 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) else: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if self._config.type == RotaryEmbeddingType.pixtral: + if self._config.type == RotaryEmbeddingType.rope_2d: position_ids = kwargs[self._transformer_kwargs.patch_position_ids] - # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + # sequence data parallelism is not yet supported with images, so we can safely assume that sequence_q == sequence_k kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] else: diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 267941741..c5b790fe4 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config 
import NormalizationConfig -from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.utils import Assert @@ -78,10 +78,11 @@ class ImageNormalizationConfig(Config): class VisionEncoderType(str, enum.Enum): none = "none" + # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. pixtral = "pixtral" -@config_class() +@config_class(registry=True) class VisionEncoderConfig(BaseModelConfig): _abstract = False @@ -90,8 +91,7 @@ class VisionEncoderConfig(BaseModelConfig): desc="Type of the vision encoder. Choices: none, pixtral.", hint=FieldHint.architecture, ) - transformer: VisionTransformerConfig = Field( - default_factory=VisionTransformerConfig, + transformer: TransformerConfig = Field( desc="Configuration for the vision transformer architecture.", hint=FieldHint.core, ) @@ -106,7 +106,6 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.optional, ) patch_norm: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) @@ -126,7 +125,6 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.optional, ) image_normalization: ImageNormalizationConfig = Field( - default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) @@ -163,3 +161,9 @@ def setup_tensor_space(self, tensor_space: TensorSpace): @property def enabled(self) -> bool: return self.type != VisionEncoderType.none + + +for name in VisionEncoderType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 68f22200a..559ecc22d 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -44,7 +44,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._distributed_config = tensor_space.distributed_config self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._lr_scale = config.adapter_lr_scale - # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index a9115c97c..12dc68db6 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -214,7 +214,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - # TODO Soham: should this be micro_sequence_length? 
padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size @@ -249,7 +248,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) patch_position_ids = torch.cat(patch_position_ids) @@ -262,7 +260,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_size, ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) - # TODO Soham: handle sequence data parallel + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 95bbebde2..661f5e516 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -241,12 +241,11 @@ def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) ) # MTP-heads > 0 are thrown away - # TODO Soham: handle offset with MTP for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True From 68431293beed488b50ed963474df1a29d05222ec Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:10:54 +0000 Subject: [PATCH 64/97] jpeg dependency --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b7e42d4dc..be579bccb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ - # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From b94b1eefda1d4447678a231e32ccdee9745c2184 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:14:56 +0000 Subject: [PATCH 65/97] install libjpeg-dev in gh actions --- .github/workflows/ci.yaml | 1 + Dockerfile | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 912ddaf5e..05ce16216 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,6 +27,7 @@ jobs: - name: Install dependencies run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" diff --git a/Dockerfile b/Dockerfile index be579bccb..dda7b6535 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,8 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. 
RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ + # && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ + && apt-get install --no-install-recommends -y acl git-lfs \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From 9e4f14fe19d3b389a257eefb6245bc15869639b5 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Jun 2025 02:35:16 +0000 Subject: [PATCH 66/97] fix sampling test --- .github/workflows/docs.yaml | 2 ++ fast_llm/data/dataset/gpt/indexed.py | 8 ++++++++ tests/data/test_sampling.py | 3 +++ 3 files changed, 13 insertions(+) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 93191972e..e8cb56d85 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -29,6 +29,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ @@ -56,6 +57,7 @@ jobs: restore-keys: | mkdocs-material- - run: | + sudo apt install libjpeg-dev pip install "torch>=2.2.2" pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 56c4c8927..2c7aefc80 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -34,6 +34,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": else GPTSampledIndexedDataset(self, sampling) ) + @property + @abc.abstractmethod + def has_images(self) -> bool: + """ + Whether the dataset contains images. + This is used to determine whether to use image-related fields in the sampled data. 
+ """ + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 386795826..a0aff3a72 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -106,6 +106,9 @@ def get_document_size(self, index: int) -> int: def name(self) -> str: return "dataset" + def has_images(self) -> bool: + return False + TEST_DATASET = SimpleGPTIndexedDataset( [ From d1c804ff558e0b34f7dc47822a281cbf9c2c796c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 6 Jun 2025 05:46:22 +0000 Subject: [PATCH 67/97] fix --- fast_llm/data/dataset/gpt/memmap.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 372415249..ba2aa5800 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -56,9 +56,6 @@ def _init( if self._version >= 3: self._has_preference_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack(" Date: Mon, 9 Jun 2025 17:40:10 +0000 Subject: [PATCH 68/97] fix data cache reloading --- fast_llm/data/dataset/gpt/sampled.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 092a1c1c9..b4648af40 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -236,9 +236,10 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) + yaml_data["unshuffled_tokens"] = loaded_yaml_data.get("unshuffled_tokens", 0) self._load_yaml_data(yaml_data) - if not self._truncate_documents and not self._parameters.use_preference_loss_spans: - del loaded_yaml_data["unshuffled_tokens"] + # if not self._truncate_documents and not self._parameters.use_preference_loss_spans: + # del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: raise RuntimeError( From cba6986a5d9665f7dc26ff50ebc6875667af43e5 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 9 Jun 2025 17:43:20 +0000 Subject: [PATCH 69/97] fix tokenization --- fast_llm/data/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 5988769f2..24eb77bd3 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -69,7 +69,7 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li char_pos = image_position image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") if char_pos < start: - self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) + tokenized_text = self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False token_ids.extend(tokenized_text) char_pos = start From 275fefa1dbb0dcd4a85417a0488e115a6efe647c Mon Sep 17 00:00:00 2001 From: shruthan Date: Wed, 11 Jun 2025 12:05:26 -0700 Subject: [PATCH 70/97] pixtral SFT (#296) Co-authored-by: sohamparikh --- fast_llm/data/dataset/gpt/memmap.py | 8 +++++--- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- fast_llm/data/preparator/gpt_memmap/prepare.py | 4 ++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ba2aa5800..acc7914f1 100644 --- 
a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,6 +58,8 @@ def _init( if self._version >= 4: self._has_images = struct.unpack(" typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - if sample.image_positions: + has_images = sample.image_positions is not None + if has_image_positions: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens @@ -594,7 +595,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx = 0 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) while image_position < loss_masking_span[0]: @@ -602,7 +603,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) span_image_tokens = 0 @@ -611,7 +612,7 @@ def __getitem__(self, index: int) -> typing.Any: image_idx += 1 image_position = ( sample.image_positions[image_idx] - if image_idx < len(sample.image_positions) + if has_images and image_idx < len(sample.image_positions) else float("inf") ) loss_masking_span[0] += prev_image_tokens diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index ad3dd4496..0b6803100 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -158,13 +158,13 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ( np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) if self._config.dataset.loss_masking_spans else None ), - item["images"] if self._config.dataset.images else None, - item["image_positions"] if self._config.dataset.image_positions else None, item.get("chosen_token_spans", None), item.get("rejected_token_spans", None), ) From 605cc7ffac53047bb82871233413ac13cef35cac Mon Sep 17 00:00:00 2001 From: root Date: Wed, 11 Jun 2025 21:12:51 +0000 Subject: [PATCH 71/97] review comments --- Dockerfile | 1 - fast_llm/data/dataset/gpt/memmap.py | 45 ++++++++----------- fast_llm/data/tokenizer.py | 28 ++++++------ fast_llm/layers/transformer/transformer.py | 4 ++ .../layers/transformer/vision_transformer.py | 16 ------- fast_llm/layers/vision_encoder/patch_conv.py | 27 ----------- .../layers/vision_encoder/preprocessing.py | 12 ----- fast_llm/models/gpt/model.py | 3 +- 8 files changed, 39 insertions(+), 97 deletions(-) delete mode 100644 fast_llm/layers/transformer/vision_transformer.py diff --git a/Dockerfile b/Dockerfile index dda7b6535..8c2efa85e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,6 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. 
RUN apt-get update \ - # && apt-get install --no-install-recommends -y acl git-lfs libjpeg-dev \ && apt-get install --no-install-recommends -y acl git-lfs \ && rm -rf /var/lib/apt/lists/* \ && git lfs install diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index acc7914f1..642cd9800 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -58,8 +58,6 @@ def _init( if self._version >= 4: self._has_images = struct.unpack("= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._image_lengths = [] + self._image_sizes = [] self._image_positions = [] images_seen = 0 for n_images in self._n_images: - self._image_lengths.append( + self._image_sizes.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -159,7 +154,7 @@ def _init( offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() self._image_positions.append( np.frombuffer( self._index_bin_buffer, @@ -214,19 +209,19 @@ def get( image_positions = None if self._has_images: image_positions = self._image_positions[idx] - + # Truncations with images are not yet supported, so we get all images from the document pixels = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.uint8), - count=self._image_lengths[idx].prod(initial=3, axis=1).sum(), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) images = [] start = 0 - for image_length in self._image_lengths[idx]: - n_pixels = image_length.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) start += n_pixels sample_spans = None if use_loss_masking_spans and self._spans is not None: @@ -302,10 +297,10 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes, self._image_lengths + return self._document_sizes, self._image_sizes def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] + return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -314,7 +309,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents = 0 doc_lengths = [] n_images = [] - image_lengths = [] + image_sizes = [] im_positions = [] total_images = 0 pointers = [] @@ -353,7 +348,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
- image_lengths.append(np.array(pixels.shape[1:])) + image_sizes.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.extend(document.image_positions) @@ -385,11 +380,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) - image_lengths = np.stack(image_lengths, dtype=np.int32) + image_sizes = np.stack(image_sizes, dtype=np.int32) im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.array([]) + image_sizes = np.array([]) im_positions = np.array([]) # Write the index file (.idx) @@ -402,12 +397,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(struct.pack(" 0 else 0)) - # Placeholder flag for preference spans - idx_stream.write(struct.pack(" 0 else 0)) # Flag to indicate whether preference loss-masking spans are present idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" tuple[li image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") for start, end in char_spans: + # Tokenize all text before the span, with image positions in mind (i.e., break text at image positions). while image_position <= start: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False @@ -76,6 +77,7 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li len(token_ids) span_length = 0 token_start = len(token_ids) + # Tokenize all text before the end of the span while image_position <= end: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False @@ -85,21 +87,21 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li char_pos = image_position image_idx += 1 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - if char_pos < end: - if end >= len(text) - 1: - tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 - else: - tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 + # Tokenize the last part of the span, since there are no more images + if char_pos < end + 1: + # end of span is end of text + tokenized_text = self._tokenize( + text[char_pos : end + 1], + begin=beginning_of_text, + end=(end >= len(text) - 1), + ) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 token_spans.append((token_start, token_start + span_length - 1)) + # Tokenize text remaining after the last span while image_position <= len(text): image_position = image_positions[image_idx] tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 392ebb889..784a0f051 100644 --- a/fast_llm/layers/transformer/transformer.py +++ 
b/fast_llm/layers/transformer/transformer.py @@ -149,3 +149,7 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + +class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py deleted file mode 100644 index 7c1be0d16..000000000 --- a/fast_llm/layers/transformer/vision_transformer.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.layers.transformer.config import VisionTransformerKwargs -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.tensor import TensorMeta - - -class VisionTransformerLayer(TransformerLayer): - _name: str = "Vision transformer layer" - - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[VisionTransformerKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 559ecc22d..3d1845dd8 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -10,33 +10,6 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -def position_ids_in_meshgrid(patch_embeddings_list, max_size): - positions = [] - for patch in patch_embeddings_list: - height, width = patch.shape[-2:] - mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_size + v_grid - positions.append(ids[:, 0]) - return torch.cat(positions) - - -def generate_block_attention_mask(patch_embeds_list, tensor): - dtype = tensor.dtype - device = tensor.device - seq_len = tensor.shape[1] - d_min = torch.finfo(dtype).min - causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) - - block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) - block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) - for start, end in zip(block_start_idx, block_end_idx): - causal_mask[start:end, start:end] = 0 - - causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) - return causal_mask - - class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 12dc68db6..77220a063 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -104,18 +104,6 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_s return torch.cat((inv_freq, inv_freq), dim=-1) -def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: - positions = [] - for h, w in image_sizes: - patch_height = h // patch_size - patch_width = w // patch_size - mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_size + v_grid - positions.append(ids[:, 0]) - return positions - - def 
position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: patch_height = height // patch_size patch_width = width // patch_size diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6b16938fb..bf3778cc7 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -113,13 +113,12 @@ def get_output_layers(self) -> list[Layer]: return layers def get_vision_layers(self) -> list[Layer]: - patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) vit_layers = [ VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ - patch_conv, + PatchConv(self._config.vision_encoder, self._tensor_space), *vit_layers, VisionAdapter(self._config.vision_encoder, self._tensor_space), MultiModalEmbedding(self._config, self._tensor_space), From 06aa7401119302d1ced30c54012e7b5d19e88ea9 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 17:40:56 +0000 Subject: [PATCH 72/97] simplified tokenization with spans --- fast_llm/data/tokenizer.py | 97 ++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 57 deletions(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index d8b0ff87b..284ae21f7 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,77 +42,60 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ if not image_positions: image_positions = [] if not char_spans: char_spans = [] - image_idx = 0 - char_pos = 0 + # Collect all positions with their type + positions = [] + for idx, pos in enumerate(image_positions): + positions.append((pos, "image")) + for idx, (start, end) in enumerate(char_spans): + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + token_ids = [] - image_token_positions = [] token_spans = [] - beginning_of_text = True - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + image_token_positions = [] + char_pos = 0 + current_span_start = None - for start, end in char_spans: - # Tokenize all text before the span, with image positions in mind (i.e., break text at image positions). 
- while image_position <= start: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - image_idx += 1 - char_pos = image_position - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - if char_pos < start: - tokenized_text = self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - char_pos = start - len(token_ids) - span_length = 0 - token_start = len(token_ids) - # Tokenize all text before the end of the span - while image_position <= end: - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - span_length += len(tokenized_text) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - # Tokenize the last part of the span, since there are no more images - if char_pos < end + 1: - # end of span is end of text + for position in positions: + if char_pos < position[0]: tokenized_text = self._tokenize( - text[char_pos : end + 1], - begin=beginning_of_text, - end=(end >= len(text) - 1), + text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 ) - beginning_of_text = False token_ids.extend(tokenized_text) - span_length += len(tokenized_text) - char_pos = end + 1 - token_spans.append((token_start, token_start + span_length - 1)) - - # Tokenize text remaining after the last span - while image_position <= len(text): - image_position = image_positions[image_idx] - tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) - beginning_of_text = False + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token + if char_pos < len(text): + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) token_ids.extend(tokenized_text) - image_token_positions.append(len(token_ids)) - char_pos = image_position - image_idx += 1 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) - token_ids.extend(tokenized_text) return token_ids, token_spans, image_token_positions From 30e3d34acca8a1cb89149ad69ea0720fa0d327ca Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Thu, 12 Jun 2025 10:42:12 -0700 Subject: [PATCH 73/97] Update fast_llm/data/preparator/gpt_memmap/prepare.py Co-authored-by: RaymondLi0 --- fast_llm/data/preparator/gpt_memmap/prepare.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0b6803100..43849857b 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -329,6 +329,7 @@ def run(self) -> None: if self._config.dataset.images else 0 ) + # Add the token-equivalent bytes of pixels to determine shard size total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens From c1aa7094924cd6931d27db5e02384fb79aaa1b36 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 17:47:17 +0000 Subject: [PATCH 74/97] rename --- fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 8 ++++---- fast_llm/engine/schedule/config.py | 2 +- fast_llm/layers/transformer/preprocessing.py | 2 +- fast_llm/layers/vision_encoder/config.py | 2 +- .../layers/vision_encoder/preprocessing.py | 18 +++++++++--------- fast_llm/models/gpt/model.py | 4 ++-- fast_llm/models/gpt/trainer.py | 2 +- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9a5aa2007..250bfcb09 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,7 +76,7 @@ class GPTSamplingParameters(SamplingParameters): use_preference_loss_spans: bool = False cross_document_attention: bool = True patch_size: int | None = None - image_size: int | None = None + max_image_size: int | None = None image_break_token: int | None = None image_end_token: int | None = None # How many extra tokens to add to the sequence length. diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 255a30963..d4bcacddd 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -151,8 +151,8 @@ def _sample(self) -> None: get_num_image_tokens( *get_resize_dims( *size, - self._parameters.image_size, - self._parameters.image_size, + self._parameters.max_image_size, + self._parameters.max_image_size, self._parameters.patch_size, ), self._parameters.patch_size, @@ -496,8 +496,8 @@ def __getitem__(self, index: int) -> typing.Any: resized_image_lengths = [ get_resize_dims( *image_length, - self._parameters.image_size, - self._parameters.image_size, + self._parameters.max_image_size, + self._parameters.max_image_size, self._parameters.patch_size, ) for image_length in image_lengths diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 204abdf1c..f5c1bc133 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -50,7 +50,7 @@ class BatchConfig(Config): hint=FieldHint.setup, ) # Image inputs - image_size: int | None = Field( + max_image_size: int | None = Field( default=None, desc="Maximum image height and width", hint=FieldHint.optional, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index ae74724c4..9b79aa1b3 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -199,7 +199,7 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: if self._config.type == RotaryEmbeddingType.rope_2d: - max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + max_num_patches = 
kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) else: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index c5b790fe4..2ea7f6114 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -22,7 +22,7 @@ class VisionEncoderKwargs: images = "images" image_patches = "image_patches" image_positions = "image_positions" - image_size = "image_size" + max_image_size = "max_image_size" image_sizes = "image_sizes" image_mean = "image_normalization_mean" image_std = "image_normalization_std" diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 77220a063..ebd41b3d7 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -84,9 +84,9 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: return F.pad(image, (0, 0, depth_padding, width_padding), 0) -def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: +def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) - max_patches_per_side = image_size // patch_size + max_patches_per_side = max_image_size // patch_size h = torch.arange(max_patches_per_side) w = torch.arange(max_patches_per_side) @@ -135,19 +135,19 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) - im_height = kwargs.get(VisionEncoderKwargs.image_size) - im_width = kwargs.get(VisionEncoderKwargs.image_size) + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + im_width = kwargs.get(VisionEncoderKwargs.max_image_size) patch_size = kwargs[VisionEncoderKwargs.patch_size] image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ - [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims] for ims in images ] kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ [ normalize( - resize(image, im_height, im_width, patch_size).to( + resize(image, max_image_size, im_width, patch_size).to( dtype=self._tensor_space.distributed_config.training_dtype.torch ) / kwargs[VisionEncoderKwargs.image_rescale_factor], @@ -219,7 +219,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) if sizes: position_ids = torch.cat( - [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] ).to(device=self._tensor_space.distributed.device) else: position_ids = torch.tensor( @@ -244,10 +244,10 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], - im_height, + max_image_size, patch_size, ).to(device=self._tensor_space.distributed.device) - kwargs[VisionEncoderKwargs.max_image_tokens] = 
div(im_height * im_width, patch_size**2) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bf3778cc7..a1479a34d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -162,7 +162,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder.enabled: - image_size = batch_meta.image_size + max_image_size = batch_meta.max_image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, @@ -176,7 +176,7 @@ def preprocess_meta( image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.max_image_size: max_image_size, VisionEncoderKwargs.image_mean: image_mean, VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b2736b447..92cb20554 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -37,7 +37,7 @@ def _get_sampling_parameters( parameters.update( { "patch_size": self._config.model.base_model.vision_encoder.patch_size, - "image_size": self._config.batch.image_size, + "max_image_size": self._config.batch.max_image_size, "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } From 8e106f74041a4c4181ecf414968f91d975000c5f Mon Sep 17 00:00:00 2001 From: root Date: Thu, 12 Jun 2025 23:18:13 +0000 Subject: [PATCH 75/97] fix conversion --- fast_llm/models/gpt/conversion.py | 8 +++++++- fast_llm/models/gpt/model.py | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 661f5e516..080b8b3ae 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -28,7 +28,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig, TransformerType from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, @@ -576,6 +576,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "type"),), fast_llm_value=TransformerType.image_encoder + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), ConstantImportParamConverter(fast_llm_names=(("type",),), 
fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), @@ -639,6 +642,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("head_dim",),), ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + ), RenameParamConverter( fast_llm_names=( ( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a1479a34d..8fc8c830b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -28,8 +28,7 @@ FlashAttnVarlenPreprocessor, RotaryEmbeddingPreprocessor, ) -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConv From 080dcb58e8a8c364513eb7f915950c7c7ffb5c28 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:15:27 +0000 Subject: [PATCH 76/97] fix sequence lengths, parallel conv --- fast_llm/data/dataset/gpt/sampled.py | 33 ++++++++++++++---------- fast_llm/layers/multi_modal/embedding.py | 5 +--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d4bcacddd..2f1575f7d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -551,26 +551,23 @@ def __getitem__(self, index: int) -> typing.Any: ) start_pos = 0 has_images = sample.image_positions is not None - if has_image_positions: + if has_images: + sample_token_ids = [] for idx, im_position in enumerate(sample.image_positions): - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + image_tokens_added) + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] if self._parameters.image_break_token is not None: height, width = resized_image_lengths[idx] num_patches_h = div(height, self._parameters.patch_size) num_patches_w = div(width, self._parameters.patch_size) - - # Create image token placeholder array image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) - - # Add break tokens after each row except the last row + # account for break tokens after each row for row in range(num_patches_h - 1): position = (row + 1) * num_patches_w + row image_token_array[position] = self._parameters.image_break_token - # add end token if specified, else break token + # handle the last row separately last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 if self._parameters.image_end_token is not None: image_token_array[last_row_position] = self._parameters.image_end_token @@ -580,11 +577,19 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not 
None: image_token_array[-1] = self._parameters.image_end_token - token_ids.append(image_token_array) + segment = np.concatenate([text_part, image_token_array], dtype=np.int64) + sample_token_ids.append(segment) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) image_tokens_added += image_sizes[idx] start_pos = im_position - token_ids.append(sample.token_ids[start_pos:]) - text_tokens_added += len(token_ids[-1]) + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + else: + token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) if sample.images: images.append(sample.images) else: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index fa5c0356b..948b2acf9 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import gather, reduce_forward, split +from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -61,7 +61,6 @@ def _forward( embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa # Cloning since we will modify the embeddings in-place embeddings = embeddings.clone() - input_ = gather(input_, group, dim=0) # the embeddings tensor are full-sized, but we might get a split of the patch embeddings # We need to determine the offset in the embeddings tensor for each sample # and also account for the special image tokens if applicable @@ -93,12 +92,10 @@ def _forward( embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ input_start_index:input_end_index, sample_idx ] - tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ sample_idx, input_start_index:input_end_index ] - tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset input_end_index = ( From f1868687f2a230a98613ab936de87d6a145b17d1 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:19:50 +0000 Subject: [PATCH 77/97] minor --- fast_llm/data/dataset/gpt/sampled.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2f1575f7d..8641ee707 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -577,8 +577,7 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not None: image_token_array[-1] = self._parameters.image_end_token - segment = np.concatenate([text_part, image_token_array], dtype=np.int64) - sample_token_ids.append(segment) + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) text_tokens_added += len(text_part) image_positions.append(text_tokens_added + image_tokens_added) image_tokens_added += 
image_sizes[idx] From 6b9ea2e1b22e83fa936aae66c95d66218adfa0b3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Jun 2025 21:37:51 +0000 Subject: [PATCH 78/97] fix image at beginning --- fast_llm/data/tokenizer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 284ae21f7..7268ba3ce 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -56,9 +56,9 @@ def tokenize( # Collect all positions with their type positions = [] - for idx, pos in enumerate(image_positions): + for pos in image_positions: positions.append((pos, "image")) - for idx, (start, end) in enumerate(char_spans): + for start, end in char_spans: positions.append((start, "span_start")) positions.append((end + 1, "span_end")) # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap @@ -71,6 +71,7 @@ def tokenize( current_span_start = None for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times if char_pos < position[0]: tokenized_text = self._tokenize( text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 @@ -79,7 +80,11 @@ def tokenize( char_pos = position[0] # beginning_of_text = False if position[1] == "image": - image_token_positions.append(len(token_ids)) + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) elif position[1] == "span_start": assert ( current_span_start is None From ad18ea1903e8ceacd58ab2817759c2f963ccb511 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 20 Jun 2025 15:14:22 -0400 Subject: [PATCH 79/97] pixtral fix conversion (#315) --- fast_llm/models/gpt/conversion.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 080b8b3ae..a7e624ffe 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -563,6 +563,26 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralNumHeadsConverter(ParamConverter): + """ + Pixtral encoder uses Multi-Head Attention. + Map `num_attention_heads` and `head_groups` to a single `num_heads` parameter. 
+ """ + + def __post_init__(self): + Assert.eq(len(self.fast_llm_names), 2) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads, head_groups) = fast_llm_values + assert head_groups == num_heads, "Pixtral encoder expects num_heads == head_groups (MHA)" + return (num_heads,) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (num_heads,) = export_values + return (num_heads, num_heads) + + class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -600,23 +620,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("hidden_size",),), ), - RenameParamConverter( + PixtralNumHeadsConverter( fast_llm_names=( ( "transformer", "num_attention_heads", ), - ), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=( ( "transformer", "head_groups", ), ), - export_names=(("num_key_value_heads",),), + export_names=(("num_attention_heads",),), ), RenameParamConverter( fast_llm_names=( From 29e66d944034f58ae454b4cfd028be21efb8f848 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 25 Jun 2025 21:11:56 +0000 Subject: [PATCH 80/97] handle no image samples --- fast_llm/data/dataset/gpt/memmap.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 642cd9800..c7a99f10f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -145,6 +145,7 @@ def _init( self._image_sizes = [] self._image_positions = [] images_seen = 0 + num_total_images = self._n_images.sum() for n_images in self._n_images: self._image_sizes.append( np.frombuffer( @@ -162,8 +163,8 @@ def _init( count=n_images, offset=offset + self._n_images.nbytes - + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize - + images_seen * np.dtype(np.int32).itemsize, + + 2 * num_total_images * np.dtype(np.int32).itemsize + + +images_seen * np.dtype(np.int32).itemsize, ) ) images_seen += n_images @@ -352,6 +353,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.extend(document.image_positions) + else: + n_images.append(0) # Update metadata doc_length = len(document.token_ids) From 06a0910bd06de25f84115b89dd8ddb566780a24a Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Jun 2025 18:23:38 +0000 Subject: [PATCH 81/97] mask special image tokens --- fast_llm/models/gpt/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8fc8c830b..9bef7ae5c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -410,6 +410,12 @@ def preprocess( if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + if self._config.vision_encoder.enabled: + labels = labels.clone() + if self._config.vision_encoder.image_break_token is not None: + labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + if self._config.vision_encoder.image_end_token is not None: + labels = torch.where(labels == 
self._config.vision_encoder.image_end_token, -100, labels) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From bbd71dfb2706d19022eccb5e55019e526ded2007 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 27 Jun 2025 16:34:27 +0000 Subject: [PATCH 82/97] avoid multiple labels cloning --- fast_llm/models/gpt/model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9bef7ae5c..23bb3d067 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -387,9 +387,11 @@ def preprocess( labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config + labels_cloned = False if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() + labels_cloned = True for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue @@ -411,10 +413,15 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) if self._config.vision_encoder.enabled: - labels = labels.clone() if self._config.vision_encoder.image_break_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) if self._config.vision_encoder.image_end_token is not None: + if not labels_cloned: + labels = labels.clone() + labels_cloned = True labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From 96a5fd82f5200712274ba217323793b91038615b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:11:33 +0000 Subject: [PATCH 83/97] fix training --- Dockerfile | 2 +- fast_llm/functional/triton/mlp.py | 4 +- fast_llm/layers/transformer/rotary/config.py | 8 ++ fast_llm/layers/transformer/rotary/rotary.py | 72 +++++++++- fast_llm/models/gpt/conversion.py | 134 +++++++++---------- fast_llm/models/gpt/model.py | 5 +- 6 files changed, 150 insertions(+), 75 deletions(-) diff --git a/Dockerfile b/Dockerfile index e98223de8..6c013c14d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. 
COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index a34af4f5e..f3d9d7d0c 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,7 +47,7 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type in ["gelu_pytorch_tanh", "gelu"]: + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -97,7 +97,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type in ["gelu_pytorch_tanh", "gelu"]: + if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index ce7af88d5..d7285714f 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -136,3 +136,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.transformer.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "rope_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..b2c69dd8d 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,14 +8,16 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -212,3 +214,71 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: DefaultRotaryConfig](DefaultRotary[Rotary2DConfig]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(max_num_patches) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_q] = 
self._rotary_embedding_frequencies[:, position_ids] + kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + assert self._tensor_space is not None + kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_q_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_q, + ) + kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + ( + self._scalar_dim, + kwargs[TransformerKwargs.sequence_k_dim], + self._scalar_dim, + self._kv_channels_dim, + ), + tensor_name=VisionTransformerKwargs.rotary_freq_k, + ) + + def _create_tensors(self, max_num_patches: int) -> None: + if max_num_patches <= self._tensor_cache_max_num_patches: + return + self._tensor_cache_max_num_patches = max_num_patches + + self._rotary_embedding_frequencies = self._get_frequencies( + max_num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + + def _get_frequencies(self, max_num_patches: int, kv_channels: int, device="cuda") -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 319a495d7..e01aaf702 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -26,11 +26,15 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LayerNormalizationConfig, NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig, TransformerType -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.common.config import LayerNormalizationConfig +from fast_llm.layers.transformer.config import RoutingType, TransformerConfig +from fast_llm.layers.transformer.rotary.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + Rotary2DConfig, + YarnRotaryConfig, +) from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, 
DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -161,6 +165,7 @@ class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfac @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="lm_decoder"), ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( @@ -228,42 +233,6 @@ def _create_weight_converters( return converters - def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) - else: - converters.append( - WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") - ) - - # MTP-heads > 0 are thrown away - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." - ) - mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) - - return converters - def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: @@ -331,7 +300,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) @@ -340,20 +309,22 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", 
"lm_head.weight")) + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) # MTP-heads > 0 are thrown away for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i + mtp_transformer_layer_index = num_layers + offset - 1 + 2 * i # MTP transformer layer converters += self._create_transformer_layer_converters( f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True @@ -466,7 +437,7 @@ def __post_init__(self): def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (rotary_config,) = fast_llm_values - if type(rotary_config) is DefaultRotaryConfig: + if type(rotary_config) is DefaultRotaryConfig or rotary_config is Rotary2DConfig: rotary_scaling = { "rope_type": "default", } @@ -663,6 +634,34 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A return (num_heads, num_heads) +class PixtralRotaryParamConverter(ParamConverter): + """ + Pixtral encoder uses 2D Rotary Embeddings. + Map `rope_theta` to a single `rotary` parameter. `rotary_scaling` is not needed. + """ + + def __init__(self, fast_llm_names, export_names): + Assert.eq(len(fast_llm_names), 1) + Assert.eq(len(export_names), 1) + self.fast_llm_names = fast_llm_names + self.export_names = export_names + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_config,) = fast_llm_values + if type(rotary_config) is Rotary2DConfig: + return (rotary_config.theta,) + else: + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + (rotary_theta,) = export_values + rotary_config = { + "type": "rope_2d", + "theta": rotary_theta, + } + return (rotary_config,) + + class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -670,17 +669,13 @@ class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfa @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), + ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), ConstantImportParamConverter( - fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "type"),), fast_llm_value=TransformerType.image_encoder + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="image_encoder"), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), - ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), 
fast_llm_value=False), RenameParamConverter( fast_llm_names=( @@ -737,17 +732,21 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), export_names=(("head_dim",),), ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d - ), - RenameParamConverter( - fast_llm_names=( - ( - "transformer", - "rotary", - "theta", - ), - ), + # ConstantImportParamConverter( + # fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.rope_2d + # ), + # RenameParamConverter( + # fast_llm_names=( + # ( + # "transformer", + # "rotary", + # "theta", + # ), + # ), + # export_names=(("rope_theta",),), + # ), + PixtralRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), export_names=(("rope_theta",),), ), RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), @@ -773,7 +772,7 @@ def _create_vision_transformer_layer_converters( ) -> list[WeightConverter]: # Vision transformer layer transformer_config = self._model.config.base_model.vision_encoder.transformer - norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) name_bias_cls = [ # Self-attn ( @@ -828,11 +827,12 @@ def _create_vision_transformer_layer_converters( def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: converters = [] + norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) if self._model.config.base_model.vision_encoder.conv_bias: converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + if norm_bias: converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 38639fc5f..436b4a60f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,10 +72,7 @@ def __init__( if self._config.vision_encoder.enabled: self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) - if self._config.vision_encoder.transformer.rotary.enabled: - self._preprocessors.append( - RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) - ) + self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] From 8f93a276e7406d6195edaacda4fbdd774d193356 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:34:41 +0000 Subject: [PATCH 84/97] fix prepare config --- fast_llm/data/preparator/gpt_memmap/config.py | 18 ++++++++++++------ fast_llm/data/preparator/gpt_memmap/prepare.py | 18 +++++++++++------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 9f25cba4c..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ 
b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -79,12 +91,6 @@ class GPTHuggingfaceDatasetConfig(Config): rejected_text: None | str = Field( default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) - image_positions: None | str = Field( - default=None, desc="Field containing image positions within a document", hint=FieldHint.optional - ) - images: None | str = Field( - default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional - ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index dee1b37bf..b100ce400 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -27,7 +27,11 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -60,9 +64,9 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ im_char_positions, ) for text, loss_mask_spans, im_char_positions in zip( - batch[self._config.dataset.field], - batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), - batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), ) ] ] @@ -160,8 +164,8 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), - item["images"] if self._config.dataset.images else None, - item["image_positions"] if self._config.dataset.image_positions else None, + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, ( np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) if self._loss_masking_spans_column @@ -344,7 +348,7 @@ def run(self) -> None: total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) total_pixels = ( sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) - if self._config.dataset.images + if self._images_column else 0 ) # Add the token-equivalent bytes of pixels to determine shard size From 
c3eda1c1c792a144d353e99f6d4532ec644a837a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 19:44:18 +0000 Subject: [PATCH 85/97] fix imports --- .github/workflows/ci.yaml | 2 +- .github/workflows/docs.yaml | 2 +- fast_llm/layers/vision_encoder/preprocessing.py | 14 +++++++++----- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 03353a79b..cb5260dca 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Run tests run: pytest . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index b509b2702..75ba3bb31 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,DEV,DOCS]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index ebd41b3d7..c81e7c646 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -2,7 +2,7 @@ import typing import torch -import torchvision.transforms.v2.functional as F +import torchvision from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -62,17 +62,21 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int height, width = get_resize_dims( height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size ) - image = F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + image = torchvision.transforms.v2.functional.resize( + image, size=(height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + ) # TODO: options for interpolation mode? - return F.resize(image, size=(target_height, target_width), interpolation=F.InterpolationMode.BICUBIC) + return torchvision.transforms.v2.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + ) def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: """ Normalize the image using the specified mean and standard deviation. 
""" - return F.normalize(image, mean=mean, std=std) + return torchvision.transforms.v2.functional.normalize(image, mean=mean, std=std) def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: @@ -81,7 +85,7 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return F.pad(image, (0, 0, depth_padding, width_padding), 0) + return torchvision.transforms.v2.functional.pad(image, (0, 0, depth_padding, width_padding), 0) def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: From 1cf0ea0285bdfb8754ce6030ff03be6582d3c7c5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 8 Jul 2025 21:46:26 +0000 Subject: [PATCH 86/97] fix tests --- fast_llm/data/dataset/gpt/fim.py | 6 +++--- fast_llm/data/dataset/gpt/indexed.py | 13 +++++++++++-- fast_llm/data/preparator/gpt_memmap/prepare.py | 2 +- tests/data/common.py | 4 ++-- tests/data/test_sampling.py | 1 + 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..b05b79b24 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) + middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 2c7aefc80..8a4440ae4 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -53,7 +53,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) @@ -70,8 +70,17 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + sizes = [dataset.get_document_sizes() for dataset in self._datasets] + return ( + np.concatenate([size[0] for size in sizes]), + np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), + ) def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + @property + def has_images(self) -> bool: + return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b100ce400..c6a0528f1 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -425,7 +425,7 @@ def _split_and_blend_dataset_configs( dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path, - image_patch_size: int, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] diff --git a/tests/data/common.py b/tests/data/common.py index 2bb90a6b4..858380816 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -127,10 +127,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 123e5e955..b8e7a92ff 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -106,6 +106,7 @@ def get_document_size(self, index: int) -> int: def name(self) -> str: return "dataset" + @property def has_images(self) -> bool: return False From 77d294c76157210caf9856b23dc9a52ca1d4f44c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Jul 2025 17:57:48 +0000 Subject: [PATCH 87/97] fix tests --- fast_llm/data/dataset/gpt/memmap.py | 3 ++- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 4 ++-- tests/data/common.py | 7 +++++- tests/data/test_sampling.py | 10 ++++++-- tests/test_config.py | 24 +++++++++++++++++++ 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c7a99f10f..2a1986b63 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -136,7 +136,7 @@ def _init( offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes self._num_pixels = 0 - self._image_sizes = None + self._image_sizes = [] self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( @@ -177,6 +177,7 @@ 
def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens + self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 29a784b77..42062a58c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes: + if image_sizes.any(): image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c6a0528f1..fce0f022c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -458,7 +458,7 @@ def _split_and_blend_dataset_configs( text_sizes, image_sizes = dataset.get_document_sizes() tokens_cumsum = text_sizes.cumsum() Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) - if image_sizes: + if image_sizes.any(): num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) # We use the patch sizes only for the purposes of even splitting and blending weights. # We can always use a different patch size for training without any significant impact @@ -466,7 +466,7 @@ def _split_and_blend_dataset_configs( image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) tokens_cumsum += image_tokens_cumsum num_pixels_cumsum = num_pixels_cumsum * 3 - Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: diff --git a/tests/data/common.py b/tests/data/common.py index 858380816..23ed9d76b 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -224,10 +224,15 @@ def __len__(self) -> int: return self._config.num_documents def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( + [], dtype=np.int64 + ) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document def get(self, index: int, *args, **kwargs) -> typing.Any: raise NotImplementedError() + + def has_images(self) -> bool: + return False diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index b8e7a92ff..296102f7d 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -98,10 +98,16 @@ def __len__(self) -> int: return len(self._samples) def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + doc_sizes = [] + im_sizes = [] + for index in range(len(self)): + doc_size, im_size = self.get_document_size(index) + doc_sizes.append(doc_size) + im_sizes.append(im_size) + return np.array(doc_sizes, 
dtype=np.int64), np.array(im_sizes, dtype=np.int64) def get_document_size(self, index: int) -> int: - return len(self._samples[index]) + return len(self._samples[index]), [] def name(self) -> str: return "dataset" diff --git a/tests/test_config.py b/tests/test_config.py index b6a9a9854..c12ef9f03 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -88,6 +88,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "multi_stage": {"zero_stage": 3}, "distributed": {"training_dtype": "bfloat16"}, + # "vision_encoder": { + # "type": "none", + # "transformer": { + # "normalization": { + # "type": "rms_norm", + # } + # } + # } } ) with NoAutoValidate(): @@ -137,6 +145,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "tie_word_embeddings": False, "vocab_size": 1000, + "vision_encoder": { + "transformer": { + "normalization": {"type": "layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + }, } else: base_model_update["transformer"]["peft"] = { @@ -146,6 +162,14 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } base_model_update["transformer"]["normalization"]["type"] = "layer_norm" base_model_update["transformer"]["rotary"] = {"type": "none"} + base_model_update["vision_encoder"] = { + "transformer": { + "normalization": {"type": "layer_norm"}, + "rotary": {"type": "none"}, + "peft": {"type": "none"}, + }, + "patch_norm": {"type": "layer_norm"}, + } expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) From 8434b20e7f2dc3e4885ab87772bea3694fddaff2 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Jul 2025 18:00:58 +0000 Subject: [PATCH 88/97] cleanup --- tests/test_config.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index c12ef9f03..52c00f0a1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -88,14 +88,6 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "multi_stage": {"zero_stage": 3}, "distributed": {"training_dtype": "bfloat16"}, - # "vision_encoder": { - # "type": "none", - # "transformer": { - # "normalization": { - # "type": "rms_norm", - # } - # } - # } } ) with NoAutoValidate(): From ef982c9cd1d73ed4ddf4d117603e07d9a7eae697 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 17 Jul 2025 15:59:39 +0000 Subject: [PATCH 89/97] fix torchvision import --- fast_llm/layers/vision_encoder/preprocessing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c81e7c646..3b857ba26 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -2,7 +2,7 @@ import typing import torch -import torchvision +import torchvision.transforms.v2 as torchvision_transforms from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -62,13 +62,13 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int height, width = get_resize_dims( height, width, intermediate_max_height, intermediate_max_width, patch_size=patch_size ) - image = torchvision.transforms.v2.functional.resize( - image, size=(height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + image = torchvision_transforms.functional.resize( + image, 
size=(height, width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC ) # TODO: options for interpolation mode? - return torchvision.transforms.v2.functional.resize( - image, size=(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC ) @@ -76,7 +76,7 @@ def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch """ Normalize the image using the specified mean and standard deviation. """ - return torchvision.transforms.v2.functional.normalize(image, mean=mean, std=std) + return torchvision_transforms.functional.normalize(image, mean=mean, std=std) def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: @@ -85,7 +85,7 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return torchvision.transforms.v2.functional.pad(image, (0, 0, depth_padding, width_padding), 0) + return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0) def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: From ca68072491dbb292bfc35397c214896bac3f27de Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Jul 2025 18:55:38 +0000 Subject: [PATCH 90/97] cosmetic changes --- fast_llm/data/dataset/gpt/fim.py | 6 +++--- fast_llm/data/tokenizer.py | 8 +++++--- fast_llm/layers/transformer/config.py | 4 ++-- fast_llm/layers/vision_encoder/adapter.py | 2 +- fast_llm/layers/vision_encoder/config.py | 14 ++++++------- fast_llm/layers/vision_encoder/patch_conv.py | 20 ++++++++++++------- .../layers/vision_encoder/preprocessing.py | 18 +---------------- fast_llm/models/gpt/conversion.py | 6 ++++-- fast_llm/models/gpt/model.py | 16 +++++++-------- tests/test_config.py | 4 ++-- 10 files changed, 46 insertions(+), 52 deletions(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index b05b79b24..843f6735d 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -158,9 +158,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer._tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer._tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer._tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, add_eos=False)], dtype=np.int64) + middle = np.array([*self._tokenizer.tokenize(middle, add_bos=False, add_eos=False)], dtype=np.int64) + suffix = np.array([*self._tokenizer.tokenize(suffix, add_bos=False)], dtype=np.int64) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? 
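The fim.py hunk above now goes through the public `tokenize` method with explicit `add_bos`/`add_eos` flags, and the tokenizer diff that follows defines the updated signature, which returns token ids together with loss-masking token spans and image token positions. A short usage sketch, assuming `tokenizer` is an already constructed `fast_llm.data.tokenizer.Tokenizer`; the sample text and flag choices are made up for illustration:

# Illustrative call against the updated Tokenizer.tokenize signature.
token_ids, token_spans, image_token_positions = tokenizer.tokenize(
    "The prefix of a document",
    add_bos=True,          # keep BOS only at the true start of a document
    add_eos=False,         # suppress EOS when more text will be appended
    char_spans=None,       # optional character spans to map onto token spans
    image_positions=None,  # optional character offsets where images occur
)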
diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index d46e38935..93fa9b81b 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -49,7 +49,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: ) def tokenize( - self, text: str, char_spans=None, image_positions=None + self, text: str, add_bos=True, add_eos=True, char_spans=None, image_positions=None ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. @@ -81,7 +81,9 @@ def tokenize( # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times if char_pos < position[0]: tokenized_text = self._tokenize( - text[char_pos : position[0]], begin=(char_pos == 0), end=position[0] > len(text) - 1 + text[char_pos : position[0]], + begin=(char_pos == 0) and add_bos, + end=position[0] > len(text) - 1 and add_eos, ) token_ids.extend(tokenized_text) char_pos = position[0] @@ -106,7 +108,7 @@ def tokenize( current_span_start = None # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0), end=True) + tokenized_text = self._tokenize(text[char_pos:], begin=(char_pos == 0) and add_bos, end=add_eos) token_ids.extend(tokenized_text) return token_ids, token_spans, image_token_positions diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 4d83215da..316a60c65 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -234,7 +234,7 @@ def _validate(self) -> None: class TransformerType(str, enum.Enum): - lm_decoder = "lm_decoder" + language_model_decoder = "language_model_decoder" image_encoder = "image_encoder" @@ -242,7 +242,7 @@ class TransformerType(str, enum.Enum): class TransformerConfig(LLMBlockConfig): _abstract = False type: TransformerType = Field( - default=TransformerType.lm_decoder, + default=TransformerType.language_model_decoder, desc="Type of the transformer. Choices: lm_decoder, image_encoder.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 41ea065d0..03f8a54b4 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -13,7 +13,7 @@ class VisionAdapter(Layer): """ - Vision adapter layer for the LLM. + Vision adapter layer that projects vision encoder features into the language model token embeddings. 
""" def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 2ea7f6114..7e8d75f36 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -39,32 +39,32 @@ class VisionEncoderKwargs: @config_class() class ImageNormalizationConfig(Config): - mean_r: float = Field( + mean_red: float = Field( default=0.48145466, desc="Mean value for the red channel in the image normalization process.", hint=FieldHint.optional, ) - mean_g: float = Field( + mean_green: float = Field( default=0.4578275, desc="Mean value for the green channel in the image normalization process.", hint=FieldHint.optional, ) - mean_b: float = Field( + mean_blue: float = Field( default=0.40821073, desc="Mean value for the blue channel in the image normalization process.", hint=FieldHint.optional, ) - std_r: float = Field( + std_red: float = Field( default=0.26862954, desc="Standard deviation value for the red channel in the image normalization process.", hint=FieldHint.optional, ) - std_g: float = Field( + std_green: float = Field( default=0.26130258, desc="Standard deviation value for the green channel in the image normalization process.", hint=FieldHint.optional, ) - std_b: float = Field( + std_blue: float = Field( default=0.27577711, desc="Standard deviation value for the blue channel in the image normalization process.", hint=FieldHint.optional, @@ -105,7 +105,7 @@ class VisionEncoderConfig(BaseModelConfig): desc="Whether to use bias in the convolutional layer.", hint=FieldHint.optional, ) - patch_norm: NormalizationConfig = Field( + patch_normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 3d1845dd8..71e1b40dc 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -10,12 +10,16 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -class PatchConv(Layer): +class PatchConvolution(Layer): + """ + A convolution layer applied to image patches to create embeddings for each patch. These embeddings are fed into the vision transformer. 
+ """ + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config - self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._sequence_tensor_parallel = self._distributed_config.sequence_tensor_parallel self._lr_scale = config.adapter_lr_scale self.weight = ParameterMeta.from_dims( ( @@ -35,8 +39,10 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): ) else: self.bias = None - self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) - self.stride = config.patch_size + self.normalization = config.patch_normalization.get_layer( + tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + ) + self._stride = config.patch_size def forward( self, @@ -53,10 +59,10 @@ def forward( out_channels = kwargs[VisionEncoderKwargs.out_channels] reshape_dims = (micro_batch_size, sequence_length, out_channels) group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = self.norm(input_.flatten(1)) + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) + patch_embeddings = self.normalization(input_.flatten(1)) patch_embeddings = patch_embeddings.view(reshape_dims) - if self._sequence_parallel: + if self._sequence_tensor_parallel: patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() patch_embeddings = split(patch_embeddings, group=group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 3b857ba26..9a01931b1 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -72,22 +72,6 @@ def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int ) -def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: - """ - Normalize the image using the specified mean and standard deviation. 
- """ - return torchvision_transforms.functional.normalize(image, mean=mean, std=std) - - -def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: - """ - Pad images on the right and bottom with 0s untitl max_height and max_width - """ - width_padding = max(0, max_height - image.size(1)) - depth_padding = max(0, max_width - image.size(2)) - return torchvision_transforms.functional.pad(image, (0, 0, depth_padding, width_padding), 0) - - def create_inv_freqs(rope_theta: int, kv_channels: int, max_image_size: int, patch_size: int) -> torch.Tensor: freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) max_patches_per_side = max_image_size // patch_size @@ -150,7 +134,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ [ - normalize( + torchvision_transforms.functional.normalize( resize(image, max_image_size, im_width, patch_size).to( dtype=self._tensor_space.distributed_config.training_dtype.torch ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e01aaf702..f6004e40f 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -670,7 +670,7 @@ class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfa def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value="pixtral"), - ConstantImportParamConverter(fast_llm_names=(("patch_norm", "type"),), fast_llm_value="rms_norm"), + ConstantImportParamConverter(fast_llm_names=(("patch_normalization", "type"),), fast_llm_value="rms_norm"), ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value="rms_norm" ), @@ -827,7 +827,9 @@ def _create_vision_transformer_layer_converters( def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: converters = [] - norm_bias = isinstance(self._model.config.base_model.vision_encoder.patch_norm, LayerNormalizationConfig) + norm_bias = isinstance( + self._model.config.base_model.vision_encoder.patch_normalization, LayerNormalizationConfig + ) converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) if self._model.config.base_model.vision_encoder.conv_bias: converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6356cf23d..f7c6b35f4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.layers.vision_encoder.patch_conv import PatchConv +from fast_llm.layers.vision_encoder.patch_conv import PatchConvolution from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -104,7 +104,7 @@ def get_vision_layers(self) -> list[Layer]: for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ - PatchConv(self._config.vision_encoder, self._tensor_space), + 
PatchConvolution(self._config.vision_encoder, self._tensor_space), *vit_layers, VisionAdapter(self._config.vision_encoder, self._tensor_space), MultiModalEmbedding(self._config, self._tensor_space), @@ -150,14 +150,14 @@ def preprocess_meta( if self._config.vision_encoder.enabled: max_image_size = batch_meta.max_image_size image_mean = [ - self._config.vision_encoder.image_normalization.mean_r, - self._config.vision_encoder.image_normalization.mean_g, - self._config.vision_encoder.image_normalization.mean_b, + self._config.vision_encoder.image_normalization.mean_red, + self._config.vision_encoder.image_normalization.mean_green, + self._config.vision_encoder.image_normalization.mean_blue, ] image_std = [ - self._config.vision_encoder.image_normalization.std_r, - self._config.vision_encoder.image_normalization.std_g, - self._config.vision_encoder.image_normalization.std_b, + self._config.vision_encoder.image_normalization.std_red, + self._config.vision_encoder.image_normalization.std_green, + self._config.vision_encoder.image_normalization.std_blue, ] image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { diff --git a/tests/test_config.py b/tests/test_config.py index 52c00f0a1..30646e660 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -143,7 +143,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "rotary": {"type": "none"}, "peft": {"type": "none"}, }, - "patch_norm": {"type": "layer_norm"}, + "patch_normalization": {"type": "layer_norm"}, }, } else: @@ -160,7 +160,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "rotary": {"type": "none"}, "peft": {"type": "none"}, }, - "patch_norm": {"type": "layer_norm"}, + "patch_normalization": {"type": "layer_norm"}, } expected_config["base_model"] = base_model_update From 55a3706437800e85cebdf331b628e0c0d937c653 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 6 Aug 2025 20:49:52 +0000 Subject: [PATCH 91/97] fixes --- fast_llm/data/dataset/gpt/memmap.py | 3 +-- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 3 ++- fast_llm/layers/transformer/attention.py | 4 ++-- fast_llm/layers/transformer/config.py | 2 +- fast_llm/layers/transformer/transformer.py | 16 +++++++++++++--- fast_llm/models/gpt/conversion.py | 8 +++++--- fast_llm/models/gpt/model.py | 12 ++++++++++-- 8 files changed, 35 insertions(+), 15 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 2a1986b63..c7a99f10f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -136,7 +136,7 @@ def _init( offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes self._num_pixels = 0 - self._image_sizes = [] + self._image_sizes = None self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( @@ -177,7 +177,6 @@ def _init( assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens - self._image_sizes = np.array(self._image_sizes, dtype=np.int32) def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 42062a58c..29a784b77 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -143,7 +143,7 @@ def 
_sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes.any(): + if image_sizes: image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 948b2acf9..a5a789f9e 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -69,6 +69,7 @@ def _forward( for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) if image_embedding_offset + num_patches < patch_start_offset: + image_embedding_offset += num_patches continue if self._config.vision_encoder.image_break_token is not None: patch_height = div(size[0], self._config.vision_encoder.patch_size) @@ -83,7 +84,7 @@ def _forward( input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset - embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_start_index = row_start_dst + max(patch_start_offset - row_start_src, 0) embeddings_end_index = ( row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) ) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index fbd6dd0c4..04f789d57 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -55,14 +55,14 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, layer_index, + layer_offset: int = 1, ): super().__init__() self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space - # TODO Soham: fix assert - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) + Assert.in_range_incl(layer_index, layer_offset, max(self._config.num_layers + layer_offset, layer_offset)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 316a60c65..0059718f5 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -243,7 +243,7 @@ class TransformerConfig(LLMBlockConfig): _abstract = False type: TransformerType = Field( default=TransformerType.language_model_decoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. 
Choices: language_model_decoder, image_encoder.", hint=FieldHint.architecture, ) normalization: NormalizationConfig = Field( diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 38a80beff..73819b8ad 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -26,7 +26,11 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, ): super().__init__() self._transformer_dim_names = config._transformer_dim_names @@ -145,12 +149,18 @@ class TransformerLayer(BaseBlock): _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + layer_offset: int = 1, ): + self._layer_offset = layer_offset super().__init__(config, tensor_space, layer_index, return_input) def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + self.self_attn = Attention(self._config, self._tensor_space, self._layer_index, self._layer_offset) class VisionTransformerLayer(TransformerLayer): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f6004e40f..a15a237f9 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -165,7 +165,9 @@ class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, Huggingfac @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("transformer", "type"),), fast_llm_value="lm_decoder"), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "type"),), fast_llm_value="language_model_decoder" + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), RenameParamConverter( @@ -833,9 +835,9 @@ def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") - converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) if self._model.config.base_model.vision_encoder.conv_bias: converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) - converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + converters.append(WeightConverter(f"layers.{offset}.normalization.weight", f"{hf_base_prefix}ln_pre.weight")) if norm_bias: - converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) + converters.append(WeightConverter(f"layers.{offset}.normalization.bias", f"{hf_base_prefix}ln_pre.bias")) num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers for i in range(num_layers): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index f7c6b35f4..a62171b3a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,7 +100,9 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - 
VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + VisionTransformerLayer( + self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1, layer_offset=1 + ) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ @@ -111,6 +113,9 @@ def get_vision_layers(self) -> list[Layer]: ] def get_layers(self) -> list[Layer]: + lm_layer_offset = ( + self._config.vision_encoder.transformer.num_layers + 3 if self._config.vision_encoder.enabled else 1 + ) return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] @@ -121,10 +126,13 @@ def get_layers(self) -> list[Layer]: TransformerLayer( self._config.transformer, self._tensor_space, - layer_index=i + 1, + layer_index=i + 1 + lm_layer_offset, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + # optionally account for patch convolution, vision transformer, vision adapter + # by default we only have the embedding layer + layer_offset=lm_layer_offset, ) for i in range(self._config.transformer.num_layers) ], From 273cc555da2a431d6eb38a8d35947b3372cdf181 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 18:59:04 -0400 Subject: [PATCH 92/97] Dataset --- Megatron-LM | 2 +- fast_llm/data/data/gpt/data.py | 28 +- fast_llm/data/dataset/gpt/config.py | 14 +- fast_llm/data/dataset/gpt/indexed.py | 39 ++- fast_llm/data/dataset/gpt/memmap.py | 423 +++++++++++++-------------- fast_llm/data/dataset/gpt/sampled.py | 5 +- 6 files changed, 265 insertions(+), 246 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 9df9b9b86..2c728ed4b 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -51,28 +51,24 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] - has_images = False - batch_images = [] - for sample in batch: - if sample.images is not None: - batch_images.append([torch.from_numpy(image) for image in sample.images]) - has_images = True - else: - batch_images.append([]) - batch_image_positions = [] - for sample in batch: - if sample.image_positions is not None: - batch_image_positions.append(torch.from_numpy(sample.image_positions)) - else: - batch_image_positions.append([]) + has_images = any(sample.images is not None for sample in batch) + if has_images: + images = [ + [] if sample.images is None else [torch.from_numpy(image) for image in sample.images] for sample in batch + ] + image_positions = [ + [] if sample.image_positions is None else torch.from_numpy(sample.image_positions) for sample in batch + ] + else: + images, image_positions = None, None return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_spans=stacked_chosen_spans, rejected_spans=stacked_rejected_spans, - images=batch_images if has_images else None, - 
image_positions=batch_image_positions if has_images else None, + images=images, + image_positions=image_positions, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 692776a24..7bfdc8515 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,15 @@ class GPTSamplingConfig(SamplingConfig): @dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): +class ImageSamplingParameters: + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None + + +@dataclasses.dataclass(kw_only=True) +class GPTSamplingParameters(SamplingParameters, ImageSamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ @@ -76,10 +84,6 @@ class GPTSamplingParameters(SamplingParameters): use_preference_loss_spans: bool = False cross_document_attention: bool = True truncate_documents: bool = True - patch_size: int | None = None - max_image_size: int | None = None - image_break_token: int | None = None - image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 8a4440ae4..669b2d9e9 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -19,12 +19,26 @@ def get_document_sizes(self) -> np.ndarray: and derived classes should try to avoid holding the whole array im memory. """ + @abc.abstractmethod + def get_image_sizes(self) -> list[np.ndarray]: + """ + The size of each image in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + @abc.abstractmethod def get_document_size(self, index: int) -> int: """ The size of a document in the dataset. """ + @abc.abstractmethod + def get_image_size(self, index: int) -> np.ndarray: + """ + The size of an image in the dataset. + """ + def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset @@ -52,12 +66,18 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else np.array([]) + return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_image_sizes(self) -> list[np.ndarray]: + # TODO: This can be really big. + return self._dataset.get_image_sizes()[self._begin : self._end] def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) + def get_image_size(self, index: int) -> np.ndarray: + return self._dataset.get_image_size(self._begin + index) + @property def has_images(self) -> bool: return self._dataset.has_images @@ -70,17 +90,20 @@ class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
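For reference, a minimal sketch (with made-up document counts, not taken from the patch) of the index arithmetic used by get_document_size/get_image_size below: _dataset_splits holds cumulative document counts, so a global index maps to a sub-dataset via searchsorted and to a local index by subtracting the split offset.

import numpy as np

# Illustration only: cumulative split points for three hypothetical sub-datasets
# holding 5, 3 and 4 documents respectively.
dataset_splits = np.array([0, 5, 8, 12])

def locate(index: int) -> tuple[int, int]:
    # Find the owning sub-dataset, then convert to a local document index,
    # mirroring GPTConcatenatedDataset.get_document_size.
    dataset = int(np.searchsorted(dataset_splits[1:], index, side="right"))
    return dataset, index - int(dataset_splits[dataset])

assert locate(0) == (0, 0)
assert locate(5) == (1, 0)   # first document of the second sub-dataset
assert locate(11) == (2, 3)  # last document of the third sub-dataset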
- # return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - sizes = [dataset.get_document_sizes() for dataset in self._datasets] - return ( - np.concatenate([size[0] for size in sizes]), - np.concatenate([size[1] for size in sizes]) if sizes[0][1] is not None else np.array([]), - ) + return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_image_sizes(self) -> list[np.ndarray]: + # TODO: This can be really big. + return sum([dataset.get_image_sizes() for dataset in self._datasets], []) def get_document_size(self, index: int) -> int: dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + def get_image_size(self, index: int) -> np.ndarray: + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_image_size(index - self._dataset_splits[dataset].item()) + @property def has_images(self) -> bool: return any(dataset.has_images for dataset in self._datasets) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c7a99f10f..e0473b7e1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -43,22 +43,14 @@ def _init( super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_images = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 4: - self._has_images = struct.unpack("= 2 else False + self._has_preference_spans = bool(struct.unpack("= 3 else False + self._has_images = bool(struct.unpack("= 4 else False self._dtype = MEMMAP_DTYPES[struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset, - ) - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=offset - + self._num_spans.nbytes - + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - offset += self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = offset + np.array(self._chosen_spans).nbytes - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - offset += np.array(self._chosen_spans).nbytes + np.array(self._rejected_spans).nbytes - - self._num_pixels = 0 - self._image_sizes = None - self._image_positions = None - if self._has_images and self._version >= 4: - self._n_images = np.frombuffer( - 
self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset - ) - self._image_sizes = [] - self._image_positions = [] - images_seen = 0 - num_total_images = self._n_images.sum() - for n_images in self._n_images: - self._image_sizes.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=n_images * 2, - offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - self._num_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() - self._image_positions.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=n_images, - offset=offset - + self._n_images.nbytes - + 2 * num_total_images * np.dtype(np.int32).itemsize - + +images_seen * np.dtype(np.int32).itemsize, - ) - ) - images_seen += n_images + if self._has_spans: + offset = self._init_spans(offset) + + if self._has_preference_spans: + offset = self._init_preference_spans(offset) + + total_pixels, _ = self._init_images(offset) if self._has_images else (0, offset) + if num_pixels is not None: + assert total_pixels == num_pixels self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) - if num_pixels is not None: - assert self._num_pixels == num_pixels + self._num_tokens = div(self._bin_buffer_mmap.size - total_pixels, np.dtype(self._dtype).itemsize) if num_tokens is not None: assert self._num_tokens == num_tokens + def _init_spans(self, offset: int) -> int: + num_spans = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._num_documents, + offset=offset, + ) + num_spans_cumsum = np.r_[0, np.cumsum(num_spans[:-1], dtype=np.int64)] + self._spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=num_spans[idx] * 2, + offset=offset + num_spans.nbytes + num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + for idx in range(self._num_documents) + ] + return offset + num_spans.nbytes + num_spans.sum() * 2 * np.dtype(np.int32).itemsize + + def _init_preference_spans(self, offset: int) -> int: + item_size = np.dtype(np.int32).itemsize + self._chosen_spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + 2 * idx * item_size, + ) + for idx in range(self._num_documents) + ] + offset += 2 * item_size * self._num_documents + self._rejected_spans = [ + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=offset + 2 * idx * item_size, + ) + for idx in range(self._num_documents) + ] + return offset + 2 * item_size * self._num_documents + + def _init_images(self, offset: int) -> tuple[int, int]: + total_pixels = 0 + image_counts = np.frombuffer(self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset) + offset += image_counts.nbytes + + self._image_sizes = [] + self._image_positions = [] + item_size = np.dtype(np.int32).itemsize + + for image_count in image_counts: + self._image_sizes.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=image_count * 2, + offset=offset, + ).reshape(-1, 2) + ) + total_pixels += self._image_sizes[-1].prod(axis=1, initial=3).sum() + offset += 2 * image_count * item_size + + for image_count in image_counts: + self._image_positions.append( + np.frombuffer(self._index_bin_buffer, dtype=np.int32, count=image_count, offset=offset) + ) + offset += 
image_count * item_size + return total_pixels, offset + def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) @@ -206,77 +192,77 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) - images = None - image_positions = None - if self._has_images: - image_positions = self._image_positions[idx] - - # Truncations with images are not yet supported, so we get all images from the document - pixels = np.frombuffer( - self._bin_buffer, - dtype=np.dtype(np.uint8), - count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), - offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, - ) - images = [] - start = 0 - for image_size in self._image_sizes[idx]: - n_pixels = image_size.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) - start += n_pixels - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - - rejected_span = self._rejected_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + + loss_masking_spans = self._get_loss_masking_spans(idx, offset, token_ids) + chosen_span, rejected_span = ( + self._get_preference_spans(idx, offset, token_ids) if use_preference_loss_spans else (None, None) + ) + images, image_positions = self._get_images(idx) return GPTSample( token_ids=token_ids, images=images, image_positions=image_positions, - loss_masking_spans=sample_spans, + loss_masking_spans=loss_masking_spans, chosen_span=chosen_span, rejected_span=rejected_span, ) + def _get_loss_masking_spans(self, idx: int, offset: int, token_ids: np.ndarray) -> np.ndarray | 
None: + if not self._has_spans: + return None + loss_masking_spans = self._spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + loss_masking_spans = loss_masking_spans[ + (loss_masking_spans[:, 0] < offset + len(token_ids)) & (loss_masking_spans[:, 1] >= offset) + ] + + # subtract by offset to normalize span boundaries + loss_masking_spans[:, 0] = np.maximum(loss_masking_spans[:, 0], offset) - offset # offset + loss_masking_spans[:, 1] = np.minimum(loss_masking_spans[:, 1], offset + len(token_ids) - 1) - offset + return loss_masking_spans + + def _get_preference_spans(self, idx: int, offset: int, token_ids: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + if not self._has_preference_spans: + raise ValueError(f"Dataset {self.name} doesn't have preference spans.") + chosen_span = self._chosen_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset + chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + + rejected_span = self._rejected_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + rejected_span = rejected_span[(rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset + rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + return chosen_span, rejected_span + + def _get_images(self, idx: int) -> tuple[list[np.ndarray] | None, np.ndarray | None]: + if not self._has_images: + return None, None + # Truncations with images are not yet supported, so we get all images from the document + pixels = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8), + count=self._image_sizes[idx].prod(initial=3, axis=1).sum(), + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + images = [] + start = 0 + for image_size in self._image_sizes[idx]: + n_pixels = image_size.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_size[0], image_size[1])) + start += n_pixels + return images, self._image_positions[idx] + @property def name(self) -> str: return self._name @@ -292,16 +278,22 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images - def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: + def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
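As a small sanity check of the pixel layout assumed by _get_images above (a sketch with synthetic data, not part of the dataset code): all images of a document are stored as one flat uint8 buffer of 3 * height * width bytes each, directly after the document's tokens, and are recovered by slicing and reshaping channel-first.

import numpy as np

# Hypothetical per-image (height, width) sizes for one document.
image_sizes = np.array([[4, 6], [2, 3]], dtype=np.int32)
pixels = np.arange(image_sizes.prod(axis=1, initial=3).sum(), dtype=np.uint8)

images = []
start = 0
for height, width in image_sizes:
    n_pixels = 3 * height * width
    # Each image is stored channel-first, matching _get_images.
    images.append(pixels[start : start + n_pixels].reshape(3, height, width))
    start += n_pixels

assert [im.shape for im in images] == [(3, 4, 6), (3, 2, 3)]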
""" - return self._document_sizes, self._image_sizes + return self._document_sizes + + def get_image_sizes(self) -> list[np.ndarray]: + return self._image_sizes if self._has_images else [np.array([])] * self._num_documents def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item(), self._image_sizes[index] if self._has_images else [] + return self._document_sizes[index].item() + + def get_image_size(self, index: int) -> np.ndarray: + return self._image_sizes[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -311,8 +303,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP doc_lengths = [] n_images = [] image_sizes = [] - im_positions = [] - total_images = 0 + image_positions = [] + has_images = False pointers = [] offset = 0 # number of spans for each document @@ -331,16 +323,27 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if dtype is None: dtype = document.token_ids.dtype assert dtype is not None, "Document dtype could not be inferred from the data." - # Ensure all documents have the same dtype assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." + pointers.append(offset) + doc_lengths.append(doc_length := len(document.token_ids)) + # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) - total_im_size = 0 - if document.images: + offset += doc_length * np.dtype(dtype).itemsize + + if document.loss_masking_spans is not None: + num_spans.append(len(document.loss_masking_spans)) + spans.append(document.loss_masking_spans) + if document.chosen_span is not None: + chosen_spans.append(document.chosen_span) + if document.rejected_span is not None: + rejected_spans.append(document.rejected_span) + + if document.images is not None: n_images.append(len(document.images)) - total_images += len(document.images) + has_images = True for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: @@ -351,59 +354,48 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
image_sizes.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) - total_im_size += pixels.size - im_positions.extend(document.image_positions) + offset += pixels.size * np.dtype(np.uint8).itemsize + image_positions.extend(document.image_positions) else: n_images.append(0) - # Update metadata - doc_length = len(document.token_ids) - doc_lengths.append(doc_length) - pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: + + assert len(spans) == len(num_spans) + if has_loss_masking_spans := len(spans) > 0: + assert len(spans) == num_documents + num_spans = np.array(num_spans, dtype=np.int32) spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - if total_images: + assert len(chosen_spans) == len(rejected_spans) + if has_preference_spans := len(chosen_spans) > 0: + assert len(chosen_spans) == num_documents + chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) + rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + + if has_images: n_images = np.array(n_images, dtype=np.int32) image_sizes = np.stack(image_sizes, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) - else: - n_images = np.array([]) - image_sizes = np.array([]) - im_positions = np.array([]) + image_positions = np.array(image_positions, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 onwards optionally add loss-masking spans - # Version 3 optionally adds chosen/rejected spans - # Version 4 onwards optionally add images + # Version 2 onwards supports loss-masking spans + # Version 3 onwards supports preference spans + # Version 4 onwards supports images idx_stream.write(struct.pack(" 0 else 0)) + idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) + idx_stream.write(struct.pack(" 0 else 0)) + idx_stream.write(struct.pack(" typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + (text_size,) = self._indexed_dataset.get_document_size(document_index) + image_lengths = self._indexed_dataset.get_image_size(document_index) resized_image_lengths = [ get_resize_dims( From 31d4857ee9d627727e87e7594b609ad62e2b27f6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 18 Aug 2025 15:23:20 -0400 Subject: [PATCH 93/97] misc --- fast_llm/data/dataset/gpt/sampled.py | 502 ++++++++++++++------------- 1 file changed, 262 insertions(+), 240 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0b86d4355..7b42e0bf6 100644 --- 
a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -95,10 +95,11 @@ def __init__( self._truncate_documents = sampling.parameters.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") - if self._indexed_dataset.has_images and self._truncate_documents: - raise RuntimeError( - "Truncating documents with images is not yet supported. Please turn off truncation to use images." + if self._indexed_dataset.has_images: + assert not self._truncate_documents, ( + "Truncating documents with images is not yet supported." " Please turn off truncation to use images." ) + assert not self._parameters.use_preference_loss_spans, "Preference loss spans not supported with images." if sampling.cache_directory is None: self._document_shuffling = MemmapArray() @@ -119,16 +120,19 @@ def __init__( ) # TODO: Names are confusing self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) - self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) - self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") # keep document sizes and len filtered docs for preference loss masking if self._parameters.use_preference_loss_spans: self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( + self._doc_length_filtered_indices = MemmapArray( base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") ) + else: + self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) + self._token_cumsum_unshuffled = MemmapArray( + base_path.with_name(base_path.name + "_unshuffled_cumsum.npy") + ) # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): @@ -140,10 +144,87 @@ def _sample(self) -> None: """ Create a `GPTSampledDataset` with the requested parameters. """ - # Get the document sizes, the main information needed for sampling. - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() - document_sizes = torch.from_numpy(document_sizes).to(self._device) - if image_sizes: + # Get the size each document, the main information needed for sampling. + # Note: "document" may refer to more than just text. + document_sizes = self._get_document_sizes() + + documents_per_epoch, tokens_per_epoch, long_docs_filter = self._get_epoch_size(document_sizes) + num_epochs, shuffled_epochs = self._get_epoch_count(documents_per_epoch, tokens_per_epoch) + + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data, cached = self._get_and_compare_yaml_data(documents_per_epoch, tokens_per_epoch, unshuffled_epochs) + if cached: + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." 
+ ) + + document_shuffling = self._get_document_shuffling(documents_per_epoch, shuffled_documents, shuffled_epochs) + + if self._parameters.use_preference_loss_spans: + # index of all documents less than seq length long + self._doc_length_filtered_indices.save(torch.nonzero(long_docs_filter, as_tuple=True)[0].numpy(force=True)) + self._document_sizes.save(document_sizes.numpy(force=True)) + if shuffled_epochs > 0: + self._document_shuffling.save(document_shuffling[: self._parameters.num_samples].numpy(force=True)) + unshuffled_tokens = 0 + + else: + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + cumsum_dtype = get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs) + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(document_sizes, 0, cumsum_dtype) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) + else: + unshuffled_tokens = 0 + + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + self._unshuffled_tokens, + cumsum_dtype, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(force=True) + ) + + yaml_data["unshuffled_tokens"] = unshuffled_tokens + self._load_yaml_data(yaml_data) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + def _get_document_sizes(self) -> torch.Tensor: + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + if self._indexed_dataset.has_images: + image_sizes = self._indexed_dataset.get_image_sizes() image_token_sizes = [] for i, sizes in enumerate(image_sizes): image_token_sizes.append( @@ -162,37 +243,43 @@ def _sample(self) -> None: for size in sizes ) ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) - else: - image_token_sizes = torch.zeros_like(document_sizes) + document_sizes += torch.tensor(image_token_sizes).to(self._device) + return document_sizes + def _get_epoch_size(self, document_sizes: torch.Tensor) -> tuple[int, int, torch.Tensor | None]: documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() - - # Calculate basic stats. 
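The cumsum-based lookup described in the _sample comments above can be illustrated with a small self-contained sketch. The sizes, rate and helper name below are hypothetical, and the real _get_token_cumsum additionally handles padding, dtypes and truncation, so treat this as illustrative only.

import numpy as np

TOKEN_CUMSUM_RATE = 4  # illustrative; the real value is a module-level constant
sequence_length = 10   # hypothetical

# Hypothetical (already shuffled) document sizes in tokens.
document_sizes = np.array([7, 3, 12, 5, 9, 4, 8, 6, 11, 2], dtype=np.int64)

# Zero-padded cumulative sum kept only every TOKEN_CUMSUM_RATE documents:
# entry k holds the total token count of the first k * TOKEN_CUMSUM_RATE documents.
token_cumsum = np.concatenate(([0], np.cumsum(document_sizes)))[::TOKEN_CUMSUM_RATE]

def find_sample_start(sample_index: int) -> tuple[int, int]:
    """Return (document index, token offset within it) for a sample's first token."""
    token_start = sample_index * sequence_length
    # Closest stored partial sum at or below the target token.
    cumsum_index = int(np.searchsorted(token_cumsum, token_start, side="right")) - 1
    document_index = cumsum_index * TOKEN_CUMSUM_RATE
    token_count = int(token_cumsum[cumsum_index])
    # Scan at most TOKEN_CUMSUM_RATE documents to reach the exact one.
    while token_count + int(document_sizes[document_index]) <= token_start:
        token_count += int(document_sizes[document_index])
        document_index += 1
    return document_index, token_start - token_count

# Sample 2 starts at token 20: documents 0 and 1 cover tokens 0-9,
# so the sample begins 10 tokens into document 2.
assert find_sample_start(2) == (2, 10)
assert find_sample_start(3) == (4, 3)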
- if not self._truncate_documents: + if self._truncate_documents: + tokens_per_epoch = document_sizes.sum().item() + long_docs_filter = None + else: assert _extension_available, ( "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 - ignored_documents = long_docs_filter.sum().item() - if ignored_documents: + long_docs_filter = document_sizes <= self._parameters.sequence_length + 1 + documents_per_epoch_filtered = long_docs_filter.sum().item() + if ignored_documents := documents_per_epoch_filtered - documents_per_epoch: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents" + f" are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() + # TODO: WHY?!?!?!? + if self._parameters.use_preference_loss_spans: + documents_per_epoch = documents_per_epoch_filtered + tokens_per_epoch = document_sizes[long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + f" > No documents shorter than {self._parameters.sequence_length+1}" + f" tokens found in dataset {self._indexed_dataset.name}." ) + return documents_per_epoch, tokens_per_epoch, long_docs_filter + def _get_epoch_count(self, documents_per_epoch: int, tokens_per_epoch: int) -> tuple[int, int]: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) elif self._truncate_documents: num_epochs = math.ceil( @@ -206,35 +293,34 @@ def _sample(self) -> None: ) # Prepare for shuffling. 
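Putting numbers to the token requirement stated in the comment above (hypothetical values, shown only to make the epoch count concrete):

import math

sequence_length, num_samples, extra_tokens = 4096, 1000, 1
tokens_per_epoch = 1_500_000
tokens_needed = sequence_length * num_samples + extra_tokens  # 4_096_001
assert math.ceil(tokens_needed / tokens_per_epoch) == 3  # epochs of data to sample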
- generator = torch.Generator(device=self._device) if self._config.shuffle == ShufflingType.skip_first_epoch: shuffled_epochs = num_epochs - 1 elif self._config.shuffle == ShufflingType.disabled: shuffled_epochs = 0 else: shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs + return num_epochs, shuffled_epochs + def _get_and_compare_yaml_data( + self, + documents_per_epoch: int, + tokens_per_epoch: int, + unshuffled_epochs: int, + ) -> tuple[dict[str, typing.Any], bool]: yaml_data = { "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._parameters.num_samples, + "sampling": self._parameters.__dict__, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._parameters.sequence_length, - "patch_size": self._parameters.patch_size, - "truncate_documents": self._truncate_documents, - "image_break_token": self._parameters.image_break_token, - "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - if self._yaml_path is not None and self._yaml_path.is_file(): + if cached := (self._yaml_path is not None and self._yaml_path.is_file()): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: @@ -251,123 +337,8 @@ def _sample(self) -> None: ) # Dataset is already sampled, skip. logger.info(f"Using existing sampling for dataset {self.name}") - return - - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. 
- if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
- dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ] - + image_token_sizes[ - document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling + return yaml_data, cached def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -410,6 +381,50 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - ] return out, num_tokens + def _get_document_shuffling( + self, + documents_per_epoch: int, + shuffled_documents: int, + shuffled_epochs: int, + ) -> torch.Tensor | None: + generator = torch.Generator(device=self._device) + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. 
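A quick illustration, with hypothetical sizes, of the single-randperm-plus-remainder scheme used by the "full" shuffle branch below: one permutation over the flattened multi-epoch index space, folded back with a remainder, avoids materializing a separate permutation per epoch.

import torch

documents_per_epoch, shuffled_epochs = 5, 3
generator = torch.Generator().manual_seed(0)
document_shuffling = torch.randperm(
    documents_per_epoch * shuffled_epochs, generator=generator
).remainder_(documents_per_epoch)
# Every document index appears exactly shuffled_epochs times across the schedule,
# in random order.
assert torch.equal(
    torch.bincount(document_shuffling, minlength=documents_per_epoch),
    torch.full((documents_per_epoch,), shuffled_epochs),
)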
+ if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + return document_shuffling + def __len__(self) -> int: return self._parameters.num_samples @@ -422,37 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: self._lazy_load() if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample + return self._get_preference_loss_span_sample(index) # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample @@ -479,10 +464,12 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] token_ids = [] - loss_masking_spans = [] - images = [] - image_positions = [] - image_tokens_added = 0 + if self._parameters.use_loss_masking_spans: + loss_masking_spans = [] + if self._indexed_dataset.has_images: + images = [] + image_positions = [] + image_tokens_added = 0 text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
@@ -491,29 +478,32 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - (text_size,) = self._indexed_dataset.get_document_size(document_index) - image_lengths = self._indexed_dataset.get_image_size(document_index) + text_size = self._indexed_dataset.get_document_size(document_index) + if self._indexed_dataset.has_images: + image_lengths = self._indexed_dataset.get_image_size(document_index) - resized_image_lengths = [ - get_resize_dims( - *image_length, - self._parameters.max_image_size, - self._parameters.max_image_size, - self._parameters.patch_size, - ) - for image_length in image_lengths - ] - image_sizes = [ - get_num_image_tokens( - *image_length, - self._parameters.patch_size, - image_break=self._parameters.image_break_token is not None, - image_end=self._parameters.image_end_token is not None, - ) - for image_length in resized_image_lengths - ] - image_tokens = sum(image_sizes) - document_size = text_size + image_tokens + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + document_size = text_size + image_tokens + else: + document_size = text_size if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -550,9 +540,8 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - start_pos = 0 - has_images = sample.image_positions is not None - if has_images: + if self._indexed_dataset.has_images: + start_pos = 0 sample_token_ids = [] for idx, im_position in enumerate(sample.image_positions): # add placeholder masked tokens for images @@ -587,42 +576,42 @@ def __getitem__(self, index: int) -> typing.Any: sample_token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(sample_token_ids[-1]) token_ids.append(np.concatenate(sample_token_ids)) - else: - token_ids.append(sample.token_ids[start_pos:]) - text_tokens_added += len(token_ids[-1]) - if sample.images: images.append(sample.images) else: - images.append([]) + token_ids.append(sample.token_ids) + text_tokens_added += len(token_ids[-1]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: - prev_image_tokens = 0 - image_idx = 0 - image_position = ( - sample.image_positions[image_idx] - if has_images and image_idx < len(sample.image_positions) - else float("inf") - ) - while image_position < loss_masking_span[0]: - prev_image_tokens += image_sizes[image_idx] - image_idx += 1 + if self._indexed_dataset.has_images: + prev_image_tokens = 0 + image_idx = 0 image_position = ( sample.image_positions[image_idx] - if has_images and image_idx < len(sample.image_positions) + if image_idx < len(sample.image_positions) else float("inf") ) - span_image_tokens = 0 - while image_position <= loss_masking_span[1]: - span_image_tokens += image_sizes[image_idx] - image_idx += 1 - image_position = ( - sample.image_positions[image_idx] - if has_images and image_idx < 
len(sample.image_positions) - else float("inf") - ) - loss_masking_span[0] += prev_image_tokens - loss_masking_span[1] += prev_image_tokens + span_image_tokens - prev_image_tokens += span_image_tokens + while image_position < loss_masking_span[0]: + prev_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + span_image_tokens = 0 + while image_position <= loss_masking_span[1]: + span_image_tokens += image_sizes[image_idx] + image_idx += 1 + image_position = ( + sample.image_positions[image_idx] + if image_idx < len(sample.image_positions) + else float("inf") + ) + loss_masking_span[0] += prev_image_tokens + loss_masking_span[1] += prev_image_tokens + span_image_tokens + # TODO: Unused, meant to be inside loop? What about 2 lines above? + prev_image_tokens += span_image_tokens + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -646,8 +635,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] if images else None - image_positions = np.array(image_positions) if image_positions else None + images = [im for img_list in images for im in img_list] if self._indexed_dataset.has_images else None + image_positions = np.array(image_positions) if self._indexed_dataset.has_images else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( @@ -658,6 +647,39 @@ def __getitem__(self, index: int) -> typing.Any: image_positions=image_positions, ) + def _get_preference_loss_span_sample(self, index: int): + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indices[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indices[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._parameters.use_loss_masking_spans, + use_preference_loss_spans=self._parameters.use_preference_loss_spans, + ) + + chosen_span_end = sample.chosen_span[1] + 1 + sequence_lengths = [ + chosen_span_end, + len(sample.token_ids) - chosen_span_end, + ] + + # compute padding size + padding = np.full((self._parameters.sequence_length + 1,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + if not self._parameters.cross_document_attention: + sample.sequence_lengths = np.array(sequence_lengths) + + return sample + @property def name(self) -> str: return self._indexed_dataset.name From 4534ae939e9b252d1e6803314b60634a6151419f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 18 Aug 2025 16:14:16 -0400 Subject: [PATCH 94/97] Reduce diff --- fast_llm/functional/config.py | 6 +- fast_llm/functional/cross_entropy.py | 2 +- fast_llm/functional/triton/mlp.py | 6 +- fast_llm/layers/transformer/attention.py | 102 +++++++------- fast_llm/layers/transformer/config.py | 135 ++++++++----------- fast_llm/layers/transformer/mlp.py | 11 +- fast_llm/layers/transformer/preprocessing.py | 30 ++--- fast_llm/layers/transformer/transformer.py | 25 +--- 8 files changed, 134 insertions(+), 183 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 
c56b63065..6ab82bfea 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -68,8 +68,7 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: torch.nn.functional.gelu, - ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -80,8 +79,7 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu", - ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7a289b579..d56dce98d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -147,7 +147,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss[loss_mask] + per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index f3d9d7d0c..ab408368f 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -47,7 +47,8 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": + # Triton doesn't like enums, so we use str instead of ActivationType. + if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -97,7 +98,8 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == "gelu" or activation_type == "gelu_pytorch_tanh": + # Triton doesn't like enums, so we use str instead of ActivationType. 
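A hedged numerical check (not part of the patch) that the constant-folded formula used in these kernels matches the tanh-approximated GeLU that config.py now maps ActivationType.gelu to:

import torch

x = torch.linspace(-4, 4, steps=101, dtype=torch.float32)
tanh_input = 0.79788456 * x * (1 + 0.044715 * x * x)  # sqrt(2/pi) * (x + 0.044715 * x**3)
tanh = 1 - 2 / (1 + torch.exp(2 * tanh_input))         # 1 - 2/(1+exp(2z)) == tanh(z)
out = x * 0.5 * (1.0 + tanh)
torch.testing.assert_close(
    out, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-5, rtol=1e-5
)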
+ if activation_type == "gelu": tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 04f789d57..c9d548d42 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -7,7 +7,12 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + TransformerSubLayerName, +) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -50,6 +55,24 @@ class Attention(torch.nn.Module): A self-attention layer. """ + _QUERY_DIMS = ( + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_heads, + TransformerDimNames.kv_channels, + ) + _KV_DIMS = ( + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.group_heads, + TransformerDimNames.kv_channels, + ) + _CONTEXT_DIMS = ( + TransformerDimNames.batch, + TransformerDimNames.sequence_q, + TransformerDimNames.composite_dense, + ) + def __init__( self, config: TransformerConfig, @@ -58,15 +81,12 @@ def __init__( layer_offset: int = 1, ): super().__init__() - self._transformer_dim_names = config._transformer_dim_names - self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, layer_offset, max(self._config.num_layers + layer_offset, layer_offset)) + # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer - self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -80,14 +100,14 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size + self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** 
(-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -95,7 +115,7 @@ def __init__( # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -104,7 +124,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -118,7 +138,7 @@ def __init__( # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), + self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -184,7 +204,7 @@ def _attn_fused( def _get_meta( self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} + hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} return TensorMeta.from_dims( tuple( hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) @@ -194,32 +214,6 @@ def _get_meta( dtype=input_.dtype, ) - @property - def _query_dims(self): - return ( - self._transformer_dim_names.batch, - self._transformer_dim_names.sequence_q, - self._transformer_dim_names.composite_heads, - self._transformer_dim_names.kv_channels, - ) - - @property - def _kv_dims(self): - return ( - self._transformer_dim_names.batch, - self._transformer_dim_names.sequence_q, - self._transformer_dim_names.group_heads, - self._transformer_dim_names.kv_channels, - ) - - @property - def _context_dims(self): - return ( - self._transformer_dim_names.batch, - self._transformer_dim_names.sequence_q, - self._transformer_dim_names.composite_dense, - ) - def _debug_log( self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> None: @@ -318,12 +312,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. 
- if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: + if (presents := kwargs.get(TransformerKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. presents.append(present := key_value.detach().requires_grad_()) @@ -363,7 +357,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -373,12 +367,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=self._causal, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -388,7 +382,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=self._causal, + causal=self._config.causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -398,25 +392,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[self._transformer_kwargs.attention_mask], - kwargs[self._transformer_kwargs.attention_mask_value], + kwargs[TransformerKwargs.attention_mask], + kwargs[TransformerKwargs.attention_mask_value], ) if self._debug_transformer: - self._debug_log(query, "query", self._query_dims, kwargs) + self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, "key", - self._kv_dims, + self._KV_DIMS, kwargs, ) self._debug_log( value, "value", - self._kv_dims, + self._KV_DIMS, kwargs, ) - self._debug_log(input_, "context", self._context_dims, kwargs) + self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
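For reference, a minimal standalone sketch (not Fast-LLM code) of the two pieces the attention diff above relies on: the plain string-constant namespace that replaces the per-instance `_transformer_dim_names` indirection, and the softmax scale `kv_channels ** (-attention_softmax_scale_power)`. The concrete numbers are illustrative assumptions, not values from a real config.

# Standalone sketch, not Fast-LLM code.
class TransformerDimNames:
    batch = "batch"
    sequence_q = "sequence_q"
    composite_heads = "composite_heads"
    kv_channels = "kv_channels"

# Class-level tuple of dim names, resolved against the tensor space when debugging.
_QUERY_DIMS = (
    TransformerDimNames.batch,
    TransformerDimNames.sequence_q,
    TransformerDimNames.composite_heads,
    TransformerDimNames.kv_channels,
)

kv_channels = 128                      # assumed head size
attention_softmax_scale_power = 0.5    # assumed config value
softmax_scale = kv_channels ** (-attention_softmax_scale_power)
# With power 0.5 this reduces to the usual 1/sqrt(d_k) attention scaling.
assert abs(softmax_scale - 1 / kv_channels**0.5) < 1e-12
print(_QUERY_DIMS, softmax_scale)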
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 0059718f5..857ae3c6b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -29,85 +29,62 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -class BaseTransformerDimNames: - _kwargs_attributes = { - "batch": "batch", - "sequence_q": "sequence_q", - "sequence_q_tp": "sequence_q_tp", - "sequence_k": "sequence_k", - "hidden": "hidden", - "head_groups": "head_groups", - "group_heads": "group_heads", - "key_and_value": "key_value", - "kv_channels": "kv_channels", - "composite_heads": "composite_heads", - "composite_query": "composite_query", - "composite_key_value": "composite_key_value", - "composite_dense": "composite_dense", - "mlp": "mlp", - "gate_and_up": "gate_and_up", - "composite_gated_mlp": "composite_gated_mlp", - "experts": "experts", - "top_experts": "top_experts", - "shared_experts": "shared_experts", - "unshared_experts": "unshared_experts", - "composite_expert_mlp": "composite_expert_mlp", - "composite_gated_expert_mlp": "composite_gated_expert_mlp", - "composite_shared_expert_mlp": "composite_shared_expert_mlp", - "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", - } - - def __init_subclass__(cls, prefix="", **kwargs): - super().__init_subclass__(**kwargs) - cls._prefix = prefix - for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): - setattr(cls, attr, f"{cls._prefix}_{value}") - - -class TransformerDimNames(BaseTransformerDimNames, prefix=""): - pass - - -class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): - pass - - -class BaseTransformerKwargs: - _kwargs_attributes = { - "rotary_freq_q": "rotary_freq_q", - "rotary_freq_k": "rotary_freq_k", - "attention_mask": "attention_mask", - "attention_mask_value": "attention_mask_value", - "sequence_lengths": "sequence_lengths", - "cu_seqlens_q": "cu_seqlens_q", - "cu_seqlens_k": "cu_seqlens_k", - "max_seqlen_q": "max_seqlen_q", - "max_seqlen_k": "max_seqlen_k", - "presents": "presents", - "past_key_values": "past_key_values", - "sequence_first": "sequence_first", - "hidden_dims": "hidden_dims", - "sequence_q_dim": "sequence_q_dim", - "sequence_k_dim": "sequence_k_dim", - "sequence_length": "sequence_length", - "micro_batch_size": "micro_batch_size", - "grad_output": "grad_output", - } - - _prefix = "" - - def __init_subclass__(cls, prefix="", **kwargs): - super().__init_subclass__(**kwargs) - cls._prefix = prefix - for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) - - -class TransformerKwargs(BaseTransformerKwargs, prefix=""): - pass - - -class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): +class TransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? 
+ sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + # Self-attention dimensions + head_groups = "head_groups" + group_heads = "group_heads" + key_and_value = "key_value" + kv_channels = "kv_channels" + composite_heads = "composite_heads" + composite_query = "composite_query" + composite_key_value = "composite_key_value" + composite_dense = "composite_dense" + # MLP dimensions + mlp = "mlp" + gate_and_up = "gate_and_up" + composite_gated_mlp = "composite_gated_mlp" + experts = "experts" + top_experts = "top_experts" + shared_experts = "shared_experts" + unshared_experts = "unshared_experts" + composite_expert_mlp = "composite_expert_mlp" + composite_gated_expert_mlp = "composite_gated_expert_mlp" + composite_shared_expert_mlp = "composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" + + +class TransformerKwargs: + rotary_freq_q = "rotary_freq_q" + rotary_freq_k = "rotary_freq_k" + attention_mask = "attention_mask" + attention_mask_value = "attention_mask_value" + sequence_lengths = "sequence_lengths" + cu_seqlens_q = "cu_seqlens_q" + cu_seqlens_k = "cu_seqlens_k" + max_seqlen_q = "max_seqlen_q" + max_seqlen_k = "max_seqlen_k" + # TODO: Review these + presents = "presents" + past_key_values = "past_key_values" + sequence_first = "sequence_first" + hidden_dims = "hidden_dims" + sequence_q_dim = "sequence_q_dim" + sequence_k_dim = "sequence_k_dim" + sequence_length = "sequence_length" + # TODO: Move + grad_output = "grad_output" + + +class VisionKwargs: patch_position_ids = "patch_position_ids" diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 83f1110c1..b01eb2aa5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert, get_lr_scale @@ -19,9 +19,6 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._name = name self._layer_index = layer_index - self._transformer_dim_names = config._transformer_dim_names - self._transformer_kwargs = config._transformer_kwargs - init_method_1 = init_normal_( std=config.init_method_std_mlp_1, min_val=config.init_method_min_mlp_1, @@ -33,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -49,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = 
LinearBase( hidden_dim, - tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cb64ccf06..dc3ddeb52 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -24,8 +24,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - self._transformer_dim_names = config._transformer_dim_names - self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -56,10 +54,10 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[self._transformer_kwargs.attention_mask] = self._mask[ + kwargs[TransformerKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) @@ -67,14 +65,14 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[self._transformer_kwargs.attention_mask] = ( - kwargs[self._transformer_kwargs.attention_mask] + kwargs[TransformerKwargs.attention_mask] = ( + kwargs[TransformerKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + kwargs[TransformerKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( + kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -82,12 +80,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=self._transformer_kwargs.attention_mask, + tensor_name=TransformerKwargs.attention_mask, dtype=torch.bool, ) - kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=self._transformer_kwargs.attention_mask_value, + tensor_name=TransformerKwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -98,8 +96,6 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - self._transformer_dim_names = config._transformer_dim_names - self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: 
dict[str, typing.Any]) -> None: """ @@ -150,17 +146,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( + kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() - kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() + kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 73819b8ad..761399e5d 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -26,15 +26,9 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name = "self_attn" def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index: int, - return_input: bool = False, + self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() - self._transformer_dim_names = config._transformer_dim_names - self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout @@ -43,8 +37,7 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms # TODO: add a seperate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) @@ -77,7 +70,7 @@ def name(self) -> str: return f"{self._name} {self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[self._transformer_kwargs.hidden_dims] + dims = kwargs[TransformerKwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) @@ -149,18 +142,12 @@ class TransformerLayer(BaseBlock): _mixer_module_name = "self_attn" def __init__( - self, - config: TransformerConfig, - tensor_space: 
TensorSpace,
-        layer_index: int,
-        return_input: bool = False,
-        layer_offset: int = 1,
+        self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False
     ):
-        self._layer_offset = layer_offset
         super().__init__(config, tensor_space, layer_index, return_input)
 
     def _create_mixer(self):
-        self.self_attn = Attention(self._config, self._tensor_space, self._layer_index, self._layer_offset)
+        self.self_attn = Attention(self._config, self._tensor_space, self._layer_index)
 
 
 class VisionTransformerLayer(TransformerLayer):

From 14aa78893775dcda33b63f285f6a777a1ba20669 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Mon, 18 Aug 2025 16:26:08 -0400
Subject: [PATCH 95/97] Reduce diff

---
 fast_llm/data/dataset/gpt/indexed.py         |  4 +-
 fast_llm/layers/transformer/attention.py     |  1 -
 fast_llm/layers/transformer/config.py        | 64 ++++++--------------
 fast_llm/layers/transformer/rotary/config.py |  9 ++-
 fast_llm/layers/transformer/rotary/rotary.py | 16 ++---
 tests/data/common.py                         |  8 +--
 tests/data/test_sampling.py                  |  4 +-
 7 files changed, 43 insertions(+), 63 deletions(-)

diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py
index 669b2d9e9..59e701a63 100644
--- a/fast_llm/data/dataset/gpt/indexed.py
+++ b/fast_llm/data/dataset/gpt/indexed.py
@@ -19,13 +19,13 @@ def get_document_sizes(self) -> np.ndarray:
         and derived classes should try to avoid holding the whole array in memory.
         """
 
-    @abc.abstractmethod
     def get_image_sizes(self) -> list[np.ndarray]:
         """
         The size of each image in the dataset.
         The resulting array could be very large, so this method should be called cautiously,
         and derived classes should try to avoid holding the whole array in memory.
         """
+        raise NotImplementedError()
 
     @abc.abstractmethod
     def get_document_size(self, index: int) -> int:
@@ -33,11 +33,11 @@ def get_document_size(self, index: int) -> int:
         The size of a document in the dataset.
         """
 
-    @abc.abstractmethod
     def get_image_size(self, index: int) -> np.ndarray:
         """
         The size of an image in the dataset.
""" + raise NotImplementedError() def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c9d548d42..e72171ec1 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -78,7 +78,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, layer_index, - layer_offset: int = 1, ): super().__init__() self._config = config diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 857ae3c6b..20a750f69 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -623,89 +623,65 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - self._transformer_dim_names.group_heads, + TransformerDimNames.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) + CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim( - self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) - ) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim( - gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, 
self.num_experts)) + tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) + tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) + CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - @property - def _transformer_kwargs(self) -> TransformerKwargs: - if self.type == TransformerType.image_encoder: - return VisionTransformerKwargs - else: - return TransformerKwargs - - @property - def _transformer_dim_names(self) -> TransformerDimNames: - if self.type == TransformerType.image_encoder: - return VisionTransformerDimNames - else: - return TransformerDimNames - for name in TransformerType: # We need this because we are using the reserved field name `type`. 
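For reference, a short sketch of the dimension bookkeeping that setup_tensor_space above builds, assuming a composite dim's size is the product of its parts (consistent with how the attention layer consumes them). The GQA-style layout below (32 heads, 8 key/value groups, 128 channels per head) is an illustrative assumption.

# Sketch only; numbers are made up for illustration.
num_attention_heads = 32
head_groups = 8
kv_channels = 128
group_heads = num_attention_heads // head_groups  # heads per key/value group

composite_heads = head_groups * group_heads                 # total attention heads
composite_query = head_groups * group_heads * kv_channels   # query projection width
composite_key_value = 2 * head_groups * kv_channels         # packed key + value width
composite_dense = head_groups * group_heads * kv_channels   # attention output width

print(composite_heads, composite_query, composite_key_value, composite_dense)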
diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index eb739e5c4..ba598e385 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -10,7 +10,14 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.transformer.rotary.rotary import ( + DefaultRotary, + Llama3Rotary, + NoRotary, + Rotary, + Rotary2D, + YarnRotary, + ) @config_class(registry=True) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index b2c69dd8d..46ea2a8b2 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -8,7 +8,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -224,29 +224,29 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None max_num_patches = kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size] self._create_tensors(max_num_patches) - position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] - kwargs[VisionTransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] - kwargs[VisionTransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + position_ids = kwargs[VisionKwargs.patch_position_ids] + kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None - kwargs[VisionTransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=VisionTransformerKwargs.rotary_freq_q, + tensor_name=TransformerKwargs.rotary_freq_q, ) - kwargs[VisionTransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=VisionTransformerKwargs.rotary_freq_k, + tensor_name=TransformerKwargs.rotary_freq_k, ) def _create_tensors(self, max_num_patches: int) -> None: diff --git a/tests/data/common.py b/tests/data/common.py index 23ed9d76b..6bd6b2126 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -127,10 +127,10 @@ def compare_indexed_dataset( loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - text_sizes, image_sizes = dataset.get_document_sizes() + sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], 
text_sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) @@ -224,9 +224,7 @@ def __len__(self) -> int: return self._config.num_documents def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64), np.array( - [], dtype=np.int64 - ) + return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 296102f7d..34dc714e1 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -104,10 +104,10 @@ def get_document_sizes(self) -> np.ndarray: doc_size, im_size = self.get_document_size(index) doc_sizes.append(doc_size) im_sizes.append(im_size) - return np.array(doc_sizes, dtype=np.int64), np.array(im_sizes, dtype=np.int64) + return np.array(doc_sizes, dtype=np.int64) def get_document_size(self, index: int) -> int: - return len(self._samples[index]), [] + return len(self._samples[index]) def name(self) -> str: return "dataset" From b6b642305b809c631221820be3e83c400ddb4ca1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 Aug 2025 08:55:52 -0400 Subject: [PATCH 96/97] misc --- fast_llm/functional/config.py | 2 - fast_llm/layers/language_model/config.py | 5 +- fast_llm/layers/vision_encoder/adapter.py | 4 +- fast_llm/layers/vision_encoder/config.py | 59 +++++++++---------- fast_llm/layers/vision_encoder/patch_conv.py | 12 ++-- .../layers/vision_encoder/preprocessing.py | 22 ++++--- fast_llm/models/gpt/model.py | 22 +++---- tests/data/test_sampling.py | 8 +-- 8 files changed, 64 insertions(+), 70 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 6ab82bfea..2c553d906 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -79,7 +79,6 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", @@ -87,7 +86,6 @@ def _set_activation_fn_map() -> None: } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} - MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b0bb6ec6f..9d8a65929 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -250,8 +250,11 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) + if self.vision_encoder.enabled: - self.vision_encoder.setup_tensor_space(tensor_space) + # TODO: Remove tensor spaces so we don't need this hack. 
+ tensor_space.vision = TensorSpace(tensor_space.distributed_config) + self.vision_encoder.setup_tensor_space(tensor_space.vision) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 03f8a54b4..fecc6d086 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -7,7 +7,7 @@ from fast_llm.functional.triton.mlp import torch_mlp_activation from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.layers.vision_encoder.config import PixtralVisionEncoderConfig, VisionEncoderDimNames from fast_llm.tensor import TensorMeta, init_normal_ @@ -16,7 +16,7 @@ class VisionAdapter(Layer): Vision adapter layer that projects vision encoder features into the language model token embeddings. """ - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): super().__init__() input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) self._activation_type = config.adapter_activation_type diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 7e8d75f36..59255e5eb 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,4 +1,4 @@ -import enum +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig @@ -76,25 +76,42 @@ class ImageNormalizationConfig(Config): ) -class VisionEncoderType(str, enum.Enum): - none = "none" - # TODO: better name? normalization, patch size, adapter can change based on implementation, no standard way currently. - pixtral = "pixtral" - - @config_class(registry=True) class VisionEncoderConfig(BaseModelConfig): + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is VisionEncoderConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return NoVisionEncoderConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={VisionEncoderConfig: "none"}) +class NoVisionEncoderConfig(BaseModelConfig): + _abstract = False + + +@config_class(dynamic_type={VisionEncoderConfig: "pixtral"}) +class PixtralVisionEncoderConfig(BaseModelConfig): _abstract = False - type: VisionEncoderType = Field( - default=VisionEncoderType.none, - desc="Type of the vision encoder. 
Choices: none, pixtral.", - hint=FieldHint.architecture, - ) transformer: TransformerConfig = Field( desc="Configuration for the vision transformer architecture.", hint=FieldHint.core, ) + patch_normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) patch_size: int = Field( default=16, desc="Patch size for the image encoder.", @@ -105,10 +122,6 @@ class VisionEncoderConfig(BaseModelConfig): desc="Whether to use bias in the convolutional layer.", hint=FieldHint.optional, ) - patch_normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers applied to the image patches.", - hint=FieldHint.optional, - ) adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", @@ -124,10 +137,6 @@ class VisionEncoderConfig(BaseModelConfig): desc="Whether to use bias in the adapter linear layer.", hint=FieldHint.optional, ) - image_normalization: ImageNormalizationConfig = Field( - desc="Configuration for the normalization layers applied to the image patches.", - hint=FieldHint.optional, - ) image_break_token: int | None = Field( default=None, desc="Token id to separate image rows. If None, no token id is applied.", @@ -157,13 +166,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) self.transformer.setup_tensor_space(tensor_space) - - @property - def enabled(self) -> bool: - return self.type != VisionEncoderType.none - - -for name in VisionEncoderType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - VisionEncoderConfig.register_subclass(name.value, VisionEncoderConfig) diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index 71e1b40dc..f1c5f9c21 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -5,8 +5,12 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + PixtralVisionEncoderConfig, + VisionEncoderDimNames, + VisionEncoderKwargs, +) from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -15,7 +19,7 @@ class PatchConvolution(Layer): A convolution layer applied to image patches to create embeddings for each patch. These embeddings are fed into the vision transformer. 
""" - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config @@ -51,7 +55,7 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] + hidden_dims = kwargs[TransformerKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 9a01931b1..fb2482d08 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -7,8 +7,12 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, VisionKwargs +from fast_llm.layers.vision_encoder.config import ( + PixtralVisionEncoderConfig, + VisionEncoderDimNames, + VisionEncoderKwargs, +) from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -102,7 +106,7 @@ def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tenso class VisionPreprocessor(Preprocessor): - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -111,7 +115,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( ( TensorDim( - VisionTransformerDimNames.batch, + TransformerDimNames.batch, kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, ), TensorDim(VisionEncoderDimNames.in_channels, 3), @@ -228,7 +232,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patches = torch.cat(patches) patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches - kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + kwargs[VisionKwargs.patch_position_ids] = patch_position_ids kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], @@ -237,12 +241,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k - kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + kwargs[TransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, 
dtype=torch.int32 ) - kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + kwargs[TransformerKwargs.cu_seqlens_k] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 ) - kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen - kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[TransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[TransformerKwargs.max_seqlen_k] = max_seqlen kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a62171b3a..9650510a9 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,9 +100,7 @@ def get_output_layers(self) -> list[Layer]: def get_vision_layers(self) -> list[Layer]: vit_layers = [ - VisionTransformerLayer( - self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1, layer_offset=1 - ) + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) for idx in range(self._config.vision_encoder.transformer.num_layers) ] return [ @@ -113,26 +111,20 @@ def get_vision_layers(self) -> list[Layer]: ] def get_layers(self) -> list[Layer]: - lm_layer_offset = ( - self._config.vision_encoder.transformer.num_layers + 3 if self._config.vision_encoder.enabled else 1 - ) return [ *( - [LanguageModelEmbedding(self._config, self._tensor_space)] + self.get_vision_layers() if not self._config.vision_encoder.enabled - else self.get_vision_layers() + else [LanguageModelEmbedding(self._config, self._tensor_space)] ), *[ TransformerLayer( self._config.transformer, self._tensor_space, - layer_index=i + 1 + lm_layer_offset, + layer_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, - # optionally account for patch convolution, vision transformer, vision adapter - # by default we only have the embedding layer - layer_offset=lm_layer_offset, ) for i in range(self._config.transformer.num_layers) ], @@ -387,7 +379,7 @@ def preprocess( # avoid changing input tokens labels = labels.clone() labels_cloned = True - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -401,9 +393,9 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 34dc714e1..b8e7a92ff 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -98,13 +98,7 @@ def __len__(self) -> int: return len(self._samples) def get_document_sizes(self) -> np.ndarray: - doc_sizes = [] - im_sizes = [] - for index in range(len(self)): - doc_size, im_size = self.get_document_size(index) - doc_sizes.append(doc_size) - im_sizes.append(im_size) - return np.array(doc_sizes, dtype=np.int64) + return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) def get_document_size(self, index: int) -> int: return len(self._samples[index]) From 684ee62af97a1d008de178506c1c69554d8d06f2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 Aug 2025 16:10:10 -0400 Subject: [PATCH 97/97] attempt --- fast_llm/layers/transformer/config.py | 2 + fast_llm/layers/vision_encoder/config.py | 15 ++-- fast_llm/layers/vision_encoder/patch_conv.py | 32 +++---- .../layers/vision_encoder/preprocessing.py | 43 ++++----- fast_llm/models/gpt/model.py | 87 ++++++++----------- tests/test_config.py | 18 +--- 6 files changed, 82 insertions(+), 115 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 20a750f69..4d63b927f 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -80,6 +80,8 @@ class TransformerKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" sequence_length = "sequence_length" + batch_dim = "batch_dim" + micro_batch_size = (micro_batch_size,) # TODO: Move grad_output = "grad_output" diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 59255e5eb..14ff578dc 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,3 +1,4 @@ +import functools import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -24,17 +25,11 @@ class VisionEncoderKwargs: image_positions = "image_positions" max_image_size = "max_image_size" image_sizes = "image_sizes" - image_mean = "image_normalization_mean" - image_std = "image_normalization_std" image_rescale_factor = "image_rescale_factor" - rope_theta = "vit_rope_theta" - rotary_inv_freq = "vit_rotary_inv_freq" kv_channels = "vit_kv_channels" max_image_tokens = "max_image_tokens" - patch_embeddings = "patch_embeddings" hidden_dims 
= "vit_hidden_dims" image_patches_meta = "vit_image_patches_meta" - out_channels = "vit_out_channels" @config_class() @@ -75,6 +70,14 @@ class ImageNormalizationConfig(Config): hint=FieldHint.optional, ) + @functools.cached_property + def mean(self) -> list[float]: + return [self.mean_red, self.mean_green, self.mean_blue] + + @functools.cached_property + def std(self) -> list[float]: + return [self.std_red, self.std_green, self.std_blue] + @config_class(registry=True) class VisionEncoderConfig(BaseModelConfig): diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py index f1c5f9c21..d91fed163 100644 --- a/fast_llm/layers/vision_encoder/patch_conv.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -6,11 +6,7 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import ( - PixtralVisionEncoderConfig, - VisionEncoderDimNames, - VisionEncoderKwargs, -) +from fast_llm.layers.vision_encoder.config import PixtralVisionEncoderConfig, VisionEncoderDimNames from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -21,10 +17,10 @@ class PatchConvolution(Layer): def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace): super().__init__() + self._config = config self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config self._sequence_tensor_parallel = self._distributed_config.sequence_tensor_parallel - self._lr_scale = config.adapter_lr_scale self.weight = ParameterMeta.from_dims( ( self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), @@ -33,20 +29,19 @@ def __init__(self, config: PixtralVisionEncoderConfig, tensor_space: TensorSpace self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), ), init_method=init_normal_(), - lr_scale=self._lr_scale, + lr_scale=self._config.adapter_lr_scale, ) if config.conv_bias: self.bias = ParameterMeta.from_dims( (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), init_method=init_normal_(), - lr_sclae=self._lr_scale, + lr_sclae=self._config.adapter_lr_scale, ) else: self.bias = None self.normalization = config.patch_normalization.get_layer( tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) ) - self._stride = config.patch_size def forward( self, @@ -58,15 +53,14 @@ def forward( hidden_dims = kwargs[TransformerKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] - sequence_length = kwargs[TransformerKwargs.sequence_length] - out_channels = kwargs[VisionEncoderKwargs.out_channels] - reshape_dims = (micro_batch_size, sequence_length, out_channels) - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) - patch_embeddings = self.normalization(input_.flatten(1)) - patch_embeddings = patch_embeddings.view(reshape_dims) - if self._sequence_tensor_parallel: + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._config.patch_size) + patch_embeddings = self.normalization(input_.flatten(1)).view( + kwargs[TransformerKwargs.batch_dim].size, + kwargs[TransformerKwargs.sequence_q_dim].size, + self._config.transformer.hidden_size, + ) + if 
kwargs[TransformerKwargs.sequence_first]: patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() - patch_embeddings = split(patch_embeddings, group=group, dim=0) + if self._sequence_tensor_parallel: + patch_embeddings = split(patch_embeddings, group=self._tensor_space.distributed.tensor_group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index fb2482d08..8f9be8012 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -119,8 +119,8 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, ), TensorDim(VisionEncoderDimNames.in_channels, 3), - TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), - TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, self._config.patch_size), + TensorDim(VisionEncoderDimNames.patch_size, self._config.patch_size), ), dtype=self._distributed_config.training_dtype.torch, ) @@ -129,22 +129,24 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) im_width = kwargs.get(VisionEncoderKwargs.max_image_size) - patch_size = kwargs[VisionEncoderKwargs.patch_size] image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ - [get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=patch_size) for im in ims] + [ + get_resize_dims(im.size(1), im.size(2), max_image_size, im_width, patch_size=self._config.patch_size) + for im in ims + ] for ims in images ] kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ [ torchvision_transforms.functional.normalize( - resize(image, max_image_size, im_width, patch_size).to( + resize(image, max_image_size, im_width, self._config.patch_size).to( dtype=self._tensor_space.distributed_config.training_dtype.torch ) / kwargs[VisionEncoderKwargs.image_rescale_factor], - mean=kwargs[VisionEncoderKwargs.image_mean], - std=kwargs[VisionEncoderKwargs.image_std], + mean=self._config.image_normalization.mean, + std=self._config.image_normalization.std, ) for image in imgs ] @@ -172,10 +174,10 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] sample_cu_seqlen = 0 for image, size, position in zip(imgs, sizes, positions): - seqlen = get_num_patches(*size, patch_size) + seqlen = get_num_patches(*size, self._config.patch_size) num_tokens = get_num_image_tokens( *size, - patch_size=patch_size, + patch_size=self._config.patch_size, image_break=self._config.image_break_token is not None, image_end=self._config.image_end_token is not None, ) @@ -188,9 +190,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: seq_patches.append( torch.cat( [ - torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( - -1, 3, patch_size, patch_size - ), + torch.nn.functional.unfold( + image, kernel_size=self._config.patch_size, stride=self._config.patch_size + ).T.reshape(-1, 3, self._config.patch_size, self._config.patch_size), ] ) ) @@ -202,7 +204,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: torch.cat( [ *seq_patches, - torch.zeros(padding_size, 3, patch_size, patch_size).to( + torch.zeros(padding_size, 3, 
self._config.patch_size, self._config.patch_size).to( dtype=self._tensor_space.distributed_config.training_dtype.torch, device=self._tensor_space.distributed.device, ), @@ -211,7 +213,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) if sizes: position_ids = torch.cat( - [position_ids_in_meshgrid(*size, max_image_size // patch_size, patch_size) for size in sizes] + [ + position_ids_in_meshgrid( + *size, max_image_size // self._config.patch_size, self._config.patch_size + ) + for size in sizes + ] ).to(device=self._tensor_space.distributed.device) else: position_ids = torch.tensor( @@ -233,13 +240,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches kwargs[VisionKwargs.patch_position_ids] = patch_position_ids - kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( - kwargs[VisionEncoderKwargs.rope_theta], - kwargs[VisionEncoderKwargs.kv_channels], - max_image_size, - patch_size, - ).to(device=self._tensor_space.distributed.device) - kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, patch_size**2) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size * im_width, self._config.patch_size**2) # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k kwargs[TransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9650510a9..754a0235e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -20,13 +20,11 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, - VisionTransformerDimNames, - VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.transformer.transformer import TransformerLayer, VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.patch_conv import PatchConvolution from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -147,36 +145,6 @@ def preprocess_meta( sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length - if self._config.vision_encoder.enabled: - max_image_size = batch_meta.max_image_size - image_mean = [ - self._config.vision_encoder.image_normalization.mean_red, - self._config.vision_encoder.image_normalization.mean_green, - self._config.vision_encoder.image_normalization.mean_blue, - ] - image_std = [ - self._config.vision_encoder.image_normalization.std_red, - self._config.vision_encoder.image_normalization.std_green, - self._config.vision_encoder.image_normalization.std_blue, - ] - image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor - vision_kwargs = { - VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - VisionEncoderKwargs.max_image_size: max_image_size, - VisionEncoderKwargs.image_mean: image_mean, - VisionEncoderKwargs.image_std: image_std, - VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, - 
@@ -374,11 +363,9 @@ def preprocess(
             labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous()
             # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss
             # TODO: take ignore_index from config
-            labels_cloned = False
             if batch.loss_masking_spans is not None:
                 # avoid changing input tokens
                 labels = labels.clone()
-                labels_cloned = True
                 for idx, spans in enumerate(batch.loss_masking_spans):
                     if not spans.numel():
                         continue
@@ -401,14 +388,8 @@ def preprocess(
                     labels = torch.where(loss_mask, labels, -100)
             if self._config.vision_encoder.enabled:
                 if self._config.vision_encoder.image_break_token is not None:
-                    if not labels_cloned:
-                        labels = labels.clone()
-                        labels_cloned = True
                     labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels)
                 if self._config.vision_encoder.image_end_token is not None:
-                    if not labels_cloned:
-                        labels = labels.clone()
-                        labels_cloned = True
                     labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels)
             kwargs[LanguageModelKwargs.labels] = labels
             kwargs.update(reference_logits[i])
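On the label handling in the hunk above: torch.where returns a new tensor, so masking the image break/end tokens no longer needs the labels_cloned bookkeeping that this hunk removes. A self-contained illustration of the masking follows; the token ids are placeholders for the example, not the ids used by the model.

# Illustrative sketch only: dropping special image tokens from the loss by setting
# their label positions to -100, the default ignore_index of torch.nn.CrossEntropyLoss.
import torch

image_break_token = 5  # placeholder id
image_end_token = 6  # placeholder id

labels = torch.tensor([[10, 5, 11, 12, 6, 13]])

# torch.where allocates a new tensor, so the original token ids are never mutated.
labels = torch.where(labels == image_break_token, -100, labels)
labels = torch.where(labels == image_end_token, -100, labels)

assert labels.tolist() == [[10, -100, 11, 12, -100, 13]]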
diff --git a/tests/test_config.py b/tests/test_config.py
index 30646e660..9e09e8041 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -137,14 +137,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
             },
             "tie_word_embeddings": False,
             "vocab_size": 1000,
-            "vision_encoder": {
-                "transformer": {
-                    "normalization": {"type": "layer_norm"},
-                    "rotary": {"type": "none"},
-                    "peft": {"type": "none"},
-                },
-                "patch_normalization": {"type": "layer_norm"},
-            },
+            "vision_encoder": {"type": "none"},
         }
     else:
         base_model_update["transformer"]["peft"] = {
@@ -154,14 +147,7 @@
         }
         base_model_update["transformer"]["normalization"]["type"] = "layer_norm"
         base_model_update["transformer"]["rotary"] = {"type": "none"}
-        base_model_update["vision_encoder"] = {
-            "transformer": {
-                "normalization": {"type": "layer_norm"},
-                "rotary": {"type": "none"},
-                "peft": {"type": "none"},
-            },
-            "patch_normalization": {"type": "layer_norm"},
-        }
+        base_model_update["vision_encoder"] = {"type": "none"}
         expected_config["base_model"] = base_model_update
         check_equal_nested(serialized_config, expected_config)