Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
fd4145a
Add content-based selection and set as fusion default
RolandBERTINJOHANNET Nov 25, 2025
2039528
Add helper to attach learned attention, keep fusion default random
RolandBERTINJOHANNET Nov 25, 2025
6930d1d
Move learned attention helper to fusion class
RolandBERTINJOHANNET Nov 25, 2025
a8004b1
Rebalance broadcast loss coefs and aggregates
RolandBERTINJOHANNET Nov 25, 2025
e646158
Fix indentation in selection helper causing test import failure
RolandBERTINJOHANNET Nov 25, 2025
4f09afc
Run ruff cleanups in selection
RolandBERTINJOHANNET Nov 25, 2025
b31049a
Format global_workspace with ruff
RolandBERTINJOHANNET Nov 25, 2025
95d054f
Align selection signatures and move learned attention helper
RolandBERTINJOHANNET Nov 25, 2025
7bca399
Apply ruff formatting to global workspace
RolandBERTINJOHANNET Nov 25, 2025
95b3b62
Update broadcasts docstrings
RolandBERTINJOHANNET Nov 28, 2025
cc9dbc3
Remove fused loss handler
RolandBERTINJOHANNET Nov 28, 2025
1b2fd83
Rename learned attention module
RolandBERTINJOHANNET Dec 8, 2025
c16fd1d
Add LearnedAttention coverage
RolandBERTINJOHANNET Dec 8, 2025
4db5e55
Fix LearnedAttention test types for mypy
RolandBERTINJOHANNET Dec 8, 2025
bc53230
Split cycle loss from broadcast path
RolandBERTINJOHANNET Dec 10, 2025
931df04
Format code with ruff
RolandBERTINJOHANNET Dec 10, 2025
9f15e82
Move cycle reconstruction into cycle loss
RolandBERTINJOHANNET Dec 10, 2025
3dca8a4
Ensure LearnedAttention only builds selected key projection
RolandBERTINJOHANNET Dec 10, 2025
c192aac
Annotate LearnedAttention key layers as optional for mypy
RolandBERTINJOHANNET Dec 10, 2025
bf3f75c
Fix mypy issues in LearnedAttention and ckpt migration CLI
RolandBERTINJOHANNET Dec 10, 2025
664dc4f
Revert ckpt migration typing tweaks
RolandBERTINJOHANNET Dec 10, 2025
6b66406
Fix mypy for ckpt migration CLI by using string paths
RolandBERTINJOHANNET Dec 10, 2025
101f7f1
Add domain-latent key option to LearnedAttention
RolandBERTINJOHANNET Dec 11, 2025
d1badae
Format LearnedAttention per ruff
RolandBERTINJOHANNET Dec 11, 2025
02167d1
Pass domain key options through init_learned_attention
RolandBERTINJOHANNET Dec 11, 2025
8388448
Drop duplicate missing-domain-dims check in attention tests
RolandBERTINJOHANNET Dec 11, 2025
553ab6f
Warn before initializing LearnedAttention
RolandBERTINJOHANNET Dec 11, 2025
887cbcf
Refactor broadcast naming and stop logging metrics
RolandBERTINJOHANNET Dec 17, 2025
17b467a
Warn when loss coef missing
RolandBERTINJOHANNET Dec 17, 2025
c6506ef
Add warning test for missing loss coef
RolandBERTINJOHANNET Dec 17, 2025
4dc4e7c
Clarify broadcast docstring
RolandBERTINJOHANNET Dec 17, 2025
8601c92
Run ruff format
RolandBERTINJOHANNET Dec 17, 2025
3213179
Restore broadcast metrics logging (minus aggregate)
RolandBERTINJOHANNET Dec 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ cython_debug/
.rgignore

