Skip to content

Commit 1e3bfed

Browse files
authored
Merge pull request #93 from UT-Austin-RPL/discret_dist_type
`discrete_dist_type`
2 parents b8450e2 + 4798608 commit 1e3bfed

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

amago/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "3.1.1"
1+
__version__ = "3.1.2"
22

33
from .experiment import Experiment
44
from .agent import Agent

amago/nets/actor_critic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
discrete: bool,
3535
gammas: torch.Tensor,
3636
continuous_dist_type: Type[PolicyOutput],
37+
discrete_dist_type: Type[PolicyOutput] = Discrete,
3738
):
3839
super().__init__()
3940
self.state_dim = state_dim
@@ -42,7 +43,7 @@ def __init__(
4243
self.gammas = gammas
4344
self.num_gammas = len(self.gammas)
4445
# determine policy output
45-
dist_type = Discrete if self.discrete else continuous_dist_type
46+
dist_type = discrete_dist_type if self.discrete else continuous_dist_type
4647
self.policy_dist = dist_type(d_action=self.action_dim)
4748
assert isinstance(self.policy_dist, PolicyOutput)
4849
assert self.policy_dist.is_discrete == self.discrete
@@ -116,13 +117,15 @@ def __init__(
116117
activation: str = "leaky_relu",
117118
dropout_p: float = 0.0,
118119
continuous_dist_type: Type[PolicyOutput] = TanhGaussian,
120+
discrete_dist_type: Type[PolicyOutput] = Discrete,
119121
):
120122
super().__init__(
121123
state_dim=state_dim,
122124
action_dim=action_dim,
123125
discrete=discrete,
124126
gammas=gammas,
125127
continuous_dist_type=continuous_dist_type,
128+
discrete_dist_type=discrete_dist_type,
126129
)
127130
# build base network
128131
self.base = MLP(
@@ -184,13 +187,15 @@ def __init__(
184187
normalization: str = "layer",
185188
dropout_p: float = 0.0,
186189
continuous_dist_type: Type[PolicyOutput] = TanhGaussian,
190+
discrete_dist_type: Type[PolicyOutput] = Discrete,
187191
):
188192
super().__init__(
189193
state_dim=state_dim,
190194
action_dim=action_dim,
191195
discrete=discrete,
192196
gammas=gammas,
193197
continuous_dist_type=continuous_dist_type,
198+
discrete_dist_type=discrete_dist_type,
194199
)
195200
self.inp = MLP(
196201
d_inp=state_dim,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="amago",
5-
version="3.1.1",
5+
version="3.1.2",
66
author="Jake Grigsby",
77
author_email="grigsby@cs.utexas.edu",
88
license="MIT",

0 commit comments

Comments
 (0)