From f9400f960008a1326573d693bb6b2d16ee3d904f Mon Sep 17 00:00:00 2001 From: shirin shahabi <71061824+shirin-shahabi@users.noreply.github.com> Date: Tue, 2 Dec 2025 15:46:04 -0500 Subject: [PATCH 1/4] feat: Add JSTprove backend integration - Add JSTprove backend (dsperse/src/backends/JSTprove.py) with CLI-based compilation, witness generation, proving, and verification - Update compiler to support per-layer backend specifications and fallback logic (jstprove -> ezkl -> onnx) - Add JSTprove installation to install.sh (Open MPI + JSTprove CLI via uv) - Add mpi4py to requirements.txt for JSTprove support - Update compiler_utils to handle JSTprove compilation success checks (no payload subdirectory) - Add JSTPROVE_COMMAND constant to constants.py - Add JSTPROVE_BACKEND.md documentation - Fix: Enable fallback for unspecified layers in per-layer backend specs - Fix: Compile all layers when no --layers flag specified (default fallback behavior) --- docs/JSTPROVE_BACKEND.md | 105 +++++ dsperse/src/backends/JSTprove.py | 469 ++++++++++++++++++++ dsperse/src/cli/compile.py | 22 +- dsperse/src/compile/compiler.py | 240 ++++++++-- dsperse/src/compile/utils/compiler_utils.py | 18 +- dsperse/src/constants.py | 6 +- install.sh | 123 ++++- requirements.txt | 4 +- 8 files changed, 945 insertions(+), 42 deletions(-) create mode 100644 docs/JSTPROVE_BACKEND.md create mode 100644 dsperse/src/backends/JSTprove.py diff --git a/docs/JSTPROVE_BACKEND.md b/docs/JSTPROVE_BACKEND.md new file mode 100644 index 0000000..4c37e07 --- /dev/null +++ b/docs/JSTPROVE_BACKEND.md @@ -0,0 +1,105 @@ +# JSTprove Backend Integration + +## Overview + +This document describes the integration of JSTprove as an additional ZK proof backend alongside EZKL in the Dsperse compilation pipeline. + +## Features + +### 1. JSTprove Backend Support +- New backend class `JSTprove` in `dsperse/src/backends/JSTprove.py` +- Uses JSTprove CLI (`jst` command) for circuit compilation, witness generation, proof generation, and verification +- Compatible with existing EZKL interface for seamless integration + +### 2. Flexible Backend Selection +The compiler now supports three modes: + +**Default (Fallback Mode):** +```bash +dsperse compile --path model/slices +``` +- Tries JSTprove first +- Falls back to EZKL if JSTprove fails +- Falls back to ONNX (skip ZK compilation) if both fail + +**Single Backend:** +```bash +dsperse compile --path model/slices --backend jstprove +dsperse compile --path model/slices --backend ezkl +``` + +**Per-Layer Backend Assignment:** +```bash +dsperse compile --path model/slices --backend "0,2:jstprove;3-4:ezkl" +``` +- Layer 0 and 2: Use JSTprove +- Layer 3 and 4: Use EZKL +- Unspecified layers use default backend + +## Installation + +1. Install Open MPI (required for JSTprove): + ```bash + brew install open-mpi # macOS + # or apt-get install openmpi-bin libopenmpi-dev # Linux + ``` + +2. Install JSTprove: + ```bash + uv tool install jstprove + # or: pip install jstprove + ``` + +3. Verify installation: + ```bash + jst --help + ``` + +The `install.sh` script has been updated to automatically install these dependencies. + +## File Changes + +### New Files +- `dsperse/src/backends/JSTprove.py` - JSTprove backend implementation + +### Modified Files +- `dsperse/src/cli/compile.py` - Added `--backend` argument +- `dsperse/src/compile/compiler.py` - Backend selection and fallback logic +- `dsperse/src/compile/utils/compiler_utils.py` - Support for JSTprove compilation success check +- `dsperse/src/constants.py` - Added JSTprove command constant +- `install.sh` - Added Open MPI and JSTprove installation +- `requirements.txt` - Added mpi4py dependency + +## Usage Examples + +**Compile all layers with default fallback:** +```bash +dsperse compile --path model/slices +``` + +**Compile specific layers with mixed backends:** +```bash +dsperse compile --path model/slices --layers "0-4" --backend "0,2:jstprove;3-4:ezkl" +``` + +**Compile with single backend:** +```bash +dsperse compile --path model/slices --backend jstprove +``` + +## Backend Comparison + +| Feature | JSTprove | EZKL | +|---------|----------|------| +| Circuit Format | `.txt` | `.compiled` | +| Keys | Not required | `vk.key`, `pk.key` | +| Settings | Dummy JSON | Full settings.json | +| CLI Command | `jst` | `ezkl` | + +## Notes + +- JSTprove uses CLI-only interface (no Python package import) +- Fallback logic ensures compilation continues even if preferred backend fails +- Metadata tracks which backend was used for each slice +- All changes maintain backward compatibility with existing EZKL workflows + diff --git a/dsperse/src/backends/JSTprove.py b/dsperse/src/backends/JSTprove.py new file mode 100644 index 0000000..ff008a1 --- /dev/null +++ b/dsperse/src/backends/JSTprove.py @@ -0,0 +1,469 @@ +""" +JSTprove backend for zero-knowledge proof generation. +This module provides a backend for generating ZK proofs using the JSTprove CLI. +""" +import json +import os +import subprocess +import torch +import logging +from pathlib import Path +from typing import Optional, Tuple, Dict, Any, Union, List + +from dsperse.src.constants import JSTPROVE_COMMAND + +# Configure logger +logger = logging.getLogger(__name__) + + +class JSTprove: + """JSTprove backend for zero-knowledge proof generation using the JSTprove CLI.""" + + # Class constants + COMMAND = JSTPROVE_COMMAND + DEFAULT_FLAGS = ["--no-banner"] + + def __init__(self, model_directory: Optional[str] = None) -> None: + """ + Initialize the JSTprove backend. + + Args: + model_directory: Optional path to the model directory for organizing artifacts. + + Raises: + RuntimeError: If JSTprove CLI is not available + """ + self.env = os.environ.copy() + self.model_directory = Path(model_directory) if model_directory else None + self._witness_format = "jstprove" # Track witness output format + + # Check if JSTprove CLI is available + try: + result = subprocess.run( + [self.COMMAND, "--help"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + if result.returncode != 0: + raise RuntimeError("JSTprove CLI not found. Please install JSTprove first.") + except FileNotFoundError: + raise RuntimeError("JSTprove CLI not found. Please install JSTprove: uv tool install jstprove") + + def _run_command( + self, + subcommand: str, + args: List[str], + check: bool = True, + capture_output: bool = True, + ) -> subprocess.CompletedProcess: + """ + Execute a JSTprove CLI command. + + Args: + subcommand: The jst subcommand (compile, witness, prove, verify) + args: Additional arguments for the subcommand + check: Whether to check return code + capture_output: Whether to capture output + + Returns: + subprocess.CompletedProcess: The completed process + + Raises: + RuntimeError: If command fails + """ + cmd = [self.COMMAND] + self.DEFAULT_FLAGS + [subcommand] + args + try: + logger.debug(f"Running JSTprove command: {' '.join(cmd)}") + process = subprocess.run( + cmd, + env=self.env, + check=check, + capture_output=capture_output, + text=True, + ) + return process + except subprocess.CalledProcessError as e: + error_msg = f"JSTprove command failed: {' '.join(cmd)}" + if e.stderr: + error_msg += f"\nError output: {e.stderr}" + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + # + # High-level methods that dispatch to specific implementations + # + + def generate_witness( + self, + input_file: Union[str, Path], + model_path: Union[str, Path], # This is the circuit path in JSTprove context + output_file: Union[str, Path], + vk_path: Optional[Union[str, Path]] = None, # Kept for backward compatibility but not used + settings_path: Optional[Union[str, Path]] = None # Kept for backward compatibility but not used + ) -> Tuple[bool, Any]: + """ + Generate a witness for the given circuit and input using JSTprove. + + Args: + input_file: Path to the input JSON file + model_path: Path to the compiled circuit file (called model_path for interface compatibility) + output_file: Path where to save the model outputs JSON + vk_path: Ignored (kept for backward compatibility) + settings_path: Ignored (kept for backward compatibility) + + Returns: + Tuple of (success: bool, output: Any) where output is the processed witness data + """ + # Normalize paths + input_file = Path(input_file) + circuit_path = Path(model_path) # model_path is actually the circuit path + output_file = Path(output_file) + witness_path = output_file.parent / f"{output_file.stem}_witness.bin" + + # Validate required files exist + if not input_file.exists(): + raise FileNotFoundError(f"Input file not found: {input_file}") + + # Check if we have an ONNX model that needs compilation, or an existing circuit + onnx_model_path = None + if circuit_path.exists() and circuit_path.suffix == '.onnx': + onnx_model_path = circuit_path + circuit_path = circuit_path.parent / f"{circuit_path.stem}_jstprove_circuit.txt" + logger.info(f"JSTprove: Compiling ONNX model {onnx_model_path} to circuit {circuit_path}") + + # If we have an ONNX model, compile it first + if onnx_model_path: + ok, err = self.compile_circuit(onnx_model_path, circuit_path) + if not ok: + raise RuntimeError(f"Circuit compilation failed: {err}") + elif not circuit_path.exists(): + raise FileNotFoundError(f"Circuit file not found: {circuit_path}") + + # Create output directories if they don't exist + output_file.parent.mkdir(parents=True, exist_ok=True) + witness_path.parent.mkdir(parents=True, exist_ok=True) + + try: + self._run_command("witness", [ + "-c", str(circuit_path), + "-i", str(input_file), + "-o", str(output_file), + "-w", str(witness_path), + ]) + except RuntimeError as e: + error_msg = f"Witness generation failed: {e}" + logger.error(error_msg) + return False, error_msg + + # Process the outputs + try: + with open(output_file, "r") as f: + output_data = json.load(f) + processed_output = self.process_witness_output(output_data) + return True, processed_output + except (json.JSONDecodeError, FileNotFoundError) as e: + error_msg = f"Failed to process witness output: {e}" + logger.error(error_msg) + return False, error_msg + + def prove( + self, + witness_path: Union[str, Path], + circuit_path: Union[str, Path], + proof_path: Union[str, Path], + pk_path: Optional[Union[str, Path]] = None, # Kept for backward compatibility but not used + check_mode: str = "unsafe", # Kept for backward compatibility but not used + settings_path: Optional[Union[str, Path]] = None # Kept for backward compatibility but not used + ) -> Tuple[bool, Union[str, Path]]: + """ + Generate a proof for the given witness and circuit using JSTprove. + + Args: + witness_path: Path to the witness file + circuit_path: Path to the compiled circuit + proof_path: Path where to save the proof + pk_path: Ignored (kept for backward compatibility) + check_mode: Ignored (kept for backward compatibility) + settings_path: Ignored (kept for backward compatibility) + + Returns: + Tuple of (success: bool, results: Union[str, Path]) where results is the proof path + """ + # Normalize paths + witness_path = Path(witness_path) + circuit_path = Path(circuit_path) + proof_path = Path(proof_path) + + # Validate required files exist + if not witness_path.exists(): + raise FileNotFoundError(f"Witness file not found: {witness_path}") + if not circuit_path.exists(): + raise FileNotFoundError(f"Circuit file not found: {circuit_path}") + + # Create output directory if it doesn't exist + proof_path.parent.mkdir(parents=True, exist_ok=True) + + try: + self._run_command("prove", [ + "-c", str(circuit_path), + "-w", str(witness_path), + "-p", str(proof_path), + ]) + except RuntimeError as e: + error_msg = f"Proof generation failed: {e}" + logger.error(error_msg) + return False, error_msg + + return True, proof_path + + def verify( + self, + proof_path: Union[str, Path], + circuit_path: Union[str, Path], + input_path: Union[str, Path], + output_path: Union[str, Path], + witness_path: Union[str, Path], + settings_path: Optional[Union[str, Path]] = None, # Kept for backward compatibility but not used + vk_path: Optional[Union[str, Path]] = None # Kept for backward compatibility but not used + ) -> bool: + """ + Verify a proof using JSTprove. + + Args: + proof_path: Path to the proof file + circuit_path: Path to the compiled circuit + input_path: Path to the input JSON used for the proof + output_path: Path to the expected outputs JSON + witness_path: Path to the witness file + settings_path: Ignored (kept for backward compatibility) + vk_path: Ignored (kept for backward compatibility) + + Returns: + True if verification succeeded, False otherwise + """ + # Normalize paths + proof_path = Path(proof_path) + circuit_path = Path(circuit_path) + input_path = Path(input_path) + output_path = Path(output_path) + witness_path = Path(witness_path) + + # Validate required files exist + required_files = [proof_path, circuit_path, input_path, output_path, witness_path] + for file_path in required_files: + if not file_path.exists(): + raise FileNotFoundError(f"Required file not found: {file_path}") + + try: + self._run_command("verify", [ + "-c", str(circuit_path), + "-i", str(input_path), + "-o", str(output_path), + "-w", str(witness_path), + "-p", str(proof_path), + ]) + return True + except RuntimeError as e: + logger.error(f"Proof verification failed: {e}") + return False + + def compile_circuit( + self, + model_path: Union[str, Path], + circuit_path: Union[str, Path], + settings_path: Optional[Union[str, Path]] = None # Kept for backward compatibility but not used + ) -> Tuple[bool, Optional[str]]: + """ + Compile a circuit from an ONNX model using JSTprove. + + Args: + model_path: Path to the original ONNX model + circuit_path: Path where to save the compiled circuit + settings_path: Ignored (kept for backward compatibility) + + Returns: + Tuple of (success: bool, error: Optional[str]) + """ + # Normalize paths + model_path = Path(model_path) + circuit_path = Path(circuit_path) + + # Validate required files exist + if not model_path.exists(): + raise FileNotFoundError(f"Model file not found: {model_path}") + + # Create output directory if it doesn't exist + circuit_path.parent.mkdir(parents=True, exist_ok=True) + + try: + self._run_command("compile", [ + "-m", str(model_path), + "-c", str(circuit_path), + ]) + return True, None + except Exception as e: + error_msg = f"Circuit compilation failed: {e}" + logger.error(error_msg) + return False, error_msg + + def circuitization_pipeline( + self, + model_path: Union[str, Path], + output_path: Union[str, Path], + input_file_path: Optional[Union[str, Path]] = None, + segment_details: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Run the JSTprove circuitization pipeline. + + In JSTprove, circuitization is a single step that compiles the model into a circuit. + The compile command handles all the necessary setup internally. + + Args: + model_path: Path to the ONNX model file. + output_path: Base path for output files. + input_file_path: Ignored (kept for backward compatibility). + segment_details: Ignored (kept for backward compatibility). + + Returns: + Dictionary containing paths to generated files and any error information. + """ + # Normalize paths + model_path = Path(model_path) + output_path = Path(output_path) + + # Ensure model_path exists + if not model_path.exists(): + raise FileNotFoundError(f"Model file not found: {model_path}") + + # Create output directory + output_path.mkdir(parents=True, exist_ok=True) + + model_name = model_path.stem + + # Define file paths (JSTprove outputs circuit and quantized model) + circuit_path = output_path / f"{model_name}_circuit.txt" + quantized_model_path = output_path / f"{model_name}_circuit_quantized_model.onnx" + witness_solver_path = output_path / f"{model_name}_circuit_witness_solver.txt" + + # Create dummy settings file for compatibility with runner analyzer + settings_path = output_path / f"{model_name}_settings.json" + + # Initialize circuitization data dictionary (match EZKL structure for compatibility) + circuitization_data: Dict[str, Any] = { + "compiled": str(circuit_path), # This is what runner_analyzer looks for + "circuit": str(circuit_path), + "quantized_model": str(quantized_model_path), + "witness_solver": str(witness_solver_path), + "calibration": input_file_path, + # Create dummy settings file for runner analyzer compatibility + "settings": str(settings_path), + # JSTprove doesn't use vk, pk in the same way as EZKL + "vk_key": None, + "pk_key": None, + } + + try: + logger.info(f"Compiling circuit for {model_name}") + # JSTprove compile command handles everything in one step + ok, err = self.compile_circuit( + model_path=model_path, + circuit_path=circuit_path, + ) + if not ok: + logger.warning("Failed to compile circuit") + circuitization_data["compile_error"] = err + else: + # Create dummy settings file for runner analyzer compatibility + dummy_settings = { + "backend": "jstprove", + "model_path": str(model_path), + "circuit_path": str(circuit_path), + "compiled_at": str(output_path), + "note": "This is a dummy settings file for dsperse compatibility. JSTprove handles settings internally." + } + with open(settings_path, 'w') as f: + json.dump(dummy_settings, f, indent=2) + logger.info(f"Circuitization pipeline completed for {model_path}") + except Exception as e: + error_msg = f"Error during circuitization: {str(e)}" + logger.exception(error_msg) + circuitization_data["error"] = error_msg + + return circuitization_data + + # Alias for backward compatibility with EZKL interface + compilation_pipeline = circuitization_pipeline + + def process_witness_output(self, witness_data: Any) -> Optional[Dict[str, Any]]: + """ + Process the witness output data to get prediction results. + + This method handles JSTprove witness output format. JSTprove outputs + a raw array of floats representing the final logits. + + Args: + witness_data: The parsed JSON data from witness output. + + Returns: + Dictionary containing processed predictions, or None if processing fails. + """ + def _to_logits(data) -> torch.Tensor: + """Helper to convert data to logits tensor with batch dimension.""" + logits = torch.tensor(data) + if logits.dim() == 1: + logits = logits.unsqueeze(0) + return logits + + try: + # JSTprove dict format with 'rescaled_output' key + if isinstance(witness_data, dict) and "rescaled_output" in witness_data: + self._witness_format = "jstprove_dict" + return {"logits": _to_logits(witness_data["rescaled_output"])} + # Raw array format + elif isinstance(witness_data, list): + self._witness_format = "jstprove_list" + return {"logits": _to_logits(witness_data)} + # EZKL-like format fallback + else: + self._witness_format = "ezkl_compat" + rescaled = witness_data["pretty_elements"]["rescaled_outputs"][0] + return {"logits": _to_logits(rescaled)} + except (KeyError, TypeError) as e: + logger.error(f"Could not process witness data: {e}") + return None + + @classmethod + def get_version(cls) -> Optional[str]: + """ + Get the JSTprove version. + + Returns: + str: JSTprove version string, or None if version cannot be determined + """ + try: + result = subprocess.run( + [cls.COMMAND, "--version"], + capture_output=True, + text=True, + timeout=5 + ) + if result.returncode == 0: + # Parse version from output + version_output = result.stdout.strip() or result.stderr.strip() + return version_output + except Exception as e: + logger.debug(f"Could not get JSTprove version: {e}") + return None + + +if __name__ == "__main__": + # Example usage with JSTprove + print("JSTprove backend example:") + print("backend = JSTprove()") + print("backend.compile_circuit('model.onnx', 'circuit.txt')") + print("backend.generate_witness('input.json', 'circuit.txt', 'output.json')") + print("backend.prove('witness.bin', 'circuit.txt', 'proof.bin')") + print("backend.verify('proof.bin', 'circuit.txt', 'input.json', 'output.json', 'witness.bin')") + diff --git a/dsperse/src/cli/compile.py b/dsperse/src/cli/compile.py index a6cae85..5561c78 100644 --- a/dsperse/src/cli/compile.py +++ b/dsperse/src/cli/compile.py @@ -115,6 +115,9 @@ def setup_parser(subparsers): compile_parser.add_argument('--input-file', '--input', '--if', '-i', dest='input_file', help='Path to input file for calibration (optional)') compile_parser.add_argument('--layers', '-l', help='Specify which layers to compile (e.g., "3, 20-22"). If not provided, all layers will be compiled.') + compile_parser.add_argument('--backend', '-b', default=None, + help='Backend specification. Can be: "jstprove", "ezkl", or per-layer like "0,2:jstprove;3-4:ezkl". ' + 'Default: try both jstprove and ezkl, fallback to onnx. Note: requires --layers to compile.') return compile_parser @@ -126,8 +129,21 @@ def compile_model(args): Args: args: The parsed command-line arguments """ - print(f"{Fore.CYAN}Compiling slices with EZKL...{Style.RESET_ALL}") - logger.info("Starting slices compilation") + backend = getattr(args, 'backend', None) + layers = getattr(args, 'layers', None) + + if not layers: + print(f"{Fore.CYAN}No layers specified. Will compile all layers with default fallback (jstprove -> ezkl -> onnx)...{Style.RESET_ALL}") + logger.info("No layers specified - compiling all layers with default fallback") + elif backend: + if ':' in backend: + print(f"{Fore.CYAN}Compiling specified layers with mixed backends...{Style.RESET_ALL}") + else: + backend_name = 'JSTprove' if backend == 'jstprove' else 'EZKL' + print(f"{Fore.CYAN}Compiling specified layers with {backend_name}...{Style.RESET_ALL}") + else: + print(f"{Fore.CYAN}Compiling specified layers (trying jstprove & ezkl, fallback to onnx)...{Style.RESET_ALL}") + logger.info(f"Starting slices compilation") # Resolve path (slices dir or .dsperse/.dslice file) target_path = getattr(args, 'path', None) or getattr(args, 'slices_path', None) @@ -172,7 +188,7 @@ def compile_model(args): # Initialize the Compiler (it supports dirs or model.onnx) try: - compiler = Compiler() + compiler = Compiler(backend=backend) logger.info(f"Compiler initialized successfully") except RuntimeError as e: error_msg = f"Failed to initialize Compiler: {e}" diff --git a/dsperse/src/compile/compiler.py b/dsperse/src/compile/compiler.py index c73e5f2..ba79678 100644 --- a/dsperse/src/compile/compiler.py +++ b/dsperse/src/compile/compiler.py @@ -14,6 +14,7 @@ from typing import Optional, Dict, Any from dsperse.src.backends.ezkl import EZKL +from dsperse.src.backends.JSTprove import JSTprove from dsperse.src.compile.utils.compiler_utils import CompilerUtils from dsperse.src.slice.utils.converter import Converter from dsperse.src.utils.utils import Utils @@ -28,19 +29,114 @@ class Compiler: to the appropriate compiler implementation based on the model type. """ - def __init__(self): + def __init__(self, backend: Optional[str] = None): """ - Initialize the Compiler with a specific implementation. + Initialize the Compiler with a specific backend configuration. Args: - compiler_impl: The compiler implementation to use + backend: Backend specification. Can be: + - None: Use jstprove with fallback to ezkl then onnx (tries jstprove first) + - "jstprove" or "ezkl": Use specific backend for all layers + - "0,2:jstprove;3-4:ezkl": Per-layer backend specification """ - self.ezkl = EZKL() + self.backend_spec = backend + self.layer_backends = {} # Map layer index -> backend name + self.use_fallback = False + + # Parse backend specification + if backend is None: + # Default: use fallback logic (try both jstprove and ezkl, then onnx) + self.default_backend = None # Will try both + self.use_fallback = True + elif ':' in str(backend): + # Per-layer specification like "0,2:jstprove;3-4:ezkl" + # Unspecified layers use default fallback, specified layers try their backend first + self.default_backend = None + self.use_fallback = True # Enable fallback for both specified and unspecified layers + self._parse_layer_backends(backend) + else: + # Simple backend name - no fallback, use only this backend + self.default_backend = backend.lower() + self.use_fallback = False + + # Initialize backends (lazy loading to avoid errors if not used) + self._jstprove = None + self._ezkl = None + + def _parse_layer_backends(self, spec: str): + """Parse layer-specific backend specification like '0,2:jstprove;3-4:ezkl'""" + parts = spec.split(';') + for part in parts: + part = part.strip() + if ':' not in part: + continue + layers_str, backend_name = part.split(':', 1) + backend_name = backend_name.strip().lower() + + # Reuse existing layer parsing utility + layer_indices = CompilerUtils.parse_layers(layers_str) + if layer_indices: + for idx in layer_indices: + self.layer_backends[idx] = backend_name + + def _get_jstprove(self): + """Lazy initialization of JSTprove backend""" + if self._jstprove is None: + try: + self._jstprove = JSTprove() + except Exception as e: + logger.warning(f"Failed to initialize JSTprove: {e}") + return None + return self._jstprove + + def _get_ezkl(self): + """Lazy initialization of EZKL backend""" + if self._ezkl is None: + try: + self._ezkl = EZKL() + except Exception as e: + logger.warning(f"Failed to initialize EZKL: {e}") + return None + return self._ezkl + + def _get_backend_for_layer(self, layer_idx: int): + """Get the backend instance for a specific layer""" + # Check if layer has specific backend assigned + if layer_idx in self.layer_backends: + backend_name = self.layer_backends[layer_idx] + if backend_name == "jstprove": + return self._get_jstprove(), "jstprove" + else: + return self._get_ezkl(), "ezkl" + elif self.default_backend is None: + # Default: try both backends (will be handled in fallback logic) + return None, None + else: + # Simple backend specified + if self.default_backend == "jstprove": + return self._get_jstprove(), "jstprove" + else: + return self._get_ezkl(), "ezkl" + + # Keep backward compatibility properties + @property + def backend(self): + # Return ezkl for backward compatibility + return self._get_ezkl() + + @property + def backend_name(self): + return self.default_backend or "ezkl" + + @property + def ezkl(self): + return self._get_ezkl() - def _compile_slice(self, slice_data, base_path: str): + def _compile_slice(self, slice_data, base_path: str, layer_idx: int = 0): """ - Function for compiling a single slice. + Function for compiling a single slice with fallback support. + Tries jstprove -> ezkl -> onnx (skip) if fallback is enabled. """ slice_path = slice_data.get('path') if slice_path and os.path.exists(slice_path): @@ -54,22 +150,81 @@ def _compile_slice(self, slice_data, base_path: str): logger.error(f"No valid path found for slice") raise FileNotFoundError(f"No valid path found for slice") - slice_output_path = os.path.join(os.path.dirname(slice_path), "ezkl") - - calibration_input = os.path.join( - os.path.dirname(slice_path), - "ezkl", - f"calibration.json" - ) if os.path.exists(os.path.join(os.path.dirname(slice_path), "ezkl", "calibration.json")) else None - - compilation_data = self.ezkl.compilation_pipeline( - slice_path, - slice_output_path, - input_file_path=calibration_input - ) - - success = CompilerUtils.is_ezkl_compilation_successful(compilation_data) - file_paths = CompilerUtils.get_relative_paths(compilation_data, calibration_input) + # Get the backend for this specific layer + backend, backend_name = self._get_backend_for_layer(layer_idx) + + # Build list of backends to try + backends_to_try = [] + if backend is not None: + # Specific backend assigned to this layer + backends_to_try = [(backend, backend_name)] + if self.use_fallback: + # Add fallback: try other backend, then onnx + if backend_name == "jstprove": + ezkl = self._get_ezkl() + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + elif backend_name == "ezkl": + jst = self._get_jstprove() + if jst: + backends_to_try.append((jst, "jstprove")) + backends_to_try.append((None, "onnx")) + elif self.use_fallback: + # No specific backend for this layer, use default fallback chain + # (jstprove -> ezkl -> onnx) + jst = self._get_jstprove() + ezkl = self._get_ezkl() + if jst: + backends_to_try.append((jst, "jstprove")) + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + backends_to_try.append((None, "onnx")) + else: + # No backend specified and no fallback - skip compilation (use pure ONNX) + backends_to_try = [(None, "onnx")] + + success = False + compilation_data = {} + used_backend = None + + for try_backend, try_backend_name in backends_to_try: + if try_backend is None: + # Skip compilation - will use onnx at runtime + logger.info(f"Slice {layer_idx}: Skipping ZK compilation, will use ONNX at runtime") + success = True + used_backend = "onnx" + compilation_data = {"skipped": True, "reason": "fallback_to_onnx"} + break + + backend_dir = try_backend_name + slice_output_path = os.path.join(os.path.dirname(slice_path), backend_dir) + + calibration_input = os.path.join( + os.path.dirname(slice_path), + backend_dir, + f"calibration.json" + ) if os.path.exists(os.path.join(os.path.dirname(slice_path), backend_dir, "calibration.json")) else None + + try: + logger.info(f"Slice {layer_idx}: Trying {try_backend_name}...") + compilation_data = try_backend.compilation_pipeline( + slice_path, + slice_output_path, + input_file_path=calibration_input + ) + success = CompilerUtils.is_ezkl_compilation_successful(compilation_data) + if success: + used_backend = try_backend_name + logger.info(f"Slice {layer_idx}: {try_backend_name} compilation succeeded") + break + else: + logger.warning(f"Slice {layer_idx}: {try_backend_name} compilation failed, trying fallback...") + except Exception as e: + logger.warning(f"Slice {layer_idx}: {try_backend_name} error: {e}, trying fallback...") + if not self.use_fallback: + raise + + file_paths = CompilerUtils.get_relative_paths(compilation_data, calibration_input) if used_backend not in [None, "onnx"] else {} if slice_data.get('slice_metadata') and os.path.exists(slice_data.get('slice_metadata')): path = Path(slice_data.get('slice_metadata')) @@ -78,7 +233,7 @@ def _compile_slice(self, slice_data, base_path: str): path = Path(os.path.join(base_path, slice_data.get('slice_metadata_relative_path'))) CompilerUtils.update_slice_metadata(path, success, file_paths) - return success, file_paths + return success, file_paths, used_backend def _compile_model(self, model_file_path: str, input_file_path: Optional[str] = None) -> str: if not os.path.isfile(model_file_path): @@ -87,8 +242,8 @@ def _compile_model(self, model_file_path: str, input_file_path: Optional[str] = circuit_folder = os.path.join(os.path.dirname(output_path_root), "model") os.makedirs(circuit_folder, exist_ok=True) # Call backend pipeline - self.ezkl.compilation_pipeline(model_file_path, circuit_folder, input_file_path=input_file_path) - logger.info(f"Compilation completed. Output saved to {circuit_folder}") + self.backend.compilation_pipeline(model_file_path, circuit_folder, input_file_path=input_file_path) + logger.info(f"Compilation completed with {self.backend_name}. Output saved to {circuit_folder}") return circuit_folder @@ -108,36 +263,54 @@ def _compile_slices(self, dir_path: str, input_file_path: Optional[str] = None, # Phase 2: Compile layers compiled_count = 0 skipped_count = 0 + backend_stats = {} # Track which backend was used for each slice for idx, slice_data in enumerate(slices_data): if layer_indices is not None and idx not in layer_indices: - logger.info(f"Skipping compilation for slice {idx} as it's not in the specified layers") + logger.info(f"Skipping ZK compilation for slice {idx} (not in specified layers) - will use pure ONNX at runtime") skipped_count += 1 continue logger.info(f"Compiling slice {idx}...") - success, file_paths = self._compile_slice(slice_data, base_path) + success, file_paths, used_backend = self._compile_slice(slice_data, base_path, layer_idx=idx) compiled_count += 1 - logger.info(f"Completed slice {idx}") + backend_stats[idx] = used_backend + logger.info(f"Completed slice {idx} with {used_backend}") + + # Get version for the backend that was actually used + backend_version = None + if used_backend == "jstprove" and self._jstprove: + backend_version = self._jstprove.get_version() if hasattr(self._jstprove, 'get_version') else None + elif used_backend == "ezkl" and self._ezkl: + backend_version = self._ezkl.get_version() if hasattr(self._ezkl, 'get_version') else None comp_block = { "compiled": bool(success), "compilation_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), - "ezkl_version": EZKL.get_version(), + "backend": used_backend, + "backend_version": backend_version, "files": file_paths } - # Attach under 'compilation.ezkl' + # Attach under 'compilation.' if isinstance(slice_data, dict): if 'compilation' not in slice_data or not isinstance(slice_data.get('compilation'), dict): slice_data['compilation'] = {} - slice_data['compilation']['ezkl'] = comp_block + slice_data['compilation'][used_backend] = comp_block # Save model-level metadata (or single slice metadata) Utils.save_metadata_file(metadata, os.path.dirname(metadata_path), os.path.basename(metadata_path)) - logger.info(f"Compilation of slices completed. Compiled {compiled_count} slices, skipped {skipped_count} slices.") + # Log summary + backend_summary = {} + for idx, backend in backend_stats.items(): + backend_summary[backend] = backend_summary.get(backend, 0) + 1 + summary_str = ", ".join(f"{k}: {v}" for k, v in backend_summary.items()) + if skipped_count > 0: + logger.info(f"Compilation completed. ZK compiled: {compiled_count} slices ({summary_str}). Skipped: {skipped_count} slices (will use pure ONNX at runtime)") + else: + logger.info(f"Compilation completed. ZK compiled: {compiled_count} slices. Backends used: {summary_str}") def compile(self, model_path: str, input_file: Optional[str] = None, layers: Optional[str] = None): @@ -159,7 +332,8 @@ def compile(self, model_path: str, input_file: Optional[str] = None, layers: Opt if layer_indices: logger.info(f"Will compile only layers with indices: {layer_indices}") else: - logger.info("Will compile all layers.") + # No layers specified: compile ALL layers with default fallback + logger.info("No layers specified. Will compile all layers with default fallback (jstprove -> ezkl -> onnx).") is_sliced, slice_path, type = CompilerUtils.is_sliced_model(model_path) if is_sliced: diff --git a/dsperse/src/compile/utils/compiler_utils.py b/dsperse/src/compile/utils/compiler_utils.py index 713bb22..05530b9 100644 --- a/dsperse/src/compile/utils/compiler_utils.py +++ b/dsperse/src/compile/utils/compiler_utils.py @@ -105,12 +105,24 @@ def _with_slice_prefix(rel_path: Optional[str], slice_dirname: str) -> Optional[ def is_ezkl_compilation_successful(compilation_data: Dict[str, Any]) -> bool: """ Determine if compilation was successful based on produced file paths. - A path is considered valid if it exists and contains the 'payload' directory. + EZKL files are in payload subdirectories, JSTprove files are in backend directories. + Supports both EZKL and JSTprove backends. """ - def _ok(key: str) -> bool: + def _ok_ezkl(key: str) -> bool: p = compilation_data.get(key) return bool(p) and os.path.exists(p) and ('payload' in str(p).split(os.sep)) - return all([_ok('compiled'), _ok('vk_key'), _ok('pk_key'), _ok('settings')]) + + def _ok_jstprove(key: str) -> bool: + p = compilation_data.get(key) + return bool(p) and os.path.exists(p) # JSTprove doesn't use payload subdirs + + # Check if this is a JSTprove compilation (has 'circuit' key, no 'vk_key'/'pk_key') + if compilation_data.get('circuit') and not compilation_data.get('vk_key'): + # JSTprove requires 'compiled' (circuit) and 'settings' + return _ok_jstprove('compiled') and _ok_jstprove('settings') + + # EZKL requires compiled, vk_key, pk_key, settings + return all([_ok_ezkl('compiled'), _ok_ezkl('vk_key'), _ok_ezkl('pk_key'), _ok_ezkl('settings')]) @staticmethod def get_relative_paths(compilation_data: Dict[str, Any], calibration_input: Optional[str]) -> dict[str, str | None]: diff --git a/dsperse/src/constants.py b/dsperse/src/constants.py index a9daf1c..e4079bd 100644 --- a/dsperse/src/constants.py +++ b/dsperse/src/constants.py @@ -3,6 +3,7 @@ """ from pathlib import Path +# EZKL configuration MIN_EZKL_VERSION = "22.0.0" EZKL_PATH = Path.home() / ".ezkl" / "ezkl" SRS_DIR = Path.home() / ".ezkl" / "srs" @@ -10,4 +11,7 @@ SRS_LOGROWS_MIN = 2 SRS_LOGROWS_MAX = 24 SRS_LOGROWS_RANGE = range(SRS_LOGROWS_MIN, SRS_LOGROWS_MAX + 1) -SRS_FILES = [f"kzg{n}.srs" for n in SRS_LOGROWS_RANGE] \ No newline at end of file +SRS_FILES = [f"kzg{n}.srs" for n in SRS_LOGROWS_RANGE] + +# JSTprove configuration +JSTPROVE_COMMAND = "jst" \ No newline at end of file diff --git a/install.sh b/install.sh index ec317bf..3e76a9f 100755 --- a/install.sh +++ b/install.sh @@ -1,9 +1,11 @@ #!/usr/bin/env bash -# install.sh - Installer for Dsperse CLI and EZKL (with lookup tables) +# install.sh - Installer for Dsperse CLI, JSTproveand EZKL (with lookup tables) # This script installs: # - Dsperse CLI (python package, console script: dsperse) # - EZKL CLI (if missing) # - EZKL lookup tables (if missing) +# - Open MPI (required for JSTprove) +# - JSTprove CLI (if missing) # It detects existing installations and will skip or prompt accordingly. set -euo pipefail @@ -198,6 +200,113 @@ ensure_ezkl() { fi } +# Install Open MPI (required for JSTprove) +install_openmpi() { + info "Checking for Open MPI installation..." + if command -v mpirun >/dev/null 2>&1; then + info "Open MPI already installed: $(command -v mpirun)" + mpirun --version 2>/dev/null | head -n1 || true + return 0 + fi + + if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS - use Homebrew + if command -v brew >/dev/null 2>&1; then + info "Installing Open MPI via Homebrew..." + if brew install open-mpi; then + info "Open MPI installed successfully via Homebrew" + else + warn "Failed to install Open MPI via Homebrew" + fi + else + warn "Homebrew not found. Please install Homebrew first, then run: brew install open-mpi" + fi + elif [[ "$OSTYPE" == "linux-gnu"* ]]; then + # Linux - try apt or yum + if command -v apt-get >/dev/null 2>&1; then + info "Installing Open MPI via apt..." + if sudo apt-get update && sudo apt-get install -y openmpi-bin libopenmpi-dev; then + info "Open MPI installed successfully via apt" + else + warn "Failed to install Open MPI via apt" + fi + elif command -v yum >/dev/null 2>&1; then + info "Installing Open MPI via yum..." + if sudo yum install -y openmpi openmpi-devel; then + info "Open MPI installed successfully via yum" + else + warn "Failed to install Open MPI via yum" + fi + else + warn "Could not detect package manager. Please install Open MPI manually." + fi + else + warn "Unsupported OS for automatic Open MPI installation. Please install manually." + fi + + if command -v mpirun >/dev/null 2>&1; then + info "Open MPI is now available: $(command -v mpirun)" + else + warn "Open MPI installation may require PATH configuration. Check your shell profile." + fi +} + +# Install JSTprove via uv +install_jstprove() { + info "Checking for JSTprove installation..." + + # Add common uv tool bin paths to PATH for detection + export PATH="$HOME/.local/bin:$PATH" + + if command -v jst >/dev/null 2>&1; then + info "JSTprove already installed: $(command -v jst)" + jst --version 2>/dev/null || jst --help 2>/dev/null | head -n1 || true + if [[ "$INTERACTIVE" == true ]] && confirm "Reinstall/upgrade JSTprove?"; then + : + else + info "Skipping JSTprove install." + return 0 + fi + fi + + # Check if uv is available + if command -v uv >/dev/null 2>&1; then + info "Installing JSTprove via uv..." + if uv tool install jstprove 2>/dev/null || uv pip install jstprove 2>/dev/null; then + info "JSTprove installed successfully via uv" + # Add uv tool bin to PATH + export PATH="$HOME/.local/bin:$PATH" + info "Added $HOME/.local/bin to PATH for JSTprove" + else + warn "Failed to install JSTprove via uv. Trying pip..." + if eval $PIP_BIN install jstprove 2>/dev/null; then + info "JSTprove installed successfully via pip" + else + warn "Failed to install JSTprove. Please install manually: uv tool install jstprove" + fi + fi + else + # Fallback to pip + info "uv not found. Installing JSTprove via pip..." + if eval $PIP_BIN install jstprove 2>/dev/null; then + info "JSTprove installed successfully via pip" + else + warn "Failed to install JSTprove via pip. Please install uv first: curl -LsSf https://astral.sh/uv/install.sh | sh" + warn "Then install JSTprove: uv tool install jstprove" + fi + fi + + if command -v jst >/dev/null 2>&1; then + info "JSTprove is now available: $(command -v jst)" + jst --version 2>/dev/null || jst --help 2>/dev/null | head -n1 || true + else + warn "JSTprove CLI (jst) not found on PATH." + warn "You may need to add ~/.local/bin to your PATH:" + warn " export PATH=\"\$HOME/.local/bin:\$PATH\"" + warn "Add this line to your shell profile (~/.bashrc, ~/.zshrc, etc.) for persistence." + fi +} + # Ensure SRS files (kzg commitment) exist under ~/.ezkl/srs ensure_srs() { if ! command -v ezkl >/dev/null 2>&1; then @@ -370,6 +479,12 @@ main() { # Lookup tables (for some ezkl versions) install_lookup_tables + # Install Open MPI (required for JSTprove) + install_openmpi + + # Install JSTprove + install_jstprove + # Final verification step for EZKL and SRS verify_environment_post_install @@ -386,6 +501,12 @@ main() { warn "EZKL not detected; some features will fall back to ONNX or fail." fi + if command -v jst >/dev/null 2>&1; then + info "JSTprove is ready: jst --help" + else + warn "JSTprove not detected; JSTprove backend features will not be available." + fi + say "\nInstallation complete!" } diff --git a/requirements.txt b/requirements.txt index 867b220..1b55376 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,6 @@ matplotlib~=3.10.1 numpy~=2.2.3 tqdm~=4.67.1 onnxruntime==1.21.0 -colorama~=0.4.6 \ No newline at end of file +colorama~=0.4.6 + +mpi4py>=3.1.0 \ No newline at end of file From 78fd8d43d1cc6993657bc817cf2bd20cc89d65e9 Mon Sep 17 00:00:00 2001 From: shirin shahabi <71061824+shirin-shahabi@users.noreply.github.com> Date: Tue, 2 Dec 2025 15:53:41 -0500 Subject: [PATCH 2/4] fix: Correct backend tracking in metadata and model compilation - Bug 1: update_slice_metadata() now accepts backend_name parameter and writes under correct key (jstprove/ezkl/onnx) instead of always 'ezkl' - Bug 2: _compile_model() now uses proper backend selection with fallback logic instead of always using EZKL --- dsperse/src/compile/compiler.py | 74 ++++++++++++++++++--- dsperse/src/compile/utils/compiler_utils.py | 27 +++++--- 2 files changed, 83 insertions(+), 18 deletions(-) diff --git a/dsperse/src/compile/compiler.py b/dsperse/src/compile/compiler.py index ba79678..a4a39c7 100644 --- a/dsperse/src/compile/compiler.py +++ b/dsperse/src/compile/compiler.py @@ -228,23 +228,81 @@ def _compile_slice(self, slice_data, base_path: str, layer_idx: int = 0): if slice_data.get('slice_metadata') and os.path.exists(slice_data.get('slice_metadata')): path = Path(slice_data.get('slice_metadata')) - CompilerUtils.update_slice_metadata(path, success, file_paths) + CompilerUtils.update_slice_metadata(path, success, file_paths, backend_name=used_backend or "onnx") elif slice_data.get('slice_metadata_relative_path') and os.path.exists(os.path.join(base_path, slice_data.get('slice_metadata_relative_path'))): path = Path(os.path.join(base_path, slice_data.get('slice_metadata_relative_path'))) - CompilerUtils.update_slice_metadata(path, success, file_paths) + CompilerUtils.update_slice_metadata(path, success, file_paths, backend_name=used_backend or "onnx") return success, file_paths, used_backend def _compile_model(self, model_file_path: str, input_file_path: Optional[str] = None) -> str: + """ + Compile a single ONNX model file (not sliced) with backend fallback support. + """ if not os.path.isfile(model_file_path): raise ValueError(f"model_path must be a file: {model_file_path}") output_path_root = os.path.splitext(model_file_path)[0] - circuit_folder = os.path.join(os.path.dirname(output_path_root), "model") - os.makedirs(circuit_folder, exist_ok=True) - # Call backend pipeline - self.backend.compilation_pipeline(model_file_path, circuit_folder, input_file_path=input_file_path) - logger.info(f"Compilation completed with {self.backend_name}. Output saved to {circuit_folder}") - return circuit_folder + + # Build list of backends to try (same logic as _compile_slice) + backends_to_try = [] + if self.default_backend: + # Specific backend requested + if self.default_backend == "jstprove": + jst = self._get_jstprove() + if jst: + backends_to_try.append((jst, "jstprove")) + else: + ezkl = self._get_ezkl() + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + if self.use_fallback: + # Add fallback options + if self.default_backend == "jstprove": + ezkl = self._get_ezkl() + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + else: + jst = self._get_jstprove() + if jst: + backends_to_try.append((jst, "jstprove")) + elif self.use_fallback: + # Default fallback: jstprove -> ezkl + jst = self._get_jstprove() + ezkl = self._get_ezkl() + if jst: + backends_to_try.append((jst, "jstprove")) + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + else: + # No backend specified, no fallback - use EZKL as default + ezkl = self._get_ezkl() + if ezkl: + backends_to_try.append((ezkl, "ezkl")) + + if not backends_to_try: + raise RuntimeError("No backends available for compilation") + + # Try each backend until one succeeds + for try_backend, try_backend_name in backends_to_try: + circuit_folder = os.path.join(os.path.dirname(output_path_root), try_backend_name) + os.makedirs(circuit_folder, exist_ok=True) + try: + logger.info(f"Compiling model with {try_backend_name}...") + compilation_data = try_backend.compilation_pipeline( + model_file_path, circuit_folder, input_file_path=input_file_path + ) + success = CompilerUtils.is_ezkl_compilation_successful(compilation_data) + if success: + logger.info(f"Compilation completed with {try_backend_name}. Output saved to {circuit_folder}") + return circuit_folder + else: + logger.warning(f"{try_backend_name} compilation failed, trying fallback...") + except Exception as e: + logger.warning(f"{try_backend_name} error: {e}, trying fallback...") + if not self.use_fallback: + raise + + raise RuntimeError("All backends failed to compile the model") def _compile_slices(self, dir_path: str, input_file_path: Optional[str] = None, layer_indices=None): diff --git a/dsperse/src/compile/utils/compiler_utils.py b/dsperse/src/compile/utils/compiler_utils.py index 05530b9..0c09120 100644 --- a/dsperse/src/compile/utils/compiler_utils.py +++ b/dsperse/src/compile/utils/compiler_utils.py @@ -187,7 +187,7 @@ def build_model_level_ezkl(payload_rel: Dict[str, Optional[str]], calibration_re @staticmethod - def update_slice_metadata(filepath: str | Path, success: bool, file_paths: Dict[str, str | None]): + def update_slice_metadata(filepath: str | Path, success: bool, file_paths: Dict[str, str | None], backend_name: str = "ezkl"): """ Update the per-slice metadata.json file with compilation results. @@ -195,6 +195,7 @@ def update_slice_metadata(filepath: str | Path, success: bool, file_paths: Dict[ filepath: Path to the slice's metadata.json file success: Boolean indicating if compilation was successful file_paths: Dictionary containing file paths for compilation results + backend_name: Name of the backend used (jstprove, ezkl, or onnx) """ # Load existing slice metadata or create new if os.path.exists(filepath): @@ -203,14 +204,20 @@ def update_slice_metadata(filepath: str | Path, success: bool, file_paths: Dict[ else: slice_metadata = {} - # Get EZKL version - ezkl_version = EZKL.get_version() + # Get backend version based on which backend was used + if backend_name == "jstprove": + from dsperse.src.backends.JSTprove import JSTprove + backend_version = JSTprove.get_version() + elif backend_name == "ezkl": + backend_version = EZKL.get_version() + else: + backend_version = None - # Create compilation info nested under 'ezkl' - ezkl_compilation_info = { + # Create compilation info nested under the backend name + compilation_info = { "compiled": success, "compilation_timestamp": __import__('time').strftime("%Y-%m-%d %H:%M:%S"), - "ezkl_version": ezkl_version, + "backend_version": backend_version, "files": { "settings": file_paths.get('settings'), "compiled_circuit": file_paths.get('compiled'), @@ -223,20 +230,20 @@ def update_slice_metadata(filepath: str | Path, success: bool, file_paths: Dict[ # Add any errors if present errors = {k: v for k, v in file_paths.items() if k.endswith('_error')} if errors: - ezkl_compilation_info["errors"] = errors + compilation_info["errors"] = errors # Ensure compilation section exists if 'compilation' not in slice_metadata: slice_metadata['compilation'] = {} - # Update slice metadata with ezkl nested under compilation - slice_metadata['compilation']['ezkl'] = ezkl_compilation_info + # Update slice metadata with backend info nested under compilation + slice_metadata['compilation'][backend_name] = compilation_info # Save updated slice metadata with open(filepath, 'w') as f: json.dump(slice_metadata, f, indent=2) - logger.debug(f"Updated slice metadata at {filepath}") + logger.debug(f"Updated slice metadata at {filepath} for backend {backend_name}") @staticmethod From 7d91b000b85912a72855139204537e2483741202 Mon Sep 17 00:00:00 2001 From: shirin shahabi <71061824+shirin-shahabi@users.noreply.github.com> Date: Tue, 2 Dec 2025 16:17:45 -0500 Subject: [PATCH 3/4] fix: Replace undefined layer_idx with idx in _compile_slice logger statements - Fixed NameError: logger statements in _compile_slice were referencing undefined variable layer_idx - Changed all logger statements to use idx parameter (lines 198, 214, 223, 226, 228) --- dsperse/src/compile/compiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dsperse/src/compile/compiler.py b/dsperse/src/compile/compiler.py index 7fc8f05..3b887c0 100644 --- a/dsperse/src/compile/compiler.py +++ b/dsperse/src/compile/compiler.py @@ -195,7 +195,7 @@ def _compile_slice(self, idx: int, slice_data: dict, base_path: str): for try_backend, try_backend_name in backends_to_try: if try_backend is None: # Skip compilation - will use onnx at runtime - logger.info(f"Slice {layer_idx}: Skipping ZK compilation, will use ONNX at runtime") + logger.info(f"Slice {idx}: Skipping ZK compilation, will use ONNX at runtime") success = True used_backend = "onnx" compilation_data = {"skipped": True, "reason": "fallback_to_onnx"} @@ -211,7 +211,7 @@ def _compile_slice(self, idx: int, slice_data: dict, base_path: str): ) if os.path.exists(os.path.join(os.path.dirname(slice_path), backend_dir, "calibration.json")) else None try: - logger.info(f"Slice {layer_idx}: Trying {try_backend_name}...") + logger.info(f"Slice {idx}: Trying {try_backend_name}...") compilation_data = try_backend.compilation_pipeline( slice_path, slice_output_path, @@ -220,12 +220,12 @@ def _compile_slice(self, idx: int, slice_data: dict, base_path: str): success = CompilerUtils.is_ezkl_compilation_successful(compilation_data) if success: used_backend = try_backend_name - logger.info(f"Slice {layer_idx}: {try_backend_name} compilation succeeded") + logger.info(f"Slice {idx}: {try_backend_name} compilation succeeded") break else: - logger.warning(f"Slice {layer_idx}: {try_backend_name} compilation failed, trying fallback...") + logger.warning(f"Slice {idx}: {try_backend_name} compilation failed, trying fallback...") except Exception as e: - logger.warning(f"Slice {layer_idx}: {try_backend_name} error: {e}, trying fallback...") + logger.warning(f"Slice {idx}: {try_backend_name} error: {e}, trying fallback...") if not self.use_fallback: raise From 191d2e6c7bb8b61f12c095eea98ec0040dc6a41c Mon Sep 17 00:00:00 2001 From: shirin shahabi <71061824+shirin-shahabi@users.noreply.github.com> Date: Fri, 12 Dec 2025 21:14:11 -0500 Subject: [PATCH 4/4] feat: Add JSTprove optimizations and runner integration - JSTprove.py: Skip recompilation if circuit already exists in witness generation Prevents unnecessary recompilation when circuit file is already available - JSTprove.py: Add warning when using rescaled outputs from output.json Clarifies that rescaled outputs come from output.json (not witness binary) - runner_analyzer.py: Check for JSTprove compilation first, then EZKL Prioritizes JSTprove backend when both are available, sets backend field - runner.py: Add JSTprove import and initialization Enables JSTprove backend in runner with graceful fallback if CLI unavailable --- dsperse/src/analyzers/runner_analyzer.py | 78 +++++++++++++++++------- dsperse/src/backends/JSTprove.py | 14 ++++- dsperse/src/constants.py | 2 +- dsperse/src/run/runner.py | 56 +++++++++++++++++ 4 files changed, 123 insertions(+), 27 deletions(-) diff --git a/dsperse/src/analyzers/runner_analyzer.py b/dsperse/src/analyzers/runner_analyzer.py index 0a68feb..630c646 100644 --- a/dsperse/src/analyzers/runner_analyzer.py +++ b/dsperse/src/analyzers/runner_analyzer.py @@ -95,16 +95,29 @@ def _process_slices_model(slices_dir: Path, slices_list: list[dict]) -> dict: dependencies = item.get("dependencies") or {} parameters = item.get("parameters", 0) - # EZKL compilation info - comp = ((item.get("compilation") or {}).get("ezkl") or {}) - files = (comp.get("files") or {}) - compiled_flag = bool(comp.get("compiled", False)) - - # Accept both keys: 'compiled_circuit' and legacy 'compiled' - compiled_rel = files.get("compiled_circuit") or files.get("compiled") - settings_rel = files.get("settings") - pk_rel = files.get("pk_key") - vk_rel = files.get("vk_key") + # Check for compilation info (JSTprove first, then EZKL) + compilation = item.get("compilation") or {} + jst_comp = compilation.get("jstprove") or {} + ezkl_comp = compilation.get("ezkl") or {} + + # Prefer JSTprove if available, otherwise EZKL + if jst_comp.get("compiled"): + backend = "jstprove" + files = jst_comp.get("files") or {} + compiled_flag = True + compiled_rel = files.get("circuit") + settings_rel = None + pk_rel = None + vk_rel = None + else: + backend = "ezkl" + files = ezkl_comp.get("files") or {} + compiled_flag = bool(ezkl_comp.get("compiled", False)) + # Accept both keys: 'compiled_circuit' and legacy 'compiled' + compiled_rel = files.get("compiled_circuit") or files.get("compiled") + settings_rel = files.get("settings") + pk_rel = files.get("pk_key") + vk_rel = files.get("vk_key") def _norm(rel: Optional[str]) -> Optional[str]: if not rel: @@ -124,6 +137,7 @@ def _norm(rel: Optional[str]) -> Optional[str]: "output_shape": output_shape, "ezkl_compatible": True, "ezkl": bool(compiled_flag), + "backend": backend, "circuit_size": 0, # unknown without touching filesystem; keep 0 "dependencies": dependencies, "parameters": parameters, @@ -171,21 +185,36 @@ def _process_slices_per_slice(slices_dir: Path, slices_data_list: list[dict]) -> dependencies = slice.get("dependencies") or {} parameters = slice.get("parameters", 0) - # EZKL compilation info - comp = ((slice.get("compilation") or {}).get("ezkl") or {}) - files = (comp.get("files") or {}) - compiled_flag = bool(comp.get("compiled", False)) - - if files: - circuit_path = os.path.join(parent_dir, files.get("compiled_circuit") or files.get("compiled")) - settings_path = os.path.join(parent_dir, files.get("settings")) - pk_path = os.path.join(parent_dir, files.get("pk_key")) - vk_path = os.path.join(parent_dir, files.get("vk_key")) - else: - circuit_path = None + # Check for compilation info (JSTprove first, then EZKL) + compilation = slice.get("compilation") or {} + jst_comp = compilation.get("jstprove") or {} + ezkl_comp = compilation.get("ezkl") or {} + + if jst_comp.get("compiled"): + backend = "jstprove" + files = jst_comp.get("files") or {} + compiled_flag = True + if files: + circuit_path = os.path.join(parent_dir, files.get("circuit")) + else: + circuit_path = None settings_path = None pk_path = None vk_path = None + else: + backend = "ezkl" + files = ezkl_comp.get("files") or {} + compiled_flag = bool(ezkl_comp.get("compiled", False)) + if files: + circuit_path = os.path.join(parent_dir, files.get("compiled_circuit") or files.get("compiled")) + settings_path = os.path.join(parent_dir, files.get("settings")) + pk_path = os.path.join(parent_dir, files.get("pk_key")) + vk_path = os.path.join(parent_dir, files.get("vk_key")) + else: + circuit_path = None + settings_path = None + pk_path = None + vk_path = None slices[slice_key] = { "path": onnx_path, @@ -193,6 +222,7 @@ def _process_slices_per_slice(slices_dir: Path, slices_data_list: list[dict]) -> "output_shape": output_shape, "ezkl_compatible": True, "ezkl": bool(compiled_flag), + "backend": backend, "circuit_size": 0, "dependencies": dependencies, "parameters": parameters, @@ -243,9 +273,11 @@ def _build_execution_chain(slices: dict): meta = slices.get(slice_key, {}) circuit_path = meta.get('circuit_path') onnx_path = meta.get('path') + backend = meta.get('backend', 'ezkl') has_circuit = circuit_path is not None and circuit_path != "" has_keys = (meta.get('pk_path') is not None) and (meta.get('vk_path') is not None) - use_circuit = bool(meta.get('ezkl')) and has_circuit and has_keys + # JSTprove doesn't require pk/vk keys; EZKL does + use_circuit = bool(meta.get('ezkl')) and has_circuit and (backend == 'jstprove' or has_keys) next_slice = ordered_keys[i + 1] if i < len(ordered_keys) - 1 else None execution_chain["nodes"][slice_key] = { diff --git a/dsperse/src/backends/JSTprove.py b/dsperse/src/backends/JSTprove.py index ff008a1..ebe7593 100644 --- a/dsperse/src/backends/JSTprove.py +++ b/dsperse/src/backends/JSTprove.py @@ -130,13 +130,15 @@ def generate_witness( if circuit_path.exists() and circuit_path.suffix == '.onnx': onnx_model_path = circuit_path circuit_path = circuit_path.parent / f"{circuit_path.stem}_jstprove_circuit.txt" - logger.info(f"JSTprove: Compiling ONNX model {onnx_model_path} to circuit {circuit_path}") - # If we have an ONNX model, compile it first - if onnx_model_path: + # If we have an ONNX model, compile it first only if circuit doesn't exist + if onnx_model_path and not circuit_path.exists(): + logger.info(f"JSTprove: Compiling ONNX model {onnx_model_path} to circuit {circuit_path}") ok, err = self.compile_circuit(onnx_model_path, circuit_path) if not ok: raise RuntimeError(f"Circuit compilation failed: {err}") + elif onnx_model_path and circuit_path.exists(): + logger.info(f"Using existing circuit: {circuit_path}") elif not circuit_path.exists(): raise FileNotFoundError(f"Circuit file not found: {circuit_path}") @@ -420,6 +422,12 @@ def _to_logits(data) -> torch.Tensor: # JSTprove dict format with 'rescaled_output' key if isinstance(witness_data, dict) and "rescaled_output" in witness_data: self._witness_format = "jstprove_dict" + # NOTE: Rescaled outputs are in output.json (from -o flag), not in the witness binary file (-w flag). + # The witness binary contains only the raw quantized values needed for proof generation. + logger.warning( + "Using rescaled outputs from output.json (not witness binary). " + "These are the model's floating-point outputs after de-quantization." + ) return {"logits": _to_logits(witness_data["rescaled_output"])} # Raw array format elif isinstance(witness_data, list): diff --git a/dsperse/src/constants.py b/dsperse/src/constants.py index e4079bd..ece5cdf 100644 --- a/dsperse/src/constants.py +++ b/dsperse/src/constants.py @@ -13,5 +13,5 @@ SRS_LOGROWS_RANGE = range(SRS_LOGROWS_MIN, SRS_LOGROWS_MAX + 1) SRS_FILES = [f"kzg{n}.srs" for n in SRS_LOGROWS_RANGE] -# JSTprove configuration +# JSTprove CLI command JSTPROVE_COMMAND = "jst" \ No newline at end of file diff --git a/dsperse/src/run/runner.py b/dsperse/src/run/runner.py index 1cabb09..4852a8e 100644 --- a/dsperse/src/run/runner.py +++ b/dsperse/src/run/runner.py @@ -13,6 +13,7 @@ from dsperse.src.analyzers.runner_analyzer import RunnerAnalyzer from dsperse.src.backends.ezkl import EZKL +from dsperse.src.backends.JSTprove import JSTprove from dsperse.src.backends.onnx_models import OnnxModels from dsperse.src.run.utils.runner_utils import RunnerUtils from dsperse.src.slice.utils.converter import Converter @@ -31,6 +32,11 @@ def __init__(self, run_metadata_path: str = None, save_metadata_path: str = None self.run_metadata = None self.ezkl_runner = EZKL() + try: + self.jstprove_runner = JSTprove() + except RuntimeError: + self.jstprove_runner = None + logger.warning("JSTprove CLI not available. JSTprove backend will be disabled.") def run(self, input_json_path, slice_path: str, output_path: str = None) -> dict: @@ -151,6 +157,56 @@ def _resolve_rel_path(p: str, base_dir: Path) -> str: return success, output_tensor, exec_info + def _run_jstprove_slice(self, slice_info: dict, input_tensor_path, output_witness_path, slice_dir: Path = None): + """Run JSTprove inference for a slice with fallback to ONNX. + Accepts paths possibly formatted as `slice_#/payload/...` or `payload/...` and resolves them + under the provided `slice_dir` if necessary. + """ + if self.jstprove_runner is None: + return False, "JSTprove CLI not available", {'success': False, 'method': 'jstprove_gen_witness', 'error': 'JSTprove CLI not available'} + + def _resolve_rel_path(p: str, base_dir: Path) -> str: + path = str((base_dir / p).resolve()) + if not Path(path).exists(): + path = str((Path(base_dir).parent / Path(p)).resolve()) + return path + + circuit_path = slice_info.get("circuit_path") + settings_path = slice_info.get("settings_path") + + # Resolve possibly relative paths + if circuit_path and not os.path.isabs(str(circuit_path)): + circuit_path = _resolve_rel_path(circuit_path, slice_dir) + + start_time = time.time() + # Attempt JSTprove execution + try: + success, output_tensor = self.jstprove_runner.generate_witness( + input_file=input_tensor_path, + model_path=circuit_path, + output_file=output_witness_path, + ) + except Exception as e: + success = False + output_tensor = str(e) + + end_time = time.time() + exec_info = { + 'success': success, + 'method': 'jstprove_gen_witness', + 'execution_time': end_time - start_time, + 'witness_path': str(output_witness_path), + 'attempted_jstprove': True + } + + if success: + exec_info['input_file'] = str(input_tensor_path.resolve()) + exec_info['output_file'] = str(output_witness_path.resolve()) + else: + exec_info['error'] = output_tensor if isinstance(output_tensor, str) else "Unknown JSTprove error" + + return success, output_tensor, exec_info + def _save_inference_output(self, results, output_path): """Save inference_output.json with execution details.""" model_path = self.run_metadata.get("model_path", "unknown")