diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 4d0c251..9b42204 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -722,7 +722,6 @@ def __init__( workspace_dim: int, loss_coefs: LossCoefs | Mapping[str, float], selection_temperature: float = 0.2, - selection_mod: SelectionBase | None = None, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -732,6 +731,7 @@ def __init__( | None | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh, + selection_mod: SelectionBase | None = None, ) -> None: """ Initializes a Global Workspace @@ -765,6 +765,8 @@ def __init__( no scheduler will be used. Defaults to use OneCycleScheduler fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation function to fuse the domains. + selection_mod (`SelectionBase | None`): optional custom selection module. + If None (default), uses `RandomSelection`. """ domain_mods = freeze_domain_modules(domain_mods) gw_mod = GWModule(