@@ -723,7 +723,6 @@ def __init__(
723723 workspace_dim : int ,
724724 loss_coefs : BroadcastLossCoefs | Mapping [str , float ],
725725 selection_temperature : float = 0.2 ,
726- selection_mod : SelectionBase | None = None ,
727726 optim_lr : float = 1e-3 ,
728727 optim_weight_decay : float = 0.0 ,
729728 scheduler_args : SchedulerArgs | None = None ,
@@ -733,6 +732,7 @@ def __init__(
733732 | None
734733 | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel .DEFAULT ,
735734 fusion_activation_fn : Callable [[torch .Tensor ], torch .Tensor ] = torch .tanh ,
735+ selection_mod : SelectionBase | None = None ,
736736 ) -> None :
737737 """
738738 Initializes a Global Workspace
@@ -766,6 +766,8 @@ def __init__(
766766 no scheduler will be used. Defaults to use OneCycleScheduler
767767 fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
768768 function to fuse the domains.
769+ selection_mod (`SelectionBase | None`): optional custom selection module.
770+ If None (default), uses `RandomSelection`.
769771 """
770772 domain_mods = freeze_domain_modules (domain_mods )
771773 gw_mod = GWModule (
0 commit comments