diff --git a/app/services/reverse/ws_imagine.py b/app/services/reverse/ws_imagine.py index 2e22978be..0c3710726 100644 --- a/app/services/reverse/ws_imagine.py +++ b/app/services/reverse/ws_imagine.py @@ -101,20 +101,53 @@ async def stream( max_retries: Optional[int] = None, ) -> AsyncGenerator[Dict[str, object], None]: retries = max(1, max_retries if max_retries is not None else 1) + parallel_enabled = bool(get_config("image.blocked_parallel_enabled", True)) logger.info( f"Image generation: prompt='{prompt[:50]}...', n={n}, ratio={aspect_ratio}, nsfw={enable_nsfw}" ) + async def _collect_once() -> list[Dict[str, object]]: + items: list[Dict[str, object]] = [] + async for item in self._stream_once( + token, prompt, aspect_ratio, n, enable_nsfw + ): + items.append(item) + return items + for attempt in range(retries): try: - yielded_any = False - async for item in self._stream_once( - token, prompt, aspect_ratio, n, enable_nsfw - ): - yielded_any = True + items = await _collect_once() + for item in items: yield item return except _BlockedError: + retries_left = retries - (attempt + 1) + if retries_left > 0 and parallel_enabled: + logger.warning( + f"WebSocket blocked/reviewed, launching {retries_left} parallel retries" + ) + tasks = [asyncio.create_task(_collect_once()) for _ in range(retries_left)] + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + continue + has_final = any( + isinstance(item, dict) + and item.get("type") == "image" + and item.get("is_final") + for item in result + ) + if has_final: + for item in result: + yield item + return + yield { + "type": "error", + "error_code": "blocked", + "error": "blocked_no_final_image", + "parallel_attempts": retries_left, + } + return if attempt + 1 < retries: logger.warning( f"WebSocket blocked/reviewed, retry {attempt + 1}/{retries}" @@ -124,7 +157,6 @@ async def stream( "type": "error", "error_code": "blocked", "error": "blocked_no_final_image", - "yielded_partial": yielded_any, } return except Exception as e: