tutorials/examples/train_hypergrid.py (66 changes: 42 additions & 24 deletions)
@@ -75,12 +75,12 @@ def __init__(
         remote_manager_rank: int | None = None,
         # Scoring config
         w_retained: float = 1.0,
-        w_novelty: float = 0.1,
+        w_novelty: float = 1.0,
         w_reward: float = 0.0,
         w_mode_bonus: float = 10.0,
         p_norm_novelty: float = 2.0,
         cdist_max_bytes: int = 268435456,
-        ema_decay: float = 0.98,
+        ema_decay: float = 0.90,
     ):
         super().__init__(
             env,
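The retuned defaults put more weight on novelty (w_novelty: 0.1 -> 1.0) and shorten the EMA memory (ema_decay: 0.98 -> 0.90). A minimal sketch of how weights like these could combine, assuming novelty is a mean p-norm distance computed with torch.cdist and the score is EMA-smoothed; the combination rule and tensor shapes here are assumptions, not the PR's code:

import torch

def score_batch(
    states: torch.Tensor,      # (B, D) states visited this iteration (assumed shape)
    reference: torch.Tensor,   # (R, D) previously retained states (assumed shape)
    retained_frac: float,
    reward: float,
    new_modes: int,
    w_retained: float = 1.0,
    w_novelty: float = 1.0,    # raised from 0.1 in this diff
    w_reward: float = 0.0,
    w_mode_bonus: float = 10.0,
    p_norm_novelty: float = 2.0,
) -> float:
    # Novelty as the mean p-norm distance to the reference set.
    novelty = torch.cdist(states, reference, p=p_norm_novelty).mean().item()
    return (
        w_retained * retained_frac
        + w_novelty * novelty
        + w_reward * reward
        + w_mode_bonus * new_modes
    )

def ema_update(ema: float, value: float, ema_decay: float = 0.90) -> float:
    # Lowering ema_decay from 0.98 to 0.90 shrinks the effective memory
    # from ~50 to ~10 iterations, so the running score adapts ~5x faster.
    return ema_decay * ema + (1.0 - ema_decay) * value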
@@ -483,38 +483,41 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id
     # Default (tests stable): on-policy, no noisy layers.
     # When --use_random_strategies is provided, sample a random initial strategy.
     if getattr(args, "use_random_strategies", False):
-        init_cfg = _sample_new_strategy(
+        cfg = _sample_new_strategy(
             args,
             agent_group_id=my_agent_group_id,
             iteration=0,
             prev_eps=9999.0,
             prev_temp=9999.0,
             prev_noisy=9999,
         )
-        args.agent_epsilon = float(init_cfg.get("epsilon", 0.0))
-        args.agent_temperature = float(init_cfg.get("temperature", 1.0))
-        args.agent_n_noisy_layers = int(init_cfg.get("n_noisy_layers", 0))
-        args.agent_noisy_std_init = float(init_cfg.get("noisy_std_init", 0.5))
     else:
         # Disable off-policy training.
-        args.agent_epsilon = 0.0
-        args.agent_temperature = 1.0
-        args.agent_n_noisy_layers = 0
-        args.agent_noisy_std_init = 0.5
+        cfg = {
+            "epsilon": 0.0,
+            "temperature": 1.0,
+            "n_noisy_layers": 0,
+            "noisy_std_init": 0.5,
+        }
 
+    args.agent_epsilon = float(cfg.get("epsilon", 0.0))
+    args.agent_temperature = float(cfg.get("temperature", 1.0))
+    args.agent_n_noisy_layers = int(cfg.get("n_noisy_layers", 0))
+    args.agent_noisy_std_init = float(cfg.get("noisy_std_init", 0.5))
+
     # Depending on the loss, we may need several estimators:
     # one (forward only) for the FM loss,
     # two (forward and backward) for other losses, and
     # three (forward, backward, logZ/logF) for DB, SubTB, and TB.

     if args.loss == "FM":
-        return set_up_fm_gflownet(
+        gflownet = set_up_fm_gflownet(
             args,
             env,
             preprocessor,
             agent_group_list,
             my_agent_group_id,
         )
+        return gflownet, cfg
     else:
         # We need a DiscretePFEstimator and a DiscretePBEstimator.
         pf_estimator, pb_estimator = set_up_pb_pf_estimators(
@@ -528,13 +528,13 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id
         assert pb_estimator is not None
 
         if args.loss == "ModifiedDB":
-            return ModifiedDBGFlowNet(pf_estimator, pb_estimator)
+            return ModifiedDBGFlowNet(pf_estimator, pb_estimator), cfg
 
         elif args.loss == "TB":
-            return TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0)
+            return TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0), cfg
 
         elif args.loss == "ZVar":
-            return LogPartitionVarianceGFlowNet(pf=pf_estimator, pb=pb_estimator)
+            return LogPartitionVarianceGFlowNet(pf=pf_estimator, pb=pb_estimator), cfg
 
         elif args.loss in ("DB", "SubTB"):
             # We also need a LogStateFlowEstimator.
@@ -548,19 +548,21 @@ def set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id
             )
 
             if args.loss == "DB":
-                return DBGFlowNet(
+                gflownet = DBGFlowNet(
                     pf=pf_estimator,
                     pb=pb_estimator,
                     logF=logF_estimator,
                 )
+                return gflownet, cfg
             elif args.loss == "SubTB":
-                return SubTBGFlowNet(
+                gflownet = SubTBGFlowNet(
                     pf=pf_estimator,
                     pb=pb_estimator,
                     logF=logF_estimator,
                     weighting=args.subTB_weighting,
                     lamda=args.subTB_lambda,
                 )
+                return gflownet, cfg
 
 
 def plot_results(env, gflownet, l1_distances, validation_steps):
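As the comment inside set_up_gflownet spells out, each loss pulls in a different set of estimators, and every branch now returns the (gflownet, cfg) pair. Summarized as an illustrative mapping (not code from the diff):

# Which estimators set_up_gflownet builds per loss; logZ (TB) is a learned
# scalar passed as init_logZ rather than a separate estimator module.
ESTIMATORS_BY_LOSS = {
    "FM": ("pf",),
    "ModifiedDB": ("pf", "pb"),
    "TB": ("pf", "pb"),
    "ZVar": ("pf", "pb"),
    "DB": ("pf", "pb", "logF"),
    "SubTB": ("pf", "pb", "logF"),
}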
@@ -731,22 +736,34 @@ def main(args) -> dict: # noqa: C901
         group_name = wandb.util.generate_id()
 
         wandb.init(
-            project=args.wandb_project, group=group_name, entity=args.wandb_entity
+            project=args.wandb_project,
+            group=group_name,
+            entity=args.wandb_entity,
+            config=vars(args),
         )
-        wandb.config.update(args)
 
     # Initialize the preprocessor.
     preprocessor = KHotPreprocessor(height=args.height, ndim=args.ndim)
+    model_builder_count = 0
 
     # Builder closure to create a fresh model + optimizer (used by spawn policy as well)
     def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]:
-        model = set_up_gflownet(
+        nonlocal model_builder_count, use_wandb
+        model_builder_count += 1
+
+        model, cfg = set_up_gflownet(
            args,
            env,
            preprocessor,
            distributed_context.agent_groups,
            distributed_context.agent_group_id,
        )
+        if use_wandb:
+            import wandb
+
+            wandb.log({"model_builder_count": model_builder_count, **cfg})
+        else:
+            print(f"Model builder count: {model_builder_count}")
         assert model is not None
         model = model.to(device)
         optim = _make_optimizer_for(model, args)
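Passing config=vars(args) at init time subsumes the removed wandb.config.update(args) call; a minimal equivalence sketch (the Namespace fields are invented for illustration):

import argparse
import wandb

args = argparse.Namespace(loss="TB", lr=1e-3)  # hypothetical fields

# Old pattern: attach the config after init.
#   run = wandb.init(project="hypergrid")
#   run.config.update(vars(args))

# New pattern: vars(args) turns the Namespace into a plain dict up front.
run = wandb.init(project="hypergrid", config=vars(args))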
@@ -943,12 +960,13 @@ def cleanup():
             timing, "averaging_model", enabled=args.timing
         ) as model_averaging_timer:
             if averaging_policy is not None:
-                assert score_dict is not None
                 gflownet, optimizer, averaging_info = averaging_policy(
                     iteration=iteration,
                     model=gflownet,
                     optimizer=optimizer,
-                    local_metric=score_dict["score"],
+                    local_metric=(
+                        score_dict["score"] if score_dict is not None else -loss.item()
+                    ),
                     group=distributed_context.train_global_group,
                 )
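The deleted assert made a missing score fatal; the inline conditional instead falls back to the negated loss, so larger local_metric values still mean better models when no score_dict is available. The guard in isolation:

import torch

def local_metric(score_dict: dict | None, loss: torch.Tensor) -> float:
    # Prefer the explicit score when scoring ran this iteration;
    # otherwise negate the scalar loss so "higher is better" still holds.
    if score_dict is not None:
        return float(score_dict["score"])
    return -loss.item()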

@@ -1407,7 +1425,7 @@ def cleanup():
     parser.add_argument(
         "--performance_tracker_threshold",
         type=float,
-        default=100,
+        default=None,
         help="Threshold for the performance tracker. If None, the performance tracker is not triggered.",
     )
     parser.add_argument(
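With the default flipped to None, the help text implies the tracker is constructed only when the flag is set. A gating sketch, where PerformanceTracker is a hypothetical stand-in for whatever the script actually builds:

import argparse

class PerformanceTracker:  # hypothetical stand-in, not this repo's API
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

parser = argparse.ArgumentParser()
parser.add_argument("--performance_tracker_threshold", type=float, default=None)
args = parser.parse_args([])

# None disables the tracker entirely, per the help text.
tracker = (
    PerformanceTracker(args.performance_tracker_threshold)
    if args.performance_tracker_threshold is not None
    else None
)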