4 changes: 4 additions & 0 deletions bergson/__init__.py
@@ -1,6 +1,7 @@
__version__ = "0.4.6"

from .collection import collect_gradients
from .collector.gradient_collectors import GradientCollector
from .config import (
AttentionConfig,
DataConfig,
@@ -11,6 +12,7 @@
)
from .data import load_gradient_dataset, load_gradients
from .gradients import GradientProcessor
from .normalizer.fit_normalizers import fit_normalizers
from .query.attributor import Attributor
from .query.faiss_index import FaissConfig
from .score.scorer import Scorer
@@ -20,10 +22,12 @@
"collect_gradients",
"load_gradients",
"load_gradient_dataset",
"fit_normalizers",
"Attributor",
"FaissConfig",
"FiniteDiff",
"GradientProcessor",
"GradientCollector",
"IndexConfig",
"DataConfig",
"AttentionConfig",
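Note: with these re-exports in place, the new helpers are importable from the package root. A minimal import-only sketch (the call signatures of fit_normalizers and GradientCollector are not shown in this diff, so only the imports are illustrated):

# Import-only sketch: these names are now re-exported from bergson/__init__.py.
# Their signatures are not part of this diff, so no calls are shown.
from bergson import GradientCollector, GradientProcessor, fit_normalizers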
10 changes: 2 additions & 8 deletions bergson/__main__.py
@@ -107,16 +107,10 @@ def execute(self):
self.command.execute()


def get_parser():
"""Get the argument parser. Used for documentation generation."""
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
parser.add_arguments(Main, dest="prog")
return parser


def main(args: Optional[list[str]] = None):
"""Parse CLI arguments and dispatch to the selected subcommand."""
parser = get_parser()
parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
parser.add_arguments(Main, dest="prog")
prog: Main = parser.parse_args(args=args).prog
prog.execute()

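For readers unfamiliar with the pattern being inlined into main() here, this is the standard simple_parsing dispatch: a dataclass is registered on the parser and the parsed instance is executed. A self-contained sketch, assuming a hypothetical Train subcommand that is not part of bergson:

# Sketch of the simple_parsing dispatch pattern used in main() above.
# The Train dataclass below is hypothetical and exists only for illustration.
from dataclasses import dataclass

from simple_parsing import ArgumentParser, ConflictResolution


@dataclass
class Train:
    lr: float = 1e-3

    def execute(self):
        print(f"training with lr={self.lr}")


parser = ArgumentParser(conflict_resolution=ConflictResolution.EXPLICIT)
parser.add_arguments(Train, dest="prog")
prog: Train = parser.parse_args(["--lr", "0.01"]).prog
prog.execute()  # prints: training with lr=0.01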
2 changes: 1 addition & 1 deletion bergson/build.py
@@ -56,7 +56,7 @@ def build_worker(
)

model, target_modules = setup_model_and_peft(cfg, rank)
processor = create_processor(cfg, rank)
processor = create_processor(model, ds, cfg, rank, target_modules)

attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}

5 changes: 3 additions & 2 deletions bergson/collector/collector.py
@@ -477,10 +477,10 @@ def run_with_collector_hooks(
total_processed = torch.tensor(0, device=self.model.device)
prof = self._setup_profiler()
step = 0

with prof:
for indices in tqdm(
self.batches, disable=self.rank != 0, desc=f"Computing {desc}"
self.batches,
desc=f"Computing {desc}",
):
batch = self.data[indices]

@@ -503,6 +503,7 @@
step += 1

self.collector.process_batch(indices, losses=losses)
total_processed += len(indices)

self.collector.teardown()
if dist.is_initialized():
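The new total_processed counter is kept as a tensor on the model's device, which makes it cheap to aggregate across ranks once the loop finishes. A standalone sketch of that idea, assuming torch.distributed; the aggregation shown here is an illustration, not bergson's actual teardown logic:

# Sketch: counting processed examples locally, then (optionally) summing the
# counter across ranks. Illustrative only; not taken from collector.py.
import torch
import torch.distributed as dist


def count_processed(batches_of_indices, device="cpu"):
    total_processed = torch.tensor(0, device=device)
    for indices in batches_of_indices:
        total_processed += len(indices)
    if dist.is_initialized():
        # Sum the per-rank counters so every rank sees the global total.
        dist.all_reduce(total_processed, op=dist.ReduceOp.SUM)
    return total_processed.item()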
3 changes: 3 additions & 0 deletions bergson/config.py
@@ -95,6 +95,9 @@ class IndexConfig:
processor_path: str = ""
"""Path to a precomputed processor."""

normalizer: Literal["adafactor", "adam", "none"] = "none"
"""Type of normalizer to use for the gradients."""

skip_preconditioners: bool = False
"""Whether to skip computing preconditioners for the gradients."""

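For scale, the new normalizer options trade fidelity of the second-moment estimate against memory: "adam" keeps a full [O, I] matrix per weight (avg_sq in the classes below), while "adafactor" keeps only an [O] row vector and an [I] column vector. A back-of-the-envelope comparison with illustrative dimensions, not taken from this diff:

# Rough state-size comparison behind the "adafactor" vs "adam" choice.
# The 4096 x 11008 shape is an illustrative MLP projection, not from bergson.
n_out, n_in = 4096, 11008
adam_floats = n_out * n_in        # full second-moment matrix: ~45M values
adafactor_floats = n_out + n_in   # row + column statistics: ~15K values
print(adam_floats / adafactor_floats)  # ~3000x less normalizer state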
204 changes: 102 additions & 102 deletions bergson/gradients.py
@@ -58,108 +58,6 @@ def state_dict(self) -> dict[str, str | Tensor]:
}


@dataclass
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.
"""

row: Tensor # shape [O]
col: Tensor # shape [I]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-30,
) -> Tensor:
"""
Normalize the row and column sums by adding a small epsilon.

Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
recommend 1e-30, but we use 1e-16 for extra numerical stability.
"""
# We follow the Adafactor implementation in the tensor2tensor repo, which is
# different from the paper and from the PyTorch implementation. First add eps
# to ensure these second moments are sufficiently far from zero. Then we don't
# need to worry about numerical stability anywhere else, and we don't need to
# materialize the outer product at any point.
r, c = self.row.add(eps), self.col.add(eps)

# This is the denominator for V, the rank-one matrix of second moment estimates:
# V = torch.outer(r, c) / denom
# V_ij = r_i * c_j / denom
# But we want to (implicitly) take the Hadamard product with the elementwise
# reciprocal square root of V:
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
denom = r.mean()

# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
# by diag(a) and right-multiplying by diag(b). In this case we can represent
# the elementwise reciprocal square root of V as ab^T where:
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
a = denom.sqrt() * r.rsqrt_() # shape [O]
b = c.rsqrt_()

# Implicitly do the Hadamard product
grad *= a[:, None] # [N, O] * [O] → [N, O]
grad *= b[None, :]
return grad

def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.
"""

avg_sq: Tensor

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-8,
) -> Tensor:
"""Normalize the gradients by the square root of the second moments."""
# Adam-style epsilon is added outside the square root
denom = self.avg_sq.sqrt()
return grad.div_(denom.add_(eps))

def to_adafactor(self) -> AdafactorNormalizer:
"""
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
self.avg_sq.ndim == 2
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"

# Compute row and column means
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
)


@dataclass
class GradientProcessor:
"""Configuration for processing and compressing gradients."""
@@ -317,3 +215,105 @@ def out_attr(layer: nn.Module) -> str:
return "out_channels"
case _:
raise ValueError(f"Unsupported layer type: {type(layer)}")


@dataclass
class AdafactorNormalizer(Normalizer):
"""
Row and column sums of second moments of gradients for a matrix-valued parameter.
"""

row: Tensor # shape [O]
col: Tensor # shape [I]

def __post_init__(self):
assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-30,
) -> Tensor:
"""
Normalize the row and column sums by adding a small epsilon.

Note: Our `eps` corresponds to epsilon_1 in the original Adafactor paper. They
recommend 1e-30, but we use 1e-16 for extra numerical stability.
"""
# We follow the Adafactor implementation in the tensor2tensor repo, which is
# different from the paper and from the PyTorch implementation. First add eps
# to ensure these second moments are sufficiently far from zero. Then we don't
# need to worry about numerical stability anywhere else, and we don't need to
# materialize the outer product at any point.
r, c = self.row.add(eps), self.col.add(eps)

# This is the denominator for V, the rank-one matrix of second moment estimates:
# V = torch.outer(r, c) / denom
# V_ij = r_i * c_j / denom
# But we want to (implicitly) take the Hadamard product with the elementwise
# reciprocal square root of V:
# (V_ij)^{-1/2} = denom.sqrt() * r_i.rsqrt() * c_j.rsqrt()
denom = r.mean()

# Hadamard product with a rank-one matrix ab^T is the same as left-multiplying
# by diag(a) and right-multiplying by diag(b). In this case we can represent
# the elementwise reciprocal square root of V as ab^T where:
# a = denom.sqrt() * r.rsqrt() and b = c.rsqrt()
a = denom.sqrt() * r.rsqrt_() # shape [O]
b = c.rsqrt_()

# Implicitly do the Hadamard product
grad *= a[:, None] # [N, O] * [O] → [N, O]
grad *= b[None, :]
return grad

def to_adam(self) -> "AdamNormalizer":
"""
Convert this Adafactor normalizer to an Adam normalizer by materializing the
rank-one second moment matrix.
"""
# Compute the second moment matrix as a square matrix of shape [O, I]
# NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
# add it outside the square root. This could cause infs though if there are
# any exactly zero rows or columns, so we should be careful.
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
return AdamNormalizer(avg_sq=avg_sq)


@dataclass
class AdamNormalizer(Normalizer):
"""
Contains the second moments of the gradients.
"""

avg_sq: Tensor

@torch.compile
def normalize_(
self,
grad: Tensor,
eps: float = 1e-8,
) -> Tensor:
"""Normalize the gradients by the square root of the second moments."""
# Adam-style epsilon is added outside the square root
denom = self.avg_sq.sqrt()
return grad.div_(denom.add_(eps))

def to_adafactor(self) -> AdafactorNormalizer:
"""
Convert this Adam normalizer to an Adafactor normalizer, minimizing the
I-divergence (generalized Kullback-Leibler divergence) between the original
and the factored second moments.
"""
# We assume avg_sq is a square matrix of shape [O, I]
assert (
self.avg_sq.ndim == 2
), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"

# Compute row and column means
return AdafactorNormalizer(
row=self.avg_sq.mean(dim=1), # shape [O]
col=self.avg_sq.mean(dim=0), # shape [I]
)
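A quick numerical check of the implicit rank-one normalization documented in AdafactorNormalizer.normalize_: scaling by a[:, None] and b[None, :] matches materializing V_ij = r_i * c_j / mean(r) and dividing by its elementwise square root. Standalone sketch, not part of bergson; it uses a single [O, I] gradient for simplicity:

# Sanity check that the implicit Hadamard product in normalize_ equals the
# explicit route through the rank-one second-moment matrix V.
import torch

torch.manual_seed(0)
n_out, n_in, eps = 8, 5, 1e-30
row, col = torch.rand(n_out), torch.rand(n_in)
grad = torch.randn(n_out, n_in)

r, c = row + eps, col + eps

# Explicit route: materialize V and divide elementwise by sqrt(V).
V = torch.outer(r, c) / r.mean()
explicit = grad / V.sqrt()

# Implicit route, as in normalize_: two rank-one scalings, no [O, I] temporary
# beyond the gradient itself.
a = r.mean().sqrt() * r.rsqrt()  # shape [n_out]
b = c.rsqrt()                    # shape [n_in]
implicit = grad * a[:, None] * b[None, :]

assert torch.allclose(explicit, implicit, atol=1e-5)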
Empty file added bergson/normalizer/__init__.py
Empty file.