
Commit cd36df6

use dqn without ppo
CircleCly committed Dec 13, 2023
1 parent 0956e23 commit cd36df6
Showing 5 changed files with 29 additions and 17 deletions.
33 changes: 21 additions & 12 deletions rl-starter-files/scripts/train.py
@@ -202,6 +202,8 @@ def similarity_bonus(llm_rsp, dqn_rsp):
     algo = torch_ac.PPOAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                             args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss,
                             reshape_reward)
+elif args.algo == "base":
+    algo = torch_ac.BaseAlgo(envs, acmodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda, args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence, preprocess_obss, None)
 else:
     raise ValueError("Incorrect algorithm name: {}".format(args.algo))
 
@@ -292,20 +294,26 @@ def similarity_bonus(llm_rsp, dqn_rsp):
         data += rreturn_per_episode.values()
         header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
         data += num_frames_per_episode.values()
-        header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
-        data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]
-
-        if args.use_dqn:
-            header += ["critic_loss", "q_values", "target_values"]
-            data += [logs["critic_loss"], logs["q_values"], logs["target_values"]]
-
+
+        if args.algo == "base":
             txt_logger.info(
-                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f} | criticL {:.3f} | Q {:.3f} | targetQ {:.3f} "
+                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} "
                 .format(*data))
         else:
-            txt_logger.info(
-                "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
-                .format(*data))
+            header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
+            data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]
+
+            if args.use_dqn:
+                header += ["critic_loss", "q_values", "target_values"]
+                data += [logs["critic_loss"], logs["q_values"], logs["target_values"]]
+
+                txt_logger.info(
+                    "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f} | criticL {:.3f} | Q {:.3f} | targetQ {:.3f} "
+                    .format(*data))
+            else:
+                txt_logger.info(
+                    "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f}"
+                    .format(*data))
 
         header += ["return_" + key for key in return_per_episode.keys()]
         data += return_per_episode.values()
@@ -324,8 +332,9 @@ def similarity_bonus(llm_rsp, dqn_rsp):
             "num_frames": num_frames,
             "update": update,
             "model_state": acmodel.state_dict(),
-            "optimizer_state": algo.optimizer.state_dict()
         }
+        if hasattr(algo, "optimizer"):
+            status["optimizer_state"] = algo.optimizer.state_dict()
         if hasattr(preprocess_obss, "vocab"):
            status["vocab"] = preprocess_obss.vocab.vocab
         utils.save_status(status, model_dir)
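Note on the train.py changes above: the new "base" branch constructs torch_ac.BaseAlgo directly, and the checkpoint code now stores optimizer_state only when the algorithm actually owns an optimizer. A checkpoint written this way presumably needs the same guard when resuming; a minimal sketch, assuming the usual utils.get_status helper from rl-starter-files (this resume snippet is illustrative, not part of the commit):

# Illustrative resume logic mirroring the conditional save above.
status = utils.get_status(model_dir)                 # assumed rl-starter-files helper
acmodel.load_state_dict(status["model_state"])       # model weights are always saved
if hasattr(algo, "optimizer") and "optimizer_state" in status:
    # Only algorithms with a gradient step (PPO/A2C) create an optimizer;
    # the "base" setting may not, so skip restoring it in that case.
    algo.optimizer.load_state_dict(status["optimizer_state"])
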
3 changes: 3 additions & 0 deletions rl-starter-files/utils/agent.py
@@ -48,6 +48,9 @@ def get_actions(self, obss):
         else:
             actions = dist.sample()
 
+        if isinstance(self.acmodel, PlannerPolicy):
+            self.acmodel.decrease_cooldown()
+
         return actions.cpu().numpy()
 
     def get_action(self, obs):
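PlannerPolicy and decrease_cooldown are defined elsewhere in this fork; the isinstance guard ticks the planner's cooldown once per batch of sampled actions and leaves plain ACModel agents untouched. A rough, purely illustrative sketch of the kind of counter the call implies (hypothetical, not the fork's actual implementation):

class PlannerPolicy(ACModel):                  # hypothetical sketch only
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cooldown = 0                      # steps until the planner may be queried again

    def decrease_cooldown(self):
        # Ticked once per Agent.get_actions call; never drops below zero.
        self.cooldown = max(0, self.cooldown - 1)
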
2 changes: 1 addition & 1 deletion torch-ac/torch_ac/__init__.py
@@ -1,3 +1,3 @@
-from torch_ac.algos import A2CAlgo, PPOAlgo
+from torch_ac.algos import A2CAlgo, PPOAlgo, BaseAlgo
 from torch_ac.model import ACModel, RecurrentACModel
 from torch_ac.utils import DictList
3 changes: 2 additions & 1 deletion torch-ac/torch_ac/algos/__init__.py
@@ -1,2 +1,3 @@
 from torch_ac.algos.a2c import A2CAlgo
-from torch_ac.algos.ppo import PPOAlgo
\ No newline at end of file
+from torch_ac.algos.ppo import PPOAlgo
+from torch_ac.algos.base import BaseAlgo
5 changes: 2 additions & 3 deletions torch-ac/torch_ac/algos/base.py
@@ -248,6 +248,5 @@ def collect_experiences(self, replay_buffer=None):
 
         return exps, logs
 
-    @abstractmethod
-    def update_parameters(self):
-        pass
+    def update_parameters(self, *args, **kwargs):
+        return {}
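
With @abstractmethod removed and update_parameters turned into a no-op that returns an empty log dict, BaseAlgo can be instantiated on its own (matching the new "base" branch in train.py) and used purely for experience collection, with no PPO gradient step. A minimal sketch of the intended usage under that assumption (variable names stand in for the parsed command-line arguments):

# Sketch: run rollout collection without any policy-gradient update.
algo = torch_ac.BaseAlgo(envs, acmodel, device, frames_per_proc, discount, lr,
                         gae_lambda, entropy_coef, value_loss_coef, max_grad_norm,
                         recurrence, preprocess_obss, None)
exps, logs1 = algo.collect_experiences()   # rollouts, returns, per-episode stats
logs2 = algo.update_parameters(exps)       # now a no-op: returns {}
logs = {**logs1, **logs2}                  # so only rollout statistics get logged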