.ruff_cache/
.poetry_cache/
2 changes: 2 additions & 0 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
combine_loss,
)
from shimmer.modules.selection import (
LearnedAttention,
RandomSelection,
SelectionBase,
SingleDomainSelection,
Expand Down Expand Up @@ -103,6 +104,7 @@
"RandomSelection",
"SelectionBase",
"SingleDomainSelection",
"LearnedAttention",
"DomainDesc",
"RepeatedDataset",
"ShimmerDataset",
Expand Down
6 changes: 3 additions & 3 deletions shimmer/cli/ckpt_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
@click.argument(
"paths",
nargs=-1,
type=click.Path(exists=True, path_type=Path, file_okay=True, dir_okay=False),
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
def migrate_ckpt(paths: Sequence[Path]):
def migrate_ckpt(paths: Sequence[str]):
"""
Script to migrate a list of checkpoints.
This can be called with:
Expand All @@ -24,4 +24,4 @@ def migrate_ckpt(paths: Sequence[Path]):
Internally, this calls `shimmer.utils.migrate_model` for each of the given paths.
"""
for path in paths:
migrate_model(path)
migrate_model(Path(path))
19 changes: 0 additions & 19 deletions shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,6 @@ def compute_tr_loss(
"""
return self.compute_loss(pred, target, raw_target)

def compute_fused_loss(
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Computes the loss for fused (fusion). Override if the fused loss is
different that the generic loss.

Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
fused loss for this domain.
"""
return self.compute_loss(pred, target, raw_target)

def compute_domain_loss(self, domain: Any) -> LossOutput | None:
"""
Compute the unimodal domain loss.
Expand Down
68 changes: 64 additions & 4 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Callable, Iterable, Mapping
from enum import Enum, auto
from pathlib import Path
Expand Down Expand Up @@ -29,6 +30,7 @@
LossCoefs,
)
from shimmer.modules.selection import (
LearnedAttention,
RandomSelection,
SelectionBase,
SingleDomainSelection,
Expand Down Expand Up @@ -65,7 +67,7 @@ class GWPredictionsBase(TypedDict):
broadcasts: dict[frozenset[str], dict[str, torch.Tensor]]
"""
broadcasts predictions of the model for each domain. It contains demi-cycles,
translations, and fused.
translations.
"""

cycles: dict[frozenset[str], dict[str, torch.Tensor]]
Expand Down Expand Up @@ -706,7 +708,7 @@ def __init__(
)


class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses]):
"""The fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

This is used to simplify a Global Workspace instantiation and only overrides the
Expand All @@ -721,6 +723,7 @@ def __init__(
workspace_dim: int,
loss_coefs: BroadcastLossCoefs | Mapping[str, float],
selection_temperature: float = 0.2,
selection_mod: SelectionBase | None = None,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
Expand Down Expand Up @@ -748,7 +751,9 @@ def __init__(
loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the
losses.
selection_temperature (`float`): temperature value for the RandomSelection
module.
module (default selection).
selection_mod (`SelectionBase | None`): optional custom selection module.
If None (default), uses `RandomSelection`.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -772,7 +777,8 @@ def __init__(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
)

selection_mod = RandomSelection(selection_temperature)
if selection_mod is None:
selection_mod = RandomSelection(selection_temperature)
loss_mod = GWLosses(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
)
Expand All @@ -787,6 +793,60 @@ def __init__(
scheduler,
)

def init_learned_attention(
self,
head_size: int = 64,
per_domain_keys: bool = False,
stopgrad: bool = True,
key_on_prefusion: bool = True,
domain_dims: Mapping[str, int] | None = None,
) -> LearnedAttention:
"""
Initialize and attach a learned content-based attention module.

This replaces `self.selection_mod` with a `LearnedAttention` configured for
the current workspace (uses `workspace_dim` and domain names from
`domain_mods`), ensuring its parameters are tracked by Lightning/torch.

Args:
    head_size (`int`): projection size of the attention head.
    per_domain_keys (`bool`): give each domain its own key projection.
    stopgrad (`bool`): passed through to `LearnedAttention`; presumably
        detaches inputs to the attention keys — confirm in the selection module.
    key_on_prefusion (`bool`): compute keys on pre-fusion (workspace-dim)
        representations. When `False`, keys are computed on raw domain
        latents, which requires `per_domain_keys=True`.
    domain_dims (`Mapping[str, int] | None`): latent dimension per domain,
        only needed when `key_on_prefusion=False`. Defaults to each
        domain module's `latent_dim`.

Returns:
    `LearnedAttention`: the newly attached selection module.

Raises:
    ValueError: if `key_on_prefusion=False` without `per_domain_keys=True`,
        or if `domain_dims` is missing an entry for some domain.
"""
# Upfront caveat for users: this selection path is experimental.
warnings.warn(
(
"LearnedAttention is best used after pretraining the global workspace "
"with a simpler selection (e.g., random or single-domain). "
"This path is minimally validated; use at your own risk."
),
UserWarning,
stacklevel=2,
)
# Raw domain latents may have different sizes per domain, so a single
# shared key projection cannot be used in that configuration.
if not key_on_prefusion and not per_domain_keys:
raise ValueError(
"key_on_prefusion=False requires per_domain_keys=True because "
"domain latent dimensions can differ."
)

final_domain_dims = domain_dims
# Domain dims are only relevant when keying on raw domain latents.
if not key_on_prefusion:
# Default to each domain module's declared latent dimension.
if final_domain_dims is None:
final_domain_dims = {
name: mod.latent_dim for name, mod in self.domain_mods.items()
}
# Every known domain must have a dimension, whether user-supplied
# or derived above.
missing = [d for d in self.domain_mods if d not in final_domain_dims]
if missing:
raise ValueError(
f"Missing domain_dims for: {', '.join(sorted(missing))}"
)

selection = LearnedAttention(
gw_dim=self.workspace_dim,
domain_names=self.domain_mods.keys(),
head_size=head_size,
per_domain_keys=per_domain_keys,
stopgrad=stopgrad,
key_on_prefusion=key_on_prefusion,
domain_dims=final_domain_dims,
)
# Assigning to an attribute of this nn.Module registers the new
# submodule, so its parameters are picked up by the optimizer.
self.selection_mod = selection
return selection


def pretrained_global_workspace(
checkpoint_path: str | Path,
Expand Down
2 changes: 1 addition & 1 deletion shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class GWModulePrediction(TypedDict):
broadcasts: dict[str, torch.Tensor]
"""
broadcasts predictions of the model for each domain. It contains demi-cycles,
translations, and fused.
translations.
"""

cycles: dict[str, torch.Tensor]
Expand Down
Loading