20 commits
7a06ca4
feat(molt): a staged version of molt
MerlinRaptor Jul 31, 2025
006280c
feat(molt): most of molt should be done right
MerlinRaptor Jul 31, 2025
5f50f3d
refactor(molt): refactor rank distribution logic and fix a bug in d…
MerlinRaptor Jul 31, 2025
ab321bd
fix(trainer): adapt log_info for situations where l_s doesn't exist
MerlinRaptor Jul 31, 2025
aa71656
refactor(molt): refactor decode for better efficiency and VRAM usage
MerlinRaptor Jul 31, 2025
82fc5d0
refactor(molt): refactor decode and achieve a comparable efficiency t…
MerlinRaptor Jul 31, 2025
05f9176
refactor(molt): refactor decode into einsum operations
MerlinRaptor Jul 31, 2025
836efeb
feat(molt): implement distributed molt
MerlinRaptor Aug 3, 2025
f614978
fix(molt): fix an init bug in dist molt
MerlinRaptor Aug 4, 2025
63b4020
fix(molt): fix a major bug in _decode_distributed and now dist molt i…
MerlinRaptor Aug 4, 2025
c10a296
feat(trainer): add per rank_group logging logic to molt
MerlinRaptor Aug 5, 2025
a2a94e3
refactor(molt): refactor rank assignments logic for disentangling dis…
MerlinRaptor Aug 8, 2025
b14ac14
misc(trainer): add more logging logic for molt
MerlinRaptor Aug 8, 2025
cba63d8
fix(train): support sae for data parallel
MerlinRaptor Aug 8, 2025
a042a49
misc: ruff fix
MerlinRaptor Aug 8, 2025
b5793b4
refactor(abstract_sae): prepare_input should now also return decoder_…
MerlinRaptor Aug 11, 2025
fb05597
fix(molt): fix a bug regarding rank assign introduced by a2a94e37ed821…
MerlinRaptor Aug 11, 2025
ad2541b
feat(analyze): support molt analysis
MerlinRaptor Aug 11, 2025
37ebc08
misc: format changes for type check
MerlinRaptor Aug 11, 2025
4a1e3ba
misc: ruff fix
MerlinRaptor Aug 11, 2025
8 changes: 8 additions & 0 deletions src/lm_saes/__init__.py
@@ -16,6 +16,7 @@
InitializerConfig,
LanguageModelConfig,
LLaDAConfig,
MOLTConfig,
MongoDBConfig,
SAEConfig,
TrainerConfig,
@@ -24,6 +25,7 @@
from .crosscoder import CrossCoder
from .database import MongoClient
from .evaluator import EvalConfig, Evaluator
from .molt import MixtureOfLinearTransform
from .resource_loaders import load_dataset, load_model
from .runners import (
AnalyzeCrossCoderSettings,
@@ -38,6 +40,7 @@
SweepSAESettings,
TrainCLTSettings,
TrainCrossCoderSettings,
TrainMOLTSettings,
TrainSAESettings,
analyze_crosscoder,
analyze_sae,
@@ -50,6 +53,7 @@
sweep_sae,
train_clt,
train_crosscoder,
train_molt,
train_sae,
)
from .sae import SparseAutoEncoder
@@ -109,4 +113,8 @@
"DirectLogitAttributeSettings",
"direct_logit_attribute",
"DirectLogitAttributorConfig",
"MOLTConfig",
"MixtureOfLinearTransform",
"train_molt",
"TrainMOLTSettings",
]
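Taken together, the __init__.py changes above make the MOLT entry points importable from the package root. A quick sketch using only names that appear in this diff:

```python
# New public names exported by this PR (all visible in the diff above).
from lm_saes import MixtureOfLinearTransform, MOLTConfig, TrainMOLTSettings, train_molt
```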
16 changes: 11 additions & 5 deletions src/lm_saes/abstract_sae.py
@@ -714,11 +714,12 @@ def compute_loss(
"""Compute the loss for the autoencoder.
Ensure that the input activations are normalized by calling `normalize_activations` before calling this method.
"""
x, encoder_kwargs = self.prepare_input(batch)
x, encoder_kwargs, decoder_kwargs = self.prepare_input(batch)

label = self.prepare_label(batch, **kwargs)

feature_acts, hidden_pre = self.encode(x, return_hidden_pre=True, **encoder_kwargs)
reconstructed = self.decode(feature_acts, **kwargs)
reconstructed = self.decode(feature_acts, **decoder_kwargs)

with timer.time("loss_calculation"):
l_rec = (reconstructed - label).pow(2)
@@ -768,9 +769,14 @@ def compute_loss(
return loss

@abstractmethod
def prepare_input(self, batch: dict[str, torch.Tensor], **kwargs) -> tuple[torch.Tensor, dict[str, Any]]:
"""Prepare the input for the encoder.
Returns a tuple of (input, kwargs) where kwargs is a dictionary of additional arguments for the encoder computation.
def prepare_input(self, batch: dict[str, torch.Tensor], **kwargs) -> tuple[torch.Tensor, dict[str, Any], dict[str, Any]]:
"""Prepare the input for the encoder and decoder.

Returns:
tuple: (input_tensor, encoder_kwargs, decoder_kwargs)
- input_tensor: The input tensor for the encoder
- encoder_kwargs: Additional arguments for the encoder
- decoder_kwargs: Additional arguments for the decoder
"""
raise NotImplementedError("Subclasses must implement this method")
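The signature change above turns the old (input, encoder_kwargs) two-tuple into a three-tuple that also carries decoder kwargs; note that compute_loss now threads decoder_kwargs into decode instead of its own **kwargs. A minimal sketch of the new contract, with a hypothetical subclass and batch key (neither is part of this PR):

```python
# Hypothetical subclass illustrating the updated prepare_input contract.
from typing import Any

import torch


class ToySAE:
    def prepare_input(
        self, batch: dict[str, torch.Tensor], **kwargs
    ) -> tuple[torch.Tensor, dict[str, Any], dict[str, Any]]:
        x = batch["activations_in"]            # hypothetical batch key
        encoder_kwargs: dict[str, Any] = {}    # forwarded to self.encode(...)
        decoder_kwargs: dict[str, Any] = {}    # forwarded to self.decode(...)
        return x, encoder_kwargs, decoder_kwargs
```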

4 changes: 2 additions & 2 deletions src/lm_saes/analysis/feature_analyzer.py
@@ -211,8 +211,8 @@ def analyze_chunk(
batch = sae.normalize_activations(batch)

# Get feature activations from SAE
x, kwargs = sae.prepare_input(batch)
feature_acts: torch.Tensor = sae.encode(x, **kwargs)
x, encoder_kwargs, _ = sae.prepare_input(batch)
feature_acts: torch.Tensor = sae.encode(x, **encoder_kwargs)
if isinstance(feature_acts, DTensor):
assert device_mesh is not None, "Device mesh is required for DTensor feature activations"
if device_mesh is not feature_acts.device_mesh:
6 changes: 4 additions & 2 deletions src/lm_saes/clt.py
@@ -415,7 +415,7 @@ def init_encoder_with_decoder_transpose(self, factor: float = 1.0):
raise NotImplementedError("init_encoder_with_decoder_transpose does not make sense for CLT")

@override
def prepare_input(self, batch: "dict[str, torch.Tensor]", **kwargs) -> "tuple[torch.Tensor, dict[str, Any]]":
def prepare_input(self, batch: "dict[str, torch.Tensor]", **kwargs) -> "tuple[torch.Tensor, dict[str, Any], dict[str, Any]]":
"""Prepare input tensor from batch by stacking all layer activations from hook_points_in."""
x_layers = []
for hook_point in self.cfg.hook_points_in:
@@ -427,7 +427,9 @@ def prepare_input(self, batch: "dict[str, torch.Tensor]", **kwargs) -> "tuple[to
if isinstance(self.W_E, DTensor) and not isinstance(x, DTensor):
assert self.device_mesh is not None
x = DTensor.from_local(x, device_mesh=self.device_mesh, placements=[torch.distributed.tensor.Replicate()])
return x, {}
encoder_kwargs = {}
decoder_kwargs = {}
return x, encoder_kwargs, decoder_kwargs

@override
def prepare_label(self, batch: "dict[str, torch.Tensor]", **kwargs) -> torch.Tensor:
223 changes: 219 additions & 4 deletions src/lm_saes/config.py
@@ -14,6 +14,7 @@
PlainSerializer,
WithJsonSchema,
)
from typing_extensions import override

from .utils.huggingface import parse_pretrained_name_or_path
from .utils.misc import (
@@ -51,7 +52,7 @@ class BaseSAEConfig(BaseModelConfig, ABC):
So this class should not be used directly but only as a base config class for other SAE variants like SAEConfig, CrossCoderConfig, etc.
"""

sae_type: Literal["sae", "crosscoder", "clt"]
sae_type: Literal["sae", "crosscoder", "clt", "molt"]
d_model: int
expansion_factor: int
use_decoder_bias: bool = True
@@ -114,7 +115,7 @@ def associated_hook_points(self) -> list[str]:


class SAEConfig(BaseSAEConfig):
sae_type: Literal["sae", "crosscoder", "clt"] = "sae"
sae_type: Literal["sae", "crosscoder", "clt", "molt"] = "sae"
hook_point_in: str
hook_point_out: str = Field(default_factory=lambda validated_model: validated_model["hook_point_in"])
use_glu_encoder: bool = False
Expand All @@ -131,7 +132,7 @@ class CLTConfig(BaseSAEConfig):
reads from the residual stream at that layer and can decode to layers L through L-1.
"""

sae_type: Literal["sae", "crosscoder", "clt"] = "clt"
sae_type: Literal["sae", "crosscoder", "clt", "molt"] = "clt"
hook_points_in: list[str]
"""List of hook points to capture input activations from, one for each layer."""
hook_points_out: list[str]
@@ -159,8 +160,222 @@ def model_post_init(self, __context):
)


class MOLTConfig(BaseSAEConfig):
"""Configuration for Mixture of Linear Transforms (MOLT).

MOLT is a more efficient alternative to transcoders that sparsely replaces
MLP computation in transformers. It converts dense MLP layers into sparse,
interpretable linear transforms.
"""

sae_type: Literal["sae", "crosscoder", "clt", "molt"] = "molt"
hook_point_in: str
"""Hook point to capture input activations from."""
hook_point_out: str
"""Hook point to output activations to."""
rank_distribution: dict[int, int] = Field(default_factory=lambda: {4: 1, 8: 2, 16: 4, 32: 8, 64: 16})
"""Dictionary mapping rank values to their integer ratios.
Keys are rank values, values are integer ratios that will be automatically normalized to proportions.
Example: {4: 1, 8: 2, 16: 4, 32: 8, 64: 16} means ratio 1:2:4:8:16 which means a proportion of 1/32, 2/32, 4/32, 8/32, 16/32,
which will be normalized to proportions automatically."""
model_parallel_size_training: int = 1
"""Number of model parallel devices for distributed training. Distinct from model_parallel_size_running which is the number of model parallel devices in both training and inference."""

def model_post_init(self, __context):
super().model_post_init(__context)
# Validate ratios
assert self.rank_distribution, "rank_distribution cannot be empty"

total_ratio = sum(self.rank_distribution.values())
assert total_ratio > 0, f"Total ratio must be positive, got {total_ratio}"

for rank, ratio in self.rank_distribution.items():
assert ratio > 0, f"Ratio for rank {rank} must be positive, got {ratio}"

# Store normalized proportions for internal use
self._normalized_proportions = {
rank: ratio / total_ratio
for rank, ratio in self.rank_distribution.items()
}
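Spelled out for the default ratios, the normalization stored by model_post_init is plain arithmetic:

```python
# Normalization of the default rank_distribution, written out by hand.
rank_distribution = {4: 1, 8: 2, 16: 4, 32: 8, 64: 16}
total_ratio = sum(rank_distribution.values())  # 31
normalized = {rank: ratio / total_ratio for rank, ratio in rank_distribution.items()}
# {4: 1/31, 8: 2/31, 16: 4/31, 32: 8/31, 64: 16/31}
```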

def generate_rank_assignments(self) -> list[int]:
"""Generate rank assignment for each of the d_sae linear transforms.

Returns:
List of rank assignments for each transform.
For example: [1, 1, 1, 1, 2, 2, 4].
In the distributed case, this method ensures that the count of each rank type is divisible by model_parallel_size_training.
"""
# Validate rank distribution
assert self.rank_distribution, "rank_distribution cannot be empty"

# Calculate base d_sae
base_d_sae = self.d_model * self.expansion_factor

# For distributed training, use special logic to ensure consistency
if self.model_parallel_size_training > 1:
return self._generate_distributed_rank_assignments(base_d_sae)
else:
return self._generate_rank_assignments_single_gpu(base_d_sae)

def _generate_rank_assignments_single_gpu(self, base_d_sae: int) -> list[int]:
"""Generate rank assignments for single GPU training.

Args:
base_d_sae: Target number of total transforms

Returns:
List of rank assignments for each transform.
"""
assignments = []

# Distribute transforms based on normalized proportions
for rank, proportion in sorted(self._normalized_proportions.items()):
count = int(base_d_sae * proportion)
assignments.extend([rank] * count)

# Handle any remaining transforms due to rounding
while len(assignments) < base_d_sae:
# Assign remaining to the most common rank (by original ratio)
most_common_rank = max(
self.rank_distribution.keys(),
key=lambda k: self.rank_distribution[k]
)
assignments.append(most_common_rank)

# Truncate if we have too many (shouldn't happen with proper proportions)
assignments = assignments[:base_d_sae]

# Verify we have exactly base_d_sae assignments
assert len(assignments) == base_d_sae, (
f"Expected {base_d_sae} assignments, got {len(assignments)}"
)

return assignments
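A worked example of the rounding above, assuming d_model=768 and expansion_factor=8 (illustrative values, not from this PR):

```python
# Single-GPU counts for base_d_sae = 768 * 8 = 6144 with the default ratios.
base_d_sae = 768 * 8
ratios = {4: 1, 8: 2, 16: 4, 32: 8, 64: 16}
total = sum(ratios.values())  # 31
counts = {rank: int(base_d_sae * ratio / total) for rank, ratio in ratios.items()}
# counts == {4: 198, 8: 396, 16: 792, 32: 1585, 64: 3171}, summing to 6142;
# the 2 leftover transforms go to the highest-ratio rank (64), so its final
# count is 3173 and the assignment list has exactly 6144 entries.
```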

def _generate_distributed_rank_assignments(self, base_d_sae: int) -> list[int]:
"""Generate rank assignments optimized for distributed training.

Ensures each rank type has count divisible by model_parallel_size_training.

Args:
base_d_sae: Target number of total transforms

Returns:
List of rank assignments for each transform.
"""
assignments = []
total_ratio = sum(self.rank_distribution.values())

# Ensure minimum requirement: each rank gets at least model_parallel_size_training transforms
min_total_needed = len(self.rank_distribution) * self.model_parallel_size_training
assert base_d_sae >= min_total_needed, (
f"base_d_sae ({base_d_sae}) must be >= min_total_needed "
f"({min_total_needed}) for distributed training with "
f"{len(self.rank_distribution)} rank types"
)

# Calculate proportional distribution
for rank in sorted(self.rank_distribution.keys()):
rank_ratio = self.rank_distribution[rank]
raw_count = int(base_d_sae * rank_ratio / total_ratio)

# Ensure count is divisible by model_parallel_size_training
count = max(
self.model_parallel_size_training, # minimum requirement
(raw_count // self.model_parallel_size_training) * self.model_parallel_size_training,
)
assignments.extend([rank] * count)

# Handle any remaining transforms due to rounding
while len(assignments) < base_d_sae:
# Assign remaining to the most common rank (by original ratio)
most_common_rank = max(
self.rank_distribution.keys(),
key=lambda k: self.rank_distribution[k]
)
# Add model_parallel_size_training transforms at a time for divisibility
remaining = base_d_sae - len(assignments)
to_add = min(self.model_parallel_size_training, remaining)
assignments.extend([most_common_rank] * to_add)

# Truncate if we have too many (shouldn't happen normally)
assignments = assignments[:base_d_sae]

# Verify divisibility constraint
rank_counts = {}
for rank in assignments:
rank_counts[rank] = rank_counts.get(rank, 0) + 1

for rank, count in rank_counts.items():
assert count % self.model_parallel_size_training == 0, (
f"Rank {rank} count {count} not divisible by "
f"model_parallel_size_training {self.model_parallel_size_training}"
)

return assignments
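Continuing the same assumed shapes with model_parallel_size_training = 4, the divisibility rounding works out as follows:

```python
# Distributed counts: each raw count is rounded down to a multiple of mp,
# with a floor of mp transforms per rank type.
mp = 4
raw = {4: 198, 8: 396, 16: 792, 32: 1585, 64: 3171}  # from the example above
counts = {rank: max(mp, (c // mp) * mp) for rank, c in raw.items()}
# counts == {4: 196, 8: 396, 16: 792, 32: 1584, 64: 3168}, summing to 6136;
# the remaining 8 slots go to rank 64 in chunks of mp (3168 + 8 = 3176),
# keeping every per-rank count divisible by 4 and the total at 6144.
```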

def get_local_rank_assignments(self, local_rank: int, model_parallel_size_running: int) -> list[int]:
"""Get rank assignments for a specific local device in distributed running (both training and inference).

Each device gets all rank groups, with each group evenly divided across devices.
This ensures consistent encoder/decoder sharding without feature_acts redistribution.

Args:
local_rank: The local rank of this process
model_parallel_size_running: Number of model parallel devices at run time (training and inference)

Returns:
List of rank assignments for this local device
For example:
global_rank_assignments = [1, 1, 2, 2], model_parallel_size_running = 2 -> local_rank_assignments = [1, 2]
"""
global_rank_assignments = self.generate_rank_assignments()
global_rank_counts = {rank: global_rank_assignments.count(rank) for rank in self.available_ranks}

# Each device gets count/model_parallel_size_running transforms of each rank type
local_assignments = []
for rank in sorted(self.rank_distribution.keys()):
global_count = global_rank_counts[rank]

# Verify even division (should be guaranteed by _generate_distributed_rank_assignments)
assert global_count % model_parallel_size_running == 0, (
f"Rank {rank} global count {global_count} not divisible by "
f"model_parallel_size_running {model_parallel_size_running}"
)

local_count = global_count // model_parallel_size_running

# Add local_count transforms of this rank type
local_assignments.extend([rank] * local_count)

return local_assignments
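With the distributed counts above and model_parallel_size_running = 4, each device's local slice is simply the per-rank count divided by the device count:

```python
# Per-device slice of each rank group (counts from the distributed example).
global_counts = {4: 196, 8: 396, 16: 792, 32: 1584, 64: 3176}
local_counts = {rank: count // 4 for rank, count in global_counts.items()}
# {4: 49, 8: 99, 16: 198, 32: 396, 64: 794} transforms on each of 4 devices
```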

@property
@override
def d_sae(self) -> int:
"""Calculate d_sae based on rank assignments with padding for distributed training."""
# Generate rank assignments and return the length
rank_assignments = self.generate_rank_assignments()
return len(rank_assignments)

@property
def available_ranks(self) -> list[int]:
"""Get sorted list of available ranks."""
return sorted(self.rank_distribution.keys())

@property
def num_rank_types(self) -> int:
"""Number of different rank types."""
return len(self.rank_distribution)

@property
def associated_hook_points(self) -> list[str]:
return [self.hook_point_in, self.hook_point_out]
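Putting the config together end to end, a hedged sketch (it assumes the remaining BaseSAEConfig fields all have defaults, and the hook point names are illustrative, not taken from this PR):

```python
from lm_saes import MOLTConfig

cfg = MOLTConfig(
    d_model=768,
    expansion_factor=8,
    hook_point_in="blocks.6.ln2.hook_normalized",  # illustrative
    hook_point_out="blocks.6.hook_mlp_out",        # illustrative
    rank_distribution={4: 1, 8: 2, 16: 4, 32: 8, 64: 16},
)
assignments = cfg.generate_rank_assignments()
assert len(assignments) == cfg.d_sae == 768 * 8
print(cfg.available_ranks)  # [4, 8, 16, 32, 64]
print(cfg.num_rank_types)   # 5
```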


class CrossCoderConfig(BaseSAEConfig):
sae_type: Literal["sae", "crosscoder", "clt"] = "crosscoder"
sae_type: Literal["sae", "crosscoder", "clt", "molt"] = "crosscoder"
hook_points: list[str]

@property
10 changes: 7 additions & 3 deletions src/lm_saes/crosscoder.py
@@ -446,7 +446,7 @@ def init_encoder_with_decoder_transpose(self, factor: float = 1.0):

@override
@timer.time("prepare_input")
def prepare_input(self, batch: dict[str, torch.Tensor], **kwargs) -> tuple[torch.Tensor, dict[str, Any]]:
def prepare_input(self, batch: dict[str, torch.Tensor], **kwargs) -> tuple[torch.Tensor, dict[str, Any], dict[str, Any]]:
def pad_to_d_model(x: torch.Tensor) -> torch.Tensor:
# TODO: Support padding for distributed setting
if x.shape[-1] > self.cfg.d_model:
@@ -462,7 +462,9 @@ def pad_to_d_model(x: torch.Tensor) -> torch.Tensor:

# The following code is to stack the activations per head to (batch, ..., n_heads, d_model)
if self.device_mesh is None or "head" not in cast(tuple[str, ...], self.device_mesh.mesh_dim_names):
return torch.stack([pad_to_d_model(batch[hook_point]) for hook_point in self.cfg.hook_points], dim=-2), {}
encoder_kwargs = {}
decoder_kwargs = {}
return torch.stack([pad_to_d_model(batch[hook_point]) for hook_point in self.cfg.hook_points], dim=-2), encoder_kwargs, decoder_kwargs
else:
# The following code stacks the activations in a distributed setting. It's a bit complicated so I'll try to explain it in detail.

@@ -507,11 +509,13 @@ def pad_to_d_model(x: torch.Tensor) -> torch.Tensor:
first_hook_point_activations.placements, first_hook_point_activations.device_mesh
)

encoder_kwargs = {}
decoder_kwargs = {}
return DTensor.from_local(
per_process_activations,
device_mesh=self.device_mesh,
placements=output_dim_map.placements(self.device_mesh),
), {}
), encoder_kwargs, decoder_kwargs

@override
@timer.time("prepare_label")