Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
translation,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses2Domains,
GWLossesBase,
LossCoefs,
Expand Down Expand Up @@ -85,7 +84,6 @@
"contrastive_loss",
"ContrastiveLoss",
"LossCoefs",
"BroadcastLossCoefs",
"combine_loss",
"GWLossesBase",
"GWLosses2Domains",
Expand Down
2 changes: 0 additions & 2 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
translation,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses2Domains,
GWLossesBase,
LossCoefs,
Expand Down Expand Up @@ -63,7 +62,6 @@
"contrastive_loss",
"ContrastiveLoss",
"LossCoefs",
"BroadcastLossCoefs",
"combine_loss",
"GWLossesBase",
"GWLosses2Domains",
Expand Down
3 changes: 1 addition & 2 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
translation,
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLosses2Domains,
GWLossesBase,
Expand Down Expand Up @@ -721,7 +720,7 @@ def __init__(
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: BroadcastLossCoefs | Mapping[str, float],
loss_coefs: LossCoefs | Mapping[str, float],
selection_temperature: float = 0.2,
selection_mod: SelectionBase | None = None,
optim_lr: float = 1e-3,
Expand Down
40 changes: 7 additions & 33 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,31 +306,9 @@ class LossCoefs(TypedDict, total=False):
"""Contrastive loss coefficient."""


class BroadcastLossCoefs(TypedDict, total=False):
"""
Dict of loss coefficients used in the GWLossesFusion.

If one is not provided, the coefficient is assumed to be 0 and will not be logged
(a warning is emitted). If the loss is explicitly set to 0, it will be logged, but
not take part in the total loss.
"""

contrastives: float
"""Contrastive loss coefficient."""

demi_cycles: float
"""demi_cycles loss coefficient. Demi-cycles aggregate fused cases too."""

cycles: float
"""cycles loss coefficient. Cycles can be many-to-one"""

translations: float
"""translation loss coefficient. Translation, like cycles, can be many-to-one."""


def combine_loss(
metrics: dict[str, torch.Tensor],
coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs,
coefs: Mapping[str, float] | LossCoefs,
) -> torch.Tensor:
"""
Combines the metrics according to the ones selected in coefs
Expand Down Expand Up @@ -626,12 +604,10 @@ def broadcast(
continue
ground_truth = latents[domain]

if num_active_domains == 1 and domain in selected_latents:
if domain in selected_latents:
loss_fn = domain_mods[domain].compute_dcy_loss
elif domain not in selected_latents:
loss_fn = domain_mods[domain].compute_tr_loss
else:
loss_fn = domain_mods[domain].compute_dcy_loss
loss_fn = domain_mods[domain].compute_tr_loss

loss_output = loss_fn(
pred, ground_truth, raw_data[group_domains][domain]
Expand All @@ -645,12 +621,10 @@ def broadcast(
{f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()}
)

if num_active_domains == 1 and domain in selected_latents:
if domain in selected_latents:
demi_cycle_losses.append(loss_label + "_loss")
elif domain not in selected_latents:
else:
translation_losses.append(loss_label + "_loss")
else: # fused loss counts toward demi_cycles aggregate
demi_cycle_losses.append(loss_label + "_loss")

if num_active_domains < num_total_domains:
cycle_cases.append(
Expand Down Expand Up @@ -752,7 +726,7 @@ def __init__(
gw_mod: GWModule,
selection_mod: SelectionBase,
domain_mods: dict[str, DomainModule],
loss_coefs: BroadcastLossCoefs | Mapping[str, float],
loss_coefs: LossCoefs | Mapping[str, float],
contrastive_fn: ContrastiveLossType,
):
"""
Expand All @@ -762,7 +736,7 @@ def __init__(
gw_mod: The GWModule for the global workspace.
selection_mod: The selection mechanism for the model.
domain_mods: A mapping of domain names to their respective DomainModule.
loss_coefs (`BroadcastLossCoefs`): coefs for the losses
loss_coefs (`LossCoefs`): coefs for the losses
contrastive_fn: The function used for computing contrastive loss.
"""
super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.modules.losses import BroadcastLossCoefs
from shimmer.modules.losses import LossCoefs


class DummyDomainModule(DomainModule):
Expand Down Expand Up @@ -35,7 +35,7 @@ def test_broadcast():
gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
workspace_dim = 10
loss_coefs: BroadcastLossCoefs = {
loss_coefs: LossCoefs = {
"cycles": 1.0,
"demi_cycles": 1.0,
"translations": 1.0,
Expand Down