diff --git a/caduceus/configuration_caduceus.py b/caduceus/configuration_caduceus.py index dfccd5b..519eea2 100644 --- a/caduceus/configuration_caduceus.py +++ b/caduceus/configuration_caduceus.py @@ -34,6 +34,7 @@ def __init__( bidirectional_strategy: Union[str, None] = "add", bidirectional_weight_tie: bool = True, rcps: bool = False, + gradient_checkpointing_stride: int = 1, complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead **kwargs, ): @@ -52,4 +53,5 @@ def __init__( self.bidirectional_strategy = bidirectional_strategy self.bidirectional_weight_tie = bidirectional_weight_tie self.rcps = rcps + self.gradient_checkpointing_stride = gradient_checkpointing_stride self.complement_map = complement_map diff --git a/caduceus/modeling_caduceus.py b/caduceus/modeling_caduceus.py index 9a1206b..12f76a1 100644 --- a/caduceus/modeling_caduceus.py +++ b/caduceus/modeling_caduceus.py @@ -5,7 +5,7 @@ import inspect import math from functools import partial -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any, Protocol import torch from mamba_ssm.modules.mamba_simple import Mamba @@ -163,7 +163,43 @@ def forward(self, input_ids): return self.word_embeddings(input_ids) -class CaduceusMixerModel(nn.Module): +class HFGCProtocol(Protocol): + """Protocol for modules that support gradient checkpointing with Hugging Face Transformers.""" + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None: + """Enable gradient checkpointing. + + Args: + gradient_checkpointing_kwargs: Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. 
+ + See Also: + - [Transformers Documentation - enable gradient checkpointing](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.gradient_checkpointing_enable) + - [Transformers Source - enable implementation](https://github.com/huggingface/transformers/blob/6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec/src/transformers/modeling_utils.py#L2521) + """ + ... + + def gradient_checkpointing_disable(self) -> None: + """Disable gradient checkpointing. + + See Also: + - [Transformers Documentation - disable gradient checkpointing](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.gradient_checkpointing_disable) + - [Transformers Source - disable implementation](https://github.com/huggingface/transformers/blob/6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec/src/transformers/modeling_utils.py#L2585) + """ + ... + + +class MCGCProtocol(Protocol): + """Protocol for modules that support gradient checkpointing with MosaicML Composer.""" + + def activation_checkpointing_fn(self, module: nn.Module) -> bool: + """Determine if module should be checkpointed. 
+ + See Also: + - [Composer Documentation - FSDP auto wrap policy](https://github.com/mosaicml/composer/blob/7fa03545cc2025f256d914abc111a068d239d632/docs/source/notes/distributed_training.rst#composers-fsdp-auto-wrap-policy) + - [MosaicML Examples - GPT implementation](https://github.com/mosaicml/examples/blob/6972fe3000d5a5480d8757ff710965514155e8db/llm/llm/gpt.py#L173-L175) + """ + +class CaduceusMixerModel(nn.Module, HFGCProtocol, MCGCProtocol): def __init__( self, config: CaduceusConfig, @@ -173,10 +209,10 @@ def __init__( super().__init__() factory_kwargs = {"device": device, "dtype": dtype} + self.config = config self.fused_add_norm = config.fused_add_norm self.rcps = config.rcps self.residual_in_fp32 = config.residual_in_fp32 - self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) # Mamba changes the order of residual and layer norm: @@ -188,6 +224,13 @@ def __init__( if layer_norm_fn is None or rms_norm_fn is None: raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + self.gradient_checkpointing = False + if not (1 <= config.gradient_checkpointing_stride <= config.n_layer): + raise ValueError( + f"`gradient_checkpointing_stride` must be between 1 and {config.n_layer}; " + f"got {config.gradient_checkpointing_stride}." 
+ ) + self.layers = nn.ModuleList( [ create_block( @@ -213,6 +256,27 @@ def __init__( ) self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f) + + def _gradient_checkpointing_indexes(self) -> list[int]: + return [ + i for i in range(len(self.layers)) + if i % self.config.gradient_checkpointing_stride == 0 + ] + + def activation_checkpointing_fn(self, module: nn.Module) -> bool: + for index in self._gradient_checkpointing_indexes(): + if self.layers[index] is module: + return True + return False + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None: + self.gradient_checkpointing = True + self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs + + def gradient_checkpointing_disable(self) -> None: + self.gradient_checkpointing = False + self.gradient_checkpointing_kwargs = None + def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): """Mixer forward.""" all_hidden_states = [] @@ -220,14 +284,21 @@ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids) - + residual = None - for layer in self.layers: + checkpoint_indexes = set(self._gradient_checkpointing_indexes()) + for index, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states.append(hidden_states) - # TODO: Add support for gradient checkpointing - hidden_states, residual = layer( - hidden_states, residual, inference_params=None + layer_fn = layer + if self.gradient_checkpointing and index in checkpoint_indexes: + layer_fn = partial( + torch.utils.checkpoint.checkpoint, + layer, **self.gradient_checkpointing_kwargs + ) + hidden_states, residual = layer_fn( + # Only positional args can be used for `torch.utils.checkpoint.checkpoint` + hidden_states, residual, None # inference_params=None ) if not self.fused_add_norm: @@ -298,7 +369,7 @@ class 
CaduceusPreTrainedModel(PreTrainedModel): """PreTrainedModel wrapper for Caduceus backbone.""" config_class = CaduceusConfig base_model_prefix = "caduceus" - supports_gradient_checkpointing = False + supports_gradient_checkpointing = True _no_split_modules = ["BiMambaWrapper"] def _init_weights( @@ -340,8 +411,7 @@ def _init_weights( with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) - -class Caduceus(CaduceusPreTrainedModel): +class Caduceus(CaduceusPreTrainedModel, HFGCProtocol): """Caduceus model that can be instantiated using HF patterns.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config) @@ -360,6 +430,12 @@ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): factory_kwargs = {"device": device, "dtype": dtype} self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs) + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None: + self.backbone.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + + def gradient_checkpointing_disable(self) -> None: + self.backbone.gradient_checkpointing_disable() + def forward( self, input_ids: torch.LongTensor = None, @@ -389,7 +465,7 @@ def forward( return hidden_states -class CaduceusForMaskedLM(CaduceusPreTrainedModel): +class CaduceusForMaskedLM(CaduceusPreTrainedModel, HFGCProtocol): """HF-compatible Caduceus model for masked language modeling.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): @@ -414,6 +490,12 @@ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): # Initialize weights and apply final processing self.post_init() + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None: + self.caduceus.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + + def gradient_checkpointing_disable(self) -> None: + self.caduceus.gradient_checkpointing_disable() + def 
get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings @@ -492,7 +574,7 @@ def forward( ) -class CaduceusForSequenceClassification(CaduceusPreTrainedModel): +class CaduceusForSequenceClassification(CaduceusPreTrainedModel, HFGCProtocol): def __init__( self, config: CaduceusConfig, @@ -638,3 +720,11 @@ def forward( logits=logits, hidden_states=transformer_outputs.hidden_states, ) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None: + self.caduceus.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + + def gradient_checkpointing_disable(self) -> None: + self.caduceus.gradient_checkpointing_disable() + + diff --git a/caduceus/tests/test_modeling.py b/caduceus/tests/test_modeling.py new file mode 100644 index 0000000..78d191e --- /dev/null +++ b/caduceus/tests/test_modeling.py @@ -0,0 +1,269 @@ +import pytest +import torch +from collections import defaultdict +from typing import Optional, Dict, Any, Literal +from torch import nn +from torch.utils.data import Dataset + +from caduceus.configuration_caduceus import CaduceusConfig +from caduceus.modeling_caduceus import CaduceusForMaskedLM +from transformers import TrainingArguments, Trainer +from composer import Trainer as ComposerTrainer +from composer.models import HuggingFaceModel +from composer.optim import DecoupledAdamW + + +# Reentrant gradient checkpointing is required to force recomputation +# on backward passes, which is necessary for tests that verify checkpointing +# usage based on observing those recomputations +USE_REENTRANT_CHECKPOINTS = True + +def create_test_model( + config_overrides: Optional[Dict[str, Any]] = None, + device: torch.device = torch.device("cuda"), + seed: int = 0 +) -> CaduceusForMaskedLM: + """Create a CaduceusForMaskedLM model with test configuration.""" + torch.random.manual_seed(seed) + + # Default test configuration + config_params: Dict[str, Any] = { + 'd_model': 128, + 'n_layer': 4, + 
'vocab_size': 10, + 'gradient_checkpointing_stride': 1, + 'pad_token_id': -100 + } + + # Update with any overrides + if config_overrides: + config_params.update(config_overrides) + + config = CaduceusConfig(**config_params) + return CaduceusForMaskedLM(config).to(device) + + +def create_test_inputs( + model: CaduceusForMaskedLM, + batch_size: Optional[int] = 2, + seq_len: int = 128, + device: Optional[torch.device] = None, +) -> Dict[str, torch.Tensor]: + """Create random input tensors for testing.""" + shape = (batch_size, seq_len) if batch_size is not None else (seq_len,) + return { + 'input_ids': torch.randint(0, model.config.vocab_size, shape, device=device), + 'labels': torch.randint(0, model.config.vocab_size, shape, device=device) + } + + +def test_caduceus_masked_lm(): + """Test basic CaduceusForMaskedLM functionality with default settings.""" + # Create model with default config + model = create_test_model() + + # Generate random input + batch_size, seq_len = 3, 128 + inputs = create_test_inputs(model, batch_size=batch_size, seq_len=seq_len, device=model.device) + + # Run forward pass + outputs = model(**inputs) + + # Check output shapes + assert outputs.logits.shape == (batch_size, seq_len, model.config.vocab_size), "Unexpected logits shape" + + # Check loss is computed and backpropagates + assert outputs.loss is not None, "Loss should be computed when labels are provided" + outputs.loss.backward() + + # Check all parameters received gradients + for name, param in model.named_parameters(): + assert param.grad is not None, f"Parameter {name} has no gradient" + assert not torch.isnan(param.grad).any(), f"Parameter {name} has NaN gradients" + assert not torch.isinf(param.grad).any(), f"Parameter {name} has Inf gradients" + +def run_direct_pass( + model: CaduceusForMaskedLM, + inputs: Dict[str, torch.Tensor], + gradient_checkpointing: bool +) -> Dict[nn.Module, int]: + """Run single forward and backward pass explicitly and return forward pass counts.""" + 
forward_counts = defaultdict(int)
+    def count_forwards(module, input, output):
+        forward_counts[module] += 1
+
+    # Register hooks on each layer
+    for layer in model.caduceus.backbone.layers:
+        layer.register_forward_hook(count_forwards)
+
+    if gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": USE_REENTRANT_CHECKPOINTS})
+    else:
+        model.gradient_checkpointing_disable()
+
+    outputs = model(**inputs)
+    outputs.loss.backward()
+
+    return dict(forward_counts)
+
+
+def run_hf_training(
+    model: CaduceusForMaskedLM,
+    inputs: Dict[str, torch.Tensor],
+    gradient_checkpointing: bool,
+    output_dir: str
+) -> Dict[nn.Module, int]:
+    """Run training using HF Trainer and return forward pass counts.
+
+    See Also:
+    - [Transformers Documentation - Trainer](https://huggingface.co/docs/transformers/v4.48.0/en/main_classes/trainer#transformers.Trainer)
+    - [Transformers Source - Trainer implementation](https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/trainer.py#L312)
+    """
+    forward_counts = defaultdict(int)
+    def count_forwards(module, input, output):
+        forward_counts[module] += 1
+
+    # Register hooks on each layer
+    for layer in model.caduceus.backbone.layers:
+        layer.register_forward_hook(count_forwards)
+
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        per_device_train_batch_size=1,
+        disable_tqdm=True,
+        gradient_checkpointing=gradient_checkpointing,
+        gradient_checkpointing_kwargs={"use_reentrant": USE_REENTRANT_CHECKPOINTS},
+        max_steps=1,
+        logging_strategy="no",
+        report_to="none",
+        # Ensure that serialization is tested as well
+        save_strategy="steps",
+        save_steps=1,
+        # Use native serialization instead of safetensors to avoid errors about
+        # saving shared tensors in different modules (as `BiMambaWrapper` does), e.g.:
+        # RuntimeError: The weights trying to be saved contained shared tensors [] that are mismatching the transformers base configuration.
+ save_safetensors=False, + # Overwrite must be enabled as pytest tmp_dir is named systematically, e.g.: + # /tmp/pytest-of-ubuntu/pytest-3/test_activation_checkpointing_1 + overwrite_output_dir=True, + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=SimpleDataset(inputs), + ) + + trainer.train() + return dict(forward_counts) + +def run_mosaic_training( + model: CaduceusForMaskedLM, + inputs: Dict[str, torch.Tensor], + gradient_checkpointing: bool, + output_dir: str +) -> Dict[nn.Module, int]: + """Run training using MosaicML Composer Trainer and return forward pass counts. + + See Also: + - [Composer Documentation - Trainer](https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.Trainer.html) + - [Composer Source - Trainer implementation](https://docs.mosaicml.com/projects/composer/en/latest/_modules/composer/trainer/trainer.html#Trainer) + """ + forward_counts = defaultdict(int) + def count_forwards(module, input, output): + forward_counts[module] += 1 + + # Register hooks on each layer + for layer in model.caduceus.backbone.layers: + layer.register_forward_hook(count_forwards) + + if gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": USE_REENTRANT_CHECKPOINTS}) + else: + model.gradient_checkpointing_disable() + + # Wrap the model for Composer + composer_model = HuggingFaceModel(model) + + optimizer = DecoupledAdamW(model.parameters()) + + trainer = ComposerTrainer( + optimizers=optimizer, + model=composer_model, + log_to_console=False, + progress_bar=False, + train_dataloader=SimpleDataset(inputs), + max_duration='1ba', + device_train_microbatch_size=1, + # Ensure that serialization is tested as well + save_folder=output_dir, + save_interval='1ep', + # Overwrite must be enabled as pytest tmp_dir is named systematically, e.g.: + # /tmp/pytest-of-ubuntu/pytest-3/test_activation_checkpointing_1 + save_overwrite=True, + ) + + trainer.fit() + 
return dict(forward_counts)
+
+class SimpleDataset(Dataset):
+    """Simple dataset wrapper for a single input (or single batch of inputs)."""
+
+    def __init__(self, inputs: Dict[str, torch.Tensor]):
+        self.inputs = inputs
+
+    def __len__(self) -> int:
+        return 1
+
+    def __getitem__(self, _: int) -> Dict[str, torch.Tensor]:
+        return self.inputs
+
+@pytest.mark.parametrize("gradient_checkpointing_stride", [1, 2])
+@pytest.mark.parametrize("mode", ["direct", "huggingface", "mosaic"])
+def test_activation_checkpointing_recomputation(
+    gradient_checkpointing_stride: int,
+    mode: Literal["direct", "huggingface", "mosaic"],
+    tmp_path,
+):
+    """Test that activation checkpointing causes expected recomputation."""
+    # Create model with specified stride
+    model = create_test_model({
+        'gradient_checkpointing_stride': gradient_checkpointing_stride
+    })
+
+    # Generate random input
+    # The HF Trainer handles device placement and batching itself, so we only
+    # do this manually in the `direct` and `mosaic` modes
+    batch_size = 2 if mode == "direct" or mode == "mosaic" else None
+    device = model.device if mode == "direct" or mode == "mosaic" else None
+    inputs = create_test_inputs(model, batch_size=batch_size, device=device)
+
+    # Run training with and without checkpointing
+    run_fn = (run_hf_training if mode == "huggingface"
+              else run_mosaic_training if mode == "mosaic"
+              else run_direct_pass)
+    kwargs = {"output_dir": str(tmp_path)} if mode in ["huggingface", "mosaic"] else {}
+
+    forward_counts_no_checkpoint = run_fn(model, inputs, False, **kwargs)
+    forward_counts_checkpoint = run_fn(model, inputs, True, **kwargs)
+
+    # Verify counts
+    for layer in model.caduceus.backbone.layers:
+        layer_idx = layer.layer_idx
+        if layer_idx % gradient_checkpointing_stride == 0:
+            # Checkpointed layers should be computed twice
+            assert forward_counts_checkpoint[layer] == 2 * forward_counts_no_checkpoint[layer]
+        else:
+            # Non-checkpointed layers should be computed the same number of times
+            assert forward_counts_checkpoint[layer] == forward_counts_no_checkpoint[layer]
+
+def test_invalid_gradient_checkpointing_stride():
+    """Test that invalid gradient_checkpointing_stride raises ValueError."""
+    # Test stride = 0
+    with pytest.raises(ValueError, match=r".*must be between 1 and \d+.*"):
+        create_test_model({'gradient_checkpointing_stride': 0})
+
+    # Test stride > n_layer
+    with pytest.raises(ValueError, match=r".*must be between 1 and \d+.*"):
+        create_test_model({'gradient_checkpointing_stride': 5})
+
+