@click.argument(
    "paths",
    nargs=-1,
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
def migrate_ckpt(paths: Sequence[str]):
    """
    Script to migrate a list of checkpoints. This can be called with:

    ```sh
    shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N
    ```

    Internally, this calls `shimmer.utils.migrate_model` for each of the given
    paths.
    """
    # click hands us plain strings; wrap each one in a Path before delegating.
    for raw_path in paths:
        migrate_model(Path(raw_path))
- - Args: - pred (`torch.Tensor`): prediction of the model - target (`torch.Tensor`): target tensor - raw_target (`Any`): raw data from the input - Results: - `LossOutput | None`: LossOuput with training loss and additional metrics. - If `None` is returned, this loss will be ignored and will not - participate in the total loss; it can be used to deactivate - fused loss for this domain. - """ - return self.compute_loss(pred, target, raw_target) - def compute_domain_loss(self, domain: Any) -> LossOutput | None: """ Compute the unimodal domain loss. diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index e2b6b9f8..07f84887 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Callable, Iterable, Mapping from enum import Enum, auto from pathlib import Path @@ -29,6 +30,7 @@ LossCoefs, ) from shimmer.modules.selection import ( + LearnedAttention, RandomSelection, SelectionBase, SingleDomainSelection, @@ -65,7 +67,7 @@ class GWPredictionsBase(TypedDict): broadcasts: dict[frozenset[str], dict[str, torch.Tensor]] """ broadcasts predictions of the model for each domain. It contains demi-cycles, - translations, and fused. + translations. """ cycles: dict[frozenset[str], dict[str, torch.Tensor]] @@ -706,7 +708,7 @@ def __init__( ) -class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): +class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, SelectionBase, GWLosses]): """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. 
This is used to simplify a Global Workspace instanciation and only overrides the @@ -721,6 +723,7 @@ def __init__( workspace_dim: int, loss_coefs: BroadcastLossCoefs | Mapping[str, float], selection_temperature: float = 0.2, + selection_mod: SelectionBase | None = None, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -748,7 +751,9 @@ def __init__( loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the losses. selection_temperature (`float`): temperature value for the RandomSelection - module. + module (default selection). + selection_mod (`SelectionBase | None`): optional custom selection module. + If None (default), uses `RandomSelection`. optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -772,7 +777,8 @@ def __init__( torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale ) - selection_mod = RandomSelection(selection_temperature) + if selection_mod is None: + selection_mod = RandomSelection(selection_temperature) loss_mod = GWLosses( gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss ) @@ -787,6 +793,60 @@ def __init__( scheduler, ) + def init_learned_attention( + self, + head_size: int = 64, + per_domain_keys: bool = False, + stopgrad: bool = True, + key_on_prefusion: bool = True, + domain_dims: Mapping[str, int] | None = None, + ) -> LearnedAttention: + """ + Initialize and attach a learned content-based attention module. + + This replaces `self.selection_mod` with a `LearnedAttention` configured for + the current workspace (uses `workspace_dim` and domain names from + `domain_mods`), ensuring its parameters are tracked by Lightning/torch. + """ + warnings.warn( + ( + "LearnedAttention is best used after pretraining the global workspace " + "with a simpler selection (e.g., random or single-domain). " + "This path is minimally validated; use at your own risk." 
+ ), + UserWarning, + stacklevel=2, + ) + if not key_on_prefusion and not per_domain_keys: + raise ValueError( + "key_on_prefusion=False requires per_domain_keys=True because " + "domain latent dimensions can differ." + ) + + final_domain_dims = domain_dims + if not key_on_prefusion: + if final_domain_dims is None: + final_domain_dims = { + name: mod.latent_dim for name, mod in self.domain_mods.items() + } + missing = [d for d in self.domain_mods if d not in final_domain_dims] + if missing: + raise ValueError( + f"Missing domain_dims for: {', '.join(sorted(missing))}" + ) + + selection = LearnedAttention( + gw_dim=self.workspace_dim, + domain_names=self.domain_mods.keys(), + head_size=head_size, + per_domain_keys=per_domain_keys, + stopgrad=stopgrad, + key_on_prefusion=key_on_prefusion, + domain_dims=final_domain_dims, + ) + self.selection_mod = selection + return selection + def pretrained_global_workspace( checkpoint_path: str | Path, diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index a7c039bb..c21a3a59 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -217,7 +217,7 @@ class GWModulePrediction(TypedDict): broadcasts: dict[str, torch.Tensor] """ broadcasts predictions of the model for each domain. It contains demi-cycles, - translations, and fused. + translations. """ cycles: dict[str, torch.Tensor] diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index d934307f..202040bc 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,3 +1,4 @@ +import warnings from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from itertools import product @@ -287,9 +288,9 @@ class LossCoefs(TypedDict, total=False): """ Dict of loss coefficients used in the GWLosses. - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. 
+ If one is not provided, the coefficient is assumed to be 0 and will not be logged + (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but + not take part in the total loss. """ demi_cycles: float @@ -309,19 +310,16 @@ class BroadcastLossCoefs(TypedDict, total=False): """ Dict of loss coefficients used in the GWLossesFusion. - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. + If one is not provided, the coefficient is assumed to be 0 and will not be logged + (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but + not take part in the total loss. """ contrastives: float """Contrastive loss coefficient.""" - fused: float - """fused loss coefficient (encode multiple domains and decode to one of them).""" - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles are always one-to-one""" + """demi_cycles loss coefficient. Demi-cycles aggregate fused cases too.""" cycles: float """cycles loss coefficient. Cycles can be many-to-one""" @@ -349,6 +347,18 @@ def combine_loss( Returns: `torch.Tensor`: the combined loss. 
""" + missing = { + name for name in _EXPECTED_COEF_KEYS if name in metrics and name not in coefs + } + for name in sorted(missing): + if name not in _MISSING_COEFS_WARNED: + warnings.warn( + f"Loss coefficient '{name}' not provided; defaulting to 0.", + UserWarning, + stacklevel=2, + ) + _MISSING_COEFS_WARNED.add(name) + loss = torch.stack( [ metrics[name] * coef @@ -360,6 +370,32 @@ def combine_loss( return loss +_EXPECTED_COEF_KEYS = {"contrastives", "demi_cycles", "cycles", "translations"} +_MISSING_COEFS_WARNED: set[str] = set() + + +class CycleCase(TypedDict): + """Container for precomputed cycle inputs to avoid recomputation.""" + + group_name: str + selected_group_label: str + selected_latents: Mapping[str, torch.Tensor] + decoded_latents: Mapping[str, torch.Tensor] + raw_group: Mapping[str, object] + + +class BroadcastResult(TypedDict): + """ + Broadcast loss output without cycle computation. + + `metrics` contains demi-cycle/translation metrics and per-example losses. + `cycle_cases` holds precomputed inputs for later cycle loss computation. + """ + + metrics: dict[str, torch.Tensor] + cycle_cases: list[CycleCase] + + class GWLosses2Domains(GWLossesBase): """ Implementation of `GWLossesBase` used for `GWModule`. @@ -516,27 +552,26 @@ def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]: yield perm -def broadcast_loss( +def broadcast( gw_mod: GWModuleBase, selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT, -) -> dict[str, torch.Tensor]: +) -> BroadcastResult: """ - Computes broadcast loss including demi-cycle, cycle, and translation losses. + Computes broadcast demi-cycle (with fused) and translation losses, and prepares + precomputed artifacts for cycle losses. 
- This return multiple metrics: + This returns multiple metrics: * `demi_cycles` - * `cycles` * `translations` - * `fused` * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order - (e.g. "from_{t,v}_to_t_loss"). - * `from_{start_group}_to_{domain}_{metric}` with - additional metrics provided by the domain_mod's - `compute_broadcast_loss` output + (e.g. "from_{t,v}_to_t_loss"). Note: fused cases are aggregated into + `demi_cycles`. + * `from_{start_group}_to_{domain}_{metric}` with additional metrics provided by + the domain module's loss outputs * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` where `{start_group}`, `{target_group}` and `{case_group}` is of the form "{domain1,domain2,domainN}" sorted in alphabetical order @@ -544,8 +579,7 @@ def broadcast_loss( domains, `{target_group}` the target domains used for the cycle and `{case_group}` all available domains participating to the loss. * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` - additional metrics provided by the domain_mod's `compute_broadcast_loss` - output + additional metrics provided by the domain module's loss outputs Args: gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use @@ -555,15 +589,14 @@ def broadcast_loss( raw_data (`RawDomainGroupsT`): raw input data Returns: - A dictionary with the total loss and additional metrics. + `BroadcastResult`: demi/translation metrics plus precomputed cycle data. 
""" # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} demi_cycle_losses: list[str] = [] - cycle_losses: list[str] = [] translation_losses: list[str] = [] - fused_losses: list[str] = [] + cycle_cases: list[CycleCase] = [] for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) @@ -598,7 +631,7 @@ def broadcast_loss( elif domain not in selected_latents: loss_fn = domain_mods[domain].compute_tr_loss else: - loss_fn = domain_mods[domain].compute_fused_loss + loss_fn = domain_mods[domain].compute_dcy_loss loss_output = loss_fn( pred, ground_truth, raw_data[group_domains][domain] @@ -616,70 +649,96 @@ def broadcast_loss( demi_cycle_losses.append(loss_label + "_loss") elif domain not in selected_latents: translation_losses.append(loss_label + "_loss") - else: # fused loss - fused_losses.append(loss_label + "_loss") + else: # fused loss counts toward demi_cycles aggregate + demi_cycle_losses.append(loss_label + "_loss") if num_active_domains < num_total_domains: - inverse_selected_latents = { - domain: decoded_latents[domain] - for domain in decoded_latents - if domain not in selected_latents - } - - inverse_selected_group_label = ( - "{" + ",".join(sorted(inverse_selected_latents)) + "}" - ) - - re_encoded_latents = gw_mod.encode(inverse_selected_latents) - re_selection_scores = selection_mod( - inverse_selected_latents, re_encoded_latents - ) - re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores) - re_decoded_latents = gw_mod.decode( - re_fused_latents, domains=selected_latents.keys() - ) - - for domain in selected_latents: - re_ground_truth = latents[domain] - re_loss_output = domain_mods[domain].compute_cy_loss( - re_decoded_latents[domain], - re_ground_truth, - raw_data[group_domains][domain], + cycle_cases.append( + CycleCase( + group_name=group_name, + selected_group_label=selected_group_label, + selected_latents=selected_latents, + decoded_latents=decoded_latents, + 
def cycle_loss_from_broadcast(
    gw_mod: GWModuleBase,
    selection_mod: SelectionBase,
    domain_mods: Mapping[str, DomainModule],
    cycle_cases: list[CycleCase],
) -> dict[str, torch.Tensor]:
    """
    Compute cycle losses from the artifacts prepared by `broadcast`.

    Each case re-encodes the decoded latents of the non-selected domains,
    fuses them, decodes back to the originally selected domains, and scores
    the round trip with each domain module's `compute_cy_loss`.

    Args:
        gw_mod: GW module used for encoding/decoding.
        selection_mod: selection module used during fusion.
        domain_mods: domain modules used to compute the losses.
        cycle_cases: precomputed cycle data produced by `broadcast`.

    Returns:
        Metrics dict containing per-case losses/metrics and aggregate `cycles`.
    """
    out: dict[str, torch.Tensor] = {}
    per_case_names: list[str] = []

    for case in cycle_cases:
        selected = case["selected_latents"]
        # Latents decoded for the domains that were *not* selected; these act
        # as the intermediate step of the cycle.
        inverse_latents = {
            name: latent
            for name, latent in case["decoded_latents"].items()
            if name not in selected
        }
        inverse_label = "{" + ",".join(sorted(inverse_latents)) + "}"

        encoded = gw_mod.encode(inverse_latents)
        scores = selection_mod(inverse_latents, encoded)
        fused = gw_mod.fuse(encoded, scores)
        decoded = gw_mod.decode(fused, domains=selected.keys())

        for domain, target in selected.items():
            label = (
                f"from_{case['selected_group_label']}_"
                f"through_{inverse_label}_to_{domain}_"
                f"case_{case['group_name']}"
            )
            loss_output = domain_mods[domain].compute_cy_loss(
                decoded[domain], target, case["raw_group"][domain]
            )
            if loss_output is None:
                continue

            out[label + "_loss"] = loss_output.loss
            for key, value in loss_output.metrics.items():
                out[f"{label}_{key}"] = value
            per_case_names.append(label + "_loss")

    if per_case_names:
        out["cycles"] = torch.mean(
            torch.stack([out[name] for name in per_case_names])
        )
    return out
class LearnedAttention(SelectionBase):
    """
    Single-step, content-based attention over GW latents.

    The query is the projected mean of the available (per-domain) latents and
    attention is a single dot-product/softmax step over the domains — there is
    no iterative refinement.

    Configuration toggles:
    - per_domain_keys: one key projection per domain instead of a shared one
    - stopgrad: detach latents before computing keys/query
    - key_on_prefusion: keys come from pre-fusion GW latents (True) or from the
      raw domain latents (False)
    - domain_dims: per-domain latent sizes; required when key_on_prefusion is
      False so each key projection can be sized correctly
    """

    def __init__(
        self,
        gw_dim: int,
        domain_names: Iterable[str],
        head_size: int = 64,
        per_domain_keys: bool = False,
        stopgrad: bool = True,
        key_on_prefusion: bool = True,
        domain_dims: Mapping[str, int] | None = None,
    ):
        super().__init__()
        self.gw_dim = int(gw_dim)
        self.head_size = int(head_size)
        self.domain_names = list(domain_names)

        # Behavior toggles (see class docstring).
        self.per_domain_keys = bool(per_domain_keys)
        self.stopgrad = bool(stopgrad)
        self.key_on_prefusion = bool(key_on_prefusion)
        self.domain_dims = dict(domain_dims) if domain_dims is not None else None

        # The query projection always maps from the GW space.
        self.query_layer = nn.Linear(self.gw_dim, self.head_size)
        self.per_key_layers: nn.ModuleDict | None
        self.shared_key_layer: nn.Linear | None
        if self.key_on_prefusion:
            # Keys are computed on pre-fusion GW latents: every input shares
            # the GW dimensionality, so a shared projection is allowed.
            if self.per_domain_keys:
                self.per_key_layers = nn.ModuleDict(
                    {
                        name: nn.Linear(self.gw_dim, self.head_size)
                        for name in self.domain_names
                    }
                )
                self.shared_key_layer = None
            else:
                self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size)
                self.per_key_layers = None
        else:
            # Keys are computed on raw domain latents whose sizes may differ,
            # so each domain needs its own projection sized by domain_dims.
            if not self.per_domain_keys:
                raise ValueError(
                    "key_on_prefusion=False requires per_domain_keys=True because "
                    "domain latent dimensions can differ."
                )
            if self.domain_dims is None:
                raise ValueError(
                    "key_on_prefusion=False requires domain_dims for key projections."
                )
            without_dims = [
                name for name in self.domain_names if name not in self.domain_dims
            ]
            if without_dims:
                raise ValueError(
                    f"Missing domain_dims for: {', '.join(sorted(without_dims))}"
                )
            self.per_key_layers = nn.ModuleDict(
                {
                    name: nn.Linear(self.domain_dims[name], self.head_size)
                    for name in self.domain_names
                }
            )
            self.shared_key_layer = None

    @staticmethod
    def _calc_attention(
        keys: dict[str, torch.Tensor],
        query: torch.Tensor,
        order: Iterable[str],
    ) -> dict[str, torch.Tensor]:
        """
        Softmax dot-product attention over the domains listed in `order`.

        Args:
            keys: mapping of domain -> key tensor (B, H)
            query: query tensor (B, H)
            order: iterable of domain names fixing the output ordering

        Returns:
            dict[str, torch.Tensor]: per-domain scores that sum to 1 per sample.
        """
        ordered = [name for name in order if name in keys]
        if not ordered:
            raise ValueError("LearnedAttention: no keys provided.")

        # (B, D) matrix of dot products between each domain key and the query.
        scores = torch.stack(
            [torch.sum(keys[name] * query, dim=1) for name in ordered], dim=1
        )
        weights = torch.softmax(scores, dim=1)

        return {name: weights[:, idx] for idx, name in enumerate(ordered)}

    def forward(
        self,
        domains: LatentsDomainGroupT,
        encodings_pre_fusion: LatentsDomainGroupT | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-domain attention weights.

        Args:
            domains: mapping from domain name to latent (B, gw_dim)
            encodings_pre_fusion: pre-fusion encodings; required when
                `key_on_prefusion=True`, otherwise used for the query when given

        Returns:
            dict[str, torch.Tensor]: per-domain attention weights.
        """
        gw_latents: Mapping[str, torch.Tensor] = domains

        present = [name for name in self.domain_names if name in gw_latents]
        if not present:
            raise ValueError(
                "LearnedAttention: no known domains present in gw_latents."
            )

        if self.key_on_prefusion:
            if encodings_pre_fusion is None:
                raise ValueError(
                    "key_on_prefusion=True requires encodings_pre_fusion inputs."
                )
            key_source = encodings_pre_fusion
        else:
            key_source = gw_latents

        absent_keys = [name for name in present if name not in key_source]
        if absent_keys:
            raise ValueError(
                f"Missing key latents for: {', '.join(sorted(absent_keys))}"
            )

        if encodings_pre_fusion is None:
            query_source = gw_latents
        else:
            query_source = encodings_pre_fusion

        absent_queries = [name for name in present if name not in query_source]
        if absent_queries:
            raise ValueError(
                f"Missing query latents for: {', '.join(sorted(absent_queries))}"
            )

        if self.stopgrad:
            # Block gradients from flowing into the latents via the attention.
            key_latents = {name: key_source[name].detach() for name in present}
            query_latents = {name: query_source[name].detach() for name in present}
        else:
            key_latents = {name: key_source[name] for name in present}
            query_latents = {name: query_source[name] for name in present}

        if self.per_domain_keys:
            if self.per_key_layers is None:
                raise RuntimeError(
                    "per_domain_keys=True but per-domain key layers are missing."
                )
            layers = self.per_key_layers
            keys = {name: layers[name](key_latents[name]) for name in present}
        else:
            if self.shared_key_layer is None:
                raise RuntimeError(
                    "per_domain_keys=False but shared key layer is missing."
                )
            shared = self.shared_key_layer
            keys = {name: shared(key_latents[name]) for name in present}

        # Query: project the mean over present domains of the query latents.
        mean_latent = torch.stack(
            [query_latents[name] for name in present], dim=0
        ).mean(0)
        query = self.query_layer(mean_latent)  # (B, H)

        return self._calc_attention(
            keys=keys,
            query=query,
            order=self.domain_names,
        )
diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 012e169e..7900e861 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -27,7 +27,7 @@ def compute_loss( return LossOutput(loss=loss) # Constructing LossOutput with the loss -def test_broadcast_loss(): +def test_broadcast(): domain_mods: dict[str, DomainModule] = { "domain1": DummyDomainModule(latent_dim=10), "domain2": DummyDomainModule(latent_dim=10), @@ -36,7 +36,6 @@ def test_broadcast_loss(): gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} workspace_dim = 10 loss_coefs: BroadcastLossCoefs = { - "fused": 1.0, "cycles": 1.0, "demi_cycles": 1.0, "translations": 1.0, @@ -56,7 +55,7 @@ def test_broadcast_loss(): learn_logit_scale=False, ) - # Adjusting the dummy data to fit the expected input structure for broadcast_loss + # Adjusting the dummy data to fit the expected input structure for broadcast # Now using a frozenset for the keys to match LatentsDomainGroupsT latent_domains = { frozenset(["domain1", "domain2"]): { @@ -65,18 +64,22 @@ def test_broadcast_loss(): } } - # Test broadcast_loss with the corrected structure - output = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) + # Test broadcast with the corrected structure + result = gw_fusion.loss_mod.broadcast(latent_domains, latent_domains) + assert "metrics" in result and "cycle_cases" in result + # Cycle metrics are computed in step(), not within broadcast + metrics = result["metrics"] + assert all(metric in metrics for metric in ["demi_cycles", "translations"]) - er_msg = "Demi-cycle, cycle, fused and translation metrics should be in the output." 
import pytest
import torch
import torch.nn as nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import (
    GlobalWorkspaceFusion,
    freeze_domain_modules,
)
from shimmer.modules.selection import LearnedAttention


def _make_latents(batch_size: int, dim: int) -> dict[str, torch.Tensor]:
    """Two random domain latents ("a" and "b") with gradients enabled."""
    return {
        name: torch.randn(batch_size, dim, requires_grad=True)
        for name in ("a", "b")
    }


def test_learned_attention_probs_sum_to_one() -> None:
    attention = LearnedAttention(gw_dim=4, domain_names=["a", "b"], head_size=2)
    latents = _make_latents(batch_size=8, dim=4)

    scores = attention(latents, encodings_pre_fusion=latents)

    assert scores["a"].shape == (8,)
    assert scores["b"].shape == (8,)
    # Attention weights must form a distribution over domains per sample.
    summed = torch.stack([scores["a"], scores["b"]], dim=1).sum(dim=1)
    assert torch.allclose(summed, torch.ones(8))


def test_learned_attention_stopgrad_toggle() -> None:
    base_latents = _make_latents(batch_size=4, dim=6)

    def fresh_inputs() -> dict[str, torch.Tensor]:
        return {
            k: v.detach().clone().requires_grad_(True)
            for k, v in base_latents.items()
        }

    # stopgrad=True: no gradient should reach the inputs.
    detached_inputs = fresh_inputs()
    detached_attention = LearnedAttention(
        gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=True
    )
    detached_scores = detached_attention(
        detached_inputs, encodings_pre_fusion=detached_inputs
    )
    torch.stack(list(detached_scores.values())).sum().backward()
    assert detached_inputs["a"].grad is None
    assert detached_inputs["b"].grad is None

    # stopgrad=False: gradients flow back to the inputs.
    grad_inputs = fresh_inputs()
    grad_attention = LearnedAttention(
        gw_dim=6, domain_names=["a", "b"], head_size=3, stopgrad=False
    )
    grad_scores = grad_attention(grad_inputs, encodings_pre_fusion=grad_inputs)
    torch.stack(list(grad_scores.values())).sum().backward()
    assert grad_inputs["a"].grad is not None
    assert grad_inputs["b"].grad is not None


def test_learned_attention_domain_key_path() -> None:
    dims = {"a": 3, "b": 5}
    attention = LearnedAttention(
        gw_dim=4,
        domain_names=dims.keys(),
        head_size=3,
        per_domain_keys=True,
        stopgrad=False,
        key_on_prefusion=False,
        domain_dims=dims,
    )

    # Keys come from raw domain latents (different sizes per domain) ...
    raw_latents = {
        "a": torch.randn(6, 3, requires_grad=True),
        "b": torch.randn(6, 5, requires_grad=True),
    }
    # ... while the query uses the gw-sized pre-fusion latents.
    prefusion_latents = {
        "a": torch.randn(6, 4, requires_grad=True),
        "b": torch.randn(6, 4, requires_grad=True),
    }

    scores = attention(raw_latents, encodings_pre_fusion=prefusion_latents)

    summed = torch.stack([scores["a"], scores["b"]], dim=1).sum(dim=1)
    assert torch.allclose(summed, torch.ones(6))


def test_learned_attention_domain_key_shared_layer_error() -> None:
    dims = {"a": 3, "b": 5}
    # A shared key layer cannot handle domains of different sizes.
    with pytest.raises(ValueError):
        LearnedAttention(
            gw_dim=4,
            domain_names=dims.keys(),
            head_size=3,
            per_domain_keys=False,
            stopgrad=True,
            key_on_prefusion=False,
            domain_dims=dims,
        )


class _DummyDomain(DomainModule):
    """Identity encoder/decoder used to build a minimal workspace."""

    def __init__(self, latent_dim: int):
        super().__init__(latent_dim)

    def encode(self, x: torch.Tensor) -> torch.Tensor:  # pragma: no cover - simple stub
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:  # pragma: no cover - simple stub
        return z


def _make_fusion_gw() -> GlobalWorkspaceFusion:
    """A tiny two-domain fusion workspace with identity GW encoders/decoders."""
    modules = freeze_domain_modules({"a": _DummyDomain(3), "b": _DummyDomain(5)})
    return GlobalWorkspaceFusion(
        domain_mods=modules,
        gw_encoders={"a": nn.Identity(), "b": nn.Identity()},
        gw_decoders={"a": nn.Identity(), "b": nn.Identity()},
        workspace_dim=4,
        loss_coefs={"contrastives": 0.0},
    )


def test_global_workspace_init_learned_attention_domain_dims() -> None:
    gw = _make_fusion_gw()

    attention = gw.init_learned_attention(
        head_size=2,
        per_domain_keys=True,
        stopgrad=False,
        key_on_prefusion=False,
    )

    assert attention.key_on_prefusion is False
    assert attention.per_key_layers is not None
    # Key layers must be sized from each domain module's latent_dim.
    assert attention.per_key_layers["a"].weight.shape[1] == 3
    assert attention.per_key_layers["b"].weight.shape[1] == 5


def test_global_workspace_init_learned_attention_shared_error() -> None:
    gw = _make_fusion_gw()

    # Shared keys with raw-domain key sources is rejected up front.
    with pytest.raises(ValueError):
        gw.init_learned_attention(
            head_size=2,
            per_domain_keys=False,
            stopgrad=True,
            key_on_prefusion=False,
        )
torch.tensor(2.0), + } + coefs = {"contrastives": 1.0} + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + loss = combine_loss(metrics, coefs) + + assert any("demi_cycles" in str(w.message) for w in caught) + assert torch.isclose(loss, torch.tensor(2.0))