diff --git a/src/gfn/gym/diffusion_sampling.py b/src/gfn/gym/diffusion_sampling.py index 4eb50ea3..93a83882 100644 --- a/src/gfn/gym/diffusion_sampling.py +++ b/src/gfn/gym/diffusion_sampling.py @@ -27,6 +27,9 @@ # - Exit action trigger: t + dt >= 1.0 - dt * TERMINAL_TIME_EPS (next step reaches terminal) TERMINAL_TIME_EPS = 1e-2 +# Default output directory for saving visualizations +OUTPUT_DIR = "output" + ############################### ### Target energy functions ### @@ -407,8 +410,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - plt.savefig(f"output/{prefix}simple_gmm.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plt.savefig(f"{OUTPUT_DIR}/{prefix}simple_gmm.png") plt.close() @@ -479,8 +482,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}gmm25.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}gmm25.png") plt.close() @@ -565,8 +568,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}posterior9of25.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}posterior9of25.png") plt.close() @@ -670,8 +673,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}funnel.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}funnel.png") plt.close() @@ -830,8 +833,8 @@ def visualize( if show: plt.show() else: - os.makedirs("output", exist_ok=True) - fig.savefig(f"output/{prefix}manywell.png") + os.makedirs(OUTPUT_DIR, exist_ok=True) + fig.savefig(f"{OUTPUT_DIR}/{prefix}manywell.png") plt.close()