diff --git a/src/zimage/cli.py b/src/zimage/cli.py index 12d703b..1b14565 100644 --- a/src/zimage/cli.py +++ b/src/zimage/cli.py @@ -27,6 +27,7 @@ get_outputs_dir, get_db_path, get_config_path, + load_config, ) from zimage.logger import get_logger, setup_logging except ImportError: @@ -40,6 +41,7 @@ get_outputs_dir, get_db_path, get_config_path, + load_config, ) from logger import get_logger, setup_logging elif __package__: @@ -52,6 +54,7 @@ get_outputs_dir, get_db_path, get_config_path, + load_config, ) from .logger import get_logger, setup_logging else: @@ -66,6 +69,7 @@ get_outputs_dir, get_db_path, get_config_path, + load_config, ) from logger import get_logger, setup_logging except ImportError: @@ -80,6 +84,7 @@ get_outputs_dir, get_db_path, get_config_path, + load_config, ) from logger import get_logger, setup_logging @@ -304,6 +309,11 @@ def collect_info(): "loras_dir": str(get_loras_dir().resolve()), "db_path": str(get_db_path().resolve()), }, + "constraints": { + "max_steps": load_config().get("max_steps", 50), + "max_width": load_config().get("max_width", 4096), + "max_height": load_config().get("max_height", 4096), + }, "env_overrides": { "Z_IMAGE_STUDIO_DATA_DIR": os.environ.get("Z_IMAGE_STUDIO_DATA_DIR"), "Z_IMAGE_STUDIO_OUTPUT_DIR": os.environ.get("Z_IMAGE_STUDIO_OUTPUT_DIR"), @@ -336,6 +346,11 @@ def format_info_text(info: dict) -> str: f" LoRAs Dir: {info['paths']['loras_dir']}", f" DB Path: {info['paths']['db_path']}", "", + "Constraints:", + f" Max Steps: {info['constraints']['max_steps']}", + f" Max Width: {info['constraints']['max_width']}", + f" Max Height: {info['constraints']['max_height']}", + "", "Environment Overrides:", f" Z_IMAGE_STUDIO_DATA_DIR: {info['env_overrides']['Z_IMAGE_STUDIO_DATA_DIR']}", f" Z_IMAGE_STUDIO_OUTPUT_DIR: {info['env_overrides']['Z_IMAGE_STUDIO_OUTPUT_DIR']}", @@ -421,6 +436,21 @@ def run_generation(args): generate_image, save_image, record_generation = _load_generation_modules() logger.info(f"DEBUG: cwd: {Path.cwd().resolve()}") + # Check constraints + config = load_config() + max_steps = config.get("max_steps", 50) + max_width = config.get("max_width", 4096) + max_height = config.get("max_height", 4096) + + if args.steps > max_steps: + log_error(f"Requested steps ({args.steps}) exceeds the maximum allowed ({max_steps}).") + sys.exit(1) + if args.width > max_width: + log_error(f"Requested width ({args.width}) exceeds the maximum allowed ({max_width}).") + sys.exit(1) + if args.height > max_height: + log_error(f"Requested height ({args.height}) exceeds the maximum allowed ({max_height}).") + # Ensure width/height are multiples of 16 for name in ["width", "height"]: v = getattr(args, name) diff --git a/src/zimage/mcp_server.py b/src/zimage/mcp_server.py index fcf03a4..ed89d08 100644 --- a/src/zimage/mcp_server.py +++ b/src/zimage/mcp_server.py @@ -9,6 +9,7 @@ import base64 import random from urllib.parse import quote +from pydantic import Field # Lazy import for yarl to avoid dependency issues try: @@ -21,11 +22,13 @@ from . import db from .storage import save_image, record_generation from .logger import get_logger, setup_logging + from .paths import load_config except ImportError: from hardware import get_available_models, normalize_precision, MODEL_ID_MAP import db from storage import save_image, record_generation from logger import get_logger, setup_logging + from paths import load_config # Lazy imports for heavy dependencies def _get_engine(): @@ -115,6 +118,19 @@ async def send_progress(percentage: int, message: str): try: await send_progress(0, "Initializing generation...") + # Enforce constraints + config = load_config() + max_steps = config.get("max_steps", 50) + max_width = config.get("max_width", 4096) + max_height = config.get("max_height", 4096) + + if steps > max_steps: + raise ValueError(f"Requested steps ({steps}) exceeds the maximum allowed ({max_steps}).") + if width > max_width: + raise ValueError(f"Requested width ({width}) exceeds the maximum allowed ({max_width}).") + if height > max_height: + raise ValueError(f"Requested height ({height}) exceeds the maximum allowed ({max_height}).") + # Normalize and validate precision try: precision = normalize_precision(precision) @@ -451,9 +467,9 @@ async def send_progress(percentage: int, message: str): @mcp.tool() async def generate( prompt: str, - steps: int = 9, - width: int = 1280, - height: int = 720, + steps: int = Field(default=9, description="Number of inference steps (max bounded by server config)"), + width: int = Field(default=1280, description="Image width in pixels (max bounded by server config)"), + height: int = Field(default=720, description="Image height in pixels (max bounded by server config)"), seed: int | None = None, precision: str = "q8", ctx: Optional[Context] = None diff --git a/src/zimage/paths.py b/src/zimage/paths.py index cc8c40c..0c49c9c 100644 --- a/src/zimage/paths.py +++ b/src/zimage/paths.py @@ -141,6 +141,9 @@ def ensure_initial_setup(): "Z_IMAGE_STUDIO_DATA_DIR": None, "Z_IMAGE_STUDIO_OUTPUT_DIR": None, "ZIMAGE_ENABLE_TORCH_COMPILE": None, + "max_steps": 50, + "max_width": 4096, + "max_height": 4096, } global _CONFIG_CACHE _CONFIG_CACHE = config diff --git a/src/zimage/server.py b/src/zimage/server.py index cbad275..503aa54 100644 --- a/src/zimage/server.py +++ b/src/zimage/server.py @@ -219,8 +219,18 @@ class GenerateResponse(BaseModel): @app.get("/models") async def get_models(): - """Get list of available models with hardware recommendations.""" - return get_available_models() + """Get list of available models with hardware recommendations and constraints.""" + from .paths import load_config + config = load_config() + models_info = get_available_models() + + # Inject constraints into the response + models_info['constraints'] = { + "max_steps": config.get("max_steps", 50), + "max_width": config.get("max_width", 4096), + "max_height": config.get("max_height", 4096), + } + return models_info @app.get("/loras") async def get_loras(): @@ -372,6 +382,19 @@ async def delete_lora(lora_id: int): @app.post("/generate", response_model=GenerateResponse) async def generate(req: GenerateRequest, background_tasks: BackgroundTasks): try: + from .paths import load_config + config = load_config() + max_steps = config.get("max_steps", 50) + max_width = config.get("max_width", 4096) + max_height = config.get("max_height", 4096) + + if req.steps > max_steps: + raise HTTPException(status_code=400, detail=f"Requested steps ({req.steps}) exceeds the maximum allowed ({max_steps}).") + if req.width > max_width: + raise HTTPException(status_code=400, detail=f"Requested width ({req.width}) exceeds the maximum allowed ({max_width}).") + if req.height > max_height: + raise HTTPException(status_code=400, detail=f"Requested height ({req.height}) exceeds the maximum allowed ({max_height}).") + # Normalize and validate precision early to avoid KeyError inside engine try: precision = normalize_precision(req.precision) diff --git a/src/zimage/static/js/main.js b/src/zimage/static/js/main.js index 1b45c63..bfece9b 100644 --- a/src/zimage/static/js/main.js +++ b/src/zimage/static/js/main.js @@ -895,6 +895,23 @@ const res = await fetch('/models'); const data = await res.json(); + if (data.constraints) { + const stepsEl = document.getElementById('steps'); + if (stepsEl && data.constraints.max_steps) { + stepsEl.max = data.constraints.max_steps; + } + + const widthEl = document.getElementById('width'); + if (widthEl && data.constraints.max_width) { + widthEl.max = data.constraints.max_width; + } + + const heightEl = document.getElementById('height'); + if (heightEl && data.constraints.max_height) { + heightEl.max = data.constraints.max_height; + } + } + if (data.device) window.currentDevice = data.device; if (data.default_precision) window.defaultPrecision = data.default_precision; diff --git a/tests/test_cli_info.py b/tests/test_cli_info.py index e061b31..76f8376 100644 --- a/tests/test_cli_info.py +++ b/tests/test_cli_info.py @@ -71,6 +71,11 @@ def test_run_info_json_outputs_valid_json(capsys, fake_hardware): "loras_dir": "e", "db_path": "f", }, + "constraints": { + "max_steps": 50, + "max_width": 4096, + "max_height": 4096, + }, "env_overrides": { "Z_IMAGE_STUDIO_DATA_DIR": None, "Z_IMAGE_STUDIO_OUTPUT_DIR": None, @@ -108,6 +113,11 @@ def test_run_info_text_includes_hardware_error(capsys): "loras_dir": "e", "db_path": "f", }, + "constraints": { + "max_steps": 50, + "max_width": 4096, + "max_height": 4096, + }, "env_overrides": { "Z_IMAGE_STUDIO_DATA_DIR": None, "Z_IMAGE_STUDIO_OUTPUT_DIR": None,