You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# because we will want to save the checkpoints only when we improve the
131
125
# total agent. We manually checkpoint those oracles instead.
132
126
self.oracle_reward_trainer=OracleTrainer(
133
-
oracle_experiment,
134
-
oracle_experiment_id,
127
+
comet_experiment,
135
128
self.oracle_training_dir,
136
129
self.oracle_train_steps,
137
130
enable_auto_checkpointing=False,
@@ -143,8 +136,7 @@ def __init__(
143
136
)
144
137
145
138
self.oracle_crit_trainer=OracleTrainer(
146
-
oracle_experiment,
147
-
oracle_experiment_id,
139
+
comet_experiment,
148
140
self.oracle_training_dir,
149
141
self.oracle_train_steps,
150
142
enable_auto_checkpointing=False,
@@ -293,10 +285,7 @@ def rl_train(
293
285
whilei<max_ep:
294
286
self.start_finetuning_epoch(i, do_warmup)
295
287
296
-
ifself.disable_oracle_training:
297
-
LOGGER.info("Oracle training is disabled. Only the agent will be trained and the dataset will not be augmented.\n",
298
-
"This is equivalent to just training the agent for an additional {} ({} x {}) epochs.".format(self.agent_train_steps*max_ep, max_ep, self.agent_train_steps))
0 commit comments