Skip to content

Commit

Permalink
Fixed critic baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
wouterkool committed Apr 9, 2018
1 parent 081bc04 commit 5c0a1c8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def eval(self, x, c):
return v.detach(), F.mse_loss(v, c.detach())

def get_learnable_parameters(self):
return self.critic.parameters()
return list(self.critic.parameters())

def epoch_callback(self, model, epoch):
pass
Expand Down
4 changes: 2 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def maybe_cuda_model(model, cuda, parallel=True):
baseline = CriticBaseline(
maybe_cuda_model(
CriticNetwork(
problem.INPUT_DIM,
problem.NODE_DIM,
opts.embedding_dim,
opts.hidden_dim,
opts.n_encode_layers,
Expand All @@ -96,7 +96,7 @@ def maybe_cuda_model(model, cuda, parallel=True):

# Initialize optimizer
optimizer = optim.Adam(
[{'params': model.parameters(), 'lr': float(opts.lr_model)}] + baseline.get_learnable_parameters()
[{'params': model.parameters(), 'lr': float(opts.lr_model)}]
+ (
[{'params': baseline.get_learnable_parameters(), 'lr': float(opts.lr_critic)}]
if len(baseline.get_learnable_parameters()) > 0
Expand Down

0 comments on commit 5c0a1c8

Please sign in to comment.