Commit 0eb9ccb

fix: save full rl checkpoint & merge comet experiments into one
1 parent 1d4f229 commit 0eb9ccb

15 files changed: 313 additions & 147 deletions

TrackToLearn/algorithms/sac_auto.py

Lines changed: 32 additions & 15 deletions
@@ -9,7 +9,7 @@
 from TrackToLearn.algorithms.sac import SAC
 from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
 from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
-from TrackToLearn.utils.torch_utils import get_device
+from TrackToLearn.utils.torch_utils import get_device, gradients_norm
 from TrackToLearn.algorithms.shared.kl import AdaptiveKLController, FixedKLController
 
 LOG_STD_MAX = 2
@@ -184,6 +184,10 @@ def load_checkpoint(self, checkpoint_file: str):
         self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
         self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
         self.alpha_optimizer.load_state_dict(checkpoint['alpha_optimizer'])
+        if checkpoint.get('replay_buffer', None) is not None:
+            self.replay_buffer.load_state_dict(checkpoint['replay_buffer'])
+        if checkpoint.get('log_alpha', None) is not None:
+            self.log_alpha = checkpoint['log_alpha']
 
     def save_checkpoint(self, checkpoint_file: str, **extra_info):
         """
@@ -200,6 +204,8 @@ def save_checkpoint(self, checkpoint_file: str, **extra_info):
             'actor_optimizer': self.actor_optimizer.state_dict(),
             'critic_optimizer': self.critic_optimizer.state_dict(),
             'alpha_optimizer': self.alpha_optimizer.state_dict(),
+            'replay_buffer': self.replay_buffer.state_dict(),
+            'log_alpha': self.log_alpha,
             **extra_info
         }
 
@@ -273,18 +279,6 @@ def update(
         # Total critic loss
         critic_loss = loss_q1 + loss_q2
 
-        losses = {
-            # 'actor_loss': actor_loss.detach(),
-            # 'alpha_loss': alpha_loss.detach(),
-            # 'critic_loss': critic_loss.detach(),
-            # 'loss_q1': loss_q1.detach(),
-            # 'loss_q2': loss_q2.detach(),
-            # 'entropy': alpha.detach(),
-            # 'Q1': current_Q1.mean().detach(),
-            # 'Q2': current_Q2.mean().detach(),
-            # 'backup': backup.mean().detach(),
-        }
-
         # Optimize the temperature
         self.alpha_optimizer.zero_grad()
         alpha_loss.backward()
@@ -313,7 +307,30 @@ def update(
             self.target.actor.parameters()
         ):
             target_param.data.copy_(
-                self.tau * param.data + (1 - self.tau) * target_param.data
-            )
+                self.tau * param.data + (1 - self.tau) * target_param.data)
+
+        # Compute the norm of the gradients to plot.
+        alpha_norm = self.log_alpha.grad.norm(2).cpu().detach().numpy()
+        critic_norm = gradients_norm(self.agent.critic)
+        actor_norm = gradients_norm(self.agent.actor)
+
+        # print("alpha_norm: ", type(alpha_norm))
+        # print("critic_norm: ", type(critic_norm))
+        # print("actor_norm: ", type(actor_norm))
+
+        losses = {
+            # 'actor_loss': actor_loss.detach(),
+            # 'alpha_loss': alpha_loss.detach(),
+            # 'critic_loss': critic_loss.detach(),
+            # 'loss_q1': loss_q1.detach(),
+            # 'loss_q2': loss_q2.detach(),
+            # 'entropy': alpha.detach(),
+            # 'Q1': current_Q1.mean().detach(),
+            # 'Q2': current_Q2.mean().detach(),
+            # 'backup': backup.mean().detach(),
+            "alpha_norm": alpha_norm,
+            "critic_norm": critic_norm,
+            "actor_norm": actor_norm,
+        }
 
         return losses
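
The gradients_norm helper imported above comes from TrackToLearn.utils.torch_utils, whose implementation is not part of this commit. As a rough sketch, such a helper typically reduces a module's parameter gradients to a single global L2 norm; the body below is an assumption for illustration, not the repository's actual code.

import torch.nn as nn


def gradients_norm(model: nn.Module) -> float:
    # Sum the squared L2 norms of every available parameter gradient,
    # then take the square root to get the module's global gradient norm.
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5

Logged once per update, these norms make exploding or vanishing gradients visible in Comet without the commented-out print statements.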

TrackToLearn/algorithms/shared/replay.py

Lines changed: 30 additions & 6 deletions
@@ -10,6 +10,7 @@
 from TrackToLearn.utils.torch_utils import get_device, get_device_str
 
 device = get_device()
+rb_type = torch.float32
 
 class OffPolicyReplayBuffer(object):
     """ Replay buffer to store transitions. Implemented in a "ring-buffer"
@@ -40,16 +41,17 @@ def __init__(
 
 
         self.state = torch.zeros(
-            (self.max_size, state_dim), dtype=torch.float32)
+            (self.max_size, state_dim), dtype=rb_type)
         self.action = torch.zeros(
-            (self.max_size, action_dim), dtype=torch.float32)
+            (self.max_size, action_dim), dtype=rb_type)
         self.next_state = torch.zeros(
-            (self.max_size, state_dim), dtype=torch.float32)
+            (self.max_size, state_dim), dtype=rb_type)
         self.reward = torch.zeros(
-            (self.max_size, 1), dtype=torch.float32)
+            (self.max_size, 1), dtype=rb_type)
         self.not_done = torch.zeros(
-            (self.max_size, 1), dtype=torch.float32)
-
+            (self.max_size, 1), dtype=rb_type)
+
+    def _pin_to_memory(self):
         if get_device_str() == "cuda":
             self.state = self.state.pin_memory()
             self.action = self.action.pin_memory()
@@ -162,6 +164,28 @@ def load_from_file(self, path):
         """
         pass
 
+    def state_dict(self):
+        size = self.size
+        return {
+            "state": self.state[:size],
+            "action": self.action[:size],
+            "next_state": self.next_state[:size],
+            "reward": self.reward[:size],
+            "not_done": self.not_done[:size],
+            "ptr": self.ptr,
+            "size": self.size
+        }
+
+    def load_state_dict(self, state_dict):
+        self.size = state_dict["size"]
+        self.ptr = state_dict["ptr"]
+
+        self.state[:self.size] = state_dict["state"]
+        self.action[:self.size] = state_dict["action"]
+        self.next_state[:self.size] = state_dict["next_state"]
+        self.reward[:self.size] = state_dict["reward"]
+        self.not_done[:self.size] = state_dict["not_done"]
+
 class OnPolicyReplayBuffer(object):
     """ Replay buffer to store transitions. Efficiency could probably be
     improved.
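
The new state_dict/load_state_dict pair serializes only the first size rows of each tensor, so a partially filled buffer yields a proportionally small checkpoint, and restoring ptr lets the ring buffer resume overwriting from the right position. A minimal round trip, assuming a constructor that takes the state and action dimensions (the full signature is not shown in this diff):

import torch

from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer

# Hypothetical dimensions, for illustration only.
buffer = OffPolicyReplayBuffer(state_dim=128, action_dim=3)

# ... training fills the buffer ...

# Persist the stored transitions alongside the rest of the RL checkpoint.
torch.save({'replay_buffer': buffer.state_dict()}, 'checkpoint.pt')

# Resume later: copy the saved transitions back into a fresh buffer.
checkpoint = torch.load('checkpoint.pt')
restored = OffPolicyReplayBuffer(state_dim=128, action_dim=3)
restored.load_state_dict(checkpoint['replay_buffer'])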

TrackToLearn/experiment/experiment.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def setup_comet(self, prefix=''):
         """
         # The comet object that will handle monitors
         self.comet_monitor = CometMonitor(
-            self.comet_experiment, self.name, self.experiment_path,
+            self.comet_experiment, self.experiment_path,
             prefix, use_comet=self.use_comet)
         print(self.hyperparameters)
         self.comet_monitor.log_parameters(self.hyperparameters)

TrackToLearn/runners/tractoracle_predict.py

Lines changed: 1 addition & 1 deletion
@@ -67,12 +67,12 @@ def test(self):
             parse_args=False,
             auto_metric_logging=False,
             disabled=True)
+        oracle_experiment.set_name(self.id)
 
         print("Done.")
 
         oracle_trainer = OracleTrainer(
             oracle_experiment,
-            self.id,
             root_dir,
             self.oracle_train_steps,
             enable_checkpointing=True,

TrackToLearn/trainers/oracle/oracle_monitor.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

TrackToLearn/trainers/oracle/oracle_trainer.py

Lines changed: 1 addition & 5 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from TrackToLearn.oracles.transformer_oracle import LightningLikeModule
-from TrackToLearn.trainers.oracle.oracle_monitor import OracleMonitor
+from TrackToLearn.utils.comet_monitor import OracleMonitor
 from TrackToLearn.utils.torch_utils import get_device
 from TrackToLearn.algorithms.shared.utils import \
     (add_item_to_means, mean_losses, add_losses, get_mean_item)
@@ -68,7 +68,6 @@ def reset(self):
 class OracleTrainer(object):
     def __init__(self,
                  experiment,
-                 experiment_id,
                  saving_path,
                  max_epochs,
                  use_comet=True,
@@ -81,7 +80,6 @@ def __init__(self,
                  metrics_prefix=None,
                  ):
         self.experiment = experiment
-        self.experiment_id = experiment_id
         self.saving_path = saving_path
 
         self.auto_checkpointing_enabled = enable_auto_checkpointing
@@ -95,7 +93,6 @@ def __init__(self,
         self.hooks_manager = HooksManager(OracleHookEvent)
         self.oracle_monitor = OracleMonitor(
             experiment=self.experiment,
-            experiment_id=self.experiment_id,
             use_comet=use_comet,
             metrics_prefix=metrics_prefix
         )
@@ -109,7 +106,6 @@ def save_hyperparameters(self):
 
         hyperparameters = self.oracle_model.hyperparameters
         hyperparameters.update({
-            'experiment_id': self.experiment_id,
             'saving_path': self.saving_path,
             'max_epochs': self.max_epochs,
             'val_interval': self.val_interval,

TrackToLearn/trainers/rlhf_train.py

Lines changed: 22 additions & 31 deletions
@@ -39,8 +39,7 @@ def __init__(
         self,
         rlhf_train_dto: dict,
         trainer_cls: TrackToLearnTraining,
-        agent_experiment: CometExperiment = None,
-        oracle_experiment: CometExperiment = None
+        comet_experiment: CometExperiment = None
     ):
         # Only load the parameters from the parent instead of calling
         # the full constructor twice. (As we call it for the agent_trainer
@@ -73,8 +72,14 @@ def __init__(
         self.agent_train_steps = rlhf_train_dto['agent_train_steps']
         self.num_workers = rlhf_train_dto['num_workers']
         self.rlhf_inter_npv = rlhf_train_dto['rlhf_inter_npv']
+
         self.disable_oracle_training = rlhf_train_dto.get(
             'disable_oracle_training', False)
+        if self.disable_oracle_training:
+            LOGGER.warning("Oracle training is disabled. The dataset will "
+                           "be augmented to evaluate the oracles during the "
+                           "agent's training.")
+
         self.batch_size = rlhf_train_dto['batch_size']
         self.oracle_batch_size = rlhf_train_dto['oracle_batch_size']
         grad_accumulation_steps = rlhf_train_dto.get(
@@ -91,15 +96,15 @@ def __init__(
 
         ################################################
         # Start by initializing the agent trainer.     #
-        if agent_experiment is None:
-            agent_experiment = CometExperiment(project_name=self.experiment,
+        if comet_experiment is None:
+            comet_experiment = CometExperiment(project_name=self.experiment,
                 workspace=rlhf_train_dto['workspace'], parse_args=False,
                 auto_metric_logging=False,
                 disabled=not self.use_comet)
 
-        agent_experiment.set_name(self.name)
+        comet_experiment.set_name(self.name)
 
-        self.agent_trainer: TrackToLearnTraining = trainer_cls(rlhf_train_dto, agent_experiment)
+        self.agent_trainer: TrackToLearnTraining = trainer_cls(rlhf_train_dto, comet_experiment)
         _ = self.agent_trainer.setup_environment_and_info()
         self.get_alg = self.agent_trainer.get_alg
 
@@ -110,17 +115,6 @@ def __init__(
 
         ################################################
         # Continue by initializing the oracle trainer. #
-        # Need this to avoid erasing the RL agent's experiment
-        # when creating a new one.
-        if oracle_experiment is None:
-            comet_ml.config.set_global_experiment(None)
-            oracle_experiment = CometExperiment(project_name="TractOracleRLHF",
-                workspace=rlhf_train_dto['workspace'], parse_args=False,
-                auto_metric_logging=False,
-                disabled=not self.use_comet)
-
-        oracle_experiment_id = '-'.join([self.experiment, self.name])
-
         dataset_to_augment = rlhf_train_dto.get('dataset_to_augment', None)
         self.dataset_manager = StreamlineDatasetManager(saving_path=self.oracle_training_dir,
                                                         dataset_to_augment_path=dataset_to_augment,
@@ -130,8 +124,7 @@ def __init__(
         # because we will want to save the checkpoints only when we improve the
         # total agent. We manually checkpoint those oracles instead.
         self.oracle_reward_trainer = OracleTrainer(
-            oracle_experiment,
-            oracle_experiment_id,
+            comet_experiment,
            self.oracle_training_dir,
            self.oracle_train_steps,
            enable_auto_checkpointing=False,
@@ -143,8 +136,7 @@ def __init__(
         )
 
         self.oracle_crit_trainer = OracleTrainer(
-            oracle_experiment,
-            oracle_experiment_id,
+            comet_experiment,
             self.oracle_training_dir,
             self.oracle_train_steps,
             enable_auto_checkpointing=False,
@@ -293,10 +285,7 @@ def rl_train(
         while i < max_ep:
             self.start_finetuning_epoch(i, do_warmup)
 
-            if self.disable_oracle_training:
-                LOGGER.info("Oracle training is disabled. Only the agent will be trained and the dataset will not be augmented.\n",
-                            "This is equivalent to just training the agent for an additional {} ({} x {}) epochs.".format(self.agent_train_steps*max_ep, max_ep, self.agent_train_steps))
-            elif not do_warmup:
+            if not do_warmup:
                 total_added = 0
 
                 with tqdm(total=self.nb_new_streamlines_per_iter,
@@ -353,9 +342,9 @@ def rl_train(
                     prettier_dict(data_stats, title="Dataset stats (iter {})".format(i)))
 
             # Train reward model
-            LOGGER.info("Training reward model...")
-            self.train_reward()
-            self.train_stopping_criterion()
+            if not self.disable_oracle_training:
+                self.train_reward()
+                self.train_stopping_criterion()
 
             # Train the RL agent
             agent_nb_steps = self.agent_train_steps if not do_warmup else self.warmup_agent_steps
@@ -369,14 +358,16 @@ def rl_train(
                 max_ep=agent_nb_steps,
                 starting_ep=current_ep,
                 save_model_dir=self.model_dir,
-                test_before_training=False
+                test_before_training=do_warmup or i == 0
             )
-            current_ep += self.agent_train_steps
 
             self.end_finetuning_epoch(i, do_warmup)
 
-            if not do_warmup:
+            if do_warmup:
+                current_ep += self.warmup_agent_steps
+            else:
                 self.backuper.backup(step=i)
+                current_ep += self.agent_train_steps
             i += 1
             do_warmup = False
