2 changes: 2 additions & 0 deletions caduceus/configuration_caduceus.py
@@ -34,6 +34,7 @@ def __init__(
bidirectional_strategy: Union[str, None] = "add",
bidirectional_weight_tie: bool = True,
rcps: bool = False,
gradient_checkpointing_stride: int = 1,
complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
**kwargs,
):
@@ -52,4 +53,5 @@ def __init__(
self.bidirectional_strategy = bidirectional_strategy
self.bidirectional_weight_tie = bidirectional_weight_tie
self.rcps = rcps
self.gradient_checkpointing_stride = gradient_checkpointing_stride
self.complement_map = complement_map
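
As a reading aid (not part of the diff): a minimal sketch of how the new field is meant to be set when building a config. `n_layer` already exists on `CaduceusConfig`; the numeric values below are illustrative only, and all other fields are assumed to keep their defaults.

# Hypothetical usage of the new config option; values are made up.
from caduceus.configuration_caduceus import CaduceusConfig

config = CaduceusConfig(n_layer=8, gradient_checkpointing_stride=2)
# Once gradient checkpointing is enabled, a stride of 2 means the backbone
# checkpoints layers 0, 2, 4 and 6 (see the validation added in
# CaduceusMixerModel.__init__: 1 <= stride <= n_layer).
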
116 changes: 103 additions & 13 deletions caduceus/modeling_caduceus.py
@@ -5,7 +5,7 @@
import inspect
import math
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Any, Protocol

import torch
from mamba_ssm.modules.mamba_simple import Mamba
@@ -163,7 +163,43 @@ def forward(self, input_ids):
return self.word_embeddings(input_ids)


class CaduceusMixerModel(nn.Module):
class HFGCProtocol(Protocol):
"""Protocol for modules that support gradient checkpointing with Hugging Face Transformers."""

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None:
"""Enable gradient checkpointing.

Args:
gradient_checkpointing_kwargs: Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.

See Also:
- [Transformers Documentation - enable gradient checkpointing](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.gradient_checkpointing_enable)
- [Transformers Source - enable implementation](https://github.com/huggingface/transformers/blob/6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec/src/transformers/modeling_utils.py#L2521)
"""
...

def gradient_checkpointing_disable(self) -> None:
"""Disable gradient checkpointing.

See Also:
- [Transformers Documentation - disable gradient checkpointing](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.gradient_checkpointing_disable)
- [Transformers Source - disable implementation](https://github.com/huggingface/transformers/blob/6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec/src/transformers/modeling_utils.py#L2585)
"""
...


class MCGCProtocol(Protocol):
"""Protocol for modules that support gradient checkpointing with MosaicML Composer."""

def activation_checkpointing_fn(self, module: nn.Module) -> bool:
"""Determine if module should be checkpointed.

See Also:
- [Composer Documentation - FSDP auto wrap policy](https://github.com/mosaicml/composer/blob/7fa03545cc2025f256d914abc111a068d239d632/docs/source/notes/distributed_training.rst#composers-fsdp-auto-wrap-policy)
- [MosaicML Examples - GPT implementation](https://github.com/mosaicml/examples/blob/6972fe3000d5a5480d8757ff710965514155e8db/llm/llm/gpt.py#L173-L175)
"""

class CaduceusMixerModel(nn.Module, HFGCProtocol, MCGCProtocol):
def __init__(
self,
config: CaduceusConfig,
@@ -173,10 +209,10 @@ def __init__(
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}

self.config = config
self.fused_add_norm = config.fused_add_norm
self.rcps = config.rcps
self.residual_in_fp32 = config.residual_in_fp32

self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)

# Mamba changes the order of residual and layer norm:
@@ -188,6 +224,13 @@ def __init__(
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

self.gradient_checkpointing = False
if not (1 <= config.gradient_checkpointing_stride <= config.n_layer):
raise ValueError(
f"`gradient_checkpointing_stride` must be between 1 and {config.n_layer}; "
f"got {config.gradient_checkpointing_stride}."
)

self.layers = nn.ModuleList(
[
create_block(
@@ -213,21 +256,49 @@ def __init__(
)
self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)


def _gradient_checkpointing_indexes(self) -> list[int]:
return [
i for i in range(len(self.layers))
if i % self.config.gradient_checkpointing_stride == 0
]

def activation_checkpointing_fn(self, module: nn.Module) -> bool:
for index in self._gradient_checkpointing_indexes():
if self.layers[index] is module:
return True
return False

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None:
self.gradient_checkpointing = True
self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs

def gradient_checkpointing_disable(self) -> None:
self.gradient_checkpointing = False
self.gradient_checkpointing_kwargs = None

def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
"""Mixer forward."""
all_hidden_states = []
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids)

residual = None
for layer in self.layers:
checkpoint_indexes = set(self._gradient_checkpointing_indexes())
for index, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(hidden_states)
# TODO: Add support for gradient checkpointing
hidden_states, residual = layer(
hidden_states, residual, inference_params=None
layer_fn = layer
if self.gradient_checkpointing and index in checkpoint_indexes:
layer_fn = partial(
torch.utils.checkpoint.checkpoint,
layer, **self.gradient_checkpointing_kwargs
)
hidden_states, residual = layer_fn(
# Only positional args can be used for `torch.utils.checkpoint.checkpoint`
hidden_states, residual, None # inference_params=None
)

if not self.fused_add_norm:
@@ -298,7 +369,7 @@ class CaduceusPreTrainedModel(PreTrainedModel):
"""PreTrainedModel wrapper for Caduceus backbone."""
config_class = CaduceusConfig
base_model_prefix = "caduceus"
supports_gradient_checkpointing = False
supports_gradient_checkpointing = True
_no_split_modules = ["BiMambaWrapper"]

def _init_weights(
@@ -340,8 +411,7 @@ def _init_weights(
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)


class Caduceus(CaduceusPreTrainedModel):
class Caduceus(CaduceusPreTrainedModel, HFGCProtocol):
"""Caduceus model that can be instantiated using HF patterns."""
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config)
@@ -360,6 +430,12 @@ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
factory_kwargs = {"device": device, "dtype": dtype}
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None:
self.backbone.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

def gradient_checkpointing_disable(self) -> None:
self.backbone.gradient_checkpointing_disable()

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -389,7 +465,7 @@ def forward(
return hidden_states


class CaduceusForMaskedLM(CaduceusPreTrainedModel):
class CaduceusForMaskedLM(CaduceusPreTrainedModel, HFGCProtocol):
"""HF-compatible Caduceus model for masked language modeling."""

def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
@@ -414,6 +490,12 @@ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
# Initialize weights and apply final processing
self.post_init()

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None:
self.caduceus.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

def gradient_checkpointing_disable(self) -> None:
self.caduceus.gradient_checkpointing_disable()

def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings

@@ -492,7 +574,7 @@ def forward(
)


class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
class CaduceusForSequenceClassification(CaduceusPreTrainedModel, HFGCProtocol):
def __init__(
self,
config: CaduceusConfig,
@@ -638,3 +720,11 @@ def forward(
logits=logits,
hidden_states=transformer_outputs.hidden_states,
)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: dict[str, Any]) -> None:
self.caduceus.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

def gradient_checkpointing_disable(self) -> None:
self.caduceus.gradient_checkpointing_disable()


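To close the loop, a hedged end-to-end sketch of the API this PR adds, assuming mamba_ssm and the Triton kernels are installed, that the remaining `CaduceusConfig` fields (vocab size and so on) keep their defaults, and that the model returns the standard Hugging Face masked-LM output. Only `gradient_checkpointing_stride`, `gradient_checkpointing_enable`, and `gradient_checkpointing_disable` come from this diff; everything else is assumed context.

# Hedged usage sketch for the methods added in this PR; values are illustrative.
import torch
from caduceus.configuration_caduceus import CaduceusConfig
from caduceus.modeling_caduceus import CaduceusForMaskedLM

config = CaduceusConfig(n_layer=8, gradient_checkpointing_stride=2)
model = CaduceusForMaskedLM(config)

# The kwargs dict is stored on the backbone and forwarded verbatim to
# torch.utils.checkpoint.checkpoint for every checkpointed layer.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

input_ids = torch.randint(0, 4, (2, 128))  # assumes a small DNA-style vocabulary
out = model(input_ids=input_ids, labels=input_ids)  # assumes the usual HF MaskedLM output
out.loss.backward()

model.gradient_checkpointing_disable()

Note that these overrides replace PreTrainedModel.gradient_checkpointing_enable entirely, so the kwargs dict is a required argument here rather than an optional one.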