diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..1081797c
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "benchmark/gflownet"]
+	path = benchmark/gflownet
+	url = https://github.com/alexhernandezgarcia/gflownet.git
+[submodule "benchmark/gfnx"]
+	path = benchmark/gfnx
+	url = https://github.com/d-tiapkin/gfnx.git
diff --git a/benchmark/README.md b/benchmark/README.md
new file mode 100644
index 00000000..d7f923e8
--- /dev/null
+++ b/benchmark/README.md
@@ -0,0 +1,174 @@
+# GFlowNet Library Benchmarks
+
+This benchmark compares three GFlowNet libraries across multiple environments:
+- **torchgfn** - PyTorch-based (this repository)
+- **gflownet** - PyTorch-based, Hydra configuration system
+- **gfnx** - JAX/Equinox-based
+
+## Setup
+
+Initialize the benchmark submodules (gflownet and gfnx) and install the extra dependencies:
+
+```bash
+git submodule update --init --recursive
+bash benchmark/dependencies.sh
+```
+
+## Environment Support Matrix
+
+Not all libraries support all environments:
+
+| Environment | torchgfn | gflownet | gfnx | Description |
+|-------------|:--------:|:--------:|:----:|-------------|
+| hypergrid | ✓ | ✓ | ✓ | Discrete grid navigation |
+| ising | ✓ | ✓ | ✓ | Discrete Ising model |
+| box | ✓ | ✓ | - | Continuous 2D box |
+| bitseq | ✓ | - | ✓ | Bit sequence generation |
+
+## Available Scenarios
+
+### Hypergrid (all libraries)
+- `tb_hypergrid_small` - 2D grid, height 8, 1000 iterations (quick test)
+- `tb_hypergrid_medium` - 4D grid, height 16, 2000 iterations
+- `tb_hypergrid_large` - 4D grid, height 32, 5000 iterations
+
+### Ising (all libraries)
+- `tb_ising_6x6` - 6x6 lattice (36 spins), 1000 iterations
+- `tb_ising_10x10` - 10x10 lattice (100 spins), 2000 iterations
+
+### Box/CCube (torchgfn, gflownet)
+- `tb_box_2d` - 2D continuous box, delta=0.25, 1000 iterations
+
+### BitSequence (torchgfn, gfnx)
+- `tb_bitseq_small` - word_size=1, seq_size=4, 2 modes, 1000 iterations
+- `tb_bitseq_medium` - word_size=2, seq_size=8, 4 modes, 2000 iterations
+
+## Usage
+
+### Quick Test
+
+```bash
+# Test with just torchgfn first (fewest dependencies)
+python benchmark/benchmark_libraries.py --scenario tb_hypergrid_small --libraries torchgfn --seeds 0
+```
+
+### Running Benchmarks
+
+**Recommended approach (run each library separately to avoid OpenMP conflicts):**
+
+```bash
+# Hypergrid - all libraries
+python benchmark/benchmark_libraries.py --scenario tb_hypergrid_small --libraries torchgfn --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_hypergrid_small --libraries gflownet --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_hypergrid_small --libraries gfnx --seeds 0 1 2
+
+# Ising - all libraries
+python benchmark/benchmark_libraries.py --scenario tb_ising_6x6 --libraries torchgfn --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_ising_6x6 --libraries gflownet --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_ising_6x6 --libraries gfnx --seeds 0 1 2
+
+# Box - torchgfn and gflownet only
+python benchmark/benchmark_libraries.py --scenario tb_box_2d --libraries torchgfn --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_box_2d --libraries gflownet --seeds 0 1 2
+
+# BitSequence - torchgfn and gfnx only
+python benchmark/benchmark_libraries.py --scenario tb_bitseq_small --libraries torchgfn --seeds 0 1 2
+python benchmark/benchmark_libraries.py --scenario tb_bitseq_small --libraries gfnx --seeds 0 1 2
+```
+
+**Run all supported libraries for an environment (automatic filtering):**
+
+```bash
+# The script
automatically selects supported libraries if --libraries is omitted +python benchmark/benchmark_libraries.py --scenario tb_box_2d --seeds 0 1 2 +# Will run torchgfn and gflownet only (gfnx doesn't support box) +``` + +Results are saved with library names in the filename (e.g., `benchmark_tb_ising_6x6_torchgfn_20231218_143052.json`). + +## Important Implementation Differences + +### Ising Environment + +| Library | Environment Class | Loss Type | Notes | +|---------|-------------------|-----------|-------| +| torchgfn | `DiscreteEBM` + `IsingModel` | Flow Matching | Uses coupling matrix J with periodic boundary conditions | +| gflownet | `ising` env | Trajectory Balance | Configured via Hydra with uniform proxy | +| gfnx | `IsingEnvironment` | Trajectory Balance | Uses `IsingRewardModule` | + +### Box/CCube Environment + +| Library | Environment Class | Policy Type | Notes | +|---------|-------------------|-------------|-------| +| torchgfn | `Box` | `BoxPFEstimator`/`BoxPBEstimator` | Continuous Beta mixture policies | +| gflownet | `ccube` | MLP with continuous actions | Configured via Hydra with corners proxy | + +### BitSequence Environment + +| Library | Environment Class | Loss Type | Notes | +|---------|-------------------|-----------|-------| +| torchgfn | `BitSequence` | Trajectory Balance | Uses shared trunk for PF/PB | +| gfnx | `BitseqEnvironment` | Trajectory Balance | Uses `BitseqRewardModule` with configurable modes | + +### General Differences + +- **torchgfn**: Pure PyTorch, imperative style, flexible API +- **gflownet**: PyTorch with Hydra configuration, more complex setup but highly configurable +- **gfnx**: JAX/Equinox, functional style, JIT compilation for performance + +## macOS OpenMP Conflict + +On macOS, conda/pip environments often have multiple copies of `libomp.dylib` from different sources (e.g., `llvm-openmp` from conda, plus bundled copies in PyTorch, scikit-learn, etc.). This causes an error: + +``` +OMP: Error #15: Initializing libomp.dylib, but found libomp.dylib already initialized. +... +Abort trap: 6 +``` + +**Cause:** Mixed conda/pip installations result in multiple OpenMP runtime libraries. Common sources include: +- `llvm-openmp` or `libopenblas` from conda +- PyTorch's bundled `libomp.dylib` +- scikit-learn's bundled `libomp.dylib` + +**Solution:** The benchmark script automatically sets `KMP_DUPLICATE_LIB_OK=TRUE` at startup to work around this conflict. + +**Note on benchmark accuracy:** While `KMP_DUPLICATE_LIB_OK=TRUE` can theoretically cause issues, in practice: +- All libraries use the same workaround, ensuring fair comparison +- The conflict is between *identical* OpenMP implementations (just different copies) +- For relative performance comparisons between libraries, the results remain valid + +**Alternative (clean environment):** For maximum confidence, create a fresh conda environment using only pip packages: + +```bash +conda create -n benchmark python=3.10 +conda activate benchmark +pip install torch torchgfn scikit-learn jax jaxlib equinox # all from pip +``` + +## Output Format + +Results are saved as JSON files in `benchmark/outputs/` with the following structure: + +```json +{ + "scenario": "tb_hypergrid_small", + "timestamp": "20231218_143052", + "config": { + "env_name": "hypergrid", + "env_kwargs": {"ndim": 2, "height": 8}, + "n_iterations": 1000, + "batch_size": 16, + ... 
+ }, + "results": [...], + "summary": { + "torchgfn": { + "n_runs": 3, + "mean_iter_time_ms": 5.23, + "std_iter_time_ms": 0.42, + "mean_throughput_iters_per_sec": 191.2, + "mean_peak_memory_mb": 512.3 + } + } +} +``` diff --git a/benchmark/benchmark_libraries.py b/benchmark/benchmark_libraries.py new file mode 100644 index 00000000..223a8490 --- /dev/null +++ b/benchmark/benchmark_libraries.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python +"""Benchmark script for comparing GFlowNet libraries. + +This script benchmarks torchgfn, gflownet, and gfnx libraries on +Trajectory Balance training for multiple environments: +- hypergrid: Discrete grid navigation (all libraries) +- ising: Discrete Ising model (all libraries) +- box/ccube: Continuous cube environment (torchgfn, gflownet only) +- bitseq: Bit sequence generation (torchgfn, gfnx only) + +Example usage: + python benchmark/benchmark_libraries.py --scenario tb_hypergrid_small --seeds 0 1 2 + python benchmark/benchmark_libraries.py --libraries torchgfn gfnx --scenario tb_bitseq_small + python benchmark/benchmark_libraries.py --scenario tb_ising_6x6 --libraries torchgfn gflownet gfnx +""" + +# ============================================================================ +# OpenMP Conflict Workaround (must be set before any imports) +# ============================================================================ +# On macOS, mixed conda/pip environments often have multiple copies of libomp +# (from llvm-openmp, torch, sklearn, etc.) which causes a crash when both are +# loaded. This workaround allows the program to continue despite the conflict. +# See: https://github.com/pytorch/pytorch/issues/78490 +import os + +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") + +import argparse # noqa: E402 +import json # noqa: E402 +import sys # noqa: E402 +import time # noqa: E402 +from datetime import datetime # noqa: E402 +from pathlib import Path # noqa: E402 +from typing import Dict, List, Type # noqa: E402 + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from benchmark.lib_runners.base import BenchmarkConfig # noqa: E402 +from benchmark.lib_runners.base import BenchmarkResult # noqa: E402 +from benchmark.lib_runners.base import LibraryRunner # noqa: E402 + +# ============================================================================ +# Lazy Import Functions (to avoid OpenMP conflicts on macOS) +# ============================================================================ + + +def _get_torchgfn_runner() -> Type[LibraryRunner]: + from benchmark.lib_runners.torchgfn_runner import TorchGFNRunner + + return TorchGFNRunner + + +def _get_gflownet_runner() -> Type[LibraryRunner]: + from benchmark.lib_runners.gflownet_runner import GFlowNetRunner + + return GFlowNetRunner + + +def _get_gfnx_runner() -> Type[LibraryRunner]: + from benchmark.lib_runners.gfnx_runner import GFNXRunner + + return GFNXRunner + + +# ============================================================================ +# Scenario Configurations +# ============================================================================ + +SCENARIOS: Dict[str, BenchmarkConfig] = { + # Hypergrid scenarios (all libraries: torchgfn, gflownet, gfnx) + "tb_hypergrid_small": BenchmarkConfig( + env_name="hypergrid", + env_kwargs={"ndim": 2, "height": 8}, + n_iterations=1000, + batch_size=16, + n_warmup=50, + ), + "tb_hypergrid_medium": BenchmarkConfig( + env_name="hypergrid", + env_kwargs={"ndim": 4, "height": 16}, + n_iterations=2000, + batch_size=16, + 
n_warmup=100, + ), + "tb_hypergrid_large": BenchmarkConfig( + env_name="hypergrid", + env_kwargs={"ndim": 4, "height": 32}, + n_iterations=5000, + batch_size=16, + n_warmup=100, + ), + # Ising scenarios (all libraries: torchgfn, gflownet, gfnx) + "tb_ising_6x6": BenchmarkConfig( + env_name="ising", + env_kwargs={"L": 6, "J": 0.44}, + n_iterations=1000, + batch_size=16, + n_warmup=50, + ), + "tb_ising_10x10": BenchmarkConfig( + env_name="ising", + env_kwargs={"L": 10, "J": 0.44}, + n_iterations=2000, + batch_size=16, + n_warmup=100, + ), + # Box/CCube scenarios (torchgfn, gflownet only - gfnx does not have this env) + "tb_box_2d": BenchmarkConfig( + env_name="box", + env_kwargs={"n_dim": 2, "delta": 0.25}, + n_iterations=1000, + batch_size=16, + n_warmup=50, + ), + # BitSequence scenarios (torchgfn, gfnx only - gflownet does not have this env) + "tb_bitseq_small": BenchmarkConfig( + env_name="bitseq", + env_kwargs={"word_size": 1, "seq_size": 4, "n_modes": 2}, + n_iterations=1000, + batch_size=16, + n_warmup=50, + ), + "tb_bitseq_medium": BenchmarkConfig( + env_name="bitseq", + env_kwargs={"word_size": 2, "seq_size": 8, "n_modes": 4}, + n_iterations=2000, + batch_size=16, + n_warmup=100, + ), +} + + +# ============================================================================ +# Library Registry (using lazy loaders to avoid importing all libraries) +# ============================================================================ + +# Maps library name to a function that returns the runner class +# This avoids importing all libraries at startup, which causes OpenMP conflicts on macOS +LIBRARY_RUNNERS: Dict[str, callable] = { + "torchgfn": _get_torchgfn_runner, + "gflownet": _get_gflownet_runner, + "gfnx": _get_gfnx_runner, +} + + +# ============================================================================ +# Environment-Library Availability Mapping +# ============================================================================ +# Not all libraries support all environments. This mapping defines which +# libraries can run each environment type. + +ENV_LIBRARY_SUPPORT: Dict[str, List[str]] = { + "hypergrid": ["torchgfn", "gflownet", "gfnx"], + "ising": ["torchgfn", "gflownet", "gfnx"], + "box": ["torchgfn", "gflownet"], # gfnx does not have continuous box/ccube + "bitseq": ["torchgfn", "gfnx"], # gflownet does not have bitsequence +} + + +def get_supported_libraries(env_name: str) -> List[str]: + """Get list of libraries that support the given environment. + + Args: + env_name: Name of the environment. + + Returns: + List of library names that support this environment. + """ + return ENV_LIBRARY_SUPPORT.get(env_name, list(LIBRARY_RUNNERS.keys())) + + +def get_default_libraries(env_name: str) -> List[str]: + """Get default libraries to run for a given environment. + + Args: + env_name: Name of the environment. + + Returns: + List of library names to use by default for this environment. + """ + return get_supported_libraries(env_name) + + +# ============================================================================ +# Benchmarking Functions +# ============================================================================ + + +def run_benchmark( + runner: LibraryRunner, + config: BenchmarkConfig, + seed: int, +) -> BenchmarkResult: + """Run a single benchmark for a library with given config and seed. + + Args: + runner: The library runner instance. + config: Benchmark configuration. + seed: Random seed. + + Returns: + BenchmarkResult with timing and memory information. 
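+
+    Note:
+        Each iteration is timed individually between runner.synchronize()
+        calls, so asynchronous device work (CUDA kernels, JAX dispatch)
+        is charged to the iteration that launched it.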
+ """ + print(f" Setting up {runner.name} with seed {seed}...") + runner.setup(config, seed) + + print(f" Running {config.n_warmup} warmup iterations...") + runner.warmup(config.n_warmup) + + print(f" Running {config.n_iterations} timed iterations...") + iter_times = [] + + # Time each iteration individually + runner.synchronize() + total_start = time.perf_counter() + + for i in range(config.n_iterations): + runner.synchronize() + iter_start = time.perf_counter() + + runner.run_iteration() + + runner.synchronize() + iter_end = time.perf_counter() + + iter_times.append(iter_end - iter_start) + + # Progress update every 10% + if (i + 1) % max(1, config.n_iterations // 10) == 0: + progress = (i + 1) / config.n_iterations * 100 + mean_time = sum(iter_times) / len(iter_times) + print( + f" Progress: {progress:.0f}% ({i+1}/{config.n_iterations}), " + f"mean iter time: {mean_time*1000:.2f}ms" + ) + + total_end = time.perf_counter() + total_time = total_end - total_start + + # Get peak memory + peak_memory = runner.get_peak_memory() + + # Cleanup + runner.cleanup() + + return BenchmarkResult( + library=runner.name, + seed=seed, + total_time=total_time, + iter_times=iter_times, + peak_memory=peak_memory, + ) + + +def aggregate_results(results: List[BenchmarkResult]) -> Dict: + """Aggregate results across seeds for each library. + + Args: + results: List of benchmark results. + + Returns: + Dictionary with summary statistics per library. + """ + import statistics + from collections import defaultdict + + by_library = defaultdict(list) + for r in results: + by_library[r.library].append(r) + + summary = {} + for library, lib_results in by_library.items(): + all_iter_times = [] + for r in lib_results: + all_iter_times.extend(r.iter_times) + + mean_iter_time = statistics.mean(all_iter_times) if all_iter_times else 0 + std_iter_time = ( + statistics.stdev(all_iter_times) if len(all_iter_times) > 1 else 0 + ) + + total_times = [r.total_time for r in lib_results] + mean_total_time = statistics.mean(total_times) if total_times else 0 + + throughputs = [r.throughput for r in lib_results] + mean_throughput = statistics.mean(throughputs) if throughputs else 0 + + peak_memories = [ + r.peak_memory_mb for r in lib_results if r.peak_memory_mb is not None + ] + mean_peak_memory = statistics.mean(peak_memories) if peak_memories else None + + summary[library] = { + "n_runs": len(lib_results), + "mean_iter_time_ms": mean_iter_time * 1000, + "std_iter_time_ms": std_iter_time * 1000, + "mean_total_time_s": mean_total_time, + "mean_throughput_iters_per_sec": mean_throughput, + "mean_peak_memory_mb": mean_peak_memory, + } + + return summary + + +def save_results( + scenario: str, + config: BenchmarkConfig, + results: List[BenchmarkResult], + output_dir: Path, +) -> Path: + """Save benchmark results to JSON file. + + Args: + scenario: Scenario name. + config: Benchmark configuration. + results: List of benchmark results. + output_dir: Output directory. + + Returns: + Path to the saved file. 
+ """ + output_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + libraries_str = "_".join(sorted(set(r.library for r in results))) + filename = f"benchmark_{scenario}_{libraries_str}_{timestamp}.json" + filepath = output_dir / filename + + summary = aggregate_results(results) + + output = { + "scenario": scenario, + "timestamp": timestamp, + "config": config.to_dict(), + "results": [r.to_dict() for r in results], + "summary": summary, + } + + with open(filepath, "w") as f: + json.dump(output, f, indent=2) + + return filepath + + +def print_summary(summary: Dict) -> None: + """Print a formatted summary of benchmark results.""" + print("\n" + "=" * 70) + print("BENCHMARK SUMMARY") + print("=" * 70) + + # Header + print( + f"{'Library':<15} {'Iter Time (ms)':<18} {'Throughput (it/s)':<20} {'Memory (MB)':<15}" + ) + print("-" * 70) + + for library, stats in sorted(summary.items()): + iter_time = f"{stats['mean_iter_time_ms']:.2f} ± {stats['std_iter_time_ms']:.2f}" + throughput = f"{stats['mean_throughput_iters_per_sec']:.1f}" + memory = ( + f"{stats['mean_peak_memory_mb']:.1f}" + if stats["mean_peak_memory_mb"] + else "N/A" + ) + + print(f"{library:<15} {iter_time:<18} {throughput:<20} {memory:<15}") + + print("=" * 70) + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark GFlowNet libraries on multiple environments" + ) + parser.add_argument( + "--scenario", + type=str, + default="tb_hypergrid_small", + choices=list(SCENARIOS.keys()), + help="Benchmark scenario to run", + ) + parser.add_argument( + "--seeds", + type=int, + nargs="+", + default=[0, 1, 2], + help="Random seeds to use", + ) + parser.add_argument( + "--libraries", + type=str, + nargs="+", + default=None, # Will be set based on environment + choices=list(LIBRARY_RUNNERS.keys()), + help="Libraries to benchmark (default: all supported for the environment)", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output directory (default: benchmark/outputs)", + ) + + args = parser.parse_args() + + # Get configuration + config = SCENARIOS[args.scenario] + output_dir = Path(args.output) if args.output else Path(__file__).parent / "outputs" + + # Determine which libraries to run + supported_libs = get_supported_libraries(config.env_name) + if args.libraries is None: + # Use all supported libraries for this environment + libraries = supported_libs + else: + # Validate that requested libraries support this environment + libraries = [] + for lib in args.libraries: + if lib in supported_libs: + libraries.append(lib) + else: + print( + f"Warning: {lib} does not support {config.env_name} environment, skipping." + ) + if not libraries: + print( + f"Error: No valid libraries for {config.env_name}. Supported: {supported_libs}" + ) + sys.exit(1) + + # Note about running multiple libraries on macOS + import platform + + if platform.system() == "Darwin" and len(libraries) > 1: + print( + "\nNote: Running multiple libraries together. KMP_DUPLICATE_LIB_OK is set\n" + "to work around macOS OpenMP conflicts. 
For cleanest results, consider\n" + "running each library separately.\n" + ) + + print("=" * 70) + print(f"GFlowNet Library Benchmark: {args.scenario}") + print("=" * 70) + print("Configuration:") + print(f" env: {config.env_name}, {config.env_kwargs}") + print(f" n_iterations={config.n_iterations}, batch_size={config.batch_size}") + print(f" n_warmup={config.n_warmup}") + print(f" lr={config.lr}, lr_logz={config.lr_logz}") + print(f" hidden_dim={config.hidden_dim}, n_layers={config.n_layers}") + print(f"Libraries: {', '.join(libraries)}") + print(f"Seeds: {args.seeds}") + print("=" * 70) + + results = [] + + for library in libraries: + print(f"\n[{library.upper()}]") + + runner_cls = LIBRARY_RUNNERS[library]() # Call the lazy loader function + + for seed in args.seeds: + print(f"\nSeed {seed}:") + try: + runner = runner_cls() + result = run_benchmark(runner, config, seed) + results.append(result) + + print(f" Total time: {result.total_time:.2f}s") + print(f" Mean iter time: {result.mean_iter_time*1000:.2f}ms") + print(f" Throughput: {result.throughput:.1f} iter/s") + if result.peak_memory_mb: + print(f" Peak memory: {result.peak_memory_mb:.1f}MB") + + except Exception as e: + print(f" ERROR: {e}") + import traceback + + traceback.print_exc() + + # Save and summarize results + if results: + filepath = save_results(args.scenario, config, results, output_dir) + print(f"\nResults saved to: {filepath}") + + summary = aggregate_results(results) + print_summary(summary) + + +if __name__ == "__main__": + main() diff --git a/benchmark/dependencies.sh b/benchmark/dependencies.sh new file mode 100644 index 00000000..280a777e --- /dev/null +++ b/benchmark/dependencies.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pip install jax==0.4.35 jaxlib==0.4.35 jax-metal==0.1.1 hydra-core equinox optax torchtyping flashbax jax_tqdm diff --git a/benchmark/gflownet b/benchmark/gflownet new file mode 160000 index 00000000..833ec44c --- /dev/null +++ b/benchmark/gflownet @@ -0,0 +1 @@ +Subproject commit 833ec44cdb7906485bf86760fa9df86d4ce865af diff --git a/benchmark/gfnx b/benchmark/gfnx new file mode 160000 index 00000000..4235e2b3 --- /dev/null +++ b/benchmark/gfnx @@ -0,0 +1 @@ +Subproject commit 4235e2b3ddb0c57273ba54e61e8ee6f45740107b diff --git a/benchmark/lib_runners/__init__.py b/benchmark/lib_runners/__init__.py new file mode 100644 index 00000000..43395d17 --- /dev/null +++ b/benchmark/lib_runners/__init__.py @@ -0,0 +1,19 @@ +"""Library runners for GFlowNet benchmarking.""" + +from benchmark.lib_runners.base import ( + BenchmarkConfig, + BenchmarkResult, + LibraryRunner, +) +from benchmark.lib_runners.gflownet_runner import GFlowNetRunner +from benchmark.lib_runners.gfnx_runner import GFNXRunner +from benchmark.lib_runners.torchgfn_runner import TorchGFNRunner + +__all__ = [ + "BenchmarkConfig", + "BenchmarkResult", + "LibraryRunner", + "TorchGFNRunner", + "GFlowNetRunner", + "GFNXRunner", +] diff --git a/benchmark/lib_runners/base.py b/benchmark/lib_runners/base.py new file mode 100644 index 00000000..9e97f494 --- /dev/null +++ b/benchmark/lib_runners/base.py @@ -0,0 +1,152 @@ +"""Base classes and configuration for library benchmarking.""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from typing import List, Optional + + +@dataclass +class BenchmarkConfig: + """Configuration for a benchmark scenario. + + Environment-specific parameters are stored in env_kwargs and + interpreted by each runner. 
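+
+    Example (a minimal sketch mirroring the tb_hypergrid_small scenario
+    defined in benchmark_libraries.py):
+
+        config = BenchmarkConfig(
+            env_name="hypergrid",
+            env_kwargs={"ndim": 2, "height": 8},
+            n_iterations=1000,
+            batch_size=16,
+            n_warmup=50,
+        )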
+ """ + + # Environment (type and parameters specified per scenario) + env_name: str # e.g., "hypergrid", "bitseq", etc. + env_kwargs: dict # Environment-specific parameters + + # Training + n_iterations: int = 1000 + batch_size: int = 16 + lr: float = 1e-3 + lr_logz: float = 0.1 + + # Network + hidden_dim: int = 256 + n_layers: int = 2 + + # Benchmark settings + n_warmup: int = 50 # Warmup iterations (excluded from timing) + + def to_dict(self) -> dict: + """Convert to dictionary for JSON serialization.""" + return asdict(self) + + +@dataclass +class BenchmarkResult: + """Results from a single benchmark run.""" + + library: str + seed: int + + # Timing (seconds) + total_time: float + iter_times: List[float] = field(default_factory=list) + + # Memory (bytes) + peak_memory: Optional[int] = None + + @property + def mean_iter_time(self) -> float: + """Mean iteration time in seconds.""" + if not self.iter_times: + return 0.0 + return sum(self.iter_times) / len(self.iter_times) + + @property + def std_iter_time(self) -> float: + """Standard deviation of iteration time.""" + if len(self.iter_times) < 2: + return 0.0 + mean = self.mean_iter_time + variance = sum((t - mean) ** 2 for t in self.iter_times) / len(self.iter_times) + return variance**0.5 + + @property + def throughput(self) -> float: + """Throughput in iterations per second.""" + if self.total_time <= 0: + return 0.0 + return len(self.iter_times) / self.total_time + + @property + def peak_memory_mb(self) -> Optional[float]: + """Peak memory in megabytes.""" + if self.peak_memory is None: + return None + return self.peak_memory / (1024 * 1024) + + def to_dict(self) -> dict: + """Convert to dictionary for JSON serialization.""" + return { + "library": self.library, + "seed": self.seed, + "total_time": self.total_time, + "mean_iter_time": self.mean_iter_time, + "std_iter_time": self.std_iter_time, + "throughput_iters_per_sec": self.throughput, + "peak_memory_mb": self.peak_memory_mb, + "n_iterations": len(self.iter_times), + } + + +class LibraryRunner(ABC): + """Abstract base class for library-specific benchmark runners.""" + + name: str # e.g., "torchgfn", "gflownet", "gfnx" + + @abstractmethod + def setup(self, config: BenchmarkConfig, seed: int) -> None: + """Initialize environment, model, and optimizer. + + Called once per seed. This phase is not timed. + + Args: + config: Benchmark configuration. + seed: Random seed for reproducibility. + """ + + @abstractmethod + def warmup(self, n_iters: int) -> None: + """Run warmup iterations. + + Used to trigger JIT compilation (JAX) and CUDA kernel caching (PyTorch). + These iterations are excluded from timing. + + Args: + n_iters: Number of warmup iterations to run. + """ + + @abstractmethod + def run_iteration(self) -> None: + """Run a single training iteration. + + This should include: sampling trajectories, computing loss, + and performing optimizer step. + """ + + @abstractmethod + def synchronize(self) -> None: + """Ensure all asynchronous operations are complete. + + For PyTorch: torch.cuda.synchronize() if on GPU + For JAX: jax.block_until_ready() on outputs + """ + + @abstractmethod + def get_peak_memory(self) -> Optional[int]: + """Return peak memory usage in bytes. + + Returns: + Peak memory in bytes, or None if not available (e.g., CPU-only). + """ + + @abstractmethod + def cleanup(self) -> None: + """Release resources and clean up. + + Called after benchmark completes. 
+ """ diff --git a/benchmark/lib_runners/gflownet_runner.py b/benchmark/lib_runners/gflownet_runner.py new file mode 100644 index 00000000..15d92b6d --- /dev/null +++ b/benchmark/lib_runners/gflownet_runner.py @@ -0,0 +1,264 @@ +"""GFlowNet library (external) runner for benchmarking. + +This runner uses the gflownet library from benchmark/gflownet/. +The library is tightly coupled with Hydra, so we use Hydra's compose API +to build the configuration and instantiate the GFlowNet agent. + +Supports environments: hypergrid, ising, box (ccube). +""" + +import sys +from pathlib import Path +from typing import List, Optional + +import torch + +from benchmark.lib_runners.base import BenchmarkConfig, LibraryRunner + +# Add gflownet to path +GFLOWNET_PATH = Path(__file__).parent.parent / "gflownet" +if str(GFLOWNET_PATH) not in sys.path: + sys.path.insert(0, str(GFLOWNET_PATH)) + + +class GFlowNetRunner(LibraryRunner): + """Benchmark runner for the external gflownet library. + + This library uses Hydra for configuration, so we use Hydra's compose API + to build the configuration programmatically. + + Supports environments: + - hypergrid: Discrete grid navigation (uses grid env) + - ising: Discrete Ising model environment + - box: Continuous cube environment (uses ccube env) + """ + + name = "gflownet" + + def __init__(self): + self.agent = None + self.device = None + self.config = None + self._iteration = 0 + + def setup(self, config: BenchmarkConfig, seed: int) -> None: + """Initialize environment, model, and optimizer using Hydra compose.""" + import random + + import numpy as np + from gflownet.utils.common import gflownet_from_config + from hydra import compose, initialize_config_dir + from hydra.core.global_hydra import GlobalHydra + from omegaconf import OmegaConf, open_dict + + self.config = config + + # Set seeds + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Reset CUDA memory stats + if self.device.type == "cuda": + torch.cuda.reset_peak_memory_stats(self.device) + + # Clear any existing Hydra state + GlobalHydra.instance().clear() + + # Use Hydra compose to build Hydra configuration + config_dir = str(GFLOWNET_PATH / "config") + + with initialize_config_dir(config_dir=config_dir, version_base="1.1"): + # Build environment-specific overrides + overrides = self._get_env_overrides(config, seed) + + cfg = compose(config_name="train", overrides=overrides) + + # Set a dummy log directory + with open_dict(cfg): + cfg.logger = OmegaConf.create( + { + "_target_": "gflownet.utils.logger.Logger", + "do": {"online": False, "checkpoints": False, "times": False}, + "logdir": { + "root": "/tmp/gflownet_benchmark", + "path": "/tmp/gflownet_benchmark", + "ckpts": "ckpts", + "overwrite": True, + }, + "project_name": "benchmark", + "lightweight": True, + "progressbar": {"skip": True, "n_iters_mean": 100}, + "debug": False, + "context": "0", + "notes": None, + "entity": None, + "run_id": None, + "is_resumed": False, + "run_name": None, + "run_name_date": False, + "run_name_job": False, + } + ) + if not hasattr(cfg, "n_samples"): + cfg.n_samples = 0 + + # Create the GFlowNet agent + self.agent = gflownet_from_config(cfg) + + self._iteration = 0 + + def _get_env_overrides(self, config: BenchmarkConfig, seed: int) -> List[str]: + """Get Hydra overrides for the specified environment.""" + env_name = config.env_name + + # Common overrides for 
all environments + common_overrides = [ + # Policy settings + "policy=mlp", + f"policy.forward.n_hid={config.hidden_dim}", + f"policy.forward.n_layers={config.n_layers}", + # Training settings + f"gflownet.optimizer.batch_size.forward={config.batch_size}", + f"gflownet.optimizer.lr={config.lr}", + f"gflownet.optimizer.lr_z_mult={config.lr_logz / config.lr}", + f"gflownet.optimizer.n_train_steps={config.n_iterations + config.n_warmup}", + # Device and seed + f"device={self.device.type}", + f"seed={seed}", + # Disable evaluation for benchmarking + "evaluator.period=0", + ] + + if env_name == "hypergrid": + return self._get_hypergrid_overrides(config) + common_overrides + elif env_name == "ising": + return self._get_ising_overrides(config) + common_overrides + elif env_name == "box": + return self._get_ccube_overrides(config) + common_overrides + else: + raise ValueError(f"Unknown environment: {env_name}") + + def _get_hypergrid_overrides(self, config: BenchmarkConfig) -> List[str]: + """Get Hydra overrides for hypergrid environment.""" + ndim = config.env_kwargs["ndim"] + height = config.env_kwargs["height"] + + return [ + # Grid environment (default env is grid) + f"env.n_dim={ndim}", + f"env.length={height}", + # Proxy - configure corners proxy inline + "proxy._target_=gflownet.proxy.box.corners.Corners", + "+proxy.mu=0.75", + "+proxy.sigma=0.05", + "+proxy.do_gaussians=true", + "+proxy.do_threshold=false", + "++proxy.reward_function=identity", + "++proxy.reward_min=0.0", + "++proxy.do_clip_rewards=false", + ] + + def _get_ising_overrides(self, config: BenchmarkConfig) -> List[str]: + """Get Hydra overrides for Ising environment.""" + L = config.env_kwargs["L"] + + return [ + # Ising environment + "env=ising", + "env.n_dim=2", + f"env.length={L}", + # Uniform proxy for Ising + "proxy=uniform", + ] + + def _get_ccube_overrides(self, config: BenchmarkConfig) -> List[str]: + """Get Hydra overrides for continuous cube (box) environment.""" + n_dim = config.env_kwargs.get("n_dim", 2) + delta = config.env_kwargs.get("delta", 0.25) + + return [ + # CCube environment + "env=ccube", + f"env.n_dim={n_dim}", + "env.n_comp=5", + f"env.min_incr={delta}", + "env.beta_params_min=0.1", + "env.beta_params_max=100.0", + # Corners proxy for box/ccube + "proxy=box/corners", + ] + + def warmup(self, n_iters: int) -> None: + """Run warmup iterations.""" + for _ in range(n_iters): + self.run_iteration() + self.synchronize() + self._iteration = 0 # Reset iteration counter after warmup + + def run_iteration(self) -> None: + """Run a single training iteration. + + This method is environment-agnostic. The bugfix for batch.proxy + applies to all environments automatically. + """ + # Sample batch + batch, _ = self.agent.sample_batch( + n_forward=self.agent.batch_size.forward, + n_train=0, + n_replay=0, + collect_forwards_masks=True, + collect_backwards_masks=self.agent.collect_backwards_masks, + ) + + # Workaround for gflownet bug: sample_batch creates an empty batch without + # proxy, then merges sub-batches into it. The merge doesn't copy the proxy. + # See gflownet.py line 594 vs 601. + # This fix applies to ALL environments (hypergrid, ising, ccube). 
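+        # If upstream fixes the merge, batch.proxy will already be set and
+        # the assignment below becomes a no-op.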
+        if batch.proxy is None:
+            batch.proxy = self.agent.proxy
+
+        # Compute loss
+        losses = self.agent.loss.compute(batch, get_sublosses=True)
+
+        # Backward and optimize
+        if all([torch.isfinite(loss) for loss in losses.values()]):
+            losses["all"].backward()
+            if self.agent.clip_grad_norm > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    self.agent.parameters(), self.agent.clip_grad_norm
+                )
+            self.agent.opt.step()
+            self.agent.lr_scheduler.step()
+            self.agent.opt.zero_grad()
+
+        self._iteration += 1
+
+    def synchronize(self) -> None:
+        """Ensure all CUDA operations are complete."""
+        if self.device is not None and self.device.type == "cuda":
+            torch.cuda.synchronize(self.device)
+
+    def get_peak_memory(self) -> Optional[int]:
+        """Return peak GPU memory usage in bytes."""
+        if self.device is not None and self.device.type == "cuda":
+            return torch.cuda.max_memory_allocated(self.device)
+        return None
+
+    def cleanup(self) -> None:
+        """Release resources."""
+        from hydra.core.global_hydra import GlobalHydra
+
+        if self.agent is not None:
+            del self.agent
+            self.agent = None
+
+        if self.device is not None and self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+        # Clear Hydra state
+        GlobalHydra.instance().clear()
diff --git a/benchmark/lib_runners/gfnx_runner.py b/benchmark/lib_runners/gfnx_runner.py
new file mode 100644
index 00000000..83639142
--- /dev/null
+++ b/benchmark/lib_runners/gfnx_runner.py
@@ -0,0 +1,997 @@
+"""GFNX library (JAX-based) runner for benchmarking.
+
+This runner adapts the gfnx library from benchmark/gfnx/.
+The library uses JAX/Equinox, so proper JIT warmup and synchronization
+via jax.block_until_ready() is critical for accurate timing.
+
+Supports environments: hypergrid, ising, bitseq.
+
+Key JAX/Equinox concepts used here:
+- eqx.filter_jit: JIT-compiles functions while handling non-array leaves (like functions)
+- eqx.partition/combine: Splits pytrees into array/non-array parts for gradient computation
+- jax.vmap: Vectorizes functions over batch dimensions
+- optax.multi_transform: Applies different optimizers to different parameter groups
+"""
+
+import sys
+from pathlib import Path
+from typing import Any, NamedTuple, Optional
+
+# Equinox is imported at module level so MLPPolicy can subclass eqx.Module;
+# this module itself is only imported lazily via the runner registry, so the
+# lazy-loading scheme in benchmark_libraries.py is preserved.
+import equinox as eqx
+
+from benchmark.lib_runners.base import BenchmarkConfig, LibraryRunner
+
+# Add gfnx to path so we can import the library
+GFNX_PATH = Path(__file__).parent.parent / "gfnx"
+if str(GFNX_PATH / "src") not in sys.path:
+    sys.path.insert(0, str(GFNX_PATH / "src"))
+
+
+class GFNXRunner(LibraryRunner):
+    """Benchmark runner for the gfnx JAX-based library.
+
+    Critical timing considerations for JAX:
+    - JIT compilation happens on first call, so warmup is essential
+    - Use jax.block_until_ready() to ensure async operations complete
+
+    Supports environments:
+    - hypergrid: Discrete grid navigation
+    - ising: Discrete Ising model environment
+    - bitseq: Bit sequence generation
+    """
+
+    name = "gfnx"
+
+    def __init__(self):
+        # Training state contains all JAX arrays and pytrees needed for training
+        self.train_state = None
+        # JIT-compiled training step function (created per environment type)
+        self.train_step_fn = None
+        self.config = None
+        self._iteration = 0
+        self._env_type = None
+
+    def setup(self, config: BenchmarkConfig, seed: int) -> None:
+        """Initialize environment, model, and optimizer based on env_name.
+
+        Dispatches to environment-specific setup methods which create:
+        1. The gfnx environment and its parameters
+        2. The policy network (MLPPolicy)
+        3. The optimizer with separate learning rates for network and logZ
+        4.
The TrainState NamedTuple containing all training state + 5. The JIT-compiled training step function + """ + self.config = config + self._env_type = config.env_name + + if config.env_name == "hypergrid": + self._setup_hypergrid(config, seed) + elif config.env_name == "ising": + self._setup_ising(config, seed) + elif config.env_name == "bitseq": + self._setup_bitseq(config, seed) + else: + raise ValueError(f"Unknown environment: {config.env_name}") + + self._iteration = 0 + + def _setup_hypergrid(self, config: BenchmarkConfig, seed: int) -> None: + """Setup HyperGrid environment with Trajectory Balance loss. + + HyperGrid is a discrete navigation environment where the agent moves + through an N-dimensional grid from origin to any terminal state. + """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + ndim = config.env_kwargs["ndim"] + height = config.env_kwargs["height"] + + # JAX uses explicit PRNG keys for reproducibility + rng_key = jax.random.PRNGKey(seed) + env_init_key = jax.random.PRNGKey(seed + 1) + + # Create gfnx environment with reward module + # EasyHypergridRewardModule provides a simple reward structure + reward_module = gfnx.EasyHypergridRewardModule() + env = gfnx.environment.HypergridEnvironment(reward_module, dim=ndim, side=height) + # env.init() creates the environment parameters (reward params, etc.) + env_params = env.init(env_init_key) + + # Create policy network that outputs both forward and backward logits + rng_key, net_init_key = jax.random.split(rng_key) + model = MLPPolicy( + input_size=env.observation_space.shape[0], + n_fwd_actions=env.action_space.n, + n_bwd_actions=env.backward_action_space.n, + hidden_size=config.hidden_dim, + train_backward_policy=True, # Learn P_B for TB loss + depth=config.n_layers, + rng_key=net_init_key, + ) + + # logZ is a learnable scalar for the partition function estimate + logZ = jnp.array(0.0) + + # Setup optimizer with separate learning rates for network and logZ + # This is common in GFlowNet training where logZ needs higher LR + model_params_init = eqx.filter(model, eqx.is_array) + initial_optax_params = {"model_params": model_params_init, "logZ": logZ} + + # param_labels maps each parameter to its optimizer group + param_labels = { + "model_params": jax.tree.map(lambda _: "network_lr", model_params_init), + "logZ": "logZ_lr", + } + + optimizer_defs = { + "network_lr": optax.adam(learning_rate=config.lr), + "logZ_lr": optax.adam(learning_rate=config.lr_logz), + } + optimizer = optax.multi_transform(optimizer_defs, param_labels) # type: ignore + opt_state = optimizer.init(initial_optax_params) + + # No exploration during benchmark for consistency + exploration_schedule = optax.constant_schedule(0.0) + + # Pack everything into a NamedTuple for easy passing through JAX transforms + self.train_state = HypergridTrainState( + rng_key=rng_key, + env=env, + env_params=env_params, + model=model, + logZ=logZ, + optimizer=optimizer, + opt_state=opt_state, + exploration_schedule=exploration_schedule, + num_envs=config.batch_size, + ) + + # Create the JIT-compiled training step + self.train_step_fn = create_hypergrid_train_step() + + def _setup_ising(self, config: BenchmarkConfig, seed: int) -> None: + """Setup Ising environment with Trajectory Balance loss. + + The Ising environment models spin configurations on a lattice. + Each state is a binary assignment of spins, built incrementally. 
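+        For an L x L lattice there are N = L**2 spins, so the environment
+        is created with dim=N (e.g., L=6 gives the 36-spin tb_ising_6x6
+        scenario).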
+ """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + L = config.env_kwargs["L"] # Lattice side length + N = L**2 # Total number of spins + + rng_key = jax.random.PRNGKey(seed) + env_init_key = jax.random.PRNGKey(seed + 1) + + # IsingRewardModule computes rewards based on spin configurations + reward_module = gfnx.IsingRewardModule() + env = gfnx.environment.IsingEnvironment(reward_module, dim=N) + env_params = env.init(env_init_key) + + rng_key, net_init_key = jax.random.split(rng_key) + model = MLPPolicy( + input_size=env.observation_space.shape[0], + n_fwd_actions=env.action_space.n, + n_bwd_actions=env.backward_action_space.n, + hidden_size=config.hidden_dim, + train_backward_policy=True, + depth=config.n_layers, + rng_key=net_init_key, + ) + + logZ = jnp.array(0.0) + + model_params_init = eqx.filter(model, eqx.is_array) + initial_optax_params = {"model_params": model_params_init, "logZ": logZ} + + param_labels = { + "model_params": jax.tree.map(lambda _: "network_lr", model_params_init), + "logZ": "logZ_lr", + } + + optimizer_defs = { + "network_lr": optax.adam(learning_rate=config.lr), + "logZ_lr": optax.adam(learning_rate=config.lr_logz), + } + optimizer = optax.multi_transform(optimizer_defs, param_labels) # type: ignore + opt_state = optimizer.init(initial_optax_params) + + exploration_schedule = optax.constant_schedule(0.0) + + self.train_state = IsingTrainState( + rng_key=rng_key, + env=env, + env_params=env_params, + model=model, + logZ=logZ, + optimizer=optimizer, + opt_state=opt_state, + exploration_schedule=exploration_schedule, + num_envs=config.batch_size, + ) + + self.train_step_fn = create_ising_train_step() + + def _setup_bitseq(self, config: BenchmarkConfig, seed: int) -> None: + """Setup BitSequence environment with Trajectory Balance loss. + + BitSequence generates sequences of bits/tokens, with rewards based + on proximity to a set of target "mode" sequences. 
+ """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + word_size = config.env_kwargs.get("word_size", 1) # Bits per word + seq_size = config.env_kwargs.get("seq_size", 4) # Words per sequence + n_modes = config.env_kwargs.get("n_modes", 2) # Number of target modes + + rng_key = jax.random.PRNGKey(seed) + env_init_key = jax.random.PRNGKey(seed + 1) + + # BitseqRewardModule rewards sequences close to the mode set + reward_module = gfnx.BitseqRewardModule( + sentence_len=seq_size, + k=word_size, + mode_set_size=n_modes, + reward_exponent=2.0, # Sharpness of reward around modes + ) + env = gfnx.BitseqEnvironment(reward_module, n=seq_size, k=word_size) + env_params = env.init(env_init_key) + + rng_key, net_init_key = jax.random.split(rng_key) + model = MLPPolicy( + input_size=env.observation_space.shape[0], + n_fwd_actions=env.action_space.n, + n_bwd_actions=env.backward_action_space.n, + hidden_size=config.hidden_dim, + train_backward_policy=True, + depth=config.n_layers, + rng_key=net_init_key, + ) + + logZ = jnp.array(0.0) + + model_params_init = eqx.filter(model, eqx.is_array) + initial_optax_params = {"model_params": model_params_init, "logZ": logZ} + + param_labels = { + "model_params": jax.tree.map(lambda _: "network_lr", model_params_init), + "logZ": "logZ_lr", + } + + optimizer_defs = { + "network_lr": optax.adam(learning_rate=config.lr), + "logZ_lr": optax.adam(learning_rate=config.lr_logz), + } + optimizer = optax.multi_transform(optimizer_defs, param_labels) # type: ignore + opt_state = optimizer.init(initial_optax_params) + + exploration_schedule = optax.constant_schedule(0.0) + + self.train_state = BitseqTrainState( + rng_key=rng_key, + env=env, + env_params=env_params, + model=model, + logZ=logZ, + optimizer=optimizer, + opt_state=opt_state, + exploration_schedule=exploration_schedule, + num_envs=config.batch_size, + ) + + self.train_step_fn = create_bitseq_train_step() + + def warmup(self, n_iters: int) -> None: + """Run warmup iterations to trigger JIT compilation. + + JAX compiles functions on first call. Running warmup iterations + ensures compilation overhead isn't counted in benchmark timing. + """ + for i in range(n_iters): + self.train_state = self.train_step_fn(i, self.train_state) + + # Block until all async operations complete + self.synchronize() + self._iteration = 0 + + def run_iteration(self) -> None: + """Run a single training iteration. + + The train_step_fn is already JIT-compiled, so this is just + a function call that updates the train_state in place. + """ + self.train_state = self.train_step_fn(self._iteration, self.train_state) + self._iteration += 1 + + def synchronize(self) -> None: + """Ensure all JAX operations are complete. + + JAX operations are asynchronous - they return immediately while + computation continues on device. block_until_ready() forces + synchronization for accurate timing measurements. + """ + import equinox as eqx + import jax + + if self.train_state is not None: + # Extract all arrays from train_state and wait for them + params, _ = eqx.partition(self.train_state, eqx.is_array) + jax.block_until_ready(params) + + def get_peak_memory(self) -> Optional[int]: + """Return peak memory usage in bytes. + + JAX memory stats are device-dependent and may not be available + on all backends (e.g., CPU backend doesn't track memory). 
+ """ + import jax + + try: + devices = jax.local_devices() + if devices and hasattr(devices[0], "memory_stats"): + stats = devices[0].memory_stats() + if stats and "peak_bytes_in_use" in stats: + return stats["peak_bytes_in_use"] + except Exception: + pass + return None + + def cleanup(self) -> None: + """Release resources and clear JAX caches.""" + import jax + + self.train_state = None + self.train_step_fn = None + + # Clear compiled function caches to free memory + jax.clear_caches() + + +# ============================================================================ +# Train State definitions for each environment +# ============================================================================ +# Using NamedTuples allows these to be passed through JAX transforms. +# Each field is typed as Any because JAX arrays and pytrees don't have +# static types that work well with Python's type system. + + +class HypergridTrainState(NamedTuple): + """Training state for Hypergrid environment.""" + + rng_key: Any # JAX PRNG key for stochastic operations + env: Any # gfnx.HypergridEnvironment instance + env_params: Any # Environment parameters (reward params, etc.) + model: Any # MLPPolicy instance + logZ: Any # Learnable log partition function estimate + optimizer: Any # optax optimizer (GradientTransformation) + opt_state: Any # Optimizer state (momentum, etc.) + exploration_schedule: Any # Epsilon schedule for exploration + num_envs: int # Batch size for parallel trajectory sampling + + +class IsingTrainState(NamedTuple): + """Training state for Ising environment.""" + + rng_key: Any + env: Any # gfnx.IsingEnvironment + env_params: Any + model: Any + logZ: Any + optimizer: Any + opt_state: Any + exploration_schedule: Any + num_envs: int + + +class BitseqTrainState(NamedTuple): + """Training state for BitSequence environment.""" + + rng_key: Any + env: Any # gfnx.BitseqEnvironment + env_params: Any + model: Any + logZ: Any + optimizer: Any + opt_state: Any + exploration_schedule: Any + num_envs: int + + +# ============================================================================ +# MLP Policy (shared across environments) +# ============================================================================ + + +class MLPPolicy: + """MLP policy network for forward and backward actions. + + Outputs both forward and backward action logits from a shared network. + The network has a single output head that's split into forward and + backward parts, encouraging parameter sharing. + + In GFlowNets: + - Forward logits: Used to sample actions when building trajectories + - Backward logits: Used to compute P_B for the TB loss + """ + + def __init__( + self, + input_size: int, + n_fwd_actions: int, + n_bwd_actions: int, + hidden_size: int, + train_backward_policy: bool, + depth: int, + rng_key, + ): + import equinox as eqx + + self.train_backward_policy = train_backward_policy + self.n_fwd_actions = n_fwd_actions + self.n_bwd_actions = n_bwd_actions + + # Output size includes both forward and backward logits + output_size = n_fwd_actions + if train_backward_policy: + output_size += n_bwd_actions + + # Equinox MLP with specified depth and width + self.network = eqx.nn.MLP( + in_size=input_size, + out_size=output_size, + width_size=hidden_size, + depth=depth, + key=rng_key, + ) + + def __call__(self, x): + """Forward pass returning both forward and backward logits. 
+ + Args: + x: Observation tensor of shape (input_size,) + + Returns: + Dict with 'forward_logits' and 'backward_logits' tensors + """ + import jax.numpy as jnp + + x = self.network(x) + if self.train_backward_policy: + # Split output into forward and backward parts + forward_logits, backward_logits = jnp.split(x, [self.n_fwd_actions], axis=-1) + else: + forward_logits = x + backward_logits = jnp.zeros(shape=(self.n_bwd_actions,), dtype=jnp.float32) + return { + "forward_logits": forward_logits, + "backward_logits": backward_logits, + } + + +# ============================================================================ +# Train Step functions for each environment +# ============================================================================ +# Each create_*_train_step() returns a JIT-compiled function that: +# 1. Samples a batch of trajectories using the forward policy +# 2. Computes the Trajectory Balance loss +# 3. Updates model parameters and logZ using gradients +# +# The train step functions are nearly identical across environments, but +# are kept separate because: +# - Different TrainState types (for type safety in the JIT) +# - Different environment-specific details in the future +# - Clearer debugging when issues arise + + +def create_hypergrid_train_step(): + """Create a JIT-compiled training step for Hypergrid. + + Returns a function: (iteration, train_state) -> train_state + """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + @eqx.filter_jit + def train_step(idx: int, train_state: HypergridTrainState) -> HypergridTrainState: + """Single training step for Hypergrid environment. + + Steps: + 1. Sample trajectories using forward policy + 2. Compute TB loss: (log P_F + log Z) vs (log P_B + log R) + 3. 
Compute gradients and update parameters + """ + rng_key = train_state.rng_key + num_envs = train_state.num_envs + env = train_state.env + env_params = train_state.env_params + + # Split model into learnable params and static structure + # This is needed for gradient computation in Equinox + policy_params, policy_static = eqx.partition(train_state.model, eqx.is_array) + + # Split RNG key for trajectory sampling + rng_key, sample_traj_key = jax.random.split(rng_key) + + # Get current exploration epsilon (0.0 for benchmark) + cur_eps = train_state.exploration_schedule(idx) + + # Forward policy function for trajectory rollout + def fwd_policy_fn(rng_key, env_obs, policy_params): + """Compute forward action logits for a batch of observations.""" + # Reconstruct the model from params and static parts + current_model = eqx.combine(policy_params, policy_static) + # vmap applies the model to each observation in the batch + policy_outputs = jax.vmap(current_model, in_axes=(0,))(env_obs) + fwd_logits = policy_outputs["forward_logits"] + + # Apply epsilon-greedy exploration (uniform random with prob epsilon) + rng_key, exploration_key = jax.random.split(rng_key) + batch_size, _ = fwd_logits.shape + exploration_mask = jax.random.bernoulli( + exploration_key, cur_eps, (batch_size,) + ) + fwd_logits = jnp.where(exploration_mask[..., None], 0, fwd_logits) + return fwd_logits, policy_outputs + + # Sample complete trajectories from initial to terminal states + traj_data, aux_info = gfnx.utils.forward_rollout( + rng_key=sample_traj_key, + num_envs=num_envs, + policy_fn=fwd_policy_fn, + policy_params=policy_params, + env=env, + env_params=env_params, + ) + + # Loss function for gradient computation + def loss_fn( + current_all_params, + static_model_parts, + current_traj_data, + current_env, + current_env_params, + ): + """Compute Trajectory Balance loss. 
+ + TB Loss = E[(log P_F(τ) + log Z - log P_B(τ) - log R(x))^2] + + Where: + - P_F(τ): Forward probability of trajectory τ + - Z: Partition function estimate (learned) + - P_B(τ): Backward probability of trajectory τ + - R(x): Reward at terminal state x + """ + model_learnable_params = current_all_params["model_params"] + logZ_val = current_all_params["logZ"] + + # Reconstruct model and get policy outputs for entire trajectory + model_to_call = eqx.combine(model_learnable_params, static_model_parts) + # Double vmap: over batch and over time steps + policy_outputs_traj = jax.vmap(jax.vmap(model_to_call))( + current_traj_data.obs + ) + + # ========== Compute Forward Log Probabilities ========== + fwd_logits_traj = policy_outputs_traj["forward_logits"] + + # Mask invalid actions (e.g., moving outside grid) + invalid_fwd_mask = jax.vmap( + current_env.get_invalid_mask, in_axes=(1, None), out_axes=1 + )(current_traj_data.state, current_env_params) + masked_fwd_logits_traj = gfnx.utils.mask_logits( + fwd_logits_traj, invalid_fwd_mask + ) + + # Convert to log probabilities + fwd_all_log_probs_traj = jax.nn.log_softmax(masked_fwd_logits_traj, axis=-1) + + # Select log prob of the action that was actually taken + fwd_logprobs_traj = jnp.take_along_axis( + fwd_all_log_probs_traj, + jnp.expand_dims(current_traj_data.action, axis=-1), + axis=-1, + ).squeeze(-1) + + # Zero out padded time steps (trajectories may have different lengths) + fwd_logprobs_traj = jnp.where(current_traj_data.pad, 0.0, fwd_logprobs_traj) + + # Sum log probs along trajectory: log P_F(τ) = Σ log P_F(a_t|s_t) + sum_log_pf_along_traj = fwd_logprobs_traj.sum(axis=1) + log_pf_traj = logZ_val + sum_log_pf_along_traj + + # ========== Compute Backward Log Probabilities ========== + # Get state transitions for computing backward actions + prev_states = jax.tree.map(lambda x: x[:, :-1], current_traj_data.state) + fwd_actions = current_traj_data.action[:, :-1] + curr_states = jax.tree.map(lambda x: x[:, 1:], current_traj_data.state) + + # Get the backward action that would undo each forward action + bwd_actions_traj = jax.vmap( + current_env.get_backward_action, + in_axes=(1, 1, 1, None), + out_axes=1, + )(prev_states, fwd_actions, curr_states, current_env_params) + + bwd_logits_traj = policy_outputs_traj["backward_logits"] + # Shift by 1: backward logits at state s_t predict action to reach s_{t-1} + bwd_logits_for_pb = bwd_logits_traj[:, 1:] + + # Mask invalid backward actions + invalid_bwd_mask = jax.vmap( + current_env.get_invalid_backward_mask, + in_axes=(1, None), + out_axes=1, + )(curr_states, current_env_params) + + masked_bwd_logits_traj = gfnx.utils.mask_logits( + bwd_logits_for_pb, invalid_bwd_mask + ) + bwd_all_log_probs_traj = jax.nn.log_softmax(masked_bwd_logits_traj, axis=-1) + + # Select log prob of the backward action + log_pb_selected = jnp.take_along_axis( + bwd_all_log_probs_traj, + jnp.expand_dims(bwd_actions_traj, axis=-1), + axis=-1, + ).squeeze(-1) + + # Zero out padded steps + pad_mask_for_bwd = current_traj_data.pad[:, :-1] + log_pb_selected = jnp.where(pad_mask_for_bwd, 0.0, log_pb_selected) + + # ========== Compute Target: log P_B(τ) + log R(x) ========== + log_rewards_at_steps = current_traj_data.log_gfn_reward[:, :-1] + masked_log_rewards_at_steps = jnp.where( + pad_mask_for_bwd, 0.0, log_rewards_at_steps + ) + + log_pb_plus_rewards_along_traj = ( + log_pb_selected + masked_log_rewards_at_steps + ) + target = jnp.sum(log_pb_plus_rewards_along_traj, axis=1) + + # ========== TB Loss: MSE between log P_F + log Z 
and log P_B + log R ========== + loss = optax.losses.squared_error(log_pf_traj, target).mean() + return loss + + # Compute loss and gradients + params_for_loss = {"model_params": policy_params, "logZ": train_state.logZ} + mean_loss, grads = eqx.filter_value_and_grad(loss_fn)( + params_for_loss, policy_static, traj_data, env, env_params + ) + + # Apply optimizer updates + optax_params_for_update = { + "model_params": policy_params, + "logZ": train_state.logZ, + } + updates, new_opt_state = train_state.optimizer.update( + grads, train_state.opt_state, optax_params_for_update + ) + + # Apply updates to get new parameters + new_model = eqx.apply_updates(train_state.model, updates["model_params"]) + new_logZ = eqx.apply_updates(train_state.logZ, updates["logZ"]) + + # Return updated train state + return train_state._replace( + rng_key=rng_key, + model=new_model, + logZ=new_logZ, + opt_state=new_opt_state, + ) + + return train_step + + +def create_ising_train_step(): + """Create a JIT-compiled training step for Ising environment. + + The training logic is identical to Hypergrid - only the environment + and state types differ. See create_hypergrid_train_step for detailed comments. + """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + @eqx.filter_jit + def train_step(idx: int, train_state: IsingTrainState) -> IsingTrainState: + rng_key = train_state.rng_key + num_envs = train_state.num_envs + env = train_state.env + env_params = train_state.env_params + + policy_params, policy_static = eqx.partition(train_state.model, eqx.is_array) + + rng_key, sample_traj_key = jax.random.split(rng_key) + cur_eps = train_state.exploration_schedule(idx) + + def fwd_policy_fn(rng_key, env_obs, policy_params): + current_model = eqx.combine(policy_params, policy_static) + policy_outputs = jax.vmap(current_model, in_axes=(0,))(env_obs) + fwd_logits = policy_outputs["forward_logits"] + + rng_key, exploration_key = jax.random.split(rng_key) + batch_size, _ = fwd_logits.shape + exploration_mask = jax.random.bernoulli( + exploration_key, cur_eps, (batch_size,) + ) + fwd_logits = jnp.where(exploration_mask[..., None], 0, fwd_logits) + return fwd_logits, policy_outputs + + traj_data, aux_info = gfnx.utils.forward_rollout( + rng_key=sample_traj_key, + num_envs=num_envs, + policy_fn=fwd_policy_fn, + policy_params=policy_params, + env=env, + env_params=env_params, + ) + + def loss_fn( + current_all_params, + static_model_parts, + current_traj_data, + current_env, + current_env_params, + ): + model_learnable_params = current_all_params["model_params"] + logZ_val = current_all_params["logZ"] + + model_to_call = eqx.combine(model_learnable_params, static_model_parts) + policy_outputs_traj = jax.vmap(jax.vmap(model_to_call))( + current_traj_data.obs + ) + + # Forward log probabilities + fwd_logits_traj = policy_outputs_traj["forward_logits"] + invalid_fwd_mask = jax.vmap( + current_env.get_invalid_mask, in_axes=(1, None), out_axes=1 + )(current_traj_data.state, current_env_params) + masked_fwd_logits_traj = gfnx.utils.mask_logits( + fwd_logits_traj, invalid_fwd_mask + ) + fwd_all_log_probs_traj = jax.nn.log_softmax(masked_fwd_logits_traj, axis=-1) + fwd_logprobs_traj = jnp.take_along_axis( + fwd_all_log_probs_traj, + jnp.expand_dims(current_traj_data.action, axis=-1), + axis=-1, + ).squeeze(-1) + fwd_logprobs_traj = jnp.where(current_traj_data.pad, 0.0, fwd_logprobs_traj) + sum_log_pf_along_traj = fwd_logprobs_traj.sum(axis=1) + log_pf_traj = logZ_val + sum_log_pf_along_traj + + # 
Backward log probabilities + prev_states = jax.tree.map(lambda x: x[:, :-1], current_traj_data.state) + fwd_actions = current_traj_data.action[:, :-1] + curr_states = jax.tree.map(lambda x: x[:, 1:], current_traj_data.state) + + bwd_actions_traj = jax.vmap( + current_env.get_backward_action, + in_axes=(1, 1, 1, None), + out_axes=1, + )(prev_states, fwd_actions, curr_states, current_env_params) + + bwd_logits_traj = policy_outputs_traj["backward_logits"] + bwd_logits_for_pb = bwd_logits_traj[:, 1:] + invalid_bwd_mask = jax.vmap( + current_env.get_invalid_backward_mask, + in_axes=(1, None), + out_axes=1, + )(curr_states, current_env_params) + + masked_bwd_logits_traj = gfnx.utils.mask_logits( + bwd_logits_for_pb, invalid_bwd_mask + ) + bwd_all_log_probs_traj = jax.nn.log_softmax(masked_bwd_logits_traj, axis=-1) + log_pb_selected = jnp.take_along_axis( + bwd_all_log_probs_traj, + jnp.expand_dims(bwd_actions_traj, axis=-1), + axis=-1, + ).squeeze(-1) + + pad_mask_for_bwd = current_traj_data.pad[:, :-1] + log_pb_selected = jnp.where(pad_mask_for_bwd, 0.0, log_pb_selected) + + log_rewards_at_steps = current_traj_data.log_gfn_reward[:, :-1] + masked_log_rewards_at_steps = jnp.where( + pad_mask_for_bwd, 0.0, log_rewards_at_steps + ) + + log_pb_plus_rewards_along_traj = ( + log_pb_selected + masked_log_rewards_at_steps + ) + target = jnp.sum(log_pb_plus_rewards_along_traj, axis=1) + + loss = optax.losses.squared_error(log_pf_traj, target).mean() + return loss + + params_for_loss = {"model_params": policy_params, "logZ": train_state.logZ} + mean_loss, grads = eqx.filter_value_and_grad(loss_fn)( + params_for_loss, policy_static, traj_data, env, env_params + ) + + optax_params_for_update = { + "model_params": policy_params, + "logZ": train_state.logZ, + } + updates, new_opt_state = train_state.optimizer.update( + grads, train_state.opt_state, optax_params_for_update + ) + + new_model = eqx.apply_updates(train_state.model, updates["model_params"]) + new_logZ = eqx.apply_updates(train_state.logZ, updates["logZ"]) + + return train_state._replace( + rng_key=rng_key, + model=new_model, + logZ=new_logZ, + opt_state=new_opt_state, + ) + + return train_step + + +def create_bitseq_train_step(): + """Create a JIT-compiled training step for BitSequence environment. + + The training logic is identical to Hypergrid - only the environment + and state types differ. See create_hypergrid_train_step for detailed comments. 
+ """ + import equinox as eqx + import gfnx + import jax + import jax.numpy as jnp + import optax + + @eqx.filter_jit + def train_step(idx: int, train_state: BitseqTrainState) -> BitseqTrainState: + rng_key = train_state.rng_key + num_envs = train_state.num_envs + env = train_state.env + env_params = train_state.env_params + + policy_params, policy_static = eqx.partition(train_state.model, eqx.is_array) + + rng_key, sample_traj_key = jax.random.split(rng_key) + cur_eps = train_state.exploration_schedule(idx) + + def fwd_policy_fn(rng_key, env_obs, policy_params): + current_model = eqx.combine(policy_params, policy_static) + policy_outputs = jax.vmap(current_model, in_axes=(0,))(env_obs) + fwd_logits = policy_outputs["forward_logits"] + + rng_key, exploration_key = jax.random.split(rng_key) + batch_size, _ = fwd_logits.shape + exploration_mask = jax.random.bernoulli( + exploration_key, cur_eps, (batch_size,) + ) + fwd_logits = jnp.where(exploration_mask[..., None], 0, fwd_logits) + return fwd_logits, policy_outputs + + traj_data, aux_info = gfnx.utils.forward_rollout( + rng_key=sample_traj_key, + num_envs=num_envs, + policy_fn=fwd_policy_fn, + policy_params=policy_params, + env=env, + env_params=env_params, + ) + + def loss_fn( + current_all_params, + static_model_parts, + current_traj_data, + current_env, + current_env_params, + ): + model_learnable_params = current_all_params["model_params"] + logZ_val = current_all_params["logZ"] + + model_to_call = eqx.combine(model_learnable_params, static_model_parts) + policy_outputs_traj = jax.vmap(jax.vmap(model_to_call))( + current_traj_data.obs + ) + + # Forward log probabilities + fwd_logits_traj = policy_outputs_traj["forward_logits"] + invalid_fwd_mask = jax.vmap( + current_env.get_invalid_mask, in_axes=(1, None), out_axes=1 + )(current_traj_data.state, current_env_params) + masked_fwd_logits_traj = gfnx.utils.mask_logits( + fwd_logits_traj, invalid_fwd_mask + ) + fwd_all_log_probs_traj = jax.nn.log_softmax(masked_fwd_logits_traj, axis=-1) + fwd_logprobs_traj = jnp.take_along_axis( + fwd_all_log_probs_traj, + jnp.expand_dims(current_traj_data.action, axis=-1), + axis=-1, + ).squeeze(-1) + fwd_logprobs_traj = jnp.where(current_traj_data.pad, 0.0, fwd_logprobs_traj) + sum_log_pf_along_traj = fwd_logprobs_traj.sum(axis=1) + log_pf_traj = logZ_val + sum_log_pf_along_traj + + # Backward log probabilities + prev_states = jax.tree.map(lambda x: x[:, :-1], current_traj_data.state) + fwd_actions = current_traj_data.action[:, :-1] + curr_states = jax.tree.map(lambda x: x[:, 1:], current_traj_data.state) + + bwd_actions_traj = jax.vmap( + current_env.get_backward_action, + in_axes=(1, 1, 1, None), + out_axes=1, + )(prev_states, fwd_actions, curr_states, current_env_params) + + bwd_logits_traj = policy_outputs_traj["backward_logits"] + bwd_logits_for_pb = bwd_logits_traj[:, 1:] + invalid_bwd_mask = jax.vmap( + current_env.get_invalid_backward_mask, + in_axes=(1, None), + out_axes=1, + )(curr_states, current_env_params) + + masked_bwd_logits_traj = gfnx.utils.mask_logits( + bwd_logits_for_pb, invalid_bwd_mask + ) + bwd_all_log_probs_traj = jax.nn.log_softmax(masked_bwd_logits_traj, axis=-1) + log_pb_selected = jnp.take_along_axis( + bwd_all_log_probs_traj, + jnp.expand_dims(bwd_actions_traj, axis=-1), + axis=-1, + ).squeeze(-1) + + pad_mask_for_bwd = current_traj_data.pad[:, :-1] + log_pb_selected = jnp.where(pad_mask_for_bwd, 0.0, log_pb_selected) + + log_rewards_at_steps = current_traj_data.log_gfn_reward[:, :-1] + masked_log_rewards_at_steps = 
jnp.where( + pad_mask_for_bwd, 0.0, log_rewards_at_steps + ) + + log_pb_plus_rewards_along_traj = ( + log_pb_selected + masked_log_rewards_at_steps + ) + target = jnp.sum(log_pb_plus_rewards_along_traj, axis=1) + + loss = optax.losses.squared_error(log_pf_traj, target).mean() + return loss + + params_for_loss = {"model_params": policy_params, "logZ": train_state.logZ} + mean_loss, grads = eqx.filter_value_and_grad(loss_fn)( + params_for_loss, policy_static, traj_data, env, env_params + ) + + optax_params_for_update = { + "model_params": policy_params, + "logZ": train_state.logZ, + } + updates, new_opt_state = train_state.optimizer.update( + grads, train_state.opt_state, optax_params_for_update + ) + + new_model = eqx.apply_updates(train_state.model, updates["model_params"]) + new_logZ = eqx.apply_updates(train_state.logZ, updates["logZ"]) + + return train_state._replace( + rng_key=rng_key, + model=new_model, + logZ=new_logZ, + opt_state=new_opt_state, + ) + + return train_step diff --git a/benchmark/lib_runners/torchgfn_runner.py b/benchmark/lib_runners/torchgfn_runner.py new file mode 100644 index 00000000..93ce342d --- /dev/null +++ b/benchmark/lib_runners/torchgfn_runner.py @@ -0,0 +1,361 @@ +"""TorchGFN library runner for benchmarking. + +Supports multiple environments: hypergrid, ising, box, bitseq. +""" + +from typing import Optional + +import torch + +from benchmark.lib_runners.base import BenchmarkConfig, LibraryRunner +from gfn.utils.common import set_seed + + +class TorchGFNRunner(LibraryRunner): + """Benchmark runner for the torchgfn library. + + Supports environments: + - hypergrid: Discrete grid navigation + - ising: Discrete EBM with Ising model energy + - box: Continuous 2D box environment + - bitseq: Bit sequence generation + """ + + name = "torchgfn" + + def __init__(self): + self.env = None + self.gflownet = None + self.optimizer = None + self.sampler = None + self.device = None + self.config = None + self._env_type = None + + def setup(self, config: BenchmarkConfig, seed: int) -> None: + """Initialize environment, model, and optimizer based on env_name.""" + set_seed(seed) + + self.config = config + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Reset CUDA memory stats if on GPU + if self.device.type == "cuda": + torch.cuda.reset_peak_memory_stats(self.device) + + self._env_type = config.env_name + + if config.env_name == "hypergrid": + self._setup_hypergrid(config) + elif config.env_name == "ising": + self._setup_ising(config) + elif config.env_name == "box": + self._setup_box(config) + elif config.env_name == "bitseq": + self._setup_bitseq(config) + else: + raise ValueError(f"Unknown environment: {config.env_name}") + + def _setup_hypergrid(self, config: BenchmarkConfig) -> None: + """Setup HyperGrid environment with TB loss.""" + from gfn.estimators import DiscretePolicyEstimator + from gfn.gflownet import TBGFlowNet + from gfn.gym import HyperGrid + from gfn.preprocessors import KHotPreprocessor + from gfn.samplers import Sampler + from gfn.utils.modules import MLP + + ndim = config.env_kwargs["ndim"] + height = config.env_kwargs["height"] + + self.env = HyperGrid( + ndim=ndim, + height=height, + reward_fn_str="original", + reward_fn_kwargs={"R0": 0.1, "R1": 0.5, "R2": 2.0}, + device=self.device, # type: ignore + calculate_partition=False, + store_all_states=False, + debug=False, + ) + + preprocessor = KHotPreprocessor(height=self.env.height, ndim=self.env.ndim) + + module_PF = MLP( + input_dim=preprocessor.output_dim, + 
output_dim=self.env.n_actions, + hidden_dim=config.hidden_dim, + n_hidden_layers=config.n_layers, + ) + module_PB = MLP( + input_dim=preprocessor.output_dim, + output_dim=self.env.n_actions - 1, + hidden_dim=config.hidden_dim, + n_hidden_layers=config.n_layers, + trunk=module_PF.trunk, + ) + + pf_estimator = DiscretePolicyEstimator( + module_PF, self.env.n_actions, preprocessor=preprocessor, is_backward=False + ) + pb_estimator = DiscretePolicyEstimator( + module_PB, self.env.n_actions, preprocessor=preprocessor, is_backward=True + ) + + self.gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0) + self.gflownet = self.gflownet.to(self.device) + + self.optimizer = torch.optim.Adam(self.gflownet.pf_pb_parameters(), lr=config.lr) + self.optimizer.add_param_group( + {"params": self.gflownet.logz_parameters(), "lr": config.lr_logz} + ) + + self.sampler = Sampler(estimator=pf_estimator) + + def _setup_ising(self, config: BenchmarkConfig) -> None: + """Setup Ising model environment with FM loss.""" + from gfn.estimators import DiscretePolicyEstimator + from gfn.gflownet import FMGFlowNet + from gfn.gym import DiscreteEBM + from gfn.gym.discrete_ebm import IsingModel + from gfn.utils.modules import MLP + + L = config.env_kwargs["L"] + J_coupling = config.env_kwargs["J"] + + # Build Ising coupling matrix with periodic boundary conditions + J = self._make_ising_J(L, J_coupling) + + N = L**2 + ising_energy = IsingModel(J) + self.env = DiscreteEBM( + N, + alpha=1, + energy=ising_energy, + device=self.device, # type: ignore + debug=False, + ) + + pf_module = MLP( + input_dim=self.env.ndim, + output_dim=self.env.n_actions, + hidden_dim=config.hidden_dim, + n_hidden_layers=config.n_layers, + activation_fn="relu", + ) + + pf_estimator = DiscretePolicyEstimator( + pf_module, self.env.n_actions, is_backward=False + ) + self.gflownet = FMGFlowNet(pf_estimator).to(self.device) + self.optimizer = torch.optim.Adam(self.gflownet.parameters(), lr=config.lr) + + # FMGFlowNet samples trajectories directly + self.sampler = None + + def _make_ising_J(self, L: int, coupling_constant: float) -> torch.Tensor: + """Build Ising coupling matrix with periodic boundary conditions.""" + + def ising_n_to_ij(L, n): + i = n // L + j = n - i * L + return (i, j) + + N = L**2 + J = torch.zeros((N, N), device=self.device) + for k in range(N): + for m in range(k): + x1, y1 = ising_n_to_ij(L, k) + x2, y2 = ising_n_to_ij(L, m) + if x1 == x2 and abs(y2 - y1) == 1: + J[k][m] = 1 + J[m][k] = 1 + elif y1 == y2 and abs(x2 - x1) == 1: + J[k][m] = 1 + J[m][k] = 1 + + # Periodic boundary conditions + for k in range(L): + J[k * L][(k + 1) * L - 1] = 1 + J[(k + 1) * L - 1][k * L] = 1 + J[k][k + N - L] = 1 + J[k + N - L][k] = 1 + + return coupling_constant * J + + def _setup_box(self, config: BenchmarkConfig) -> None: + """Setup continuous Box environment with TB loss.""" + from gfn.gflownet import TBGFlowNet + from gfn.gym import Box + from gfn.gym.helpers.box_utils import ( + BoxPBEstimator, + BoxPBMLP, + BoxPFEstimator, + BoxPFMLP, + ) + from gfn.samplers import Sampler + + delta = config.env_kwargs.get("delta", 0.25) + + self.env = Box( + delta=delta, + epsilon=1e-10, + device=self.device, # type: ignore + debug=False, + ) + + # Box environment uses specialized policy modules + pf_module = BoxPFMLP( + hidden_dim=config.hidden_dim, + n_hidden_layers=config.n_layers, + n_components=2, + n_components_s0=4, + ) + pb_module = BoxPBMLP( + hidden_dim=config.hidden_dim, + n_hidden_layers=config.n_layers, + n_components=2, + 
trunk=pf_module.trunk, # Tied weights + ) + + pf_estimator = BoxPFEstimator( + self.env, + pf_module, + n_components_s0=4, + n_components=2, + min_concentration=0.1, + max_concentration=5.1, + ) + pb_estimator = BoxPBEstimator( + self.env, + pb_module, + n_components=2, + min_concentration=0.1, + max_concentration=5.1, + ) + + self.gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator) + self.gflownet = self.gflownet.to(self.device) + + # Optimizer with separate learning rates + self.optimizer = torch.optim.Adam(pf_module.parameters(), lr=config.lr) + self.optimizer.add_param_group( + {"params": pb_module.last_layer.parameters(), "lr": config.lr} + ) + if "logZ" in dict(self.gflownet.named_parameters()): + logZ = dict(self.gflownet.named_parameters())["logZ"] + self.optimizer.add_param_group({"params": [logZ], "lr": config.lr_logz}) + + self.sampler = Sampler(estimator=pf_estimator) + + def _setup_bitseq(self, config: BenchmarkConfig) -> None: + """Setup BitSequence environment with TB loss.""" + from gfn.estimators import DiscretePolicyEstimator + from gfn.gflownet import TBGFlowNet + from gfn.gym import BitSequence + from gfn.samplers import Sampler + from gfn.utils.modules import MLP + + word_size = config.env_kwargs.get("word_size", 1) + seq_size = config.env_kwargs.get("seq_size", 4) + n_modes = config.env_kwargs.get("n_modes", 2) + + # Generate random mode set + H = torch.randint( + 0, 2, (n_modes, seq_size), dtype=torch.long, device=self.device + ) + self.env = BitSequence(word_size, seq_size, n_modes, H=H, debug=False) + + pf = MLP( + self.env.words_per_seq, self.env.n_actions, hidden_dim=config.hidden_dim + ) + pb = MLP(self.env.words_per_seq, self.env.n_actions - 1, trunk=pf.trunk) + + pf_estimator = DiscretePolicyEstimator(pf, n_actions=self.env.n_actions) + pb_estimator = DiscretePolicyEstimator( + pb, n_actions=self.env.n_actions, is_backward=True + ) + + self.gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, init_logZ=0.0).to( + self.device + ) + + non_logz_params = [ + v for k, v in dict(self.gflownet.named_parameters()).items() if k != "logZ" + ] + self.optimizer = torch.optim.Adam(non_logz_params, lr=config.lr) + logz_params = [dict(self.gflownet.named_parameters())["logZ"]] + self.optimizer.add_param_group({"params": logz_params, "lr": config.lr_logz}) + + self.sampler = Sampler(estimator=pf_estimator) + + def warmup(self, n_iters: int) -> None: + """Run warmup iterations for CUDA kernel caching.""" + for _ in range(n_iters): + self.run_iteration() + self.synchronize() + + def run_iteration(self) -> None: + """Run a single training iteration.""" + if self._env_type == "ising": + # FMGFlowNet uses direct sampling + trajectories = self.gflownet.sample_trajectories( + self.env, # type: ignore + n=self.config.batch_size, + save_estimator_outputs=False, + save_logprobs=False, + ) + training_samples = self.gflownet.to_training_samples(trajectories) + self.optimizer.zero_grad() + loss = self.gflownet.loss( + self.env, # type: ignore + training_samples, # type: ignore + recalculate_all_logprobs=True, + ) + loss.backward() + self.optimizer.step() + else: + # TBGFlowNet with Sampler + trajectories = self.sampler.sample_trajectories( + self.env, # type: ignore + n=self.config.batch_size, + save_logprobs=self._env_type == "box", # Box saves logprobs + save_estimator_outputs=False, + epsilon=0.0, + ) + + self.optimizer.zero_grad() + loss = self.gflownet.loss_from_trajectories( + self.env, # type: ignore + trajectories, + recalculate_all_logprobs=(self._env_type != "box"), + ) + 
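+            # TB objective minimized below (cf. the TB loss documented in the
+            # gfnx trainer above):
+            #   (log Z + Σ_t log P_F(a_t|s_t) - Σ_t log P_B(a_t|s_t) - log R(x))^2
+            # For box, the log-probs saved during sampling are reused; the
+            # discrete envs recompute them under the current policy parameters.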
loss.backward() + torch.nn.utils.clip_grad_norm_(self.gflownet.parameters(), 1.0) + self.optimizer.step() + + def synchronize(self) -> None: + """Ensure all CUDA operations are complete.""" + if self.device.type == "cuda": + torch.cuda.synchronize(self.device) + + def get_peak_memory(self) -> Optional[int]: + """Return peak GPU memory usage in bytes.""" + if self.device.type == "cuda": + return torch.cuda.max_memory_allocated(self.device) + return None + + def cleanup(self) -> None: + """Release resources.""" + del self.gflownet + del self.optimizer + del self.sampler + del self.env + + if self.device is not None and self.device.type == "cuda": + torch.cuda.empty_cache() + + self.gflownet = None + self.optimizer = None + self.sampler = None + self.env = None diff --git a/benchmark/sanity_check.py b/benchmark/sanity_check.py new file mode 100644 index 00000000..4dd705c2 --- /dev/null +++ b/benchmark/sanity_check.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +import time + +import jax +import jax.numpy as jnp +import torch + +# --- Settings --- +# 4096 is a sweet spot: large enough to require GPU power, +# but fits easily in memory. +N = 4096 +LOOPS = 50 + +print(f"--- Benchmark Config: {N}x{N} Matrix Multiplication ---") +print(f"JAX Devices: {jax.devices()}") +print(f"PyTorch Version: {torch.__version__}") + +# --- 1. JAX Setup --- +key = jax.random.PRNGKey(0) +jax_x = jax.random.normal(key, (N, N)) +jax_y = jax.random.normal(key, (N, N)) + + +def matmul_fn(a, b): + return jnp.dot(a, b) + + +jit_matmul = jax.jit(matmul_fn) + +# --- 2. JAX EAGER MODE --- +print("\n--- 1. JAX Eager (No JIT) ---") +_ = matmul_fn(jax_x, jax_y).block_until_ready() # Warmup +start = time.time() +for _ in range(LOOPS): + _ = matmul_fn(jax_x, jax_y).block_until_ready() +jax_eager_time = (time.time() - start) / LOOPS +print(f"Time: {jax_eager_time:.4f} s") + +# --- 3. JAX JIT MODE --- +print("\n--- 2. JAX JIT (Compiled) ---") +print("Compiling...") +_ = jit_matmul(jax_x, jax_y).block_until_ready() # Compilation triggers here +start = time.time() +for _ in range(LOOPS): + _ = jit_matmul(jax_x, jax_y).block_until_ready() +jax_jit_time = (time.time() - start) / LOOPS +print(f"Time: {jax_jit_time:.4f} s") + +# --- 4. PYTORCH MPS (GPU) --- +print("\n--- 3. PyTorch MPS (Apple GPU) ---") +if torch.backends.mps.is_available(): + dev_mps = torch.device("mps") + x_mps = torch.randn(N, N, device=dev_mps) + y_mps = torch.randn(N, N, device=dev_mps) + + # Warmup + torch.mm(x_mps, y_mps) + torch.mps.synchronize() + + start = time.time() + for _ in range(LOOPS): + torch.mm(x_mps, y_mps) + torch.mps.synchronize() # Critical for fair timing + torch_mps_time = (time.time() - start) / LOOPS + print(f"Time: {torch_mps_time:.4f} s") +else: + torch_mps_time = None + print("MPS not available.") + +# --- 5. PYTORCH CPU --- +print("\n--- 4. 
PyTorch CPU ---")
+dev_cpu = torch.device("cpu")
+x_cpu = torch.randn(N, N, device=dev_cpu)
+y_cpu = torch.randn(N, N, device=dev_cpu)
+
+# Warmup
+torch.mm(x_cpu, y_cpu)
+
+start = time.time()
+for _ in range(LOOPS):
+    torch.mm(x_cpu, y_cpu)
+    # CPU is synchronous by default, no special sync needed
+torch_cpu_time = (time.time() - start) / LOOPS
+print(f"Time: {torch_cpu_time:.4f} s")
+
+# --- SUMMARY TABLE ---
+print("\n" + "=" * 45)
+print(f"{'Method':<20} | {'Time (s)':<10} | {'Rel Speed'}")
+print("-" * 45)
+
+# Use JAX JIT as the baseline (1.0x)
+baseline = jax_jit_time
+
+
+def fmt_speed(t):
+    if t is None:
+        return "N/A"
+    return f"{baseline / t:.2f}x"
+
+
+print(f"{'JAX JIT':<20} | {jax_jit_time:<10.4f} | 1.00x (Baseline)")
+print(f"{'JAX Eager':<20} | {jax_eager_time:<10.4f} | {fmt_speed(jax_eager_time)}")
+if torch_mps_time is not None:
+    print(
+        f"{'PyTorch MPS':<20} | {torch_mps_time:<10.4f} | {fmt_speed(torch_mps_time)}"
+    )
+print(f"{'PyTorch CPU':<20} | {torch_cpu_time:<10.4f} | {fmt_speed(torch_cpu_time)}")
+print("=" * 45)
diff --git a/pyproject.toml b/pyproject.toml
index b226cba3..aa489a6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -142,6 +142,13 @@ exclude = [
     "**/.*", # Exclude dot files and folders
 ]
+# Fully disable pyright's Optional/None-handling checks
+reportOptionalMemberAccess = "none" # e.g., variable.method()
+reportOptionalSubscript = "none" # e.g., variable['key']
+reportOptionalCall = "none" # e.g., variable()
+reportOptionalOperand = "none" # e.g., variable + 1
+reportOptionalIterable = "none" # e.g., for x in variable:
+
 strict = [
 ]
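
# The Optional-handling overrides above exist because the benchmark runners
# initialize attributes to None in __init__ and only populate them in setup(),
# a call order pyright cannot verify statically. A minimal sketch of the
# pattern these settings stop flagging (hypothetical names, not part of the
# runner code):

from typing import Optional


class Runner:
    def __init__(self) -> None:
        self.model: Optional[str] = None  # populated later by setup()

    def setup(self) -> None:
        self.model = "ready"

    def run(self) -> None:
        # Without the overrides, pyright reports reportOptionalMemberAccess
        # here: it cannot prove setup() ran before run().
        print(self.model.upper())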