Skip to content

Commit 5b4a8dc

Browse files
Make selection_mod optional without breaking positional args
1 parent 5cce857 commit 5b4a8dc

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

shimmer/modules/global_workspace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,6 @@ def __init__(
723723
workspace_dim: int,
724724
loss_coefs: BroadcastLossCoefs | Mapping[str, float],
725725
selection_temperature: float = 0.2,
726-
selection_mod: SelectionBase | None = None,
727726
optim_lr: float = 1e-3,
728727
optim_weight_decay: float = 0.0,
729728
scheduler_args: SchedulerArgs | None = None,
@@ -733,6 +732,7 @@ def __init__(
733732
| None
734733
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
735734
fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
735+
selection_mod: SelectionBase | None = None,
736736
) -> None:
737737
"""
738738
Initializes a Global Workspace
@@ -766,6 +766,8 @@ def __init__(
766766
no scheduler will be used. Defaults to use OneCycleScheduler
767767
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
768768
function to fuse the domains.
769+
selection_mod (`SelectionBase | None`): optional custom selection module.
770+
If None (default), uses `RandomSelection`.
769771
"""
770772
domain_mods = freeze_domain_modules(domain_mods)
771773
gw_mod = GWModule(

0 commit comments

Comments
 (0)