diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index d33904fd..69da6f96 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -75,12 +75,12 @@ def __init__( remote_manager_rank: int | None = None, # Scoring config w_retained: float = 1.0, - w_novelty: float = 0.1, + w_novelty: float = 1.0, w_reward: float = 0.0, w_mode_bonus: float = 10.0, p_norm_novelty: float = 2.0, cdist_max_bytes: int = 268435456, - ema_decay: float = 0.98, + ema_decay: float = 0.90, ): super().__init__( env, @@ -483,7 +483,7 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id # Default (tests stable): on-policy, no noisy layers. # When --use_random_strategies is provided, sample a random initial strategy. if getattr(args, "use_random_strategies", False): - init_cfg = _sample_new_strategy( + cfg = _sample_new_strategy( args, agent_group_id=my_agent_group_id, iteration=0, @@ -491,16 +491,18 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id prev_temp=9999.0, prev_noisy=9999, ) - args.agent_epsilon = float(init_cfg.get("epsilon", 0.0)) - args.agent_temperature = float(init_cfg.get("temperature", 1.0)) - args.agent_n_noisy_layers = int(init_cfg.get("n_noisy_layers", 0)) - args.agent_noisy_std_init = float(init_cfg.get("noisy_std_init", 0.5)) else: - # Disable off-policy training. - args.agent_epsilon = 0.0 - args.agent_temperature = 1.0 - args.agent_n_noisy_layers = 0 - args.agent_noisy_std_init = 0.5 + cfg = { + "epsilon": 0.0, + "temperature": 1.0, + "n_noisy_layers": 0, + "noisy_std_init": 0.5, + } + + args.agent_epsilon = float(cfg.get("epsilon", 0.0)) + args.agent_temperature = float(cfg.get("temperature", 1.0)) + args.agent_n_noisy_layers = int(cfg.get("n_noisy_layers", 0)) + args.agent_noisy_std_init = float(cfg.get("noisy_std_init", 0.5)) # Depending on the loss, we may need several estimators: # one (forward only) for FM loss, @@ -508,13 +510,14 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id # three (forward, backward, logZ/logF) estimators for DB, TB. if args.loss == "FM": - return set_up_fm_gflownet( + gflownet = set_up_fm_gflownet( args, env, preprocessor, agent_group_list, my_agent_group_id, ) + return gflownet, cfg else: # We need a DiscretePFEstimator and a DiscretePBEstimator. pf_estimator, pb_estimator = set_up_pb_pf_estimators( @@ -528,13 +531,13 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id assert pb_estimator is not None if args.loss == "ModifiedDB": - return ModifiedDBGFlowNet(pf_estimator, pb_estimator) + return ModifiedDBGFlowNet(pf_estimator, pb_estimator), cfg elif args.loss == "TB": - return TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + return TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0), cfg elif args.loss == "ZVar": - return LogPartitionVarianceGFlowNet(pf=pf_estimator, pb=pb_estimator) + return LogPartitionVarianceGFlowNet(pf=pf_estimator, pb=pb_estimator), cfg elif args.loss in ("DB", "SubTB"): # We also need a LogStateFlowEstimator. @@ -548,19 +551,21 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id ) if args.loss == "DB": - return DBGFlowNet( + gflownet = DBGFlowNet( pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, ) + return gflownet, cfg elif args.loss == "SubTB": - return SubTBGFlowNet( + gflownet = SubTBGFlowNet( pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, weighting=args.subTB_weighting, lamda=args.subTB_lambda, ) + return gflownet, cfg def plot_results(env, gflownet, l1_distances, validation_steps): @@ -731,22 +736,34 @@ def main(args) -> dict: # noqa: C901 group_name = wandb.util.generate_id() wandb.init( - project=args.wandb_project, group=group_name, entity=args.wandb_entity + project=args.wandb_project, + group=group_name, + entity=args.wandb_entity, + config=vars(args), ) - wandb.config.update(args) # Initialize the preprocessor. preprocessor = KHotPreprocessor(height=args.height, ndim=args.ndim) + model_builder_count = 0 # Builder closure to create a fresh model + optimizer (used by spawn policy as well) def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: - model = set_up_gflownet( + nonlocal model_builder_count, use_wandb + model_builder_count += 1 + + model, cfg = set_up_gflownet( args, env, preprocessor, distributed_context.agent_groups, distributed_context.agent_group_id, ) + if use_wandb: + import wandb + + wandb.log({"model_builder_count": model_builder_count, **cfg}) + else: + print(f"Model builder count: {model_builder_count}") assert model is not None model = model.to(device) optim = _make_optimizer_for(model, args) @@ -943,12 +960,13 @@ def cleanup(): timing, "averaging_model", enabled=args.timing ) as model_averaging_timer: if averaging_policy is not None: - assert score_dict is not None gflownet, optimizer, averaging_info = averaging_policy( iteration=iteration, model=gflownet, optimizer=optimizer, - local_metric=score_dict["score"], + local_metric=( + score_dict["score"] if score_dict is not None else -loss.item() + ), group=distributed_context.train_global_group, ) @@ -1407,7 +1425,7 @@ def cleanup(): parser.add_argument( "--performance_tracker_threshold", type=float, - default=100, + default=None, help="Threshold for the performance tracker. If None, the performance tracker is not triggered.", ) parser.add_argument(