diff --git a/apps/gradio-demo/main.py b/apps/gradio-demo/main.py index 77511b73..79ff2cd1 100644 --- a/apps/gradio-demo/main.py +++ b/apps/gradio-demo/main.py @@ -1,3 +1,8 @@ + + +Here is the improved code with all emojis removed, maintaining a clean, professional text-based interface. + +```python import asyncio import json import logging @@ -6,651 +11,532 @@ import time import uuid from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from pathlib import Path -from typing import AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict import gradio as gr from dotenv import load_dotenv from hydra import compose, initialize_config_dir from omegaconf import DictConfig + +# Assuming these are local project imports from src.config.settings import expose_sub_agents_as_tools from src.core.pipeline import create_pipeline_components, execute_task_pipeline from utils import contains_chinese, replace_chinese_punctuation -# Create global cleanup thread pool for operations that won't be affected by asyncio.cancel -cleanup_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="cleanup") +# ================= Configuration & Logging ================= +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) - -# Load environment variables from .env file load_dotenv() # Global Hydra initialization flag _hydra_initialized = False +# ================= Types & Data Structures ================= + +class ToolCallData(TypedDict, total=False): + tool_name: str + input: Any + output: Any + content: str # Used for show_text/message + +class AgentData(TypedDict): + agent_name: str + tool_call_order: List[str] + tools: Dict[str, ToolCallData] + +@dataclass +class UIState: + task_id: Optional[str] = None + agent_order: List[str] = field(default_factory=list) + agents: Dict[str, AgentData] = field(default_factory=dict) + current_agent_id: Optional[str] = None + errors: List[str] = field(default_factory=list) + +# ================= Configuration Loading ================= def load_miroflow_config(config_overrides: Optional[dict] = None) -> DictConfig: - """ - Load the full MiroFlow configuration using Hydra, similar to how benchmarks work. - """ + """Load the full MiroFlow configuration using Hydra.""" global _hydra_initialized - # Get the path to the miroflow agent config directory miroflow_config_dir = Path(__file__).parent.parent / "miroflow-agent" / "conf" - miroflow_config_dir = miroflow_config_dir.resolve() - print("config dir", miroflow_config_dir) - if not miroflow_config_dir.exists(): - raise FileNotFoundError( - f"MiroFlow config directory not found: {miroflow_config_dir}" - ) + raise FileNotFoundError(f"MiroFlow config directory not found: {miroflow_config_dir}") - # Initialize Hydra if not already done if not _hydra_initialized: try: - initialize_config_dir( - config_dir=str(miroflow_config_dir), version_base=None - ) + initialize_config_dir(config_dir=str(miroflow_config_dir), version_base=None) _hydra_initialized = True except Exception as e: logger.warning(f"Hydra already initialized or error: {e}") - # Compose configuration with environment variable overrides - overrides = [] - - # Add environment variable based overrides (refer to scripts/debug.sh) - llm_provider = os.getenv( - "DEFAULT_LLM_PROVIDER", "qwen" - ) # debug.sh defaults to qwen - model_name = os.getenv( - "DEFAULT_MODEL_NAME", "MiroThinker" - ) # debug.sh default model - agent_set = os.getenv("DEFAULT_AGENT_SET", "evaluation") # debug.sh uses evaluation - base_url = os.getenv("BASE_URL", "http://localhost:11434") - print("base_url", base_url) + overrides = [ + f"llm.provider={os.getenv('DEFAULT_LLM_PROVIDER', 'qwen')}", + f"llm.model_name={os.getenv('DEFAULT_MODEL_NAME', 'MiroThinker')}", + f"llm.base_url={os.getenv('BASE_URL', 'http://localhost:11434')}", + f"agent={os.getenv('DEFAULT_AGENT_SET', 'evaluation')}", + "benchmark=gaia-validation", + "+pricing=default", + ] # Map provider names to config files - provider_config_map = { - "anthropic": "claude", - "openai": "openai", - "deepseek": "deepseek", - "qwen": "qwen-3", - } + provider_map = {"anthropic": "claude", "openai": "openai", "deepseek": "deepseek", "qwen": "qwen-3"} + if os.getenv("DEFAULT_LLM_PROVIDER") in provider_map: + overrides[0] = f"llm={provider_map[os.getenv('DEFAULT_LLM_PROVIDER')]}" - llm_config = provider_config_map.get( - llm_provider, "qwen-3" - ) # default changed to qwen-3 - overrides.extend( - [ - f"llm={llm_config}", - f"llm.provider={llm_provider}", - f"llm.model_name={model_name}", - f"llm.base_url={base_url}", - f"agent={agent_set}", # use evaluation instead of default - "benchmark=gaia-validation", # refer to debug.sh - "+pricing=default", - ] - ) - - # Add config overrides from request if config_overrides: - for key, value in config_overrides.items(): - if isinstance(value, dict): - for subkey, subvalue in value.items(): - overrides.append(f"{key}.{subkey}={subvalue}") + for k, v in config_overrides.items(): + if isinstance(v, dict): + overrides.extend([f"{k}.{sk}={sv}" for sk, sv in v.items()]) else: - overrides.append(f"{key}={value}") + overrides.append(f"{k}={v}") try: - cfg = compose(config_name="config", overrides=overrides) - return cfg + return compose(config_name="config", overrides=overrides) except Exception as e: logger.error(f"Failed to compose Hydra config: {e}") - exit() - + raise -# pre load main agent tool definitions to speed up the first request +# Pre-load resources cfg = load_miroflow_config(None) -# Create pipeline components -main_agent_tool_manager, sub_agent_tool_managers, output_formatter = ( - create_pipeline_components(cfg) -) +main_agent_tool_manager, sub_agent_tool_managers, output_formatter = create_pipeline_components(cfg) tool_definitions = asyncio.run(main_agent_tool_manager.get_all_tool_definitions()) tool_definitions += expose_sub_agents_as_tools(cfg.agent.sub_agents) -# pre load sub agent tool definitions to speed up the first request sub_agent_tool_definitions = { - name: asyncio.run(sub_agent_tool_manager.get_all_tool_definitions()) - for name, sub_agent_tool_manager in sub_agent_tool_managers.items() + name: asyncio.run(mgr.get_all_tool_definitions()) + for name, mgr in sub_agent_tool_managers.items() } +# ================= Core Logic Classes ================= class ThreadSafeAsyncQueue: - """Thread-safe async queue wrapper""" - + """Wrapper for asyncio.Queue to handle thread-safe puts.""" def __init__(self): self._queue = asyncio.Queue() - self._loop = None + self._loop = asyncio.get_running_loop() self._closed = False - def set_loop(self, loop): - self._loop = loop + async def put(self, item: Any): + if not self._closed: + await self._queue.put(item) - async def put(self, item): - """Put data safely from any thread""" + def put_nowait_threadsafe(self, item: Any): + """Schedules a put from a different thread.""" if self._closed: return - await self._queue.put(item) - - def put_nowait_threadsafe(self, item): - """Put data from other threads""" - if self._closed or not self._loop: + if not self._loop.is_running(): return - self._loop.call_soon_threadsafe(lambda: asyncio.create_task(self.put(item))) + # We schedule the coroutine on the loop + asyncio.run_coroutine_threadsafe(self.put(item), self._loop) - async def get(self): + async def get(self) -> Any: return await self._queue.get() def close(self): self._closed = True +class PipelineRunner: + """Manages the execution of the pipeline in a separate thread.""" + def __init__(self, task_id: str, query: str, queue: ThreadSafeAsyncQueue): + self.task_id = task_id + self.query = query + self.queue = queue + self.cancel_event = threading.Event() + self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="pipeline") + self.future = None -def filter_google_search_organic(organic: List[dict]) -> List[dict]: - """ - Filter google search organic results to remove unnecessary information - """ - result = [] - for item in organic: - result.append( - { - "title": item.get("title", ""), - "link": item.get("link", ""), - } - ) - return result - - -def is_scrape_error(result: str) -> bool: - """ - Check if the scrape result is an error - """ - try: - json.loads(result) - return False - except json.JSONDecodeError: - return True - - -def filter_message(message: dict) -> dict: - """ - Filter message to remove unnecessary information - """ - if message["event"] == "tool_call": - tool_name = message["data"].get("tool_name") - tool_input = message["data"].get("tool_input") - if ( - tool_name == "google_search" - and isinstance(tool_input, dict) - and "result" in tool_input - ): - result_dict = json.loads(tool_input["result"]) - if "organic" in result_dict: - new_result = { - "organic": filter_google_search_organic(result_dict["organic"]) - } - message["data"]["tool_input"]["result"] = json.dumps( - new_result, ensure_ascii=False - ) - if ( - tool_name in ["scrape", "scrape_website"] - and isinstance(tool_input, dict) - and "result" in tool_input - ): - # if error, it can not be json - if is_scrape_error(tool_input["result"]): - message["data"]["tool_input"] = {"error": tool_input["result"]} - else: - message["data"]["tool_input"] = {} - return message - - -async def stream_events_optimized( - task_id: str, query: str, _: Optional[dict] = None, disconnect_check=None -) -> AsyncGenerator[dict, None]: - """Optimized event stream generator that directly outputs structured events, no longer wrapped as SSE strings.""" - workflow_id = task_id - last_send_time = time.time() - last_heartbeat_time = time.time() + def start(self): + self.future = self.executor.submit(self._run_pipeline_thread) - # Create thread-safe queue - stream_queue = ThreadSafeAsyncQueue() - stream_queue.set_loop(asyncio.get_event_loop()) + def cancel(self): + self.cancel_event.set() + # We don't wait for future result here to allow UI to update immediately - cancel_event = threading.Event() - - def run_pipeline_in_thread(): + def _run_pipeline_thread(self): try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - class ThreadQueueWrapper: - def __init__(self, thread_queue, cancel_event): - self.thread_queue = thread_queue - self.cancel_event = cancel_event - + # Wrapper to interface with thread-safe queue + class QueueProxy: async def put(self, item): - if self.cancel_event.is_set(): - logger.info("Pipeline cancelled, stopping execution") - return - self.thread_queue.put_nowait_threadsafe(filter_message(item)) - - wrapper_queue = ThreadQueueWrapper(stream_queue, cancel_event) - - global cfg - global main_agent_tool_manager - global sub_agent_tool_managers - global output_formatter - global tool_definitions - global sub_agent_tool_definitions - - async def pipeline_with_cancellation(): - pipeline_task = asyncio.create_task( + if self.cancel_event.is_set(): return + # Filter before sending to UI + self.queue.put_nowait_threadsafe(filter_message(item)) + + proxy = QueueProxy() + proxy.cancel_event = self.cancel_event + proxy.queue = self.queue + + async def run(): + task = asyncio.create_task( execute_task_pipeline( - cfg=cfg, - task_id=workflow_id, - task_description=query, - task_file_name=None, - main_agent_tool_manager=main_agent_tool_manager, + cfg=cfg, task_id=self.task_id, task_description=self.query, + task_file_name=None, main_agent_tool_manager=main_agent_tool_manager, sub_agent_tool_managers=sub_agent_tool_managers, - output_formatter=output_formatter, - stream_queue=wrapper_queue, + output_formatter=output_formatter, stream_queue=proxy, log_dir=os.getenv("LOG_DIR", "logs/api-server"), tool_definitions=tool_definitions, sub_agent_tool_definitions=sub_agent_tool_definitions, ) ) - - async def check_cancellation(): - while not cancel_event.is_set(): - await asyncio.sleep(0.5) - logger.info("Cancel event detected, cancelling pipeline") - pipeline_task.cancel() - - cancel_task = asyncio.create_task(check_cancellation()) - - try: - done, pending = await asyncio.wait( - [pipeline_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - for task in done: - if task == pipeline_task: - try: - await task - except asyncio.CancelledError: - logger.info("Pipeline task was cancelled") - except Exception as e: - logger.error(f"Pipeline execution error: {e}") - pipeline_task.cancel() - cancel_task.cancel() - - loop.run_until_complete(pipeline_with_cancellation()) + + # Cancellation watcher + while not self.cancel_event.is_set(): + await asyncio.sleep(0.5) + if task.done(): break + + if self.cancel_event.is_set(): + task.cancel() + try: await task + except asyncio.CancelledError: pass + + loop.run_until_complete(run()) except Exception as e: - if not cancel_event.is_set(): - logger.error(f"Pipeline error: {e}", exc_info=True) - stream_queue.put_nowait_threadsafe( - { - "event": "error", - "data": {"error": str(e), "workflow_id": workflow_id}, - } - ) + logger.error(f"Pipeline Thread Exception: {e}", exc_info=True) + self.queue.put_nowait_threadsafe({ + "event": "error", "data": {"error": str(e), "workflow_id": self.task_id} + }) finally: - stream_queue.put_nowait_threadsafe(None) - if "loop" in locals(): - loop.close() - - executor = ThreadPoolExecutor(max_workers=1) - future = executor.submit(run_pipeline_in_thread) - - try: - while True: - try: - if disconnect_check and await disconnect_check(): - logger.info("Client disconnected, stopping pipeline") - cancel_event.set() - break - message = await asyncio.wait_for(stream_queue.get(), timeout=0.1) - if message is None: - logger.info("Pipeline completed") - break - yield message - last_send_time = time.time() - except asyncio.TimeoutError: - current_time = time.time() - if current_time - last_send_time > 300: - logger.info("Stream timeout") - break - if future.done(): - try: - message = stream_queue._queue.get_nowait() - if message is not None: - yield message - continue - except Exception: - break - if current_time - last_heartbeat_time >= 15: - yield { - "event": "heartbeat", - "data": {"timestamp": current_time, "workflow_id": workflow_id}, - } - last_heartbeat_time = current_time - except Exception as e: - logger.error(f"Stream error: {e}", exc_info=True) - yield { - "event": "error", - "data": {"workflow_id": workflow_id, "error": f"Stream error: {str(e)}"}, - } - finally: - cancel_event.set() - stream_queue.close() - try: - future.result(timeout=1.0) - except Exception: - pass - executor.shutdown(wait=False) - + self.queue.put_nowait_threadsafe(None) # Sentinel for end of stream + loop.close() -# ========================= Gradio Integration ========================= + def cleanup(self): + if self.future: + self.future.cancel() + self.executor.shutdown(wait=False) +# ================= Helper Functions ================= -def _init_render_state(): - return { - "agent_order": [], - "agents": {}, # agent_id -> {"agent_name": str, "tool_call_order": [], "tools": {tool_call_id: {...}}} - "current_agent_id": None, - "errors": [], - } - +def filter_google_search_organic(organic: List[dict]) -> List[dict]: + return [{"title": i.get("title", ""), "link": i.get("link", "")} for i in organic] -def _append_show_text(tool_entry: dict, delta: str): - existing = tool_entry.get("content", "") - tool_entry["content"] = existing + delta +def filter_message(message: dict) -> dict: + """Sanitize message data for UI rendering.""" + if message.get("event") != "tool_call": + return message + + data = message.get("data", {}) + tool_name = data.get("tool_name") + tool_input = data.get("tool_input") + if not isinstance(tool_input, dict): return message -def _is_empty_payload(value) -> bool: - if value is None: - return True - if isinstance(value, str): - stripped = value.strip() - return stripped == "" or stripped in ("{}", "[]") - if isinstance(value, (dict, list, tuple, set)): - return len(value) == 0 - return False + # Filter Search Results + if tool_name == "google_search" and "result" in tool_input: + try: + res = json.loads(tool_input["result"]) + if "organic" in res: + res["organic"] = filter_google_search_organic(res["organic"]) + tool_input["result"] = json.dumps(res, ensure_ascii=False) + except json.JSONDecodeError: pass + + # Filter Scrape Results + if tool_name in ["scrape", "scrape_website"] and "result" in tool_input: + try: + json.loads(tool_input["result"]) # Check validity + # If valid JSON, we might want to truncate it or just hide it entirely + # For now, just clear it to save UI space, or keep if small + if len(tool_input["result"]) > 5000: + tool_input["result"] = tool_input["result"][:5000] + "... [truncated]" + except json.JSONDecodeError: + # It's an error text + tool_input = {"error": tool_input["result"]} + + message["data"]["tool_input"] = tool_input + return message +# ================= UI State Logic ================= -def _render_markdown(state: dict) -> str: - lines = [] - emoji_cycle = ["๐Ÿง ", "๐Ÿ”Ž", "๐Ÿ› ๏ธ", "๐Ÿ“š", "๐Ÿค–", "๐Ÿงช", "๐Ÿ“", "๐Ÿงญ", "โš™๏ธ", "๐Ÿงฎ"] - # Render errors first if any - if state.get("errors"): - lines.append("### โŒ Errors") - for idx, err in enumerate(state["errors"], start=1): - lines.append(f"- **Error {idx}**: {err}") - lines.append("\n---\n") - for idx, agent_id in enumerate(state.get("agent_order", [])): - agent = state["agents"].get(agent_id, {}) - agent_name = agent.get("agent_name", "unknown") - emoji = emoji_cycle[idx % len(emoji_cycle)] - lines.append(f"### {emoji} Agent: {agent_name}") - for call_id in agent.get("tool_call_order", []): - call = agent["tools"].get(call_id, {}) - tool_name = call.get("tool_name", "unknown_tool") - if tool_name in ("show_text", "message"): - content = call.get("content", "") - if content: - lines.append(content) - else: - tool_input = call.get("input") - tool_output = call.get("output") - has_input = not _is_empty_payload(tool_input) - has_output = not _is_empty_payload(tool_output) - if not has_input and not has_output: - # No parameters, only show tool name with emoji on separate line - if tool_name == "Partial Summary": - lines.append("\n๐Ÿ’กPartial Summary\n") - else: - lines.append(f"\n๐Ÿ”ง{tool_name}\n") - else: - # Show as collapsible details for any tool with input or output - if tool_name == "Partial Summary": - summary = f"๐Ÿ’ก{tool_name} ({call_id[:8]})" - else: - summary = f"๐Ÿ”ง{tool_name} ({call_id[:8]})" - lines.append(f"\n
{summary}") - if has_input: - pretty = json.dumps(tool_input, ensure_ascii=False, indent=2) - lines.append("\n**Input**:\n") - lines.append(f"```json\n{pretty}\n```") - if has_output: - pretty = json.dumps(tool_output, ensure_ascii=False, indent=2) - lines.append("\n**Output**:\n") - lines.append(f"```json\n{pretty}\n```") - lines.append("
\n") - lines.append("\n---\n") - return "\n".join(lines) if lines else "Waiting..." - - -def _update_state_with_event(state: dict, message: dict): +def _update_state(state: UIState, message: dict) -> UIState: event = message.get("event") data = message.get("data", {}) + if event == "start_of_agent": agent_id = data.get("agent_id") - agent_name = data.get("agent_name", "unknown") - if agent_id and agent_id not in state["agents"]: - state["agents"][agent_id] = { - "agent_name": agent_name, + if agent_id: + state.agents[agent_id] = { + "agent_name": data.get("agent_name", "Unknown"), "tool_call_order": [], - "tools": {}, + "tools": {} } - state["agent_order"].append(agent_id) - state["current_agent_id"] = agent_id - elif event == "end_of_agent": - # End marker, no special handling needed, keep structure - state["current_agent_id"] = None + state.agent_order.append(agent_id) + state.current_agent_id = agent_id + elif event == "tool_call": - tool_call_id = data.get("tool_call_id") - tool_name = data.get("tool_name", "unknown_tool") - agent_id = state.get("current_agent_id") or ( - state["agent_order"][-1] if state["agent_order"] else None - ) - if not agent_id: - return state - agent = state["agents"].setdefault( - agent_id, {"agent_name": "unknown", "tool_call_order": [], "tools": {}} - ) - tools = agent["tools"] - if tool_call_id not in tools: - tools[tool_call_id] = {"tool_name": tool_name} - agent["tool_call_order"].append(tool_call_id) - entry = tools[tool_call_id] - if tool_name == "show_text" and "delta_input" in data: - delta = data.get("delta_input", {}).get("text", "") - _append_show_text(entry, delta) - elif tool_name == "show_text" and "tool_input" in data: - ti = data.get("tool_input") - text = "" - if isinstance(ti, dict): - text = ti.get("text", "") or ( - (ti.get("result") or {}).get("text") - if isinstance(ti.get("result"), dict) - else "" - ) - elif isinstance(ti, str): - text = ti - if text: - _append_show_text(entry, text) - else: - # Distinguish between input and output: - if "tool_input" in data: - # Could be input (first time) or output with result (second time) - ti = data["tool_input"] - # If contains result, assign to output; otherwise assign to input - if isinstance(ti, dict) and "result" in ti: - entry["output"] = ti + tool_id = data.get("tool_call_id") + agent_id = state.current_agent_id or (state.agent_order[-1] if state.agent_order else None) + + if agent_id and tool_id: + if tool_id not in state.agents[agent_id]["tools"]: + state.agents[agent_id]["tools"][tool_id] = {"tool_name": data.get("tool_name", "unknown")} + state.agents[agent_id]["tool_call_order"].append(tool_id) + + entry = state.agents[agent_id]["tools"][tool_id] + tool_name = entry["tool_name"] + + # Handle text streaming + if tool_name == "show_text": + delta = data.get("delta_input", {}).get("text", "") + full = data.get("tool_input", {}).get("text", "") + entry["content"] = entry.get("content", "") + delta or full + + # Handle I/O + elif "tool_input" in data: + inp = data["tool_input"] + if isinstance(inp, dict) and "result" in inp: + entry["output"] = inp else: - # Only update input if we don't already have valid input data, or if the new data is not empty - if "input" not in entry or not _is_empty_payload(ti): - entry["input"] = ti + entry["input"] = inp + elif event == "message": - # Same incremental text display as show_text, aggregated by message_id - message_id = data.get("message_id") - agent_id = state.get("current_agent_id") or ( - state["agent_order"][-1] if state["agent_order"] else None - ) - if not agent_id: - return state - agent = state["agents"].setdefault( - agent_id, {"agent_name": "unknown", "tool_call_order": [], "tools": {}} - ) - tools = agent["tools"] - if message_id not in tools: - tools[message_id] = {"tool_name": "message"} - agent["tool_call_order"].append(message_id) - entry = tools[message_id] - delta_content = (data.get("delta") or {}).get("content", "") - if isinstance(delta_content, str) and delta_content: - _append_show_text(entry, delta_content) + # Same as show_text essentially + msg_id = data.get("message_id") + agent_id = state.current_agent_id or (state.agent_order[-1] if state.agent_order else None) + if agent_id and msg_id: + if msg_id not in state.agents[agent_id]["tools"]: + state.agents[agent_id]["tools"][msg_id] = {"tool_name": "message"} + state.agents[agent_id]["tool_call_order"].append(msg_id) + + entry = state.agents[agent_id]["tools"][msg_id] + delta = (data.get("delta") or {}).get("content", "") + if delta: + entry["content"] = entry.get("content", "") + delta + elif event == "error": - # Collect errors, display uniformly during rendering - err_text = data.get("error") if isinstance(data, dict) else None - if not err_text: - try: - err_text = json.dumps(data, ensure_ascii=False) - except Exception: - err_text = str(data) - state.setdefault("errors", []).append(err_text) - else: - # Ignore heartbeat or other events - pass + err = data.get("error", str(data)) + state.errors.append(err) + return state +def _render_markdown(state: UIState) -> str: + if not state.agent_order and not state.errors: + return "### System Ready\nWaiting for a task..." -_CANCEL_FLAGS = {} + lines = [] + + if state.errors: + lines.append("### Errors") + for err in state.errors: + lines.append(f"- `{err}`") + lines.append("---") + + for idx, agent_id in enumerate(state.agent_order): + agent = state.agents[agent_id] + name = agent["agent_name"] + + lines.append(f"### Agent: {name}") + + for call_id in agent["tool_call_order"]: + tool = agent["tools"][call_id] + t_name = tool["tool_name"] + + if t_name in ("show_text", "message"): + content = tool.get("content", "") + if content: + lines.append(f"\n{content}") + else: + has_io = "input" in tool or "output" in tool + if not has_io: + lines.append(f"\n*Used `{t_name}`*") + else: + summary = f"[Tool] {t_name}" + lines.append(f"\n
{summary}") + + if "input" in tool: + lines.append("**Input:**") + lines.append(f"```json\n{json.dumps(tool['input'], ensure_ascii=False, indent=2)}\n```") + + if "output" in tool: + lines.append("**Output:**") + out_str = json.dumps(tool['output'], ensure_ascii=False, indent=2) + # Truncate huge outputs in UI + if len(out_str) > 2000: + out_str = out_str[:2000] + "\n... [truncated]" + lines.append(f"```json\n{out_str}\n```") + + lines.append("
") + lines.append("---") + + return "\n".join(lines) + +# ================= Gradio Interface ================= + +_CANCEL_FLAGS: Dict[str, bool] = {} _CANCEL_LOCK = threading.Lock() - -def _set_cancel_flag(task_id: str): - with _CANCEL_LOCK: - _CANCEL_FLAGS[task_id] = True - - -def _reset_cancel_flag(task_id: str): +def _set_cancel(task_id: str, status: bool): with _CANCEL_LOCK: - _CANCEL_FLAGS[task_id] = False + _CANCEL_FLAGS[task_id] = status - -async def _disconnect_check_for_task(task_id: str): +def _get_cancel(task_id: str) -> bool: with _CANCEL_LOCK: return _CANCEL_FLAGS.get(task_id, False) - -def _spinner_markup(running: bool) -> str: - if not running: - return "" - return ( - '\n\n
' - '
' - "Generating..." - "
\n\n" - ) - - -async def gradio_run(query: str, ui_state: Optional[dict]): +async def gradio_run(query: str, history: list): query = replace_chinese_punctuation(query or "") if contains_chinese(query): yield ( - "we only support English input for the time being.", + _render_markdown(UIState(errors=["Chinese input is currently unsupported."])), gr.update(interactive=True), gr.update(interactive=False), - ui_state or {"task_id": None}, + history ) return + task_id = str(uuid.uuid4()) - _reset_cancel_flag(task_id) - if not ui_state: - ui_state = {"task_id": task_id} - else: - ui_state = {**ui_state, "task_id": task_id} - state = _init_render_state() - # Initial: disable Run, enable Stop, and show spinner at bottom of text - yield ( - _render_markdown(state) + _spinner_markup(True), - gr.update(interactive=False), - gr.update(interactive=True), - ui_state, - ) - async for message in stream_events_optimized( - task_id, query, None, lambda: _disconnect_check_for_task(task_id) - ): - state = _update_state_with_event(state, message) - md = _render_markdown(state) + _set_cancel(task_id, False) + + state = UIState(task_id=task_id) + queue = ThreadSafeAsyncQueue() + runner = PipelineRunner(task_id, query, queue) + runner.start() + + try: + # UI Update: Started yield ( - md + _spinner_markup(True), + _render_markdown(state) + "\n\n*Initializing...*", gr.update(interactive=False), gr.update(interactive=True), - ui_state, + history + [[query, None]] ) - # End: enable Run, disable Stop, remove spinner - yield ( - _render_markdown(state), - gr.update(interactive=True), - gr.update(interactive=False), - ui_state, - ) + last_heartbeat = time.time() + + while True: + if _get_cancel(task_id): + runner.cancel() + yield ( + _render_markdown(state) + "\n\n*Stopped by user.*", + gr.update(interactive=True), + gr.update(interactive=False), + history + ) + break + + try: + # Wait for message with timeout to allow checking cancel flag + msg = await asyncio.wait_for(queue.get(), timeout=0.2) + + if msg is None: + # End of stream + yield ( + _render_markdown(state), + gr.update(interactive=True), + gr.update(interactive=False), + history + ) + break + + state = _update_state(state, msg) + yield ( + _render_markdown(state), + gr.update(interactive=False), + gr.update(interactive=True), + history + ) + + last_heartbeat = time.time() -def stop_current(ui_state: Optional[dict]): - tid = (ui_state or {}).get("task_id") - if tid: - _set_cancel_flag(tid) - # Immediately switch button availability: enable Run, disable Stop - return ( - gr.update(interactive=True), - gr.update(interactive=False), - ) + except asyncio.TimeoutError: + if time.time() - last_heartbeat > 30: + # Pipeline seems stalled + runner.cancel() + yield ( + _render_markdown(state) + "\n\n*Timeout waiting for agent response.*", + gr.update(interactive=True), + gr.update(interactive=False), + history + ) + break + finally: + runner.cleanup() + _set_cancel(task_id, False) + +def stop_click(history: list): + # The actual cancellation is handled in the async generator via flags + # This just updates buttons immediately to give feedback + return gr.update(interactive=True), gr.update(interactive=False) +def clear_history(): + return [], None, UIState() + +# ================= Main Application ================= def build_demo(): custom_css = """ - #log-view { border: 1px solid #ececec; padding: 12px; border-radius: 8px; scroll-behavior: smooth; } + #log-view { + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 16px; + background-color: #f9fafb; + max-height: 600px; + overflow-y: auto; + font-family: 'Segoe UI', sans-serif; + } + details { margin-bottom: 8px; border: 1px solid #e5e7eb; border-radius: 4px; padding: 4px; } + summary { cursor: pointer; font-weight: 600; color: #374151; } + pre { background: #f3f4f6; padding: 8px; border-radius: 4px; overflow-x: auto; } """ - with gr.Blocks(css=custom_css) as demo: - gr.Markdown(""" - **MiroFlow DeepResearch - Gradio Demo** - Enter an English question and observe Agents and tool calls in real time (Markdown + collapsible sections). - """) + + with gr.Blocks(css=custom_css, title="MiroFlow DeepResearch") as demo: + gr.Markdown("# MiroFlow DeepResearch") + gr.Markdown("Multi-agent research pipeline. Enter a query to start the agents.") + with gr.Row(): - inp = gr.Textbox(lines=3, label="Question (English only)") + with gr.Column(scale=4): + query_input = gr.Textbox(lines=3, label="Query", placeholder="e.g., What are the latest advancements in fusion energy?") + with gr.Column(scale=1): + run_btn = gr.Button("Run", variant="primary", size="lg") + stop_btn = gr.Button("Stop", variant="stop", interactive=False, size="lg") + + with gr.Row(): + clear_btn = gr.Button("Clear History", size="sm") + with gr.Row(): - run_btn = gr.Button("Run") - stop_btn = gr.Button("Stop", variant="stop", interactive=False) - out_md = gr.Markdown("", elem_id="log-view") - ui_state = gr.State({"task_id": None}) - # run: outputs -> markdown, run_btn(update), stop_btn(update), ui_state + # We use a Markdown component to render the logs nicely + output_log = gr.Markdown(elem_id="log-view", value="### System Ready") + + # State storage (using a simple dict to mimic internal state) + ui_state = gr.State(value=UIState()) + chat_history = gr.State(value=[]) + + # Event bindings run_btn.click( fn=gradio_run, - inputs=[inp, ui_state], - outputs=[out_md, run_btn, stop_btn, ui_state], + inputs=[query_input, chat_history], + outputs=[output_log, run_btn, stop_btn, chat_history] + ) + + stop_btn.click( + fn=stop_click, + inputs=[chat_history], + outputs=[run_btn, stop_btn] + ).then( # Then update the log to show stopping + lambda s: _render_markdown(s) + "\n\n*Stopping...*", + inputs=[ui_state], + outputs=[output_log] ) - # stop: outputs -> run_btn(update), stop_btn(update) - stop_btn.click(fn=stop_current, inputs=[ui_state], outputs=[run_btn, stop_btn]) - return demo + clear_btn.click( + fn=clear_history, + outputs=[chat_history, output_log, ui_state] + ) + + return demo if __name__ == "__main__": demo = build_demo() - host = os.getenv("HOST", "0.0.0.0") - port = int(os.getenv("PORT", "8000")) - demo.queue().launch(server_name=host, server_port=port) + demo.queue() + demo.launch( + server_name=os.getenv("HOST", "0.0.0.0"), + server_port=int(os.getenv("PORT", "8000")), + show_error=True + ) +```