From 68e5368a64bd343cb2f5eb17144d01f71e15cf8b Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Wed, 14 Jan 2026 14:06:50 +0100 Subject: [PATCH 1/8] Cleaning up losses.py Removed BroadcastLossCoeff as it is 100% identical to LossCoeff Removed an elif statement in broadcast function, to simplify the classification between demi-cycles and translation. --- shimmer/modules/losses.py | 43 ++++++++------------------------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 202040b..a92657f 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -305,32 +305,9 @@ class LossCoefs(TypedDict, total=False): contrastives: float """Contrastive loss coefficient.""" - -class BroadcastLossCoefs(TypedDict, total=False): - """ - Dict of loss coefficients used in the GWLossesFusion. - - If one is not provided, the coefficient is assumed to be 0 and will not be logged - (a warning is emitted). If the loss is explicitly set to 0, it will be logged, but - not take part in the total loss. - """ - - contrastives: float - """Contrastive loss coefficient.""" - - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles aggregate fused cases too.""" - - cycles: float - """cycles loss coefficient. Cycles can be many-to-one""" - - translations: float - """translation loss coefficient. Translation, like cycles, can be many-to-one.""" - - def combine_loss( metrics: dict[str, torch.Tensor], - coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs, + coefs: Mapping[str, float] | LossCoefs, ) -> torch.Tensor: """ Combines the metrics according to the ones selected in coefs @@ -626,12 +603,10 @@ def broadcast( continue ground_truth = latents[domain] - if num_active_domains == 1 and domain in selected_latents: + if domain in selected_latents: loss_fn = domain_mods[domain].compute_dcy_loss - elif domain not in selected_latents: - loss_fn = domain_mods[domain].compute_tr_loss else: - loss_fn = domain_mods[domain].compute_dcy_loss + loss_fn = domain_mods[domain].compute_tr_loss loss_output = loss_fn( pred, ground_truth, raw_data[group_domains][domain] @@ -645,13 +620,11 @@ def broadcast( {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()} ) - if num_active_domains == 1 and domain in selected_latents: + if domain in selected_latents: demi_cycle_losses.append(loss_label + "_loss") - elif domain not in selected_latents: + else: translation_losses.append(loss_label + "_loss") - else: # fused loss counts toward demi_cycles aggregate - demi_cycle_losses.append(loss_label + "_loss") - + if num_active_domains < num_total_domains: cycle_cases.append( CycleCase( @@ -752,7 +725,7 @@ def __init__( gw_mod: GWModule, selection_mod: SelectionBase, domain_mods: dict[str, DomainModule], - loss_coefs: BroadcastLossCoefs | Mapping[str, float], + loss_coefs: LossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, ): """ @@ -762,7 +735,7 @@ def __init__( gw_mod: The GWModule for the global workspace. selection_mod: The selection mechanism for the model. domain_mods: A mapping of domain names to their respective DomainModule. - loss_coefs (`BroadcastLossCoefs`): coefs for the losses + loss_coefs (`LossCoefs`): coefs for the losses contrastive_fn: The function used for computing contrastive loss. """ super().__init__() From e97d1a564267f375fba43ce193a8672fa7211209 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:15:58 +0100 Subject: [PATCH 2/8] Fix formatting issue in losses.py To avoid ruff error --- shimmer/modules/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index a92657f..8a62049 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -624,7 +624,7 @@ def broadcast( demi_cycle_losses.append(loss_label + "_loss") else: translation_losses.append(loss_label + "_loss") - + if num_active_domains < num_total_domains: cycle_cases.append( CycleCase( From baf58d7c77610523cfaa831e6ba3cae864e11389 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:25:25 +0000 Subject: [PATCH 3/8] reformatted losses.py with ruff --- shimmer/modules/losses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 8a62049..e6ed67c 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -305,6 +305,7 @@ class LossCoefs(TypedDict, total=False): contrastives: float """Contrastive loss coefficient.""" + def combine_loss( metrics: dict[str, torch.Tensor], coefs: Mapping[str, float] | LossCoefs, From 0d35abe24835fcc916f606f90af7a2c9a3c9167d Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:30:08 +0100 Subject: [PATCH 4/8] Remove BroadcastLossCoefs import from global_workspace.py --- shimmer/modules/global_workspace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 07f8488..43a946a 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -23,7 +23,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses, GWLosses2Domains, GWLossesBase, From fb8a03486404b9db77a1e3bef02a609f102be1b2 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:32:16 +0100 Subject: [PATCH 5/8] Change type hint for loss_coefs parameter --- shimmer/modules/global_workspace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 43a946a..4d0c251 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -720,7 +720,7 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: BroadcastLossCoefs | Mapping[str, float], + loss_coefs: LossCoefs | Mapping[str, float], selection_temperature: float = 0.2, selection_mod: SelectionBase | None = None, optim_lr: float = 1e-3, From ea5790bc77e08ffb7ee9e4d9e77b0ebe3a0409d7 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:35:24 +0100 Subject: [PATCH 6/8] Remove BroadcastLossCoefs from imports --- shimmer/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 5e8355d..763bcfe 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -33,7 +33,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, @@ -85,7 +84,6 @@ "contrastive_loss", "ContrastiveLoss", "LossCoefs", - "BroadcastLossCoefs", "combine_loss", "GWLossesBase", "GWLosses2Domains", From 85b2ca569f2b52a7a8d2d8eccc05c57e915cf2a1 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:35:59 +0100 Subject: [PATCH 7/8] Remove BroadcastLossCoefs from module imports Removed BroadcastLossCoefs from imports. --- shimmer/modules/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index cd5957e..9145a76 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -27,7 +27,6 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, @@ -63,7 +62,6 @@ "contrastive_loss", "ContrastiveLoss", "LossCoefs", - "BroadcastLossCoefs", "combine_loss", "GWLossesBase", "GWLosses2Domains", From 3f72dfe1b6d14f165fe26aaf6b2839e933a071d8 Mon Sep 17 00:00:00 2001 From: Rufin VanRullen <40198228+rufinv@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:37:10 +0100 Subject: [PATCH 8/8] Replace BroadcastLossCoefs with LossCoefs --- tests/test_broadcast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 7900e86..ff5e39e 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -5,7 +5,7 @@ from shimmer.modules.domain import DomainModule, LossOutput from shimmer.modules.global_workspace import GlobalWorkspaceFusion -from shimmer.modules.losses import BroadcastLossCoefs +from shimmer.modules.losses import LossCoefs class DummyDomainModule(DomainModule): @@ -35,7 +35,7 @@ def test_broadcast(): gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} workspace_dim = 10 - loss_coefs: BroadcastLossCoefs = { + loss_coefs: LossCoefs = { "cycles": 1.0, "demi_cycles": 1.0, "translations": 1.0,