diff --git a/apps/gradio-demo/main.py b/apps/gradio-demo/main.py
index 77511b73..79ff2cd1 100644
--- a/apps/gradio-demo/main.py
+++ b/apps/gradio-demo/main.py
@@ -6,651 +6,532 @@
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
from pathlib import Path
-from typing import AsyncGenerator, List, Optional
+from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict
import gradio as gr
from dotenv import load_dotenv
from hydra import compose, initialize_config_dir
from omegaconf import DictConfig
+
+# Assuming these are local project imports
from src.config.settings import expose_sub_agents_as_tools
from src.core.pipeline import create_pipeline_components, execute_task_pipeline
from utils import contains_chinese, replace_chinese_punctuation
-# Create global cleanup thread pool for operations that won't be affected by asyncio.cancel
-cleanup_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="cleanup")
+# ================= Configuration & Logging =================
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
logger = logging.getLogger(__name__)
-
-# Load environment variables from .env file
load_dotenv()
# Global Hydra initialization flag
_hydra_initialized = False
+# ================= Types & Data Structures =================
+
+class ToolCallData(TypedDict, total=False):
+ tool_name: str
+ input: Any
+ output: Any
+ content: str # Used for show_text/message
+
+class AgentData(TypedDict):
+ agent_name: str
+ tool_call_order: List[str]
+ tools: Dict[str, ToolCallData]
+
+@dataclass
+class UIState:
+ task_id: Optional[str] = None
+ agent_order: List[str] = field(default_factory=list)
+ agents: Dict[str, AgentData] = field(default_factory=dict)
+ current_agent_id: Optional[str] = None
+ errors: List[str] = field(default_factory=list)
+
+# ================= Configuration Loading =================
def load_miroflow_config(config_overrides: Optional[dict] = None) -> DictConfig:
- """
- Load the full MiroFlow configuration using Hydra, similar to how benchmarks work.
- """
+ """Load the full MiroFlow configuration using Hydra."""
global _hydra_initialized
- # Get the path to the miroflow agent config directory
miroflow_config_dir = Path(__file__).parent.parent / "miroflow-agent" / "conf"
- miroflow_config_dir = miroflow_config_dir.resolve()
- print("config dir", miroflow_config_dir)
-
if not miroflow_config_dir.exists():
- raise FileNotFoundError(
- f"MiroFlow config directory not found: {miroflow_config_dir}"
- )
+ raise FileNotFoundError(f"MiroFlow config directory not found: {miroflow_config_dir}")
- # Initialize Hydra if not already done
if not _hydra_initialized:
try:
- initialize_config_dir(
- config_dir=str(miroflow_config_dir), version_base=None
- )
+ initialize_config_dir(config_dir=str(miroflow_config_dir), version_base=None)
_hydra_initialized = True
except Exception as e:
logger.warning(f"Hydra already initialized or error: {e}")
- # Compose configuration with environment variable overrides
- overrides = []
-
- # Add environment variable based overrides (refer to scripts/debug.sh)
- llm_provider = os.getenv(
- "DEFAULT_LLM_PROVIDER", "qwen"
- ) # debug.sh defaults to qwen
- model_name = os.getenv(
- "DEFAULT_MODEL_NAME", "MiroThinker"
- ) # debug.sh default model
- agent_set = os.getenv("DEFAULT_AGENT_SET", "evaluation") # debug.sh uses evaluation
- base_url = os.getenv("BASE_URL", "http://localhost:11434")
- print("base_url", base_url)
+ overrides = [
+ f"llm.provider={os.getenv('DEFAULT_LLM_PROVIDER', 'qwen')}",
+ f"llm.model_name={os.getenv('DEFAULT_MODEL_NAME', 'MiroThinker')}",
+ f"llm.base_url={os.getenv('BASE_URL', 'http://localhost:11434')}",
+ f"agent={os.getenv('DEFAULT_AGENT_SET', 'evaluation')}",
+ "benchmark=gaia-validation",
+ "+pricing=default",
+ ]
# Map provider names to config files
- provider_config_map = {
- "anthropic": "claude",
- "openai": "openai",
- "deepseek": "deepseek",
- "qwen": "qwen-3",
- }
+ provider_map = {"anthropic": "claude", "openai": "openai", "deepseek": "deepseek", "qwen": "qwen-3"}
+ if os.getenv("DEFAULT_LLM_PROVIDER") in provider_map:
+ overrides[0] = f"llm={provider_map[os.getenv('DEFAULT_LLM_PROVIDER')]}"
- llm_config = provider_config_map.get(
- llm_provider, "qwen-3"
- ) # default changed to qwen-3
- overrides.extend(
- [
- f"llm={llm_config}",
- f"llm.provider={llm_provider}",
- f"llm.model_name={model_name}",
- f"llm.base_url={base_url}",
- f"agent={agent_set}", # use evaluation instead of default
- "benchmark=gaia-validation", # refer to debug.sh
- "+pricing=default",
- ]
- )
-
- # Add config overrides from request
if config_overrides:
- for key, value in config_overrides.items():
- if isinstance(value, dict):
- for subkey, subvalue in value.items():
- overrides.append(f"{key}.{subkey}={subvalue}")
+ for k, v in config_overrides.items():
+ if isinstance(v, dict):
+ overrides.extend([f"{k}.{sk}={sv}" for sk, sv in v.items()])
else:
- overrides.append(f"{key}={value}")
+ overrides.append(f"{k}={v}")
try:
- cfg = compose(config_name="config", overrides=overrides)
- return cfg
+ return compose(config_name="config", overrides=overrides)
except Exception as e:
logger.error(f"Failed to compose Hydra config: {e}")
- exit()
-
+ raise
-# pre load main agent tool definitions to speed up the first request
+# Pre-load resources
cfg = load_miroflow_config(None)
-# Create pipeline components
-main_agent_tool_manager, sub_agent_tool_managers, output_formatter = (
- create_pipeline_components(cfg)
-)
+main_agent_tool_manager, sub_agent_tool_managers, output_formatter = create_pipeline_components(cfg)
tool_definitions = asyncio.run(main_agent_tool_manager.get_all_tool_definitions())
tool_definitions += expose_sub_agents_as_tools(cfg.agent.sub_agents)
-# pre load sub agent tool definitions to speed up the first request
sub_agent_tool_definitions = {
- name: asyncio.run(sub_agent_tool_manager.get_all_tool_definitions())
- for name, sub_agent_tool_manager in sub_agent_tool_managers.items()
+ name: asyncio.run(mgr.get_all_tool_definitions())
+ for name, mgr in sub_agent_tool_managers.items()
}
+# ================= Core Logic Classes =================
class ThreadSafeAsyncQueue:
- """Thread-safe async queue wrapper"""
-
+ """Wrapper for asyncio.Queue to handle thread-safe puts."""
def __init__(self):
self._queue = asyncio.Queue()
- self._loop = None
+ self._loop = asyncio.get_running_loop()
self._closed = False
- def set_loop(self, loop):
- self._loop = loop
+ async def put(self, item: Any):
+ if not self._closed:
+ await self._queue.put(item)
- async def put(self, item):
- """Put data safely from any thread"""
+ def put_nowait_threadsafe(self, item: Any):
+ """Schedules a put from a different thread."""
if self._closed:
return
- await self._queue.put(item)
-
- def put_nowait_threadsafe(self, item):
- """Put data from other threads"""
- if self._closed or not self._loop:
+ if not self._loop.is_running():
return
- self._loop.call_soon_threadsafe(lambda: asyncio.create_task(self.put(item)))
+ # We schedule the coroutine on the loop
+ asyncio.run_coroutine_threadsafe(self.put(item), self._loop)
- async def get(self):
+ async def get(self) -> Any:
return await self._queue.get()
def close(self):
self._closed = True
+class PipelineRunner:
+ """Manages the execution of the pipeline in a separate thread."""
+ def __init__(self, task_id: str, query: str, queue: ThreadSafeAsyncQueue):
+ self.task_id = task_id
+ self.query = query
+ self.queue = queue
+ self.cancel_event = threading.Event()
+ self.executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="pipeline")
+ self.future = None
-def filter_google_search_organic(organic: List[dict]) -> List[dict]:
- """
- Filter google search organic results to remove unnecessary information
- """
- result = []
- for item in organic:
- result.append(
- {
- "title": item.get("title", ""),
- "link": item.get("link", ""),
- }
- )
- return result
-
-
-def is_scrape_error(result: str) -> bool:
- """
- Check if the scrape result is an error
- """
- try:
- json.loads(result)
- return False
- except json.JSONDecodeError:
- return True
-
-
-def filter_message(message: dict) -> dict:
- """
- Filter message to remove unnecessary information
- """
- if message["event"] == "tool_call":
- tool_name = message["data"].get("tool_name")
- tool_input = message["data"].get("tool_input")
- if (
- tool_name == "google_search"
- and isinstance(tool_input, dict)
- and "result" in tool_input
- ):
- result_dict = json.loads(tool_input["result"])
- if "organic" in result_dict:
- new_result = {
- "organic": filter_google_search_organic(result_dict["organic"])
- }
- message["data"]["tool_input"]["result"] = json.dumps(
- new_result, ensure_ascii=False
- )
- if (
- tool_name in ["scrape", "scrape_website"]
- and isinstance(tool_input, dict)
- and "result" in tool_input
- ):
- # if error, it can not be json
- if is_scrape_error(tool_input["result"]):
- message["data"]["tool_input"] = {"error": tool_input["result"]}
- else:
- message["data"]["tool_input"] = {}
- return message
-
-
-async def stream_events_optimized(
- task_id: str, query: str, _: Optional[dict] = None, disconnect_check=None
-) -> AsyncGenerator[dict, None]:
- """Optimized event stream generator that directly outputs structured events, no longer wrapped as SSE strings."""
- workflow_id = task_id
- last_send_time = time.time()
- last_heartbeat_time = time.time()
+ def start(self):
+ self.future = self.executor.submit(self._run_pipeline_thread)
- # Create thread-safe queue
- stream_queue = ThreadSafeAsyncQueue()
- stream_queue.set_loop(asyncio.get_event_loop())
+ def cancel(self):
+ self.cancel_event.set()
+ # We don't wait for future result here to allow UI to update immediately
- cancel_event = threading.Event()
-
- def run_pipeline_in_thread():
+ def _run_pipeline_thread(self):
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
- class ThreadQueueWrapper:
- def __init__(self, thread_queue, cancel_event):
- self.thread_queue = thread_queue
- self.cancel_event = cancel_event
-
+ # Wrapper to interface with thread-safe queue
+ class QueueProxy:
async def put(self, item):
- if self.cancel_event.is_set():
- logger.info("Pipeline cancelled, stopping execution")
- return
- self.thread_queue.put_nowait_threadsafe(filter_message(item))
-
- wrapper_queue = ThreadQueueWrapper(stream_queue, cancel_event)
-
- global cfg
- global main_agent_tool_manager
- global sub_agent_tool_managers
- global output_formatter
- global tool_definitions
- global sub_agent_tool_definitions
-
- async def pipeline_with_cancellation():
- pipeline_task = asyncio.create_task(
+ if self.cancel_event.is_set(): return
+ # Filter before sending to UI
+ self.queue.put_nowait_threadsafe(filter_message(item))
+
+ proxy = QueueProxy()
+ proxy.cancel_event = self.cancel_event
+ proxy.queue = self.queue
+
+ async def run():
+ task = asyncio.create_task(
execute_task_pipeline(
- cfg=cfg,
- task_id=workflow_id,
- task_description=query,
- task_file_name=None,
- main_agent_tool_manager=main_agent_tool_manager,
+ cfg=cfg, task_id=self.task_id, task_description=self.query,
+ task_file_name=None, main_agent_tool_manager=main_agent_tool_manager,
sub_agent_tool_managers=sub_agent_tool_managers,
- output_formatter=output_formatter,
- stream_queue=wrapper_queue,
+ output_formatter=output_formatter, stream_queue=proxy,
log_dir=os.getenv("LOG_DIR", "logs/api-server"),
tool_definitions=tool_definitions,
sub_agent_tool_definitions=sub_agent_tool_definitions,
)
)
-
- async def check_cancellation():
- while not cancel_event.is_set():
- await asyncio.sleep(0.5)
- logger.info("Cancel event detected, cancelling pipeline")
- pipeline_task.cancel()
-
- cancel_task = asyncio.create_task(check_cancellation())
-
- try:
- done, pending = await asyncio.wait(
- [pipeline_task, cancel_task],
- return_when=asyncio.FIRST_COMPLETED,
- )
- for task in pending:
- task.cancel()
- for task in done:
- if task == pipeline_task:
- try:
- await task
- except asyncio.CancelledError:
- logger.info("Pipeline task was cancelled")
- except Exception as e:
- logger.error(f"Pipeline execution error: {e}")
- pipeline_task.cancel()
- cancel_task.cancel()
-
- loop.run_until_complete(pipeline_with_cancellation())
+
+ # Cancellation watcher
+ while not self.cancel_event.is_set():
+ await asyncio.sleep(0.5)
+ if task.done(): break
+
+ if self.cancel_event.is_set():
+ task.cancel()
+ try: await task
+ except asyncio.CancelledError: pass
+
+ loop.run_until_complete(run())
except Exception as e:
- if not cancel_event.is_set():
- logger.error(f"Pipeline error: {e}", exc_info=True)
- stream_queue.put_nowait_threadsafe(
- {
- "event": "error",
- "data": {"error": str(e), "workflow_id": workflow_id},
- }
- )
+ logger.error(f"Pipeline Thread Exception: {e}", exc_info=True)
+ self.queue.put_nowait_threadsafe({
+ "event": "error", "data": {"error": str(e), "workflow_id": self.task_id}
+ })
finally:
- stream_queue.put_nowait_threadsafe(None)
- if "loop" in locals():
- loop.close()
-
- executor = ThreadPoolExecutor(max_workers=1)
- future = executor.submit(run_pipeline_in_thread)
-
- try:
- while True:
- try:
- if disconnect_check and await disconnect_check():
- logger.info("Client disconnected, stopping pipeline")
- cancel_event.set()
- break
- message = await asyncio.wait_for(stream_queue.get(), timeout=0.1)
- if message is None:
- logger.info("Pipeline completed")
- break
- yield message
- last_send_time = time.time()
- except asyncio.TimeoutError:
- current_time = time.time()
- if current_time - last_send_time > 300:
- logger.info("Stream timeout")
- break
- if future.done():
- try:
- message = stream_queue._queue.get_nowait()
- if message is not None:
- yield message
- continue
- except Exception:
- break
- if current_time - last_heartbeat_time >= 15:
- yield {
- "event": "heartbeat",
- "data": {"timestamp": current_time, "workflow_id": workflow_id},
- }
- last_heartbeat_time = current_time
- except Exception as e:
- logger.error(f"Stream error: {e}", exc_info=True)
- yield {
- "event": "error",
- "data": {"workflow_id": workflow_id, "error": f"Stream error: {str(e)}"},
- }
- finally:
- cancel_event.set()
- stream_queue.close()
- try:
- future.result(timeout=1.0)
- except Exception:
- pass
- executor.shutdown(wait=False)
-
+ self.queue.put_nowait_threadsafe(None) # Sentinel for end of stream
+ loop.close()
-# ========================= Gradio Integration =========================
+ def cleanup(self):
+ if self.future:
+ self.future.cancel()
+ self.executor.shutdown(wait=False)
+# ================= Helper Functions =================
-def _init_render_state():
- return {
- "agent_order": [],
- "agents": {}, # agent_id -> {"agent_name": str, "tool_call_order": [], "tools": {tool_call_id: {...}}}
- "current_agent_id": None,
- "errors": [],
- }
-
+def filter_google_search_organic(organic: List[dict]) -> List[dict]:
+ return [{"title": i.get("title", ""), "link": i.get("link", "")} for i in organic]
-def _append_show_text(tool_entry: dict, delta: str):
- existing = tool_entry.get("content", "")
- tool_entry["content"] = existing + delta
+def filter_message(message: dict) -> dict:
+ """Sanitize message data for UI rendering."""
+ if message.get("event") != "tool_call":
+ return message
+
+ data = message.get("data", {})
+ tool_name = data.get("tool_name")
+ tool_input = data.get("tool_input")
+ if not isinstance(tool_input, dict): return message
-def _is_empty_payload(value) -> bool:
- if value is None:
- return True
- if isinstance(value, str):
- stripped = value.strip()
- return stripped == "" or stripped in ("{}", "[]")
- if isinstance(value, (dict, list, tuple, set)):
- return len(value) == 0
- return False
+ # Filter Search Results
+ if tool_name == "google_search" and "result" in tool_input:
+ try:
+ res = json.loads(tool_input["result"])
+ if "organic" in res:
+ res["organic"] = filter_google_search_organic(res["organic"])
+ tool_input["result"] = json.dumps(res, ensure_ascii=False)
+ except json.JSONDecodeError: pass
+
+ # Filter Scrape Results
+ if tool_name in ["scrape", "scrape_website"] and "result" in tool_input:
+ try:
+ json.loads(tool_input["result"]) # Check validity
+ # If valid JSON, we might want to truncate it or just hide it entirely
+ # For now, just clear it to save UI space, or keep if small
+ if len(tool_input["result"]) > 5000:
+ tool_input["result"] = tool_input["result"][:5000] + "... [truncated]"
+ except json.JSONDecodeError:
+ # It's an error text
+ tool_input = {"error": tool_input["result"]}
+
+ message["data"]["tool_input"] = tool_input
+ return message
+# ================= UI State Logic =================
-def _render_markdown(state: dict) -> str:
- lines = []
- emoji_cycle = ["๐ง ", "๐", "๐ ๏ธ", "๐", "๐ค", "๐งช", "๐", "๐งญ", "โ๏ธ", "๐งฎ"]
- # Render errors first if any
- if state.get("errors"):
- lines.append("### โ Errors")
- for idx, err in enumerate(state["errors"], start=1):
- lines.append(f"- **Error {idx}**: {err}")
- lines.append("\n---\n")
- for idx, agent_id in enumerate(state.get("agent_order", [])):
- agent = state["agents"].get(agent_id, {})
- agent_name = agent.get("agent_name", "unknown")
- emoji = emoji_cycle[idx % len(emoji_cycle)]
- lines.append(f"### {emoji} Agent: {agent_name}")
- for call_id in agent.get("tool_call_order", []):
- call = agent["tools"].get(call_id, {})
- tool_name = call.get("tool_name", "unknown_tool")
- if tool_name in ("show_text", "message"):
- content = call.get("content", "")
- if content:
- lines.append(content)
- else:
- tool_input = call.get("input")
- tool_output = call.get("output")
- has_input = not _is_empty_payload(tool_input)
- has_output = not _is_empty_payload(tool_output)
- if not has_input and not has_output:
- # No parameters, only show tool name with emoji on separate line
- if tool_name == "Partial Summary":
- lines.append("\n๐กPartial Summary\n")
- else:
- lines.append(f"\n๐ง{tool_name}\n")
- else:
- # Show as collapsible details for any tool with input or output
- if tool_name == "Partial Summary":
- summary = f"๐ก{tool_name} ({call_id[:8]})"
- else:
- summary = f"๐ง{tool_name} ({call_id[:8]})"
- lines.append(f"\n<details><summary>{summary}</summary>\n")
- if has_input:
- pretty = json.dumps(tool_input, ensure_ascii=False, indent=2)
- lines.append("\n**Input**:\n")
- lines.append(f"```json\n{pretty}\n```")
- if has_output:
- pretty = json.dumps(tool_output, ensure_ascii=False, indent=2)
- lines.append("\n**Output**:\n")
- lines.append(f"```json\n{pretty}\n```")
- lines.append("\n</details>\n")
+
+ if "input" in tool:
+ lines.append("**Input:**")
+ lines.append(f"```json\n{json.dumps(tool['input'], ensure_ascii=False, indent=2)}\n```")
+
+ if "output" in tool:
+ lines.append("**Output:**")
+ out_str = json.dumps(tool['output'], ensure_ascii=False, indent=2)
+ # Truncate huge outputs in UI
+ if len(out_str) > 2000:
+ out_str = out_str[:2000] + "\n... [truncated]"
+ lines.append(f"```json\n{out_str}\n```")
+
+ lines.append("