-
Notifications
You must be signed in to change notification settings - Fork 58
[WIP] Adding MultiAgent Utilities #323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 16 commits
1d49049
ef4a179
2ecd086
53450a8
274aff9
38f95f0
0927001
daa8b2a
44db72e
4ef8f48
d8cf1a9
8d2cf06
1365585
5067e42
6f0563e
5061abe
e6a378c
915d19d
b0b5025
8cc732b
b8f7f6a
e50e230
835819e
793c045
2635fd5
cd87506
65b6520
3d01b85
a62c100
841ff66
2be8df5
ac9b5a8
10282f0
79b531b
a3885a0
e3dc677
4c2ad51
43554e4
6828e93
eac920c
194065f
c0198bc
fe40835
a50204a
4a3cd74
602a7b5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,14 @@ | |
|
|
||
| import torch # noqa | ||
| import torch.nn as nn # noqa | ||
| import torch.nn.functional as F | ||
| from gym import spaces | ||
| from torch.distributions import Categorical, Normal | ||
|
|
||
| from genrl.core.base import BaseActorCritic | ||
| from genrl.core.policies import MlpPolicy | ||
| from genrl.core.values import MlpValue | ||
| from genrl.utils.utils import cnn | ||
| from genrl.utils.utils import cnn, shared_mlp | ||
|
|
||
|
|
||
| class MlpActorCritic(BaseActorCritic): | ||
|
|
@@ -216,10 +217,119 @@ def get_value(self, inp: torch.Tensor) -> torch.Tensor: | |
| return value | ||
|
|
||
|
|
||
class SharedActorCritic(BaseActorCritic):
    """Actor-Critic whose actor and critic heads share a common MLP trunk.

    The two networks are constructed by ``shared_mlp``: each has its own
    "pre" layers, then a shared middle trunk, then its own "post" layers.

    :param critic_prev: Sizes of the critic-only layers before the shared trunk
    :param actor_prev: Sizes of the actor-only layers before the shared trunk
    :param shared: Sizes of the trunk layers shared by actor and critic
    :param critic_post: Sizes of the critic-only layers after the shared trunk
    :param actor_post: Sizes of the actor-only layers after the shared trunk
    :param weight_init: Weight initialisation scheme forwarded to ``shared_mlp``
    :param activation_func: Activation function forwarded to ``shared_mlp``
    """

    def __init__(
        self,
        critic_prev,
        actor_prev,
        shared,
        critic_post,
        actor_post,
        weight_init,
        activation_func,
    ):
        super(SharedActorCritic, self).__init__()

        self.critic, self.actor = shared_mlp(
            critic_prev,
            actor_prev,
            shared,
            critic_post,
            actor_post,
            weight_init,
            activation_func,
            False,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, state_critic, state_action):
        """Dispatch to the critic or actor head.

        Exactly one of the two arguments is expected to be non-None:
        ``state_critic`` routes through the critic, ``state_action`` through
        the actor. Returns None when both arguments are None (made explicit
        here; the original fell off the end of the function).
        """
        if state_critic is not None:
            return self.critic(state_critic)
        if state_action is not None:
            return self.actor(state_action)
        return None

    def get_action(self, state, deterministic=False):
        """Select an action index from the actor head for ``state``.

        :param state: Observation tensor fed to the actor network
        :param deterministic: If True, pick the highest-probability action
            instead of sampling
        :return: Integer action index
        """
        logits = self.forward(None, state)

        # NOTE(review): softmax over dim=0 assumes an unbatched logits
        # vector — confirm against callers before batching states.
        dist = F.softmax(logits, dim=0)
        probs = Categorical(dist)
        if deterministic:
            # BUG FIX: the original called torch.argmax(probs) on the
            # Categorical distribution object, which raises a TypeError;
            # argmax must be taken over the probability tensor itself.
            index = torch.argmax(dist).item()
        else:
            index = probs.sample().cpu().detach().item()
        return index

    def get_value(self, state):
        """Return the critic's value estimate for ``state``."""
        value = self.forward(state, None)
        return value
|
|
||
|
|
||
class Actor(MlpPolicy):
    """MLP policy network used as the standalone actor in multi-agent setups.

    :param state_dim: State space of the environment
    :param action_dim: Action space of the environment
    :param hidden: Hidden-layer sizes of the MLP
    :param discrete: Whether the action space is discrete
    """

    def __init__(
        self,
        state_dim: spaces.Space,
        action_dim: spaces.Space,
        hidden: Tuple = (32, 32),
        discrete: bool = True,
        **kwargs,
    ):
        # BUG FIX: the original passed ``discrete ** kwargs`` (exponentiating
        # a bool by a dict -> TypeError, and kwargs were never forwarded);
        # ``discrete`` and ``**kwargs`` must be separate arguments.
        super(Actor, self).__init__(state_dim, action_dim, hidden, discrete, **kwargs)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, state):
        """Return raw action logits for ``state``."""
        state = self.model(state)
        return state

    def get_action(self, state, deterministic=False):
        """Select an action index for ``state``.

        :param state: Observation tensor fed to the policy network
        :param deterministic: If True, pick the highest-probability action
            instead of sampling
        :return: Integer action index
        """
        logits = self.forward(state)

        # NOTE(review): softmax over dim=0 assumes an unbatched logits
        # vector — confirm against callers before batching states.
        dist = F.softmax(logits, dim=0)
        probs = Categorical(dist)
        if deterministic:
            # BUG FIX: the original called torch.argmax(probs) on the
            # Categorical distribution object, which raises a TypeError;
            # argmax must be taken over the probability tensor itself.
            index = torch.argmax(dist).item()
        else:
            index = probs.sample().cpu().detach().item()
        return index
|
|
||
|
|
||
class Critic(MlpValue):
    """MLP value network used as the standalone critic in multi-agent setups.

    :param state_dim: State space of the environment
    :param action_dim: Action space of the environment
    :param fc_layers: Hidden-layer sizes of the MLP
    :param val_type: Type of value estimated (e.g. "V" for state values)
    """

    def __init__(
        self,
        state_dim: spaces.Space,
        action_dim: spaces.Space,
        fc_layers: Tuple = (32, 32),
        val_type: str = "V",
        **kwargs,
    ):
        super(Critic, self).__init__(
            state_dim, action_dim, fc_layers, val_type, **kwargs
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, state):
        """Pass ``state`` through the underlying value model."""
        return self.model(state)

    def get_value(self, state):
        """Return the value estimate for ``state``."""
        return self.forward(state)
|
||
|
|
||
|
|
||
# Maps configuration-string identifiers to their actor-critic classes.
actor_critic_registry = dict(
    mlp=MlpActorCritic,
    cnn=CNNActorCritic,
    mlp12=MlpSingleActorMultiCritic,
    mlpshared=SharedActorCritic,
)
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its very risky to change the versions here because of compatibility. Can you revert these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I change that I get some errors. Any other way to rectify it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you changed the isort.cfg file?