@@ -34,6 +34,7 @@ def __init__(
3434 discrete : bool ,
3535 gammas : torch .Tensor ,
3636 continuous_dist_type : Type [PolicyOutput ],
37+ discrete_dist_type : Type [PolicyOutput ] = Discrete ,
3738 ):
3839 super ().__init__ ()
3940 self .state_dim = state_dim
@@ -42,7 +43,7 @@ def __init__(
4243 self .gammas = gammas
4344 self .num_gammas = len (self .gammas )
4445 # determine policy output
45- dist_type = Discrete if self .discrete else continuous_dist_type
46+ dist_type = discrete_dist_type if self .discrete else continuous_dist_type
4647 self .policy_dist = dist_type (d_action = self .action_dim )
4748 assert isinstance (self .policy_dist , PolicyOutput )
4849 assert self .policy_dist .is_discrete == self .discrete
@@ -116,13 +117,15 @@ def __init__(
116117 activation : str = "leaky_relu" ,
117118 dropout_p : float = 0.0 ,
118119 continuous_dist_type : Type [PolicyOutput ] = TanhGaussian ,
120+ discrete_dist_type : Type [PolicyOutput ] = Discrete ,
119121 ):
120122 super ().__init__ (
121123 state_dim = state_dim ,
122124 action_dim = action_dim ,
123125 discrete = discrete ,
124126 gammas = gammas ,
125127 continuous_dist_type = continuous_dist_type ,
128+ discrete_dist_type = discrete_dist_type ,
126129 )
127130 # build base network
128131 self .base = MLP (
@@ -184,13 +187,15 @@ def __init__(
184187 normalization : str = "layer" ,
185188 dropout_p : float = 0.0 ,
186189 continuous_dist_type : Type [PolicyOutput ] = TanhGaussian ,
190+ discrete_dist_type : Type [PolicyOutput ] = Discrete ,
187191 ):
188192 super ().__init__ (
189193 state_dim = state_dim ,
190194 action_dim = action_dim ,
191195 discrete = discrete ,
192196 gammas = gammas ,
193197 continuous_dist_type = continuous_dist_type ,
198+ discrete_dist_type = discrete_dist_type ,
194199 )
195200 self .inp = MLP (
196201 d_inp = state_dim ,
0 commit comments