From 728ede3b83c3781bbe55e7348fb788ecd37892d9 Mon Sep 17 00:00:00 2001 From: rickychen-infinirc Date: Tue, 27 Jan 2026 19:48:35 +0800 Subject: [PATCH 1/3] feat: replace LLM agent with Bayesian optimization for auto-tuning - Add BayesianTuningService using Optuna TPE sampler - Fix SGLang container startup command - Add real-time logs to tuning jobs - Simplify AutoTuning UI --- backend/app/api/auto_tuning.py | 25 +- backend/app/models/tuning.py | 42 +- backend/app/schemas/tuning.py | 60 +- backend/app/services/bayesian_tuner.py | 986 +++++++++++++++ backend/app/services/deployer.py | 7 +- frontend/src/pages/AutoTuning.tsx | 1532 ++++++++++++------------ 6 files changed, 1855 insertions(+), 797 deletions(-) create mode 100644 backend/app/services/bayesian_tuner.py diff --git a/backend/app/api/auto_tuning.py b/backend/app/api/auto_tuning.py index 7a4dcc6..f6b2f99 100644 --- a/backend/app/api/auto_tuning.py +++ b/backend/app/api/auto_tuning.py @@ -67,6 +67,13 @@ def tuning_job_to_response(job: TuningJob, include_conversation: bool = True) -> conversation_log = [ConversationMessage(**msg) for msg in job.conversation_log] + # Parse logs + logs = None + if job.logs: + from app.schemas.tuning import TuningLogEntry + + logs = [TuningLogEntry(**log) for log in job.logs] + return TuningJobResponse( id=job.id, model_id=job.model_id, @@ -79,6 +86,7 @@ def tuning_job_to_response(job: TuningJob, include_conversation: bool = True) -> progress=progress, best_config=job.best_config, all_results=job.all_results, + logs=logs, conversation_log=conversation_log, created_at=job.created_at, updated_at=job.updated_at, @@ -677,15 +685,24 @@ class DummyJob: # ============================================================================ -# Auto-Tuning Agent Runner +# Auto-Tuning Runner # ============================================================================ async def run_auto_tuning(job_id: int, llm_config: dict | None = None): - """Run the LLM-driven Auto-Tuning Agent""" - from 
app.services.tuning_agent import run_tuning_agent + """Run the Auto-Tuning process using Bayesian optimization. + + Uses Optuna's TPE (Tree-structured Parzen Estimator) for efficient + hyperparameter search instead of LLM Agent. + + Args: + job_id: The tuning job ID + llm_config: Legacy parameter (ignored, kept for API compatibility) + """ + from app.services.bayesian_tuner import run_bayesian_tuning - await run_tuning_agent(job_id, llm_config) + # Default to 10 trials for good optimization coverage + await run_bayesian_tuning(job_id, n_trials=10) async def _run_benchmark_test(deployment: Deployment, request: BenchmarkRequest) -> dict: diff --git a/backend/app/models/tuning.py b/backend/app/models/tuning.py index 14696d5..c5244d4 100644 --- a/backend/app/models/tuning.py +++ b/backend/app/models/tuning.py @@ -42,9 +42,26 @@ class TuningJob(Base): model_id: Mapped[int] = mapped_column(Integer, ForeignKey("llm_models.id"), nullable=False) worker_id: Mapped[int] = mapped_column(Integer, ForeignKey("workers.id"), nullable=False) optimization_target: Mapped[str] = mapped_column( - String(50), default=OptimizationTarget.BALANCED.value + String(50), default=OptimizationTarget.THROUGHPUT.value ) + # Tuning configuration - which frameworks and parameters to test + # Format: { + # "engines": ["vllm", "sglang"], + # "parameters": { + # "tensor_parallel_size": [1, 2], + # "gpu_memory_utilization": [0.85, 0.90], + # "max_model_len": [4096, 8192] + # }, + # "benchmark": { + # "duration_seconds": 60, + # "input_length": 512, + # "output_length": 128, + # "concurrency": [1, 4, 8] + # } + # } + tuning_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) + # Job status status: Mapped[str] = mapped_column(String(50), default=TuningJobStatus.PENDING.value) status_message: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -52,16 +69,33 @@ class TuningJob(Base): total_steps: Mapped[int] = mapped_column(Integer, default=5) # Progress details (JSON for flexibility) + 
# Format: { + # "step": 3, + # "total_steps": 5, + # "step_name": "benchmarking", + # "configs_tested": 5, + # "configs_total": 12, + # "current_config": {"engine": "vllm", "tensor_parallel_size": 2, ...}, + # "results": [{"config": {...}, "throughput_tps": 1234.5, ...}, ...] + # } progress: Mapped[dict | None] = mapped_column(JSON, nullable=True) - # Results + # Results - final sorted results + # Format: [{"rank": 1, "engine": "vllm", "config": {...}, "throughput_tps": 1500, ...}, ...] best_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) all_results: Mapped[list | None] = mapped_column(JSON, nullable=True) - # Agent conversation log (for UI display) - # Format: [{"role": "user"|"assistant"|"tool", "content": "...", "tool_calls": [...], "timestamp": "..."}] + # Agent conversation ID (links to conversations table for Agent Chat display) + conversation_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("conversations.id"), nullable=True + ) + + # Legacy: Agent conversation log (for backward compatibility) conversation_log: Mapped[list | None] = mapped_column(JSON, nullable=True) + # Tuning logs for frontend display + logs: Mapped[list | None] = mapped_column(JSON, nullable=True) + # Metadata created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(UTC) diff --git a/backend/app/schemas/tuning.py b/backend/app/schemas/tuning.py index 43513bc..616f872 100644 --- a/backend/app/schemas/tuning.py +++ b/backend/app/schemas/tuning.py @@ -25,17 +25,52 @@ class LLMConfig(BaseModel): # ============================================================================ +class TuningParameters(BaseModel): + """Parameters to test during tuning""" + + tensor_parallel_size: list[int] = Field( + default=[1], description="Tensor parallel sizes to test" + ) + gpu_memory_utilization: list[float] = Field( + default=[0.85, 0.90], description="GPU memory utilization values to test (0.0-1.0)" + ) + max_model_len: list[int] = 
Field(default=[4096], description="Max model lengths to test") + max_num_seqs: list[int] | None = Field( + default=None, description="Max concurrent sequences to test" + ) + + +class BenchmarkSettings(BaseModel): + """Benchmark test settings""" + + duration_seconds: int = Field(default=60, ge=10, le=300, description="Test duration per config") + input_length: int = Field(default=512, ge=64, le=8192, description="Input token length") + output_length: int = Field(default=128, ge=16, le=2048, description="Output token length") + concurrency: list[int] = Field(default=[1, 4], description="Concurrency levels to test") + + +class TuningConfig(BaseModel): + """Full tuning configuration""" + + engines: list[str] = Field( + default=["vllm"], description="Inference engines to test: vllm, sglang, ollama" + ) + parameters: TuningParameters = Field(default_factory=TuningParameters) + benchmark: BenchmarkSettings = Field(default_factory=BenchmarkSettings) + + class TuningJobCreate(BaseModel): """Schema for creating a tuning job""" model_id: int = Field(..., description="ID of the model to tune") worker_id: int = Field(..., description="ID of the worker to use") optimization_target: OptimizationTarget = Field( - default=OptimizationTarget.BALANCED, description="What to optimize for" + default=OptimizationTarget.THROUGHPUT, description="What to optimize for" ) - llm_config: LLMConfig | None = Field( - None, description="LLM configuration for the agent (uses chat panel's selected model)" + tuning_config: TuningConfig = Field( + default_factory=TuningConfig, description="Tuning configuration" ) + llm_config: LLMConfig | None = Field(None, description="LLM configuration for the agent") class TuningJobProgress(BaseModel): @@ -44,12 +79,18 @@ class TuningJobProgress(BaseModel): step: int total_steps: int step_name: str - step_description: str + step_description: str | None = None configs_tested: int = 0 configs_total: int = 0 current_config: dict | None = None best_config_so_far: dict 
| None = None best_score_so_far: float | None = None + # Bayesian optimization specific fields + completed_trials: int | None = None + successful_trials: int | None = None + deployment_status: str | None = None + deployment_message: str | None = None + elapsed_seconds: int | None = None class ConversationMessage(BaseModel): @@ -63,6 +104,14 @@ class ConversationMessage(BaseModel): name: str | None = None # Tool name for tool responses +class TuningLogEntry(BaseModel): + """A single log entry""" + + timestamp: str + level: str + message: str + + class TuningJobResponse(BaseModel): """Schema for tuning job response""" @@ -70,6 +119,7 @@ class TuningJobResponse(BaseModel): model_id: int worker_id: int optimization_target: str + tuning_config: TuningConfig | None = None status: str status_message: str | None = None current_step: int @@ -77,6 +127,8 @@ class TuningJobResponse(BaseModel): progress: TuningJobProgress | None = None best_config: dict | None = None all_results: list | None = None + logs: list[TuningLogEntry] | None = None + conversation_id: int | None = None conversation_log: list[ConversationMessage] | None = None created_at: datetime updated_at: datetime diff --git a/backend/app/services/bayesian_tuner.py b/backend/app/services/bayesian_tuner.py new file mode 100644 index 0000000..1abe1d7 --- /dev/null +++ b/backend/app/services/bayesian_tuner.py @@ -0,0 +1,986 @@ +""" +Bayesian Optimization-based Auto-Tuning Service + +Uses Optuna's TPE (Tree-structured Parzen Estimator) for efficient +hyperparameter search. This replaces the LLM Agent approach with +systematic Bayesian optimization while maintaining MCP-compatible +tool interfaces. 
+ +Key concepts: +- Bayesian Optimization: Uses surrogate model + acquisition function +- Filter-Scorer Architecture: Filters invalid configs, scores valid ones +- Knowledge Transfer: Uses historical results to warm-start optimization +""" + +import asyncio +import json +import logging +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +import httpx +import optuna +from optuna.samplers import TPESampler +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.database import async_session_maker +from app.models.deployment import Deployment, DeploymentStatus +from app.models.tuning import OptimizationTarget, PerformanceKnowledge, TuningJob, TuningJobStatus +from app.models.worker import Worker +from app.services.deployer import DeployerService + +# Configure logging with detailed format +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Suppress Optuna's verbose logging +optuna.logging.set_verbosity(optuna.logging.WARNING) + + +class SearchSpace(Enum): + """Parameter search space types""" + + CATEGORICAL = "categorical" + INTEGER = "integer" + FLOAT = "float" + LOG_FLOAT = "log_float" + + +@dataclass +class TuningParameter: + """Definition of a tunable parameter""" + + name: str + space_type: SearchSpace + choices: list[Any] | None = None # For categorical + low: float | None = None # For numeric + high: float | None = None # For numeric + step: float | None = None # For discrete numeric + default: Any = None + + def suggest(self, trial: optuna.Trial) -> Any: + """Generate parameter suggestion from Optuna trial""" + if self.space_type == SearchSpace.CATEGORICAL: + return trial.suggest_categorical(self.name, self.choices) + elif self.space_type == SearchSpace.INTEGER: + return trial.suggest_int( + self.name, int(self.low), int(self.high), step=int(self.step or 1) + ) + elif 
def build_llm_search_space(hardware: HardwareProfile) -> list[TuningParameter]:
    """
    Construct the parameter search space for the given hardware.

    This is the Filter phase of the Filter-Scorer architecture: ranges and
    choices are narrowed up front so that the sampler is never offered
    values the hardware obviously cannot support.

    Args:
        hardware: Profile of the worker's GPUs (count and per-GPU VRAM).

    Returns:
        The tunable parameters in a fixed order: ``engine``,
        ``gpu_memory_utilization``, ``max_num_seqs`` and, only when more
        than one GPU is present, ``tensor_parallel_size``.
    """
    per_gpu_vram = hardware.vram_per_gpu_gb

    # vLLM is always a candidate; SGLang is only offered with >= 8 GB per GPU.
    engines = ["vllm"]
    if per_gpu_vram >= 8:
        engines.append("sglang")

    # Memory-utilization bounds (more conservative on smaller GPUs).
    # Both bounds must sit on the 0.05 step grid measured from `low`;
    # otherwise Optuna rounds `high` down and emits a warning on every
    # suggest call. The middle tier previously used 0.92, which was
    # silently truncated to 0.90.
    if per_gpu_vram < 12:
        mem_low, mem_high = 0.70, 0.85
    elif per_gpu_vram < 24:
        mem_low, mem_high = 0.75, 0.90
    else:
        mem_low, mem_high = 0.80, 0.95

    params = [
        TuningParameter(
            name="engine",
            space_type=SearchSpace.CATEGORICAL,
            choices=engines,
            default="vllm",
        ),
        TuningParameter(
            name="gpu_memory_utilization",
            space_type=SearchSpace.FLOAT,
            low=mem_low,
            high=mem_high,
            step=0.05,
            default=0.85,
        ),
        # Max concurrent sequences: wider range on GPUs with >= 24 GB.
        TuningParameter(
            name="max_num_seqs",
            space_type=SearchSpace.INTEGER,
            low=4,
            high=64 if per_gpu_vram >= 24 else 32,
            step=4,
            default=16,
        ),
    ]

    # Tensor parallelism is only tunable with multiple GPUs; offer every
    # power of two up to 8 that fits in the available GPU count.
    if hardware.gpu_count > 1:
        tp_choices = [tp for tp in (1, 2, 4, 8) if tp <= hardware.gpu_count]
        params.append(
            TuningParameter(
                name="tensor_parallel_size",
                space_type=SearchSpace.CATEGORICAL,
                choices=tp_choices,
                default=1,
            )
        )

    return params
class ObjectiveCalculator:
    """
    Computes the optimization objective from benchmark metrics.

    Implements the Scorer phase: valid configurations are scored by
    performance. Every objective is expressed so that higher is better,
    which lets the caller run a single "maximize" study for all targets.
    """

    def __init__(self, target: OptimizationTarget):
        self.target = target

    def compute(self, metrics: dict[str, float]) -> float:
        """
        Calculate the objective value for one benchmark result.

        Args:
            metrics: Benchmark output. Reads ``throughput_tps``,
                ``avg_ttft_ms`` and ``avg_tpot_ms``; missing latency keys
                are treated as unmeasured.

        Returns:
            A score where higher is better. ``float("-inf")`` is returned
            when the metrics cannot be scored for the selected target
            (no tokens produced, or latency missing / non-finite).
        """
        unusable = float("-inf")
        infinity = float("inf")

        throughput = metrics.get("throughput_tps", 0.0)
        ttft = metrics.get("avg_ttft_ms", infinity)
        tpot = metrics.get("avg_tpot_ms", infinity)

        # A latency figure is only meaningful when tokens were actually
        # produced. The benchmark reports all-zero metrics when no request
        # succeeded; without this guard a zero "latency" would beat every
        # real measurement. The chained comparisons are also False for NaN.
        latency_measured = throughput > 0 and 0 < ttft < infinity and 0 <= tpot < infinity

        if self.target == OptimizationTarget.LATENCY:
            if not latency_measured:
                return unusable
            # Minimize latency: negate for maximization, weighting TPOT more.
            return -1.0 * (ttft + tpot * 10)

        if self.target == OptimizationTarget.BALANCED:
            if not latency_measured:
                return unusable
            # Combined score: throughput discounted by a latency factor.
            latency_factor = 1 + (ttft / 100) + (tpot / 10)
            return throughput / latency_factor

        # THROUGHPUT, COST (throughput is used as an efficiency proxy for
        # now) and any unrecognized target: maximize raw throughput.
        return throughput
async def run(self) -> dict[str, Any]:
    """Execute the tuning process.

    Phases: analyze hardware, query the knowledge base for warm-start
    configurations, build the search space, create the Optuna study, run
    the trial loop (deploy -> benchmark -> cleanup per trial), then
    persist the best configuration.

    Returns:
        ``{"success": True, "best_config", "best_value", "trials_completed"}``
        on success; ``{"success": False, "reason": "cancelled"}`` when the
        job was cancelled; ``{"success": False, "error": str}`` when any
        step raised. Exceptions are caught, logged and reflected in the job
        status rather than propagated.
    """
    start_time = time.time()

    await self._log("INFO", "=" * 50)
    await self._log("INFO", "Starting Bayesian Optimization")
    await self._log("INFO", f"Model: {self.job.model.name}")
    await self._log("INFO", f"Target: {self.job.optimization_target}")
    await self._log("INFO", f"Trials: {self.n_trials}")
    await self._log("INFO", "=" * 50)

    try:
        # Phase 1: Hardware analysis
        await self._log("INFO", "Phase 1: Analyzing hardware...")
        await self._update_status(
            TuningJobStatus.ANALYZING, "Analyzing hardware configuration..."
        )
        hardware = await self._analyze_hardware()

        if hardware.gpu_count == 0:
            raise RuntimeError("No GPUs detected on worker")

        await self._log("INFO", f"Hardware: {hardware.gpu_count}x {hardware.gpu_name}")
        await self._log("INFO", f"VRAM: {hardware.total_vram_gb:.1f} GB total")

        # Phase 2: Query knowledge base
        await self._log("INFO", "Phase 2: Querying knowledge base...")
        await self._update_status(
            TuningJobStatus.QUERYING_KB, "Checking historical performance data..."
        )
        warm_start_params = await self._query_knowledge_base(hardware)

        if warm_start_params:
            await self._log("INFO", f"Found {len(warm_start_params)} warm-start configurations")
        else:
            await self._log("INFO", "No historical data found, starting fresh")

        # Phase 3: Build search space
        await self._log("INFO", "Phase 3: Building search space...")
        search_space = build_llm_search_space(hardware)
        config_filter = ConfigurationFilter(hardware, model_size_gb=self._estimate_model_size())
        objective_calc = ObjectiveCalculator(OptimizationTarget(self.job.optimization_target))

        await self._log("INFO", f"Search space: {[p.name for p in search_space]}")

        # Phase 4: Create Optuna study
        await self._log("INFO", "Phase 4: Initializing TPE sampler...")
        sampler = TPESampler(
            seed=42,
            n_startup_trials=min(3, self.n_trials // 2),
            multivariate=True,  # Consider parameter correlations
        )

        study = optuna.create_study(
            direction="maximize",
            sampler=sampler,
            study_name=f"lmstack_tuning_{self.job.id}",
        )

        # Add warm-start trials from knowledge base
        if warm_start_params:
            for params in warm_start_params[:2]:  # At most 2 warm-start
                try:
                    study.enqueue_trial(params)
                    await self._log("INFO", f"Enqueued warm-start: {params}")
                except Exception as e:
                    await self._log("WARNING", f"Failed to enqueue warm-start: {e}")

        # Phase 5: Optimization loop
        await self._log("INFO", "Phase 5: Starting optimization loop...")
        await self._update_status(TuningJobStatus.EXPLORING, "Starting optimization...")

        # (objective, outcome) for every trial that produced a usable score.
        # The winner is chosen from this list rather than from the study:
        # the study also holds the -inf placeholders told for filtered and
        # failed trials, so with no successful trial it would still report
        # one of those as "best".
        scored: list[tuple[float, TrialOutcome]] = []

        for trial_idx in range(self.n_trials):
            if self._cancelled:
                break

            # Check cancellation
            await self.db.refresh(self.job)
            if self.job.status == TuningJobStatus.CANCELLED.value:
                self._cancelled = True
                await self._log("INFO", "Cancelled by user")
                break

            await self._log("INFO", "-" * 40)
            await self._log("INFO", f"Trial {trial_idx + 1}/{self.n_trials}")

            # Generate trial parameters
            trial = study.ask()
            params = {p.name: p.suggest(trial) for p in search_space}
            await self._log("INFO", f"Parameters: {params}")

            # Filter invalid configurations
            is_valid, reason = config_filter.is_valid(params)
            if not is_valid:
                await self._log("WARNING", f"Skipped: {reason}")
                study.tell(trial, float("-inf"))
                continue

            # Update progress
            await self._update_progress(
                step=trial_idx + 1,
                total=self.n_trials,
                step_name=f"Trial {trial_idx + 1}",
                params=params,
            )

            # Execute trial
            outcome = await self._execute_trial(trial_idx, params)
            self._outcomes.append(outcome)

            # A benchmark in which no request succeeded reports zero
            # throughput; that is a failed trial, not a measurement.
            measured_tps = outcome.metrics.get("throughput_tps", 0) if outcome.metrics else 0
            if outcome.success and measured_tps <= 0:
                outcome.success = False
                outcome.error_message = "Benchmark produced no successful requests"

            # Report to Optuna
            if outcome.success:
                objective = objective_calc.compute(outcome.metrics)
                study.tell(trial, objective)
                if objective > float("-inf"):  # also False for NaN
                    scored.append((objective, outcome))
                await self._log("INFO", f"Trial {trial_idx + 1} completed:")
                await self._log("INFO", f"  Objective: {objective:.2f}")
                await self._log(
                    "INFO", f"  TPS: {outcome.metrics.get('throughput_tps', 0):.1f}"
                )
                await self._log(
                    "INFO", f"  TTFT: {outcome.metrics.get('avg_ttft_ms', 0):.0f} ms"
                )
            else:
                study.tell(trial, float("-inf"))
                await self._log(
                    "ERROR", f"Trial {trial_idx + 1} failed: {outcome.error_message}"
                )

        # Phase 6: Finalize results
        await self._log("INFO", "=" * 50)
        await self._log("INFO", "Phase 6: Finalizing results...")

        if self._cancelled:
            await self._update_status(TuningJobStatus.CANCELLED, "Tuning was cancelled")
            await self._log("INFO", "Job cancelled")
            return {"success": False, "reason": "cancelled"}

        # Record per-trial results first so they are persisted even when
        # no trial succeeded and the job is marked as failed below.
        self.job.all_results = [
            {
                "parameters": o.parameters,
                "metrics": o.metrics,
                "success": o.success,
                "error": o.error_message,
            }
            for o in self._outcomes
        ]

        if not scored:
            raise RuntimeError("No trial produced a usable benchmark result")

        # Get best result (first one wins on ties)
        best_value, best_outcome = max(scored, key=lambda item: item[0])
        best_params = dict(best_outcome.parameters)

        # Save to knowledge base
        await self._save_to_knowledge_base(best_params, best_outcome)

        # Update job with results
        self.job.best_config = {
            **best_params,
            "objective_value": best_value,
            "metrics": best_outcome.metrics,
        }

        elapsed = time.time() - start_time
        await self._log("INFO", "=" * 50)
        await self._log("INFO", "Optimization Complete!")
        await self._log("INFO", f"Duration: {elapsed/60:.1f} minutes")
        await self._log("INFO", f"Trials: {len(self._outcomes)} completed")
        await self._log("INFO", f"Best objective: {best_value:.2f}")
        await self._log("INFO", f"Best config: {best_params}")
        await self._log(
            "INFO", f"Best TPS: {best_outcome.metrics.get('throughput_tps', 0):.1f}"
        )
        await self._log("INFO", "=" * 50)

        # Update progress to show completion
        await self._update_progress(
            step=self.n_trials,
            total=self.n_trials,
            step_name="Completed",
            params=best_params,
        )

        await self._update_status(TuningJobStatus.COMPLETED, "Tuning completed successfully")
        self.job.completed_at = datetime.now(UTC)
        await self.db.commit()

        return {
            "success": True,
            "best_config": best_params,
            "best_value": best_value,
            "trials_completed": len(self._outcomes),
        }

    except Exception as e:
        await self._log("ERROR", f"Tuning failed: {e}")
        logger.exception(f"[Job #{self.job.id}] Tuning failed: {e}")
        await self._update_status(TuningJobStatus.FAILED, f"Error: {str(e)}")
        return {"success": False, "error": str(e)}

    finally:
        # Cleanup any remaining deployment
        await self._cleanup_deployment()
"""Extract model family from name""" + name_lower = name.lower() + families = { + "qwen": "Qwen", + "llama": "Llama", + "mistral": "Mistral", + "deepseek": "DeepSeek", + "phi": "Phi", + "gemma": "Gemma", + } + for key, value in families.items(): + if key in name_lower: + return value + return "Unknown" + + async def _execute_trial(self, trial_idx: int, params: dict[str, Any]) -> TrialOutcome: + """Execute a single trial: deploy, benchmark, cleanup""" + start_time = time.time() + outcome = TrialOutcome(trial_id=trial_idx, parameters=params.copy()) + + try: + # Step 1: Deploy model + await self._log("INFO", "Deploying model...") + await self._update_status( + TuningJobStatus.BENCHMARKING, f"Trial {trial_idx + 1}: Deploying model..." + ) + + deployment_id = await self._deploy_model(params) + outcome.deployment_id = deployment_id + self._current_deployment_id = deployment_id + await self._log("INFO", f"Deployment #{deployment_id} created") + + # Step 2: Wait for deployment + await self._log("INFO", "Waiting for model to load...") + await self._update_status( + TuningJobStatus.BENCHMARKING, f"Trial {trial_idx + 1}: Waiting for model to load..." + ) + + ready = await self._wait_for_deployment(deployment_id, timeout=self.timeout_per_trial) + if not ready: + outcome.success = False + outcome.error_message = "Deployment timeout" + await self._log("ERROR", f"Deployment timeout after {self.timeout_per_trial}s") + return outcome + + await self._log("INFO", "Model loaded successfully") + + # Step 3: Run benchmark + await self._log("INFO", "Running benchmark...") + await self._update_status( + TuningJobStatus.BENCHMARKING, f"Trial {trial_idx + 1}: Running benchmark..." 
async def _deploy_model(self, params: dict[str, Any]) -> int:
    """Create a deployment record for one trial and start deploying it.

    Args:
        params: Trial parameters. Reads ``engine`` (default ``"vllm"``),
            ``tensor_parallel_size`` (default 1, used to pick GPU indexes
            ``0..tp-1``), ``gpu_memory_utilization`` and ``max_num_seqs``.

    Returns:
        The id of the newly created deployment row. The deployment itself
        continues in a background task; callers poll its status.
    """
    engine = params.get("engine", "vllm")
    gpu_indexes = list(range(params.get("tensor_parallel_size", 1)))

    # vLLM and SGLang expose the same knobs under different flag names.
    if engine == "sglang":
        flag_names = {
            "gpu_memory_utilization": "mem-fraction-static",
            "max_num_seqs": "max-running-requests",
        }
    else:
        # vLLM parameter names (default)
        flag_names = {
            "gpu_memory_utilization": "gpu-memory-utilization",
            "max_num_seqs": "max-num-seqs",
        }
    extra_params = {flag: params[key] for key, flag in flag_names.items() if key in params}

    # Create deployment
    deployment = Deployment(
        name=f"tuning-trial-{self.job.id}-{int(time.time())}",
        model_id=self.job.model_id,
        worker_id=self.job.worker_id,
        backend=engine,
        gpu_indexes=gpu_indexes,
        extra_params=extra_params,
        status=DeploymentStatus.PENDING.value,
    )

    self.db.add(deployment)
    await self.db.commit()
    await self.db.refresh(deployment)

    # Capture the id now; the callback below must not touch the ORM object.
    deployment_id = deployment.id

    # Start deployment async. The event loop only keeps weak references to
    # tasks, so hold a strong reference until the task finishes; otherwise
    # it may be garbage-collected mid-execution.
    pending = getattr(self, "_deploy_tasks", None)
    if pending is None:
        pending = set()
        self._deploy_tasks = pending

    def _on_done(finished: asyncio.Task) -> None:
        pending.discard(finished)
        if finished.cancelled():
            return
        # Retrieve the exception so a failed deploy is logged instead of
        # surfacing only as "Task exception was never retrieved".
        error = finished.exception()
        if error is not None:
            logger.error("Deployment #%s task failed: %s", deployment_id, error)

    task = asyncio.create_task(self.deployer.deploy(deployment_id))
    pending.add(task)
    task.add_done_callback(_on_done)

    return deployment_id
async def _http_benchmark(
    self,
    base_url: str,
    model_name: str,
    num_requests: int,
    concurrency: int,
) -> dict[str, float]:
    """Execute an HTTP benchmark against an OpenAI-compatible endpoint.

    Sends two warm-up requests, then ``num_requests`` streaming chat
    completions with at most ``concurrency`` in flight, counting each
    non-empty content chunk as one token.

    Args:
        base_url: Endpoint base URL; ``/chat/completions`` is appended.
        model_name: Value sent as the ``model`` field.
        num_requests: Number of measured requests.
        concurrency: Maximum simultaneous requests.

    Returns:
        ``throughput_tps`` (tokens per wall-clock second over the measured
        batch), ``avg_ttft_ms``, ``avg_tpot_ms``, ``successful_requests``
        and ``total_requests``. All metrics are zero when no request
        succeeded.
    """
    test_prompt = "Explain the concept of machine learning in simple terms. " * 20

    semaphore = asyncio.Semaphore(concurrency)

    async def make_request(client: httpx.AsyncClient) -> dict | None:
        """Run one streaming request; return its timings or None on failure."""
        async with semaphore:
            start = time.perf_counter()
            first_token_time = None
            token_count = 0

            try:
                async with client.stream(
                    "POST",
                    f"{base_url}/chat/completions",
                    json={
                        "model": model_name,
                        "messages": [{"role": "user", "content": test_prompt}],
                        "max_tokens": 64,
                        "stream": True,
                    },
                    timeout=60.0,
                ) as resp:
                    if resp.status_code != 200:
                        logger.warning("Benchmark request returned HTTP %s", resp.status_code)
                        return None

                    async for line in resp.aiter_lines():
                        if not line.startswith("data: ") or line == "data: [DONE]":
                            continue
                        try:
                            chunk = json.loads(line[6:])
                        except json.JSONDecodeError:
                            continue

                        # Tolerate chunks with an empty/missing `choices`
                        # list or a null `delta` instead of discarding the
                        # whole request on an IndexError/AttributeError.
                        choices = chunk.get("choices") or [{}]
                        delta = choices[0].get("delta") or {}
                        if delta.get("content"):
                            if first_token_time is None:
                                first_token_time = time.perf_counter()
                            token_count += 1

                end = time.perf_counter()
            except Exception as exc:
                # Best-effort: a failed request is dropped from the sample,
                # but the reason is logged rather than swallowed.
                logger.warning("Benchmark request failed: %s", exc)
                return None

            if first_token_time is None or token_count == 0:
                return None

            return {
                "ttft_ms": (first_token_time - start) * 1000,
                "tpot_ms": (
                    ((end - first_token_time) / max(1, token_count - 1)) * 1000
                    if token_count > 1
                    else 0
                ),
                "tokens": token_count,
                "total_time": end - start,
            }

    async with httpx.AsyncClient() as client:
        # Warmup
        for _ in range(2):
            await make_request(client)

        # Actual benchmark, timed as a whole: requests overlap, so summing
        # per-request durations would understate aggregate throughput by
        # roughly the concurrency factor.
        batch_start = time.perf_counter()
        results = await asyncio.gather(*(make_request(client) for _ in range(num_requests)))
        wall_time = time.perf_counter() - batch_start

    valid = [r for r in results if r]
    if not valid:
        return {
            "throughput_tps": 0,
            "avg_ttft_ms": 0,
            "avg_tpot_ms": 0,
            "successful_requests": 0,
            "total_requests": num_requests,
        }

    total_tokens = sum(r["tokens"] for r in valid)
    # Single-token responses have no inter-token interval; exclude them.
    tpot_samples = [r["tpot_ms"] for r in valid if r["tpot_ms"] > 0]

    return {
        "throughput_tps": round(total_tokens / wall_time, 2) if wall_time > 0 else 0,
        "avg_ttft_ms": round(sum(r["ttft_ms"] for r in valid) / len(valid), 2),
        "avg_tpot_ms": round(sum(tpot_samples) / max(1, len(tpot_samples)), 2),
        "successful_requests": len(valid),
        "total_requests": num_requests,
    }
_update_progress( + self, + step: int, + total: int, + step_name: str, + params: dict[str, Any] | None = None, + ): + """Update job progress""" + self.job.current_step = step + self.job.total_steps = total + self.job.progress = { + "step": step, + "total_steps": total, + "step_name": step_name, + "current_config": params, + "completed_trials": len(self._outcomes), + "successful_trials": len([o for o in self._outcomes if o.success]), + } + await self.db.commit() + + +async def run_bayesian_tuning(job_id: int, n_trials: int = 10): + """ + Entry point for running Bayesian optimization tuning. + + Args: + job_id: TuningJob ID + n_trials: Number of optimization trials + """ + async with async_session_maker() as db: + result = await db.execute( + select(TuningJob) + .where(TuningJob.id == job_id) + .options( + selectinload(TuningJob.model), + selectinload(TuningJob.worker), + ) + ) + job = result.scalar_one_or_none() + + if not job: + logger.error(f"Tuning job {job_id} not found") + return + + service = BayesianTuningService(db, job, n_trials=n_trials) + await service.run() diff --git a/backend/app/services/deployer.py b/backend/app/services/deployer.py index bf4d62d..537d50d 100644 --- a/backend/app/services/deployer.py +++ b/backend/app/services/deployer.py @@ -740,9 +740,14 @@ def _build_sglang_config( """Build SGLang container command and environment. SGLang uses similar command-line arguments to vLLM but with some - differences in parameter names. + differences in parameter names. Unlike vLLM, the sglang Docker image + does not have a proper ENTRYPOINT, so we need to explicitly specify + the launch command. 
""" cmd = [ + "python", + "-m", + "sglang.launch_server", "--model-path", model.model_id, "--host", diff --git a/frontend/src/pages/AutoTuning.tsx b/frontend/src/pages/AutoTuning.tsx index ca054bf..4a0aa76 100644 --- a/frontend/src/pages/AutoTuning.tsx +++ b/frontend/src/pages/AutoTuning.tsx @@ -1,9 +1,14 @@ -import { useEffect, useState, useCallback } from "react"; +/** + * Auto-Tuning Page + * + * Bayesian optimization-based hyperparameter tuning for LLM deployments. + * Uses Optuna TPE (Tree-structured Parzen Estimator) for efficient search. + */ +import React, { useEffect, useState, useCallback, useRef } from "react"; import { Button, Card, Form, - Modal, Select, Space, Table, @@ -13,15 +18,16 @@ import { Typography, Empty, Tooltip, - Radio, Tabs, Statistic, Row, Col, - Input, - Divider, Alert, Popconfirm, + Modal, + Timeline, + Descriptions, + Spin, } from "antd"; import { PlusOutlined, @@ -35,82 +41,54 @@ import { RocketOutlined, BarChartOutlined, HistoryOutlined, - ApiOutlined, DeleteOutlined, - CommentOutlined, + PlayCircleOutlined, + ClockCircleOutlined, + AimOutlined, + SettingOutlined, + LineChartOutlined, } from "@ant-design/icons"; -import { useAppTheme } from "../hooks/useTheme"; import { workersApi, modelsApi } from "../services/api"; -import { deploymentsApi } from "../api"; import { api } from "../api/client"; -import type { Worker, LLMModel, Deployment } from "../types"; +import type { Worker, LLMModel } from "../types"; import { useResponsive } from "../hooks"; import { useAuth } from "../contexts/AuthContext"; -import { - CHAT_PANEL_STORAGE_KEY, - TUNING_JOB_EVENT_KEY, - type CustomEndpoint, - type ChatPanelState, -} from "../components/chat-panel"; import dayjs from "dayjs"; import relativeTime from "dayjs/plugin/relativeTime"; +import duration from "dayjs/plugin/duration"; dayjs.extend(relativeTime); +dayjs.extend(duration); -const { Text, Paragraph } = Typography; - -// Helper to load chat panel state (shared with Chat Panel) -function 
loadChatPanelState(): Partial { - try { - const saved = localStorage.getItem(CHAT_PANEL_STORAGE_KEY); - if (saved) { - return JSON.parse(saved); - } - } catch { - // Ignore load errors - } - return {}; -} - -// Helper to save chat panel state (shared with Chat Panel) -function saveChatPanelState(state: Partial) { - try { - const current = loadChatPanelState(); - localStorage.setItem( - CHAT_PANEL_STORAGE_KEY, - JSON.stringify({ ...current, ...state }), - ); - } catch { - // Ignore save errors - } -} +const { Text, Title } = Typography; const REFRESH_INTERVAL = 3000; +// ============================================================================ // Types +// ============================================================================ + interface TuningJobProgress { step: number; total_steps: number; step_name: string; - step_description: string; - configs_tested: number; - configs_total: number; + step_description?: string; current_config?: Record; - best_config_so_far?: Record; - best_score_so_far?: number; + completed_trials?: number; + successful_trials?: number; +} + +interface TrialResult { + parameters?: Record; + metrics?: Record; + success?: boolean; + error?: string; } -interface ConversationMessage { - role: "user" | "assistant" | "tool"; - content: string; - timestamp?: string; - tool_calls?: Array<{ - id: string; - name: string; - arguments: string; - }>; - tool_call_id?: string; - name?: string; // tool name +interface LogEntry { + timestamp: string; + level: string; + message: string; } interface TuningJob { @@ -124,8 +102,8 @@ interface TuningJob { total_steps: number; progress?: TuningJobProgress; best_config?: Record; - all_results?: Record[]; - conversation_log?: ConversationMessage[]; + all_results?: TrialResult[]; + logs?: LogEntry[]; created_at: string; updated_at: string; completed_at?: string; @@ -150,77 +128,416 @@ interface KnowledgeRecord { created_at: string; } -// Helper functions -function getStatusColor(status: string): string { +// 
============================================================================ +// Constants +// ============================================================================ + +const STATUS_CONFIG: Record< + string, + { color: string; icon: React.ReactNode; label: string } +> = { + pending: { + color: "default", + icon: , + label: "Pending", + }, + analyzing: { + color: "processing", + icon: , + label: "Analyzing", + }, + querying_kb: { + color: "processing", + icon: , + label: "Querying KB", + }, + exploring: { + color: "processing", + icon: , + label: "Exploring", + }, + benchmarking: { + color: "processing", + icon: , + label: "Benchmarking", + }, + completed: { + color: "success", + icon: , + label: "Completed", + }, + failed: { color: "error", icon: , label: "Failed" }, + cancelled: { + color: "warning", + icon: , + label: "Cancelled", + }, +}; + +const OPTIMIZATION_TARGETS = [ + { + value: "throughput", + label: "Throughput", + description: "Maximize tokens per second (TPS)", + icon: , + }, + { + value: "latency", + label: "Latency", + description: "Minimize time-to-first-token and response time", + icon: , + }, + { + value: "balanced", + label: "Balanced", + description: "Optimize for both throughput and latency", + icon: , + }, +]; + +// ============================================================================ +// Helper Components +// ============================================================================ + +function StatusTag({ status }: { status: string }) { + const config = STATUS_CONFIG[status] || STATUS_CONFIG.pending; + return ( + + {config.label} + + ); +} + +function TargetTag({ target }: { target: string }) { + const config = OPTIMIZATION_TARGETS.find((t) => t.value === target); const colors: Record = { - pending: "default", - analyzing: "processing", - querying_kb: "processing", - exploring: "processing", - benchmarking: "processing", - completed: "success", - failed: "error", - cancelled: "warning", + throughput: "green", + latency: 
"blue", + balanced: "purple", }; - return colors[status] || "default"; + return ( + {config?.label || target} + ); } -function getStatusIcon(status: string) { - const icons: Record = { - pending: , - analyzing: , - querying_kb: , - exploring: , - benchmarking: , - completed: , - failed: , - cancelled: , +function LogViewer({ + logs, + maxHeight = 300, +}: { + logs: LogEntry[]; + maxHeight?: number; +}) { + const logContainerRef = React.useRef(null); + + // Auto-scroll to bottom when new logs arrive + React.useEffect(() => { + if (logContainerRef.current) { + logContainerRef.current.scrollTop = logContainerRef.current.scrollHeight; + } + }, [logs]); + + const getLevelColor = (level: string) => { + switch (level.toUpperCase()) { + case "ERROR": + return "#ff4d4f"; + case "WARNING": + return "#faad14"; + case "INFO": + default: + return "#8c8c8c"; + } }; - return icons[status] || ; + + return ( +
+ {logs.length === 0 ? ( + + Waiting for logs... + + ) : ( + logs.map((log, idx) => { + const time = dayjs(log.timestamp).format("HH:mm:ss"); + return ( +
+ [{time}]{" "} + + {log.level.padEnd(7)} + {" "} + {log.message} +
+ ); + }) + )} +
+ ); } -function getOptimizationTargetLabel(target: string): string { - const labels: Record = { - throughput: "Throughput (TPS)", - latency: "Latency (TTFT/TPOT)", - cost: "Cost (Min Resources)", - balanced: "Balanced", - }; - return labels[target] || target; +function ProgressDisplay({ job }: { job: TuningJob }) { + const { progress, status } = job; + + if (!progress) { + return -; + } + + const completed = progress.completed_trials ?? progress.step ?? 0; + const total = progress.total_steps || 10; + const percent = Math.round((completed / total) * 100); + const successful = progress.successful_trials ?? 0; + + const isRunning = [ + "analyzing", + "querying_kb", + "exploring", + "benchmarking", + ].includes(status); + + return ( + +
+ Trial {completed} / {total} +
+
Successful: {successful}
+ {progress.step_name &&
Current: {progress.step_name}
} + + } + > + `${completed}/${total}`} + style={{ width: 100, minWidth: 80 }} + /> +
+ ); +} + +function JobDetailCard({ + job, + onClose, +}: { + job: TuningJob; + onClose: () => void; +}) { + const bestMetrics = job.best_config?.metrics as + | Record + | undefined; + const trials = job.all_results || []; + const successfulTrials = trials.filter((t) => t.success); + const logs = job.logs || []; + + return ( +
+ {/* Best Configuration */} + {job.best_config && ( + + + Best Configuration + + } + style={{ marginBottom: 16 }} + > + + + + + + + + + + + + + + + + + + {(job.best_config.max_num_seqs as number) || "-"} + + + {(job.best_config.tensor_parallel_size as number) || 1} + + + {(job.best_config.objective_value as number)?.toFixed(2) || "-"} + + + + )} + + {/* Logs */} + {logs.length > 0 && ( + + + Execution Logs + {logs.length} entries + + } + style={{ marginBottom: 16 }} + > + + + )} + + {/* Trial Results Table */} + {trials.length > 0 && ( + + ({ + ...t, + key: idx, + trial_num: idx + 1, + }))} + columns={[ + { + title: "#", + dataIndex: "trial_num", + key: "trial_num", + width: 50, + }, + { + title: "Engine", + key: "engine", + render: (_, record) => record.parameters?.engine || "-", + }, + { + title: "GPU Mem", + key: "gpu_mem", + render: (_, record) => { + const val = record.parameters + ?.gpu_memory_utilization as number; + return val ? `${(val * 100).toFixed(0)}%` : "-"; + }, + }, + { + title: "TPS", + key: "tps", + render: (_, record) => { + const val = record.metrics?.throughput_tps; + return val ? ( + {val.toFixed(1)} + ) : ( + "-" + ); + }, + }, + { + title: "TTFT", + key: "ttft", + render: (_, record) => { + const val = record.metrics?.avg_ttft_ms; + return val ? `${val.toFixed(0)} ms` : "-"; + }, + }, + { + title: "Status", + key: "success", + width: 80, + render: (_, record) => + record.success ? ( + + ) : ( + + + + ), + }, + ]} + size="small" + pagination={false} + scroll={{ y: 200 }} + /> + + )} + + {/* Summary */} +
+ |}> + Total Trials: {trials.length} + Successful: {successfulTrials.length} + {job.completed_at && ( + + Duration:{" "} + {dayjs(job.completed_at).diff(dayjs(job.created_at), "minute")}{" "} + min + + )} + +
+ + ); } +// ============================================================================ +// Main Component +// ============================================================================ + export default function AutoTuning() { const [jobs, setJobs] = useState([]); const [workers, setWorkers] = useState([]); const [models, setModels] = useState([]); - const [deployments, setDeployments] = useState([]); const [knowledge, setKnowledge] = useState([]); const [loading, setLoading] = useState(true); - const [modalOpen, setModalOpen] = useState(false); + const [createModalOpen, setCreateModalOpen] = useState(false); const [detailModal, setDetailModal] = useState(null); + const [logModal, setLogModal] = useState(null); const [form] = Form.useForm(); - const [addEndpointForm] = Form.useForm(); const { isMobile } = useResponsive(); - const { isDark } = useAppTheme(); const { canEdit } = useAuth(); - // Custom endpoints from shared localStorage (same as Chat Panel) - const [customEndpoints, setCustomEndpoints] = useState( - () => loadChatPanelState().customEndpoints || [], - ); + // -------------------------------------------------------------------------- + // Data Fetching + // -------------------------------------------------------------------------- - // LLM source type for modal - const [llmSourceType, setLlmSourceType] = useState<"deployment" | "custom">( - "deployment", - ); - const [showAddEndpoint, setShowAddEndpoint] = useState(false); - - // Save custom endpoints to shared localStorage (same as Chat Panel) - useEffect(() => { - saveChatPanelState({ customEndpoints }); - }, [customEndpoints]); - - // Fetch tuning jobs const fetchJobs = useCallback(async () => { try { const response = await api.get("/auto-tuning/jobs"); @@ -230,7 +547,6 @@ export default function AutoTuning() { } }, []); - // Fetch knowledge base const fetchKnowledge = useCallback(async () => { try { const response = await api.post("/auto-tuning/knowledge/query", { @@ -242,23 +558,19 @@ export 
default function AutoTuning() { } }, []); - // Fetch workers, models, and deployments const fetchResources = useCallback(async () => { try { - const [workersRes, modelsRes, deploymentsRes] = await Promise.all([ + const [workersRes, modelsRes] = await Promise.all([ workersApi.list(), modelsApi.list(), - deploymentsApi.list({ status: "running" }), ]); setWorkers(workersRes.items || []); setModels(modelsRes.items || []); - setDeployments(deploymentsRes.items || []); } catch (error) { console.error("Failed to fetch resources:", error); } }, []); - // Initial load useEffect(() => { const load = async () => { setLoading(true); @@ -279,103 +591,49 @@ export default function AutoTuning() { "benchmarking", ].includes(j.status), ); - - if (!hasRunningJobs) return; - + if (!hasRunningJobs && !logModal) return; const interval = setInterval(fetchJobs, REFRESH_INTERVAL); return () => clearInterval(interval); - }, [jobs, fetchJobs]); + }, [jobs, fetchJobs, logModal]); + + // Update log modal when jobs refresh + useEffect(() => { + if (logModal) { + const updatedJob = jobs.find((j) => j.id === logModal.id); + if (updatedJob) { + setLogModal(updatedJob); + } + } + }, [jobs, logModal?.id]); + + // -------------------------------------------------------------------------- + // Actions + // -------------------------------------------------------------------------- - // Create new tuning job const handleCreate = async (values: { model_id: number; worker_id: number; optimization_target: string; - llm_deployment_id?: number; - llm_custom_endpoint?: string; }) => { try { - // Build LLM config based on selection - let llm_config: Record | undefined; - - if (llmSourceType === "deployment" && values.llm_deployment_id) { - llm_config = { deployment_id: values.llm_deployment_id }; - } else if (llmSourceType === "custom" && values.llm_custom_endpoint) { - const endpoint = customEndpoints.find( - (e) => e.id === values.llm_custom_endpoint, - ); - if (endpoint) { - llm_config = { - base_url: 
endpoint.endpoint, - api_key: endpoint.apiKey, - model: endpoint.modelId, - }; - } - } - const response = await api.post("/auto-tuning/jobs", { model_id: values.model_id, worker_id: values.worker_id, optimization_target: values.optimization_target, - llm_config, }); - message.success("Auto-tuning job started"); - setModalOpen(false); + + message.success(`Tuning job #${response.data.id} created successfully`); + setCreateModalOpen(false); form.resetFields(); - setLlmSourceType("deployment"); fetchJobs(); - - // Trigger Chat Panel to open with tuning job view - const jobId = response.data.id; - if (jobId) { - localStorage.setItem( - TUNING_JOB_EVENT_KEY, - JSON.stringify({ - jobId, - timestamp: Date.now(), - }), - ); - // Dispatch storage event for same-window listeners - window.dispatchEvent( - new StorageEvent("storage", { - key: TUNING_JOB_EVENT_KEY, - newValue: JSON.stringify({ jobId, timestamp: Date.now() }), - }), - ); - } } catch (error: unknown) { const err = error as { response?: { data?: { detail?: string } } }; - message.error(err.response?.data?.detail || "Failed to start tuning job"); + message.error( + err.response?.data?.detail || "Failed to create tuning job", + ); } }; - // Add custom endpoint - const handleAddEndpoint = (values: { - name: string; - endpoint: string; - apiKey?: string; - modelId?: string; - }) => { - const newEndpoint: CustomEndpoint = { - id: `custom-${Date.now()}`, - name: values.name, - endpoint: values.endpoint, - apiKey: values.apiKey, - modelId: values.modelId, - }; - setCustomEndpoints((prev) => [...prev, newEndpoint]); - addEndpointForm.resetFields(); - setShowAddEndpoint(false); - message.success("Endpoint added"); - }; - - // Delete custom endpoint - const handleDeleteEndpoint = (id: string) => { - setCustomEndpoints((prev) => prev.filter((e) => e.id !== id)); - message.success("Endpoint removed"); - }; - - // Cancel job const handleCancel = async (jobId: number) => { try { await 
api.post(`/auto-tuning/jobs/${jobId}/cancel`); @@ -387,7 +645,6 @@ export default function AutoTuning() { } }; - // Delete job const handleDelete = async (jobId: number) => { try { await api.delete(`/auto-tuning/jobs/${jobId}`); @@ -399,46 +656,45 @@ export default function AutoTuning() { } }; - // View job details (fetch with conversation log) - const [detailLoading, setDetailLoading] = useState(false); - const handleViewDetail = async (job: TuningJob) => { - setDetailModal(job); // Show modal immediately with basic info - setDetailLoading(true); + const handleDeployBestConfig = async (job: TuningJob) => { + if (!job.best_config) { + message.warning("No best configuration available"); + return; + } + try { - const response = await api.get(`/auto-tuning/jobs/${job.id}`); - setDetailModal(response.data); - } catch (error) { - console.error("Failed to fetch job details:", error); - } finally { - setDetailLoading(false); + const engine = (job.best_config.engine as string) || "vllm"; + const gpuMemUtil = job.best_config.gpu_memory_utilization as number; + const maxNumSeqs = job.best_config.max_num_seqs as number; + const tpSize = (job.best_config.tensor_parallel_size as number) || 1; + + const extraParams: Record = {}; + if (gpuMemUtil) extraParams["gpu-memory-utilization"] = gpuMemUtil; + if (maxNumSeqs) extraParams["max-num-seqs"] = maxNumSeqs; + + await api.post("/deployments", { + model_id: job.model_id, + worker_id: job.worker_id, + name: `tuned-${job.model_name?.split("/").pop() || "model"}-${Date.now()}`, + backend: engine, + gpu_indexes: Array.from({ length: tpSize }, (_, i) => i), + extra_params: extraParams, + }); + + message.success("Deployment created with optimized configuration"); + setDetailModal(null); + } catch (error: unknown) { + const err = error as { response?: { data?: { detail?: string } } }; + message.error( + err.response?.data?.detail || "Failed to create deployment", + ); } }; - // Auto-refresh detail modal for running jobs - useEffect(() => 
{ - if (!detailModal) return; - const isRunning = [ - "pending", - "analyzing", - "querying_kb", - "exploring", - "benchmarking", - ].includes(detailModal.status); - if (!isRunning) return; - - const interval = setInterval(async () => { - try { - const response = await api.get(`/auto-tuning/jobs/${detailModal.id}`); - setDetailModal(response.data); - } catch (error) { - console.error("Failed to refresh job:", error); - } - }, 2000); + // -------------------------------------------------------------------------- + // Computed Values + // -------------------------------------------------------------------------- - return () => clearInterval(interval); - }, [detailModal?.id, detailModal?.status]); - - // Stats const completedJobs = jobs.filter((j) => j.status === "completed").length; const runningJobs = jobs.filter((j) => [ @@ -449,42 +705,50 @@ export default function AutoTuning() { "benchmarking", ].includes(j.status), ).length; + const availableWorkers = workers.filter( + (w) => w.status === "online" && w.gpu_info && w.gpu_info.length > 0, + ); + + // -------------------------------------------------------------------------- + // Table Columns + // -------------------------------------------------------------------------- - // Table columns for jobs const jobColumns = [ { title: "Model", dataIndex: "model_name", key: "model_name", - render: (name: string) => {name || "Unknown"}, - }, - { - title: "Worker", - dataIndex: "worker_name", - key: "worker_name", - responsive: ["md" as const], + render: (name: string, record: TuningJob) => ( + + {name || "Unknown"} + + Job #{record.id} + + + ), }, { title: "Target", dataIndex: "optimization_target", key: "optimization_target", - responsive: ["sm" as const], - render: (target: string) => ( - {getOptimizationTargetLabel(target)} - ), + width: 120, + render: (target: string) => , }, { title: "Status", dataIndex: "status", key: "status", + width: 140, render: (status: string, record: TuningJob) => ( - - - 
{status.toUpperCase()} - - {record.progress && ["benchmarking"].includes(status) && ( - - {record.progress.configs_tested}/{record.progress.configs_total} + + + {record.status_message && ( + + {record.status_message.slice(0, 30)} )} @@ -493,42 +757,26 @@ export default function AutoTuning() { { title: "Progress", key: "progress", - width: 120, - render: (_: unknown, record: TuningJob) => { - if (!record.progress) return "-"; - const percent = Math.round( - (record.progress.step / record.progress.total_steps) * 100, - ); - return ( - - - - ); - }, + width: 130, + render: (_: unknown, record: TuningJob) => ( + + ), }, { - title: "Score", - key: "best_score", - responsive: ["lg" as const], + title: "Best TPS", + key: "best_tps", + width: 100, render: (_: unknown, record: TuningJob) => { - const score = - record.progress?.best_score_so_far ?? - (record.best_config?.score as number | undefined); - return typeof score === "number" ? ( - {score.toFixed(2)} + const metrics = record.best_config?.metrics as + | Record + | undefined; + const tps = metrics?.throughput_tps; + return tps ? 
( + + {tps.toFixed(1)} + ) : ( - "-" + - ); }, }, @@ -536,12 +784,18 @@ export default function AutoTuning() { title: "Created", dataIndex: "created_at", key: "created_at", + width: 120, responsive: ["md" as const], - render: (date: string) => dayjs(date).fromNow(), + render: (date: string) => ( + + {dayjs(date).fromNow()} + + ), }, { title: "Actions", key: "actions", + width: 220, render: (_: unknown, record: TuningJob) => { const isRunning = [ "pending", @@ -551,49 +805,44 @@ export default function AutoTuning() { "benchmarking", ].includes(record.status); return ( - - - - - + )} + {record.status === "completed" && ( + <> + + + + )} + {record.status === "failed" && ( + + )} {isRunning && canEdit && ( )} {!isRunning && canEdit && ( handleDelete(record.id)} okText="Delete" @@ -609,67 +858,35 @@ export default function AutoTuning() { }, ]; - // Table columns for knowledge base const knowledgeColumns = [ { title: "Model", dataIndex: "model_name", key: "model_name", - render: (name: string, record: KnowledgeRecord) => ( -
- {name} - {!isMobile && ( - <> -
- - {record.model_family} - - - )} -
- ), + render: (name: string) => {name}, }, { title: "GPU", key: "gpu", responsive: ["md" as const], render: (_: unknown, record: KnowledgeRecord) => ( -
- - {record.gpu_count}x {record.gpu_model} - -
- - {record.total_vram_gb.toFixed(1)} GB - -
+ + {record.gpu_count}x {record.gpu_model} + ), }, { title: "Engine", dataIndex: "engine", key: "engine", - render: (engine: string, record: KnowledgeRecord) => ( - - {engine} - {record.quantization && ( - - {record.quantization} - - )} - - ), - }, - { - title: "TP", - dataIndex: "tensor_parallel", - key: "tensor_parallel", - responsive: ["lg" as const], + width: 100, + render: (engine: string) => {engine}, }, { title: "TPS", dataIndex: "throughput_tps", key: "throughput_tps", + width: 80, render: (v: number) => {v.toFixed(1)}, sorter: (a: KnowledgeRecord, b: KnowledgeRecord) => a.throughput_tps - b.throughput_tps, @@ -678,39 +895,28 @@ export default function AutoTuning() { title: "TTFT", dataIndex: "ttft_ms", key: "ttft_ms", + width: 80, responsive: ["sm" as const], render: (v: number) => `${v.toFixed(0)} ms`, - sorter: (a: KnowledgeRecord, b: KnowledgeRecord) => a.ttft_ms - b.ttft_ms, }, { title: "TPOT", dataIndex: "tpot_ms", key: "tpot_ms", + width: 80, responsive: ["md" as const], render: (v: number) => `${v.toFixed(1)} ms`, - sorter: (a: KnowledgeRecord, b: KnowledgeRecord) => a.tpot_ms - b.tpot_ms, - }, - { - title: "Score", - dataIndex: "score", - key: "score", - responsive: ["lg" as const], - render: (v: number | undefined) => - v ? {v.toFixed(2)} : "-", - sorter: (a: KnowledgeRecord, b: KnowledgeRecord) => - (a.score || 0) - (b.score || 0), }, ]; - // Online workers with GPUs - const availableWorkers = workers.filter( - (w) => w.status === "online" && w.gpu_info && w.gpu_info.length > 0, - ); + // -------------------------------------------------------------------------- + // Render + // -------------------------------------------------------------------------- return (
- {/* Stats Cards */} - + {/* Statistics Cards */} +
0 ? "#1890ff" : "#d9d9d9" }} - /> + runningJobs > 0 ? ( + + ) : ( + + ) } /> @@ -749,7 +957,8 @@ export default function AutoTuning() { title={ - Auto-Tuning Agent + Auto-Tuning + Bayesian Optimization } extra={ @@ -767,19 +976,22 @@ export default function AutoTuning() { )} } > - - Auto-Tuning Agent automatically finds the best deployment - configuration. Use the Chat Panel on the right to - interact with the agent, or start a job directly below. - + } + style={{ marginBottom: 16 }} + /> - Job History + Tuning Jobs ), children: ( @@ -797,9 +1009,8 @@ export default function AutoTuning() { columns={jobColumns} rowKey="id" loading={loading} - pagination={{ pageSize: 10 }} + pagination={{ pageSize: 10, showSizeChanger: false }} scroll={{ x: "max-content" }} - style={{ overflowX: "auto" }} locale={{ emptyText: ( } - onClick={() => setModalOpen(true)} + onClick={() => setCreateModalOpen(true)} > - Start Auto-Tuning + Create First Job )} @@ -829,65 +1040,60 @@ export default function AutoTuning() { ), children: ( -
- - Historical benchmark results used for configuration - recommendations. The agent uses this data to suggest optimal - configs for similar setups. - -
- ), - }} - /> - +
+ ), + }} + /> ), }, ]} /> - {/* Create Modal */} + {/* Create Job Modal */} - Start Auto-Tuning + New Auto-Tuning Job } - open={modalOpen} + open={createModalOpen} onCancel={() => { - setModalOpen(false); + setCreateModalOpen(false); form.resetFields(); - setLlmSourceType("deployment"); - setShowAddEndpoint(false); }} footer={null} - width={600} + width={520} > -
- {/* Model to tune */} + - {/* Worker */} - + {availableWorkers.length === 0 ? ( + + No workers with GPU available - ))} + ) : ( + availableWorkers.map((worker) => ( + + + {worker.name} + + {worker.gpu_info?.length || 0} GPU + + + + )) + )} - {/* Optimization Target */} - - - Throughput - Latency - Balanced - Cost - + + - - - - Agent LLM - - - - {/* Agent LLM Selection */} } + message="Parameters searched automatically" + description={ +
    +
  • Inference engine (vLLM, SGLang)
  • +
  • GPU memory utilization
  • +
  • Maximum concurrent sequences
  • +
  • Tensor parallelism (if multiple GPUs)
  • +
+ } + style={{ marginBottom: 24 }} /> - - setLlmSourceType(e.target.value)} - > - Local Deployment - Custom Endpoint - - - - {llmSourceType === "deployment" && ( - - - - )} - - {llmSourceType === "custom" && ( - <> - 0, - message: "Please select an endpoint", - }, - ]} - > - - - - {customEndpoints.length === 0 && !showAddEndpoint && ( - - )} - - {showAddEndpoint && ( - setShowAddEndpoint(false)} - > - Cancel - - } - style={{ marginBottom: 16 }} - > - - - - - - - - - - - - - - - - - )} - - )} - - - + + + -
- {/* Detail Modal - Docker-style Log View */} + {/* Job Detail Modal */} - - Tuning Log - {detailModal?.model_name} - - {detailModal?.status.toUpperCase()} - - {detailLoading && } + + Tuning Results + {detailModal && Job #{detailModal.id}} } open={!!detailModal} onCancel={() => setDetailModal(null)} - footer={null} - width={900} - styles={{ body: { padding: 0 } }} + footer={ + + {detailModal?.best_config && ( + + )} + + + } + width={800} > {detailModal && ( + setDetailModal(null)} + /> + )} + + + {/* Live Log Modal */} + + + Live Logs + {logModal && ( + <> + Job #{logModal.id} + + + )} + + } + open={!!logModal} + onCancel={() => setLogModal(null)} + footer={ + + + + + } + width={800} + > + {logModal && (
- {/* Docker-style Log Container */} -
- {detailModal.conversation_log && - detailModal.conversation_log.length > 0 ? ( - detailModal.conversation_log.map((msg, idx) => { - const timestamp = msg.timestamp - ? dayjs(msg.timestamp).format("HH:mm:ss") - : ""; - - if (msg.role === "user") { - return ( -
- [{timestamp}] - [USER] - {msg.content} -
- ); - } - - if (msg.role === "assistant") { - return ( -
- [{timestamp}] - [AGENT] - {msg.content && ( - - {msg.content} - - )} - {msg.tool_calls && msg.tool_calls.length > 0 && ( -
- {msg.tool_calls.map((tc, tcIdx) => ( -
- -> Calling: {tc.name}( - {(() => { - try { - const args = JSON.parse(tc.arguments); - return Object.entries(args) - .map( - ([k, v]) => `${k}=${JSON.stringify(v)}`, - ) - .join(", "); - } catch { - return tc.arguments; - } - })()} - ) -
- ))} -
- )} -
- ); - } - - if (msg.role === "tool") { - let content = msg.content; - try { - const parsed = JSON.parse(msg.content); - content = JSON.stringify(parsed, null, 2); - } catch { - // Keep original - } - return ( -
- [{timestamp}] - - {" "} - [TOOL:{msg.name}]{" "} - -
- {content} -
-
- ); - } - - return null; - }) - ) : ( -
- {detailLoading - ? "Loading logs..." - : detailModal.status === "pending" - ? "Waiting for agent to start..." - : "No logs available"} -
- )} + {/* Progress */} + + +
+ + {logModal.model_name} + {logModal.status_message} + + + + + + + - {/* Running indicator */} - {[ - "pending", - "analyzing", - "querying_kb", - "exploring", - "benchmarking", - ].includes(detailModal.status) && ( -
- - [{dayjs().format("HH:mm:ss")}] - - [STATUS] - - {detailModal.status_message || "Processing..."} - - {" "} - _ - - -
- )} - + {/* Logs */} + - {/* Best Config Section */} - {detailModal.best_config && ( -
- Best Configuration: -
-                  {JSON.stringify(detailModal.best_config, null, 2)}
-                
-
- )} +
+ + Logs auto-refresh every {REFRESH_INTERVAL / 1000} seconds + +
)} From e164931a502109a0c549861dd0c51905402cdbb4 Mon Sep 17 00:00:00 2001 From: rickychen-infinirc Date: Tue, 27 Jan 2026 19:52:06 +0800 Subject: [PATCH 2/3] feat: add MCP agent chat and tools - Add Agent Chat panel with streaming support - Add MCP client for backend - Add benchmark and web search tools - Refactor chat panel components --- backend/app/api/__init__.py | 6 +- backend/app/api/agent.py | 733 +++++++++++ backend/app/database.py | 52 + backend/app/models/conversation.py | 52 +- backend/app/services/mcp/__init__.py | 63 + backend/app/services/mcp/agent.py | 848 +++++++++++++ backend/app/services/mcp/client.py | 617 ++++++++++ backend/app/services/mcp/types.py | 185 +++ backend/app/services/tuning_agent.py | 190 ++- backend/requirements.txt | 1 + frontend/src/App.tsx | 436 +++---- .../components/chat-panel/AgentChatView.tsx | 1087 +++++++++++++++++ .../src/components/chat-panel/ChatPanel.tsx | 635 +++++----- .../components/chat-panel/TuningJobView.tsx | 873 ++++++++----- frontend/src/components/chat-panel/index.ts | 15 +- .../src/components/chat-panel/useAgentChat.ts | 761 ++++++++++++ frontend/src/components/chat/ChatInput.tsx | 10 +- frontend/src/contexts/ChatPanelContext.tsx | 145 +++ mcp-server/src/client.ts | 88 +- mcp-server/src/formatters.ts | 30 +- mcp-server/src/index.ts | 435 ++++++- mcp-server/src/tools/benchmark.ts | 373 ++++++ mcp-server/src/tools/webSearch.ts | 139 +++ 23 files changed, 6905 insertions(+), 869 deletions(-) create mode 100644 backend/app/api/agent.py create mode 100644 backend/app/services/mcp/__init__.py create mode 100644 backend/app/services/mcp/agent.py create mode 100644 backend/app/services/mcp/client.py create mode 100644 backend/app/services/mcp/types.py create mode 100644 frontend/src/components/chat-panel/AgentChatView.tsx create mode 100644 frontend/src/components/chat-panel/useAgentChat.ts create mode 100644 frontend/src/contexts/ChatPanelContext.tsx create mode 100644 mcp-server/src/tools/benchmark.ts create mode 
100644 mcp-server/src/tools/webSearch.ts diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index ebecf06..57cbe57 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -3,6 +3,7 @@ from fastapi import APIRouter from app.api import ( + agent, api_keys, apps, auth, @@ -68,5 +69,8 @@ # Chat Proxy for external endpoints api_router.include_router(chat_proxy.router, tags=["chat-proxy"]) -# Auto-Tuning Agent +# Auto-Tuning (legacy job-based API) api_router.include_router(auto_tuning.router, prefix="/auto-tuning", tags=["auto-tuning"]) + +# MCP-based AI Agent +api_router.include_router(agent.router, prefix="/agent", tags=["agent"]) diff --git a/backend/app/api/agent.py b/backend/app/api/agent.py new file mode 100644 index 0000000..98bbe34 --- /dev/null +++ b/backend/app/api/agent.py @@ -0,0 +1,733 @@ +""" +Agent Chat API + +Provides SSE streaming endpoints for the MCP-based AI agent. +This enables Claude Code-style interaction where users can see +the agent's thinking process and tool executions in real-time. 
+ +Endpoints: + POST /agent/chat - Stream agent chat with SSE + POST /agent/chat/simple - Simple request-response chat + GET /agent/tools - List available tools + GET /agent/conversations - List user's Agent conversations + GET /agent/conversations/{id} - Get conversation details + DELETE /agent/conversations/{id} - Delete conversation + DELETE /agent/conversation - Clear conversation history (legacy) +""" + +import asyncio +import json +import logging +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.core.deps import require_viewer +from app.database import get_db +from app.models.conversation import Conversation, ConversationType, Message, MessageRole +from app.models.deployment import Deployment, DeploymentStatus +from app.models.user import User +from app.services.mcp import AgentService, EventType + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# ============================================================================ +# Request/Response Models +# ============================================================================ + + +class LLMConfig(BaseModel): + """LLM configuration for the agent.""" + + provider: str = Field( + default="system", + description="LLM provider: 'system' (local deployment), 'openai', or 'custom'", + ) + deployment_id: int | None = Field( + default=None, description="Deployment ID when using 'system' provider" + ) + api_key: str | None = Field(default=None, description="API key for external providers") + base_url: str | None = Field(default=None, description="Base URL for custom provider") + model: str | None = Field(default=None, description="Model name to use") + + +class AgentChatRequest(BaseModel): + """Request body for agent chat.""" + + 
message: str = Field(..., description="User message to send to the agent") + llm_config: LLMConfig = Field(default_factory=LLMConfig, description="LLM configuration") + conversation_id: int | None = Field( + default=None, description="Database conversation ID to continue" + ) + + +class AgentChatSimpleRequest(BaseModel): + """Request body for simple (non-streaming) agent chat.""" + + message: str = Field(..., description="User message to send to the agent") + llm_config: LLMConfig = Field(default_factory=LLMConfig, description="LLM configuration") + + +class AgentChatSimpleResponse(BaseModel): + """Response for simple agent chat.""" + + response: str = Field(..., description="Agent's response") + conversation_history: list[dict] = Field( + default_factory=list, description="Full conversation history" + ) + + +class ToolInfo(BaseModel): + """Information about an available tool.""" + + name: str + description: str + parameters: list[dict] + + +class ToolsResponse(BaseModel): + """Response for listing tools.""" + + tools: list[ToolInfo] + total: int + + +class ConversationSummary(BaseModel): + """Summary of a conversation for list view.""" + + id: int + title: str + conversation_type: str + created_at: datetime + updated_at: datetime + message_count: int + + +class ConversationListResponse(BaseModel): + """Response for listing conversations.""" + + conversations: list[ConversationSummary] + total: int + + +class MessageDetail(BaseModel): + """Detail of a message in a conversation.""" + + id: int + role: str + content: str + thinking: str | None = None + tool_calls: list | None = None + tool_call_id: str | None = None + step_type: str | None = None + execution_time_ms: float | None = None + created_at: datetime + + +class ConversationDetail(BaseModel): + """Full conversation with messages.""" + + id: int + title: str + conversation_type: str + agent_config: dict | None = None + created_at: datetime + updated_at: datetime + messages: list[MessageDetail] + + +# 
============================================================================ +# Active Agent Sessions +# ============================================================================ + +# Store active agent sessions (in production, use Redis or similar) +_active_sessions: dict[str, AgentService] = {} +_session_lock = asyncio.Lock() + + +async def get_or_create_agent( + session_id: str, + llm_config: LLMConfig, + db: AsyncSession, + api_token: str | None = None, +) -> AgentService: + """Get existing agent session or create a new one.""" + async with _session_lock: + if session_id in _active_sessions: + return _active_sessions[session_id] + + # Resolve deployment if using system provider + llm_base_url = None + llm_api_key = llm_config.api_key + llm_model = llm_config.model + + if llm_config.provider == "system" and llm_config.deployment_id: + # Look up the deployment + result = await db.execute( + select(Deployment) + .where(Deployment.id == llm_config.deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + llm_api_key = "dummy" + llm_model = llm_config.model or "default" + + elif llm_config.provider == "openai": + llm_base_url = "https://api.openai.com/v1" + llm_model = llm_config.model or "gpt-4o" + + elif llm_config.provider == "custom": + llm_base_url = llm_config.base_url + llm_model = llm_config.model or "default" + + else: + raise HTTPException( + status_code=400, + detail=f"Invalid provider: {llm_config.provider}. 
" + "Use 'system' with deployment_id, 'openai' with api_key, or 'custom' with base_url", + ) + + # Create agent with MCP configuration + # The MCP server needs to call back to the LMStack API + from app.config import get_settings + + settings = get_settings() + mcp_api_url = f"http://localhost:{settings.port}/api" + + agent = AgentService( + llm_base_url=llm_base_url, + llm_api_key=llm_api_key, + llm_model=llm_model, + mcp_api_url=mcp_api_url, + mcp_api_token=api_token, + ) + await agent.initialize() + + _active_sessions[session_id] = agent + return agent + + +async def cleanup_session(session_id: str) -> None: + """Cleanup an agent session.""" + async with _session_lock: + agent = _active_sessions.pop(session_id, None) + if agent: + await agent.cleanup() + + +# ============================================================================ +# API Endpoints +# ============================================================================ + + +@router.post("/chat") +async def agent_chat_stream( + request: AgentChatRequest, + http_request: Request, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + Stream agent chat with Server-Sent Events (SSE). + + The agent will process the user's message and stream back events + including thinking process, tool executions, and responses. + + Messages are persisted to the database for conversation continuity. 
+ + Event types: + - thinking: Agent is processing + - planning: Agent is planning actions + - message: Agent message to user + - tool_start: Starting tool execution + - tool_progress: Tool execution progress + - tool_result: Tool execution completed + - tool_error: Tool execution failed + - done: Agent finished + - error: Agent error + - cancelled: User cancelled + """ + from app.database import async_session_maker + + # Extract auth token from request to pass to MCP server + auth_header = http_request.headers.get("Authorization", "") + api_token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else None + + # Create or load conversation from database + conversation_id = request.conversation_id + if conversation_id: + # Load existing conversation + result = await db.execute( + select(Conversation) + .where(Conversation.id == conversation_id) + .where(Conversation.user_id == current_user.id) + ) + conversation = result.scalar_one_or_none() + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + else: + # Create new conversation + title = request.message[:50] + "..." 
if len(request.message) > 50 else request.message + conversation = Conversation( + user_id=current_user.id, + title=title, + conversation_type=ConversationType.AGENT.value, + agent_config={ + "llm_provider": request.llm_config.provider, + "llm_model": request.llm_config.model, + "deployment_id": request.llm_config.deployment_id, + }, + ) + db.add(conversation) + await db.commit() + await db.refresh(conversation) + conversation_id = conversation.id + + # Save user message to database + user_message = Message( + conversation_id=conversation_id, + role=MessageRole.USER.value, + content=request.message, + ) + db.add(user_message) + await db.commit() + + # Use conversation_id as session key + session_id = f"agent_conv_{conversation_id}" + + async def event_generator(): + agent = None + accumulated_content = "" + accumulated_thinking = "" + tool_calls_list = [] + + try: + agent = await get_or_create_agent(session_id, request.llm_config, db, api_token) + + # First, send conversation_id to client + init_event = json.dumps( + { + "type": "init", + "data": {"conversation_id": conversation_id}, + } + ) + yield f"data: {init_event}\n\n" + + async for event in agent.chat(request.message): + # Format as SSE + event_data = json.dumps(event.to_dict(), ensure_ascii=False) + yield f"data: {event_data}\n\n" + + # Accumulate content for database + if event.type == EventType.MESSAGE and event.content: + accumulated_content += event.content + elif event.type == EventType.THINKING and event.content: + accumulated_thinking += event.content + elif event.type == EventType.TOOL_START: + tool_calls_list.append( + { + "tool_name": event.data.get("tool_name") if event.data else None, + "arguments": event.data.get("arguments") if event.data else None, + "status": "running", + } + ) + elif event.type == EventType.TOOL_RESULT: + if tool_calls_list: + tool_calls_list[-1]["status"] = "completed" + tool_calls_list[-1]["result"] = ( + event.data.get("result") if event.data else None + ) + 
tool_calls_list[-1]["execution_time_ms"] = ( + event.data.get("execution_time_ms") if event.data else None + ) + elif event.type == EventType.TOOL_ERROR: + if tool_calls_list: + tool_calls_list[-1]["status"] = "error" + tool_calls_list[-1]["error"] = ( + event.data.get("error") if event.data else event.content + ) + + # Check for disconnect + if await http_request.is_disconnected(): + agent.cancel() + break + + # If done or error, save to database + if event.type in (EventType.DONE, EventType.ERROR, EventType.CANCELLED): + break + + except Exception as e: + logger.exception(f"Agent chat error: {e}") + error_event = json.dumps( + { + "type": "error", + "content": str(e), + } + ) + yield f"data: {error_event}\n\n" + accumulated_content = f"Error: {str(e)}" + + finally: + # Save assistant message to database using a fresh session + if accumulated_content or tool_calls_list: + try: + async with async_session_maker() as save_db: + assistant_message = Message( + conversation_id=conversation_id, + role=MessageRole.ASSISTANT.value, + content=accumulated_content or "No response", + thinking=accumulated_thinking if accumulated_thinking else None, + tool_calls=tool_calls_list if tool_calls_list else None, + ) + save_db.add(assistant_message) + await save_db.commit() + except Exception as save_error: + logger.error(f"Failed to save assistant message: {save_error}") + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@router.post("/chat/simple", response_model=AgentChatSimpleResponse) +async def agent_chat_simple( + request: AgentChatSimpleRequest, + http_request: Request, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + Simple (non-streaming) agent chat. + + Sends a message to the agent and waits for the complete response. 
+ Useful for programmatic access or when streaming is not needed. + """ + # Extract auth token from request to pass to MCP server + auth_header = http_request.headers.get("Authorization", "") + api_token = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else None + + # Get MCP configuration + from app.config import get_settings + + settings = get_settings() + mcp_api_url = f"http://localhost:{settings.port}/api" + + # Create a temporary agent for this request + llm_base_url = None + llm_api_key = request.llm_config.api_key + llm_model = request.llm_config.model + + if request.llm_config.provider == "system" and request.llm_config.deployment_id: + result = await db.execute( + select(Deployment) + .where(Deployment.id == request.llm_config.deployment_id) + .options(selectinload(Deployment.worker)) + ) + deployment = result.scalar_one_or_none() + + if not deployment: + raise HTTPException(status_code=404, detail="Deployment not found") + if deployment.status != DeploymentStatus.RUNNING.value: + raise HTTPException(status_code=400, detail="Deployment is not running") + + worker = deployment.worker + llm_base_url = f"http://{worker.host}:{deployment.port}/v1" + llm_api_key = "dummy" + llm_model = request.llm_config.model or "default" + + elif request.llm_config.provider == "openai": + llm_base_url = "https://api.openai.com/v1" + llm_model = request.llm_config.model or "gpt-4o" + + elif request.llm_config.provider == "custom": + llm_base_url = request.llm_config.base_url + llm_model = request.llm_config.model or "default" + + else: + raise HTTPException( + status_code=400, + detail=f"Invalid provider: {request.llm_config.provider}", + ) + + try: + async with AgentService( + llm_base_url=llm_base_url, + llm_api_key=llm_api_key, + llm_model=llm_model, + mcp_api_url=mcp_api_url, + mcp_api_token=api_token, + ) as agent: + response = await agent.chat_simple(request.message) + history = agent.get_conversation_history() + + return 
AgentChatSimpleResponse( + response=response, + conversation_history=history, + ) + + except Exception as e: + logger.exception(f"Agent chat error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/tools", response_model=ToolsResponse) +async def list_agent_tools( + current_user: User = Depends(require_viewer), +): + """ + List all tools available to the agent. + + Returns the list of MCP tools that the agent can use. + """ + try: + from app.services.mcp import MCPClient + + async with MCPClient() as client: + tools = await client.list_tools() + tool_infos = [ + ToolInfo( + name=t.name, + description=t.description, + parameters=[ + { + "name": p.name, + "type": p.type, + "description": p.description, + "required": p.required, + } + for p in t.parameters + ], + ) + for t in tools + ] + return ToolsResponse(tools=tool_infos, total=len(tool_infos)) + + except Exception as e: + logger.exception(f"Failed to list tools: {e}") + raise HTTPException(status_code=500, detail=f"Failed to list tools: {e}") + + +@router.delete("/conversation") +async def clear_conversation( + session_id: str = Query(..., description="Session ID to clear"), + current_user: User = Depends(require_viewer), +): + """ + Clear the conversation history for a session. + + This will reset the agent's memory for the specified session. + """ + async with _session_lock: + agent = _active_sessions.get(session_id) + if agent: + agent.reset() + return {"success": True, "message": "Conversation cleared"} + else: + return {"success": False, "message": "Session not found"} + + +@router.post("/cancel") +async def cancel_operation( + session_id: str = Query(..., description="Session ID to cancel"), + current_user: User = Depends(require_viewer), +): + """ + Cancel an ongoing agent operation. + + This will stop the current tool execution or LLM call. 
+ """ + async with _session_lock: + agent = _active_sessions.get(session_id) + if agent: + agent.cancel() + return {"success": True, "message": "Operation cancelled"} + else: + return {"success": False, "message": "Session not found"} + + +# ============================================================================ +# Conversation Management Endpoints +# ============================================================================ + + +@router.get("/conversations", response_model=ConversationListResponse) +async def list_conversations( + conversation_type: str = Query( + default="agent", + description="Filter by conversation type: 'agent' or 'chat'", + ), + limit: int = Query( + default=50, ge=1, le=100, description="Maximum number of conversations to return" + ), + offset: int = Query(default=0, ge=0, description="Number of conversations to skip"), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + List user's conversations. + + Returns a paginated list of conversations with summary info. 
+ """ + # Count total + count_query = ( + select(Conversation) + .where(Conversation.user_id == current_user.id) + .where(Conversation.conversation_type == conversation_type) + ) + count_result = await db.execute(count_query) + total = len(count_result.all()) + + # Get conversations with message count + query = ( + select(Conversation) + .where(Conversation.user_id == current_user.id) + .where(Conversation.conversation_type == conversation_type) + .order_by(Conversation.updated_at.desc()) + .offset(offset) + .limit(limit) + .options(selectinload(Conversation.messages)) + ) + result = await db.execute(query) + conversations = result.scalars().all() + + summaries = [ + ConversationSummary( + id=conv.id, + title=conv.title, + conversation_type=conv.conversation_type, + created_at=conv.created_at, + updated_at=conv.updated_at, + message_count=len(conv.messages), + ) + for conv in conversations + ] + + return ConversationListResponse(conversations=summaries, total=total) + + +@router.get("/conversations/{conversation_id}", response_model=ConversationDetail) +async def get_conversation( + conversation_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + Get a conversation with all its messages. + + Returns the full conversation history including tool calls and thinking. 
+ """ + result = await db.execute( + select(Conversation) + .where(Conversation.id == conversation_id) + .where(Conversation.user_id == current_user.id) + .options(selectinload(Conversation.messages)) + ) + conversation = result.scalar_one_or_none() + + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + + messages = [ + MessageDetail( + id=msg.id, + role=msg.role, + content=msg.content, + thinking=msg.thinking, + tool_calls=msg.tool_calls, + tool_call_id=msg.tool_call_id, + step_type=msg.step_type, + execution_time_ms=msg.execution_time_ms, + created_at=msg.created_at, + ) + for msg in conversation.messages + ] + + return ConversationDetail( + id=conversation.id, + title=conversation.title, + conversation_type=conversation.conversation_type, + agent_config=conversation.agent_config, + created_at=conversation.created_at, + updated_at=conversation.updated_at, + messages=messages, + ) + + +@router.delete("/conversations/{conversation_id}") +async def delete_conversation( + conversation_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + Delete a conversation and all its messages. + + This will permanently remove the conversation history. 
+ """ + result = await db.execute( + select(Conversation) + .where(Conversation.id == conversation_id) + .where(Conversation.user_id == current_user.id) + ) + conversation = result.scalar_one_or_none() + + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + + # Delete conversation (messages will be cascade deleted) + await db.delete(conversation) + await db.commit() + + # Also cleanup any active session + session_id = f"agent_conv_{conversation_id}" + await cleanup_session(session_id) + + return {"success": True, "message": "Conversation deleted"} + + +@router.patch("/conversations/{conversation_id}") +async def update_conversation( + conversation_id: int, + title: str = Query(..., description="New title for the conversation"), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(require_viewer), +): + """ + Update a conversation's title. + """ + result = await db.execute( + select(Conversation) + .where(Conversation.id == conversation_id) + .where(Conversation.user_id == current_user.id) + ) + conversation = result.scalar_one_or_none() + + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + + conversation.title = title + await db.commit() + + return {"success": True, "message": "Conversation updated"} diff --git a/backend/app/database.py b/backend/app/database.py index eaa925a..2f7ea8f 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -61,6 +61,58 @@ async def column_exists(table_name: str, column_name: str) -> bool: ) logger.info("'is_local' column added!") + # Migration: Add conversation_type to conversations (for Agent chat support) + if not await column_exists("conversations", "conversation_type"): + logger.info("Adding 'conversation_type' column to conversations table...") + await conn.execute( + text( + "ALTER TABLE conversations ADD COLUMN conversation_type VARCHAR(20) DEFAULT 'chat' NOT NULL" + ) + ) + logger.info("'conversation_type' 
column added!") + + # Migration: Add agent_config to conversations (for Agent configuration) + if not await column_exists("conversations", "agent_config"): + logger.info("Adding 'agent_config' column to conversations table...") + await conn.execute(text("ALTER TABLE conversations ADD COLUMN agent_config JSON")) + logger.info("'agent_config' column added!") + + # Migration: Add tool_calls to messages (for Agent tool calls) + if not await column_exists("messages", "tool_calls"): + logger.info("Adding 'tool_calls' column to messages table...") + await conn.execute(text("ALTER TABLE messages ADD COLUMN tool_calls JSON")) + logger.info("'tool_calls' column added!") + + # Migration: Add tool_call_id to messages (for Agent tool results) + if not await column_exists("messages", "tool_call_id"): + logger.info("Adding 'tool_call_id' column to messages table...") + await conn.execute(text("ALTER TABLE messages ADD COLUMN tool_call_id VARCHAR(100)")) + logger.info("'tool_call_id' column added!") + + # Migration: Add step_type to messages (for Agent execution steps) + if not await column_exists("messages", "step_type"): + logger.info("Adding 'step_type' column to messages table...") + await conn.execute(text("ALTER TABLE messages ADD COLUMN step_type VARCHAR(50)")) + logger.info("'step_type' column added!") + + # Migration: Add execution_time_ms to messages (for tool execution timing) + if not await column_exists("messages", "execution_time_ms"): + logger.info("Adding 'execution_time_ms' column to messages table...") + await conn.execute(text("ALTER TABLE messages ADD COLUMN execution_time_ms FLOAT")) + logger.info("'execution_time_ms' column added!") + + # Migration: Add tuning_config to tuning_jobs (for multi-framework testing) + if not await column_exists("tuning_jobs", "tuning_config"): + logger.info("Adding 'tuning_config' column to tuning_jobs table...") + await conn.execute(text("ALTER TABLE tuning_jobs ADD COLUMN tuning_config JSON")) + logger.info("'tuning_config' 
column added!") + + # Migration: Add conversation_id to tuning_jobs (for Agent Chat integration) + if not await column_exists("tuning_jobs", "conversation_id"): + logger.info("Adding 'conversation_id' column to tuning_jobs table...") + await conn.execute(text("ALTER TABLE tuning_jobs ADD COLUMN conversation_id INTEGER")) + logger.info("'conversation_id' column added!") + async def init_db(): """Initialize database tables and run migrations""" diff --git a/backend/app/models/conversation.py b/backend/app/models/conversation.py index 31603cf..468a006 100644 --- a/backend/app/models/conversation.py +++ b/backend/app/models/conversation.py @@ -1,13 +1,22 @@ """Conversation and Message database models""" from datetime import datetime +from enum import Enum -from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func +from sqlalchemy import DateTime, Float, ForeignKey, Integer, String, Text, func +from sqlalchemy.dialects.sqlite import JSON from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base +class ConversationType(str, Enum): + """Type of conversation""" + + CHAT = "chat" # Traditional chat with deployment + AGENT = "agent" # MCP-based agent chat + + class Conversation(Base): """Chat conversation""" @@ -19,11 +28,19 @@ class Conversation(Base): ) title: Mapped[str] = mapped_column(String(255), nullable=False) - # Optional: link to deployment used + # Conversation type: "chat" or "agent" + conversation_type: Mapped[str] = mapped_column( + String(20), default=ConversationType.CHAT.value, nullable=False + ) + + # Optional: link to deployment used (for chat type) deployment_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("deployments.id"), nullable=True ) + # Agent configuration (for agent type) + agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) + # Timestamps created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) updated_at: Mapped[datetime] = 
class MessageStepType(str, Enum):
    """Type of agent execution step.

    String-valued so members compare equal to their plain-string form.
    NOTE(review): these look like the values meant for ``Message.step_type``
    ("thinking, tool_call, etc.") - confirm against the code that writes it.
    """

    THINKING = "thinking"
    PLANNING = "planning"
    REASONING = "reasoning"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    MESSAGE = "message"
AgentService - An AI agent that uses MCP for system interaction + +Usage: + # Direct MCP client usage + from app.services.mcp import MCPClient + + async with MCPClient() as client: + result = await client.call_tool("get_hardware_info", {"worker_id": 1}) + + # Agent service usage + from app.services.mcp import AgentService + + async with AgentService(...) as agent: + async for event in agent.chat("Deploy Qwen-7B on Worker 1"): + print(event) +""" + +from .agent import ( + AGENT_SYSTEM_PROMPT, + AgentEvent, + AgentService, + ConversationMessage, + EventType, + create_agent, +) +from .client import MCPClient, MCPClientPool +from .types import ( + MCPConnectionError, + MCPError, + MCPResource, + MCPTimeoutError, + MCPTool, + MCPToolError, + ToolCallResult, + ToolCallStatus, +) + +__all__ = [ + # Client + "MCPClient", + "MCPClientPool", + # Types + "MCPError", + "MCPConnectionError", + "MCPToolError", + "MCPTimeoutError", + "ToolCallResult", + "ToolCallStatus", + "MCPResource", + "MCPTool", + # Agent + "AgentService", + "AgentEvent", + "EventType", + "ConversationMessage", + "create_agent", + "AGENT_SYSTEM_PROMPT", +] diff --git a/backend/app/services/mcp/agent.py b/backend/app/services/mcp/agent.py new file mode 100644 index 0000000..016c541 --- /dev/null +++ b/backend/app/services/mcp/agent.py @@ -0,0 +1,848 @@ +""" +MCP-based Agent Service + +This module implements an AI agent that uses MCP to interact with the LMStack platform. +The agent can autonomously execute tasks like model deployment, benchmarking, and +configuration optimization through natural language interaction. + +The agent follows a Claude Code-style interaction pattern: +1. User provides a natural language request +2. Agent analyzes the request and plans actions +3. Agent executes actions via MCP tools, streaming progress +4. 
Agent provides results and recommendations + +Example: + async with AgentService() as agent: + async for event in agent.chat("Deploy Qwen-7B on Worker 1"): + print(event) +""" + +import asyncio +import json +import logging +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + +from .client import MCPClient +from .types import MCPTool + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Type Definitions +# ============================================================================ + + +class EventType(str, Enum): + """Types of events streamed from the agent.""" + + # Agent status events + THINKING = "thinking" # Agent is processing + PLANNING = "planning" # Agent is planning actions + REASONING = "reasoning" # Model's reasoning/thinking process + MESSAGE = "message" # Agent message to user + + # Tool execution events + TOOL_START = "tool_start" # Starting tool execution + TOOL_PROGRESS = "tool_progress" # Tool execution progress + TOOL_RESULT = "tool_result" # Tool execution completed + TOOL_ERROR = "tool_error" # Tool execution failed + + # UI events + PAGE_REFERENCE = "page_reference" # Reference to a LMStack page + ACTION_SUGGESTIONS = "action_suggestions" # Suggested actions user can click + + # Control events + DONE = "done" # Agent finished + ERROR = "error" # Agent error + CANCELLED = "cancelled" # User cancelled + + +@dataclass +class AgentEvent: + """An event emitted by the agent during execution.""" + + type: EventType + content: str | None = None + data: dict | None = None + timestamp: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> dict: + return { + "type": self.type.value, + "content": self.content, + "data": self.data, + "timestamp": self.timestamp.isoformat(), + } + + +@dataclass +class ConversationMessage: + """A message in the 
conversation history.""" + + role: str # "user", "assistant", "tool" + content: str + tool_calls: list[dict] | None = None + tool_call_id: str | None = None + timestamp: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> dict: + result = { + "role": self.role, + "content": self.content, + "timestamp": self.timestamp.isoformat(), + } + if self.tool_calls: + result["tool_calls"] = self.tool_calls + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + return result + + def to_api_format(self) -> dict: + """Convert to OpenAI API message format.""" + result = {"role": self.role, "content": self.content} + if self.tool_calls: + result["tool_calls"] = self.tool_calls + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + return result + + +# ============================================================================ +# System Prompt +# ============================================================================ + + +AGENT_SYSTEM_PROMPT = """You are LMStack AI Assistant, an expert in deploying and optimizing Large Language Model inference services. + +## Your Capabilities + +You can help users with: +1. **Model Deployment**: Deploy LLM models to GPU workers +2. **Performance Optimization**: Find optimal configurations for throughput or latency +3. **System Management**: Monitor workers, containers, and deployments +4. **Troubleshooting**: Diagnose and fix deployment issues + +## Available Tools + +You have access to MCP tools that let you: +- Query hardware information (GPUs, memory, workers) +- Deploy and manage models +- Run performance benchmarks +- Query the performance knowledge base +- Manage containers and storage + +## Workflow Guidelines + +### For Deployment Requests: +1. First check available workers with `list_workers` +2. Verify the model exists with `list_models` +3. Check GPU availability and memory +4. Deploy with appropriate configuration +5. Wait for deployment to complete +6. 
class AgentService:
    """
    AI Agent service that uses MCP for system interaction.

    The agent sends the conversation to an OpenAI-compatible chat-completions
    client, executes every tool call the model requests through ``MCPClient``,
    appends the results to the conversation and asks the model again. The loop
    ends when the model answers without tool calls, when ``cancel()`` is
    called, on an error, or after ``max_iterations`` LLM round trips. Progress
    is reported as a stream of ``AgentEvent`` objects.
    """

    # Tool name -> LMStack UI page worth linking after that tool ran.
    # Built once at class creation instead of on every lookup.
    _PAGE_REFERENCES: dict[str, dict[str, str]] = {
        "list_workers": {
            "path": "/workers",
            "title": "Workers",
            "icon": "cluster",
            "description": "查看所有 GPU Worker 節點",
        },
        "get_gpu_status": {
            "path": "/workers",
            "title": "Workers",
            "icon": "cluster",
            "description": "查看 GPU 狀態與記憶體使用",
        },
        "list_deployments": {
            "path": "/deployments",
            "title": "Deployments",
            "icon": "rocket",
            "description": "查看所有模型部署",
        },
        "deploy_model": {
            "path": "/deployments",
            "title": "Deployments",
            "icon": "rocket",
            "description": "查看部署狀態",
        },
        "list_models": {
            "path": "/models",
            "title": "Models",
            "icon": "database",
            "description": "查看所有可用模型",
        },
        "add_model": {
            "path": "/models",
            "title": "Models",
            "icon": "database",
            "description": "查看新增的模型",
        },
        "list_containers": {
            "path": "/containers",
            "title": "Containers",
            "icon": "container",
            "description": "查看 Docker 容器",
        },
        "get_system_status": {
            "path": "/",
            "title": "Dashboard",
            "icon": "dashboard",
            "description": "查看系統總覽",
        },
        "list_api_keys": {
            "path": "/api-keys",
            "title": "API Keys",
            "icon": "key",
            "description": "管理 API 金鑰",
        },
    }

    def __init__(
        self,
        llm_client: Any = None,
        llm_model: str = "gpt-4o",
        llm_base_url: str | None = None,
        llm_api_key: str | None = None,
        mcp_api_url: str | None = None,
        mcp_api_token: str | None = None,
        max_iterations: int = 20,
    ):
        """
        Initialize the Agent Service.

        Args:
            llm_client: Pre-configured LLM client (OpenAI compatible).
            llm_model: Model to use for the agent.
            llm_base_url: Base URL for LLM API.
            llm_api_key: API key for LLM.
            mcp_api_url: LMStack API URL for MCP server.
            mcp_api_token: LMStack API token for MCP server.
            max_iterations: Maximum number of LLM round trips per ``chat`` call.
        """
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.llm_base_url = llm_base_url
        self.llm_api_key = llm_api_key
        self.max_iterations = max_iterations

        # MCP configuration
        self.mcp_api_url = mcp_api_url
        self.mcp_api_token = mcp_api_token

        # State
        self.mcp_client: MCPClient | None = None
        self.conversation: list[ConversationMessage] = []
        self.tools: list[MCPTool] = []
        self._cancelled = False
        self._lock = asyncio.Lock()

    async def __aenter__(self) -> "AgentService":
        """Initialize the agent."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Cleanup the agent."""
        await self.cleanup()

    async def initialize(self) -> None:
        """
        Connect to the MCP server, load its tools and create the LLM client.

        If any step fails the MCP connection is torn down again before the
        error propagates: ``__aexit__`` is not invoked when ``__aenter__``
        raises, so nothing else would stop the server subprocess.
        """
        # Initialize MCP client with LLM config for auto-tuning
        self.mcp_client = MCPClient(
            api_url=self.mcp_api_url,
            api_token=self.mcp_api_token,
            llm_base_url=self.llm_base_url,
            llm_api_key=self.llm_api_key,
            llm_model=self.llm_model,
        )
        try:
            await self.mcp_client.connect()

            # Load available tools
            self.tools = await self.mcp_client.list_tools()
            logger.info(f"Agent initialized with {len(self.tools)} tools")

            # Initialize LLM client if not provided
            if self.llm_client is None:
                from openai import AsyncOpenAI

                self.llm_client = AsyncOpenAI(
                    base_url=self.llm_base_url,
                    api_key=self.llm_api_key or "dummy",
                )
        except BaseException:
            # Includes cancellation: release the subprocess, then re-raise.
            await self.cleanup()
            raise

    async def cleanup(self) -> None:
        """Disconnect from the MCP server, if connected."""
        if self.mcp_client:
            await self.mcp_client.disconnect()
            self.mcp_client = None

    def cancel(self) -> None:
        """Ask the running ``chat`` call to stop at its next checkpoint."""
        self._cancelled = True

    def reset(self) -> None:
        """Reset the agent state for a new conversation."""
        self.conversation = []
        self._cancelled = False

    def _build_tools_schema(self) -> list[dict]:
        """Build OpenAI-compatible tools schema from MCP tools."""
        schema = []
        for tool in self.tools:
            properties = {}
            required = []
            for param in tool.parameters:
                prop = {"type": param.type, "description": param.description}
                if param.enum:
                    prop["enum"] = param.enum
                properties[param.name] = prop
                if param.required:
                    required.append(param.name)

            schema.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": {
                            "type": "object",
                            "properties": properties,
                            "required": required,
                        },
                    },
                }
            )
        return schema

    def _get_page_reference(self, tool_name: str, tool_args: dict) -> dict | None:
        """
        Get the UI page reference for a tool execution.

        Returns a fresh dict (callers may mutate it) or ``None`` when the tool
        has no associated page. ``tool_args`` is currently unused.
        """
        reference = self._PAGE_REFERENCES.get(tool_name)
        return dict(reference) if reference is not None else None

    def _get_action_suggestions(
        self, tool_name: str, tool_result: str, tool_args: dict
    ) -> list[dict] | None:
        """
        Generate clickable follow-up suggestions based on a tool result.

        The suggestions are derived by pattern-matching the textual tool
        output. Returns ``None`` when there is nothing to suggest.
        """
        import re

        # Suggestions are UI sugar; a non-text result must not raise here.
        if not isinstance(tool_result, str):
            tool_result = "" if tool_result is None else str(tool_result)

        suggestions: list[dict] = []

        # After listing models, suggest deployment actions and Auto-Tune
        if tool_name == "list_models":
            # Look for patterns like "**ModelName** (ID: N)"
            model_matches = re.findall(r"\*\*([^*]+)\*\* \(ID: (\d+)\)", tool_result)
            for model_name, model_id in model_matches[:2]:  # Limit to 2 deploy suggestions
                suggestions.append(
                    {
                        "label": f"部署 {model_name}",
                        "message": f"部署模型 {model_name} (ID: {model_id}) 到 worker 1,使用 vllm",
                        "icon": "rocket",
                        "type": "primary",
                    }
                )

            # Add Auto-Tune suggestion
            if model_matches:
                suggestions.append(
                    {
                        "label": "Auto-Tune 最佳配置",
                        "message": "為這些模型找出最佳部署配置,請先查詢 GPU 狀態,然後搜尋網路上的最佳配置建議",
                        "icon": "experiment",
                        "type": "default",
                    }
                )

        # After listing deployments, suggest stop actions for running ones
        elif tool_name == "list_deployments":
            running_matches = re.findall(
                r"## ([^\[]+) \[running\][\s\S]*?- \*\*ID:\*\* (\d+)", tool_result
            )
            for dep_name, dep_id in running_matches[:2]:
                suggestions.append(
                    {
                        "label": f"停止 {dep_name.strip()}",
                        "message": f"停止部署 ID {dep_id}",
                        "icon": "stop",
                        "type": "danger",
                    }
                )

        # After listing workers, suggest checking GPU status
        elif tool_name == "list_workers":
            suggestions.append(
                {
                    "label": "查看 GPU 詳細狀態",
                    "message": "顯示所有 GPU 的詳細狀態和記憶體使用情況",
                    "icon": "monitor",
                    "type": "default",
                }
            )

        # After get_gpu_status, suggest searching for optimal configs
        elif tool_name == "get_gpu_status":
            gpu_model_match = re.search(
                r"GPU \d+: (NVIDIA[^,\n]+|RTX[^,\n]+|A\d+[^,\n]+|H\d+[^,\n]+)", tool_result
            )
            if gpu_model_match:
                gpu_model = gpu_model_match.group(1).strip()
                suggestions.append(
                    {
                        "label": f"搜尋 {gpu_model} 最佳配置",
                        "message": f"搜尋網路上 {gpu_model} 部署 LLM 的最佳配置和效能基準",
                        "icon": "search",
                        "type": "primary",
                    }
                )
            suggestions.append(
                {
                    "label": "列出可部署的模型",
                    "message": "列出所有可以部署的模型",
                    "icon": "unordered-list",
                    "type": "default",
                }
            )

        # After web search, suggest deploying with found config
        elif tool_name in ("web_search", "search_llm_config"):
            suggestions.append(
                {
                    "label": "使用此配置部署",
                    "message": "根據搜尋結果中的最佳配置部署模型",
                    "icon": "rocket",
                    "type": "primary",
                }
            )
            suggestions.append(
                {
                    "label": "列出現有模型",
                    "message": "列出所有已註冊的模型",
                    "icon": "unordered-list",
                    "type": "default",
                }
            )

        # After deploy_model, suggest checking status
        elif tool_name == "deploy_model":
            suggestions.append(
                {
                    "label": "查看部署狀態",
                    "message": "列出所有部署狀態",
                    "icon": "check-circle",
                    "type": "primary",
                }
            )

        return suggestions or None

    def _build_messages(self, user_message: str | None = None) -> list[dict]:
        """
        Build the messages array for an LLM API call.

        Args:
            user_message: Extra user message to append after the stored
                history. Pass ``None`` (the default) when the message is
                already part of ``self.conversation``.

        Returns:
            System prompt, conversation history and, if given, ``user_message``.
        """
        messages = [{"role": "system", "content": AGENT_SYSTEM_PROMPT}]
        messages.extend(msg.to_api_format() for msg in self.conversation)
        if user_message is not None:
            messages.append({"role": "user", "content": user_message})
        return messages

    @staticmethod
    def _parse_tool_arguments(raw_arguments: str | None) -> tuple[dict, str | None]:
        """
        Decode the JSON argument string of a tool call.

        Returns:
            ``(arguments, None)`` on success, or ``({}, error_text)`` when the
            string is not a JSON object. An empty string or JSON ``null`` means
            "no arguments" and yields ``{}``.
        """
        if raw_arguments is None or not raw_arguments.strip():
            return {}, None
        try:
            parsed = json.loads(raw_arguments)
        except json.JSONDecodeError as exc:
            return {}, f"Invalid JSON in tool arguments: {exc}"
        if parsed is None:
            return {}, None
        if not isinstance(parsed, dict):
            return {}, "Tool arguments must be a JSON object"
        return parsed, None

    def _record_tool_result(self, tool_call_id: str, content: str) -> None:
        """Append a ``tool`` message answering ``tool_call_id`` to the history."""
        self.conversation.append(
            ConversationMessage(role="tool", content=content, tool_call_id=tool_call_id)
        )

    async def chat(
        self,
        message: str,
        on_event: Callable[[AgentEvent], None] | None = None,
    ) -> AsyncIterator[AgentEvent]:
        """
        Process a user message and stream agent responses.

        Args:
            message: The user's message.
            on_event: Optional synchronous callback invoked with each event
                just before it is yielded. Exceptions it raises are logged and
                do not interrupt the stream.

        Yields:
            AgentEvent objects representing the agent's progress. The stream
            ends with a DONE, ERROR or CANCELLED event.
        """
        events = self._run_chat(message)
        try:
            async for event in events:
                if on_event is not None:
                    try:
                        on_event(event)
                    except Exception:
                        logger.exception("Agent on_event callback failed")
                yield event
        finally:
            # Make sure the inner generator is finalized if the consumer stops early.
            await events.aclose()

    async def _run_chat(self, message: str) -> AsyncIterator[AgentEvent]:
        """Run the LLM/tool loop for one user message (implementation of ``chat``)."""
        self._cancelled = False

        # The user message is stored once here; every LLM call below is built
        # from the stored history only, so it is never sent twice.
        self.conversation.append(ConversationMessage(role="user", content=message))

        try:
            yield AgentEvent(type=EventType.THINKING, content="Analyzing your request...")

            # The tool list does not change while a chat is running.
            tools_schema = self._build_tools_schema()

            for iteration in range(1, self.max_iterations + 1):
                if self._cancelled:
                    yield AgentEvent(type=EventType.CANCELLED, content="Operation cancelled")
                    return

                logger.debug(f"Agent iteration {iteration}")

                request_kwargs: dict[str, Any] = {
                    "model": self.llm_model,
                    "messages": self._build_messages(),
                    "stream": True,
                }
                if tools_schema:
                    # Omit the fields entirely when there are no tools rather
                    # than sending explicit nulls.
                    request_kwargs["tools"] = tools_schema
                    request_kwargs["tool_choice"] = "auto"

                accumulated_content = ""
                accumulated_tool_calls: dict[int, dict] = {}
                finish_reason = None

                try:
                    # Use streaming for better UX
                    stream = await self.llm_client.chat.completions.create(**request_kwargs)

                    async for chunk in stream:
                        if self._cancelled:
                            yield AgentEvent(
                                type=EventType.CANCELLED, content="Operation cancelled"
                            )
                            return

                        choice = chunk.choices[0] if chunk.choices else None
                        delta = choice.delta if choice else None
                        if not delta:
                            continue

                        # Reasoning/thinking content (e.g., DeepSeek-R1)
                        reasoning_content = getattr(delta, "reasoning_content", None)
                        if reasoning_content:
                            yield AgentEvent(
                                type=EventType.REASONING,
                                content=reasoning_content,
                            )

                        # Visible answer text, streamed as it arrives
                        if delta.content:
                            accumulated_content += delta.content
                            yield AgentEvent(type=EventType.MESSAGE, content=delta.content)

                        # Tool calls arrive as fragments keyed by index
                        for tc_delta in delta.tool_calls or []:
                            slot = accumulated_tool_calls.setdefault(
                                tc_delta.index,
                                {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                },
                            )
                            if tc_delta.id:
                                slot["id"] = tc_delta.id
                            function_delta = tc_delta.function
                            if function_delta:
                                if function_delta.name:
                                    slot["function"]["name"] = function_delta.name
                                if function_delta.arguments:
                                    slot["function"]["arguments"] += function_delta.arguments

                        if choice.finish_reason:
                            finish_reason = choice.finish_reason

                except Exception as e:
                    logger.error(f"LLM API error: {e}")
                    yield AgentEvent(
                        type=EventType.ERROR,
                        content=f"Failed to communicate with LLM: {str(e)}",
                    )
                    return

                tool_calls_list = [
                    accumulated_tool_calls[i] for i in sorted(accumulated_tool_calls)
                ]

                if not tool_calls_list:
                    # No tool calls: this is the final response (its text was
                    # already streamed). Re-querying with unchanged history on
                    # a non-"stop" finish would only repeat the request, so
                    # finish here and expose the unusual reason instead.
                    self.conversation.append(
                        ConversationMessage(role="assistant", content=accumulated_content)
                    )
                    done_data = (
                        None
                        if finish_reason in (None, "stop")
                        else {"finish_reason": finish_reason}
                    )
                    yield AgentEvent(type=EventType.DONE, data=done_data)
                    return

                self.conversation.append(
                    ConversationMessage(
                        role="assistant",
                        content=accumulated_content,
                        tool_calls=tool_calls_list,
                    )
                )

                for position, tool_call in enumerate(tool_calls_list):
                    if self._cancelled:
                        # Every requested tool call needs a tool message,
                        # otherwise the stored history is rejected on the next
                        # request; answer the outstanding ones before stopping.
                        for pending in tool_calls_list[position:]:
                            self._record_tool_result(
                                pending["id"], "Error: Operation cancelled"
                            )
                        yield AgentEvent(
                            type=EventType.CANCELLED, content="Operation cancelled"
                        )
                        return

                    tool_name = tool_call["function"]["name"]
                    tool_args, argument_error = self._parse_tool_arguments(
                        tool_call["function"]["arguments"]
                    )

                    if argument_error is not None:
                        # Do not run a tool with guessed arguments; report the
                        # problem to the user and to the model so it can retry.
                        yield AgentEvent(
                            type=EventType.TOOL_ERROR,
                            content=f"Failed {tool_name}: {argument_error}",
                            data={"tool_name": tool_name, "error": argument_error},
                        )
                        self._record_tool_result(tool_call["id"], f"Error: {argument_error}")
                        continue

                    yield AgentEvent(
                        type=EventType.TOOL_START,
                        content=f"Executing {tool_name}",
                        data={"tool_name": tool_name, "arguments": tool_args},
                    )

                    # Execute the tool via MCP. A raised exception is treated
                    # like a failed tool so the call still gets its answer.
                    result = None
                    try:
                        if self.mcp_client is None:
                            raise RuntimeError("MCP client is not connected")
                        result = await self.mcp_client.call_tool(tool_name, tool_args)
                        error_text = None if result.success else result.error
                    except Exception as exc:
                        logger.exception(f"Tool {tool_name} raised: {exc}")
                        error_text = str(exc)

                    if result is not None and result.success:
                        yield AgentEvent(
                            type=EventType.TOOL_RESULT,
                            content=f"Completed {tool_name}",
                            data={
                                "tool_name": tool_name,
                                "result": result.result,
                                "execution_time_ms": result.execution_time_ms,
                            },
                        )

                        page_ref = self._get_page_reference(tool_name, tool_args)
                        if page_ref:
                            yield AgentEvent(type=EventType.PAGE_REFERENCE, data=page_ref)

                        suggestions = self._get_action_suggestions(
                            tool_name, result.result, tool_args
                        )
                        if suggestions:
                            yield AgentEvent(
                                type=EventType.ACTION_SUGGESTIONS,
                                data={"suggestions": suggestions},
                            )

                        self._record_tool_result(tool_call["id"], result.result)
                    else:
                        yield AgentEvent(
                            type=EventType.TOOL_ERROR,
                            content=f"Failed {tool_name}: {error_text}",
                            data={"tool_name": tool_name, "error": error_text},
                        )
                        self._record_tool_result(tool_call["id"], f"Error: {error_text}")

                # Loop again so the model can react to the tool results.

            # Max iterations reached
            yield AgentEvent(
                type=EventType.ERROR,
                content="Maximum iterations reached. Please try a simpler request.",
            )

        except Exception as e:
            logger.exception(f"Agent error: {e}")
            yield AgentEvent(
                type=EventType.ERROR,
                content=f"An error occurred: {str(e)}",
            )

    async def chat_simple(self, message: str) -> str:
        """
        Simple chat interface that returns the response text.

        MESSAGE events are streaming fragments, so they are concatenated
        without a separator; a single newline is inserted where the model
        paused to run a tool.

        Args:
            message: The user's message.

        Returns:
            The agent's response text, ``"Error: ..."`` if the agent reported
            an error, or ``"Operation cancelled"`` if it was cancelled.
        """
        response_parts: list[str] = []
        async for event in self.chat(message):
            if event.type == EventType.MESSAGE:
                response_parts.append(event.content or "")
            elif event.type == EventType.TOOL_START:
                if response_parts and not response_parts[-1].endswith("\n"):
                    response_parts.append("\n")
            elif event.type == EventType.ERROR:
                return f"Error: {event.content}"
            elif event.type == EventType.CANCELLED:
                return "Operation cancelled"

        return "".join(response_parts)

    def get_conversation_history(self) -> list[dict]:
        """Get the conversation history as a list of dicts."""
        return [msg.to_dict() for msg in self.conversation]
+ """ + # Determine LLM configuration based on provider + if provider == "system" and deployment_id: + # Use a local LMStack deployment + # Need to look up the deployment to get the endpoint + # For now, assume localhost with standard port + llm_base_url = f"http://localhost:8000/api/deployments/{deployment_id}/v1" + llm_api_key = "dummy" + llm_model = model or "default" + elif provider == "openai": + llm_base_url = "https://api.openai.com/v1" + llm_api_key = api_key + llm_model = model or "gpt-4o" + elif provider == "custom": + llm_base_url = base_url + llm_api_key = api_key or "dummy" + llm_model = model or "default" + else: + raise ValueError(f"Invalid provider: {provider}") + + agent = AgentService( + llm_base_url=llm_base_url, + llm_api_key=llm_api_key, + llm_model=llm_model, + mcp_api_url=mcp_api_url, + mcp_api_token=mcp_api_token, + ) + + await agent.initialize() + return agent diff --git a/backend/app/services/mcp/client.py b/backend/app/services/mcp/client.py new file mode 100644 index 0000000..19429e3 --- /dev/null +++ b/backend/app/services/mcp/client.py @@ -0,0 +1,617 @@ +""" +MCP Client Implementation + +A Python client for communicating with MCP (Model Context Protocol) servers. +Implements JSON-RPC 2.0 over stdio transport. + +This client can connect to any MCP-compliant server and execute tools, +read resources, and manage the connection lifecycle. 
+ +Example: + async with MCPClient() as client: + # List available tools + tools = await client.list_tools() + + # Call a tool + result = await client.call_tool("deploy_model", { + "model_id": 1, + "worker_id": 1, + }) +""" + +import asyncio +import json +import logging +import os +import time +from collections.abc import AsyncIterator +from pathlib import Path + +from .types import ( + JSONRPCRequest, + JSONRPCResponse, + MCPConnectionError, + MCPProtocolError, + MCPResource, + MCPTimeoutError, + MCPTool, + ToolCallResult, + ToolCallStatus, +) + +logger = logging.getLogger(__name__) + + +class MCPClient: + """ + Async MCP Client for communicating with MCP servers. + + This client manages a subprocess running the MCP server and communicates + with it using JSON-RPC 2.0 over stdin/stdout. + + Attributes: + server_path: Path to the MCP server executable + env: Environment variables to pass to the server process + timeout: Default timeout for operations in seconds + """ + + # Default path to the LMStack MCP server + # Path: backend/app/services/mcp/client.py -> 5 parents up = lmstack/ -> mcp-server/dist/index.js + DEFAULT_SERVER_PATH = ( + Path(__file__).parent.parent.parent.parent.parent / "mcp-server" / "dist" / "index.js" + ) + + def __init__( + self, + server_path: str | Path | None = None, + api_url: str | None = None, + api_token: str | None = None, + timeout: float = 60.0, + llm_base_url: str | None = None, + llm_api_key: str | None = None, + llm_model: str | None = None, + ): + """ + Initialize the MCP client. + + Args: + server_path: Path to the MCP server. Defaults to LMStack MCP server. + api_url: LMStack API URL. Defaults to environment variable. + api_token: LMStack API token. Defaults to environment variable. + timeout: Default timeout for operations in seconds. + llm_base_url: LLM base URL to pass to MCP server for auto-tuning. + llm_api_key: LLM API key to pass to MCP server for auto-tuning. 
+ llm_model: LLM model name to pass to MCP server for auto-tuning. + """ + self.server_path = Path(server_path) if server_path else self.DEFAULT_SERVER_PATH + self.timeout = timeout + + # Build environment for the MCP server + self.env = os.environ.copy() + if api_url: + self.env["LMSTACK_API_URL"] = api_url + if api_token: + self.env["LMSTACK_API_TOKEN"] = api_token + + # Pass LLM configuration for auto-tuning agent + if llm_base_url: + self.env["AGENT_LLM_BASE_URL"] = llm_base_url + if llm_api_key: + self.env["AGENT_LLM_API_KEY"] = llm_api_key + if llm_model: + self.env["AGENT_LLM_MODEL"] = llm_model + + # Connection state + self._process: asyncio.subprocess.Process | None = None + self._request_id = 0 + self._pending_requests: dict[int, asyncio.Future] = {} + self._read_task: asyncio.Task | None = None + self._connected = False + self._lock = asyncio.Lock() + + # Cached data + self._tools: list[MCPTool] | None = None + self._resources: list[MCPResource] | None = None + + @property + def connected(self) -> bool: + """Check if the client is connected to the MCP server.""" + return self._connected and self._process is not None + + async def __aenter__(self) -> "MCPClient": + """Async context manager entry.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit.""" + await self.disconnect() + + async def connect(self) -> None: + """ + Connect to the MCP server. + + Starts the MCP server subprocess and initializes the connection. + + Raises: + MCPConnectionError: If connection fails. + """ + if self._connected: + return + + async with self._lock: + if self._connected: + return + + try: + # Validate server path + if not self.server_path.exists(): + raise MCPConnectionError( + f"MCP server not found at {self.server_path}. " + "Please run 'npm run build' in the mcp-server directory." 
async def _cleanup(self) -> None:
    """
    Release all connection resources; safe to call repeatedly.

    Fails in-flight requests, stops the reader task, shuts down the server
    subprocess (terminate, then kill after 5 seconds) and clears the caches.
    """
    self._connected = False

    # Fail in-flight requests with a catchable error. Cancelling the futures
    # would surface in the waiting callers as CancelledError, which looks
    # like task cancellation and bypasses ``except Exception`` handlers.
    for future in self._pending_requests.values():
        if not future.done():
            future.set_exception(MCPConnectionError("MCP client disconnected"))
    self._pending_requests.clear()

    # Cancel read task
    if self._read_task:
        self._read_task.cancel()
        try:
            await self._read_task
        except asyncio.CancelledError:
            pass
        self._read_task = None

    # Terminate process
    if self._process:
        # returncode stays None while the process is running; signalling an
        # already-exited process would only raise and log a bogus warning.
        if self._process.returncode is None:
            try:
                self._process.terminate()
                await asyncio.wait_for(self._process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                self._process.kill()
                await self._process.wait()
            except Exception as e:
                logger.warning(f"Error terminating MCP server: {e}")
        self._process = None

    # Clear cache
    self._tools = None
    self._resources = None
async def _send_request(self, method: str, params: dict | None = None) -> JSONRPCResponse:
    """
    Send a JSON-RPC request and wait for the matching response.

    Args:
        method: The method name to call.
        params: Optional parameters for the method.

    Returns:
        The JSON-RPC response delivered by the reader task.

    Raises:
        MCPTimeoutError: If no response arrives within ``self.timeout``.
        MCPConnectionError: If there is no subprocess/stdin to write to.
    """
    if not self._process or not self._process.stdin:
        raise MCPConnectionError("Not connected to MCP server")

    self._request_id += 1
    request_id = self._request_id

    request = JSONRPCRequest(method=method, params=params, id=request_id)

    # The reader task resolves this future when a response with the same id
    # arrives. create_future() ties it to the running loop.
    future: asyncio.Future[JSONRPCResponse] = asyncio.get_running_loop().create_future()
    self._pending_requests[request_id] = future

    try:
        message = json.dumps(request.to_dict()) + "\n"
        self._process.stdin.write(message.encode())
        await self._process.stdin.drain()

        logger.debug(f"Sent request: {method} (id={request_id})")

        return await asyncio.wait_for(future, timeout=self.timeout)

    except asyncio.TimeoutError:
        raise MCPTimeoutError(f"Request timed out: {method}")
    finally:
        # Always unregister, including when the caller is cancelled:
        # CancelledError is not an Exception subclass, so an
        # `except Exception` cleanup would leave the entry behind.
        self._pending_requests.pop(request_id, None)
async def _read_responses(self) -> None:
    """
    Background task that reads newline-delimited JSON from the server.

    Each complete line is parsed as a JSON-RPC message. Messages carrying an
    id resolve the matching future in ``self._pending_requests``; messages
    without an id are logged as notifications. The task ends when stdout
    reaches EOF, when it is cancelled, or on an unexpected read error. On
    exit, any request still waiting is failed with MCPConnectionError so
    callers do not sit until their timeout.
    """
    if not self._process or not self._process.stdout:
        return

    stdout = self._process.stdout
    # Buffer raw bytes and decode per line: a 4096-byte read can end in the
    # middle of a multi-byte UTF-8 character, so decoding each chunk on its
    # own can raise UnicodeDecodeError and kill the reader.
    buffer = bytearray()
    try:
        while True:
            chunk = await stdout.read(4096)
            if not chunk:
                break

            buffer += chunk

            # Process complete lines; a trailing partial line stays buffered.
            while True:
                newline_at = buffer.find(b"\n")
                if newline_at < 0:
                    break
                raw_line = bytes(buffer[:newline_at])
                del buffer[: newline_at + 1]

                try:
                    line = raw_line.decode("utf-8").strip()
                    if not line:
                        continue

                    data = json.loads(line)
                    response = JSONRPCResponse.from_dict(data)

                    if response.id is not None:
                        # Match response to pending request
                        future = self._pending_requests.pop(response.id, None)
                        if future and not future.done():
                            future.set_result(response)
                        else:
                            logger.warning(f"No pending request for id: {response.id}")
                    else:
                        # This is a notification from the server
                        logger.debug(f"Received notification: {data}")

                except (UnicodeDecodeError, json.JSONDecodeError) as e:
                    logger.warning(f"Invalid JSON from server: {e}")
                except Exception as e:
                    logger.error(f"Error processing response: {e}")

    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.error(f"Error reading from MCP server: {e}")
    finally:
        # Nobody will resolve these futures once the reader is gone.
        for future in self._pending_requests.values():
            if not future.done():
                future.set_exception(
                    MCPConnectionError("MCP server closed the connection")
                )
        self._pending_requests.clear()
async def call_tool(self, name: str, arguments: dict | None = None) -> ToolCallResult:
    """
    Call a tool on the MCP server.

    Failures are reported through the returned result and are never raised:
    JSON-RPC errors, tool results flagged ``isError``, timeouts and any
    other exception all produce a ``ToolCallResult`` with status ERROR.

    Args:
        name: The tool name.
        arguments: Optional arguments for the tool (sent as ``{}`` if omitted).

    Returns:
        The tool call result; ``result`` holds the concatenated text items
        of the response content on success.
    """
    # Monotonic clock: wall-clock adjustments must not skew the duration.
    started = time.monotonic()

    def elapsed_ms() -> float:
        return (time.monotonic() - started) * 1000

    def failure(message: str, duration_ms: float) -> ToolCallResult:
        return ToolCallResult(
            tool_name=name,
            status=ToolCallStatus.ERROR,
            error=message,
            execution_time_ms=duration_ms,
        )

    try:
        response = await self._send_request(
            "tools/call",
            {
                "name": name,
                "arguments": arguments or {},
            },
        )

        execution_time = elapsed_ms()

        if not response.success:
            error_msg = (
                response.error.get("message", "Unknown error")
                if response.error
                else "Unknown error"
            )
            return failure(error_msg, execution_time)

        payload = response.result
        if not isinstance(payload, dict):
            # A response with neither error nor an object result is malformed;
            # say so instead of surfacing an AttributeError message.
            return failure("Malformed tool response: missing result", execution_time)

        # Only text items contribute to the result string.
        result_text = "".join(
            item.get("text", "")
            for item in payload.get("content", [])
            if item.get("type") == "text"
        )

        # The tool itself can flag its output as an error.
        if payload.get("isError", False):
            return failure(result_text, execution_time)

        return ToolCallResult(
            tool_name=name,
            status=ToolCallStatus.SUCCESS,
            result=result_text,
            execution_time_ms=execution_time,
        )

    except MCPTimeoutError:
        return failure("Tool execution timed out", elapsed_ms())
    except Exception as e:
        return failure(str(e), elapsed_ms())
class MCPClientPool:
    """
    Pool of MCP clients for concurrent access.

    Useful when you need multiple simultaneous connections to the MCP server.
    Use it as an async context manager; ``acquire()`` hands out a client that
    no other caller currently holds, and ``release(client)`` returns it.
    """

    def __init__(self, size: int = 3, **client_kwargs):
        """
        Initialize the client pool.

        Args:
            size: Number of clients in the pool (at least 1).
            **client_kwargs: Arguments to pass to MCPClient constructor.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError("size must be at least 1")
        self.size = size
        self.client_kwargs = client_kwargs
        self._clients: list[MCPClient] = []
        # Clients nobody holds, and clients handed out in acquisition order.
        # Tracking them explicitly avoids handing a busy client to a second
        # caller while an idle one exists, which plain round-robin could do.
        self._idle: list[MCPClient] = []
        self._busy: list[MCPClient] = []
        # Permits always equal len(self._idle) once the pool is entered.
        self._semaphore = asyncio.Semaphore(size)

    async def __aenter__(self) -> "MCPClientPool":
        """Connect all clients; on failure disconnect the ones already started."""
        try:
            for _ in range(self.size):
                client = MCPClient(**self.client_kwargs)
                # Register first so a failing connect() is still cleaned up.
                self._clients.append(client)
                await client.connect()
        except BaseException:
            # __aexit__ is not called when __aenter__ raises.
            await self._disconnect_all()
            raise
        self._idle = list(self._clients)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Disconnect all clients."""
        await self._disconnect_all()

    async def _disconnect_all(self) -> None:
        """Disconnect every client, logging failures without skipping the rest."""
        clients, self._clients = self._clients, []
        self._idle.clear()
        self._busy.clear()
        outcomes = await asyncio.gather(
            *(client.disconnect() for client in clients), return_exceptions=True
        )
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                logger.warning(f"Error disconnecting pooled MCP client: {outcome}")

    async def acquire(self) -> "MCPClient":
        """
        Acquire a client from the pool, waiting until one is idle.

        Returns:
            An MCP client not currently held by another caller.

        Raises:
            RuntimeError: If the pool has no connected clients (not entered).
        """
        await self._semaphore.acquire()
        if not self._idle:
            # Give the permit back so the pool is not wedged.
            self._semaphore.release()
            raise RuntimeError("MCPClientPool has no connected clients; use 'async with' first")
        client = self._idle.pop(0)
        self._busy.append(client)
        return client

    def release(self, client: "MCPClient | None" = None) -> None:
        """
        Release a client back to the pool.

        Args:
            client: The client obtained from ``acquire()``. When omitted (the
                legacy call form), the longest-held client is returned to
                the pool; pass the client explicitly whenever callers may
                finish out of order.
        """
        if client is None:
            if not self._busy:
                return  # nothing checked out: ignore the extra release
            client = self._busy.pop(0)
        else:
            try:
                self._busy.remove(client)
            except ValueError:
                return  # not checked out (double release or foreign object)
        self._idle.append(client)
        self._semaphore.release()
@dataclass
class MCPResource:
    """A resource advertised by an MCP server."""

    uri: str
    name: str
    description: str
    mime_type: str = "text/plain"

    @classmethod
    def from_dict(cls, data: dict) -> "MCPResource":
        """Build a resource from its wire-format dict, defaulting absent keys."""
        # (attribute name, wire key, value used when the key is absent)
        field_map = (
            ("uri", "uri", ""),
            ("name", "name", ""),
            ("description", "description", ""),
            ("mime_type", "mimeType", "text/plain"),
        )
        kwargs = {attr: data.get(key, fallback) for attr, key, fallback in field_map}
        return cls(**kwargs)
type: str + description: str + required: bool = False + enum: list[str] | None = None + default: Any = None + + +@dataclass +class MCPTool: + """Represents an MCP tool.""" + + name: str + description: str + parameters: list[MCPToolParameter] = field(default_factory=list) + requires_confirmation: bool = False + + @classmethod + def from_dict(cls, data: dict) -> "MCPTool": + params = [] + input_schema = data.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required = input_schema.get("required", []) + + for param_name, param_info in properties.items(): + params.append( + MCPToolParameter( + name=param_name, + type=param_info.get("type", "string"), + description=param_info.get("description", ""), + required=param_name in required, + enum=param_info.get("enum"), + default=param_info.get("default"), + ) + ) + + return cls( + name=data.get("name", ""), + description=data.get("description", ""), + parameters=params, + ) + + +@dataclass +class JSONRPCRequest: + """JSON-RPC 2.0 request.""" + + method: str + params: dict | None = None + id: int | str | None = None + jsonrpc: str = "2.0" + + def to_dict(self) -> dict: + result = {"jsonrpc": self.jsonrpc, "method": self.method} + if self.params is not None: + result["params"] = self.params + if self.id is not None: + result["id"] = self.id + return result + + +@dataclass +class JSONRPCResponse: + """JSON-RPC 2.0 response.""" + + id: int | str | None + result: Any = None + error: dict | None = None + jsonrpc: str = "2.0" + + @classmethod + def from_dict(cls, data: dict) -> "JSONRPCResponse": + return cls( + id=data.get("id"), + result=data.get("result"), + error=data.get("error"), + jsonrpc=data.get("jsonrpc", "2.0"), + ) + + @property + def success(self) -> bool: + return self.error is None diff --git a/backend/app/services/tuning_agent.py b/backend/app/services/tuning_agent.py index 311163e..9d4fead 100644 --- a/backend/app/services/tuning_agent.py +++ b/backend/app/services/tuning_agent.py @@ -93,22 
+93,22 @@ === DIAGNOSING DEPLOYMENT ISSUES === When wait_for_deployment times out: -1. FIRST call test_deployment_endpoint to check if API is actually ready - - If ready=true: Great! Proceed to run_benchmark - - If ready=false: Continue to step 2 -2. Call get_deployment_logs to check container logs (use tail=100 or more) -3. Look for common patterns in logs: - - "Loading checkpoint shards" or "Loading model weights" - model is loading, keep waiting - - "INFO: Started server process" or "Uvicorn running" - vLLM is ready! - - "CUDA out of memory" - try quantization or fewer GPUs - - "Error" or "Exception" - check the error message -4. Based on logs, decide: - - If model loading: call test_deployment_endpoint every 30s until ready - - If OOM error: stop_deployment and try with quantization - - If other error: stop_deployment and try different engine/config - -DO NOT just give up on timeout - always test endpoint and check logs first! -A 0.6B model should load in 1-2 minutes, larger models (7B+) may take 5-10 minutes. +1. Call test_deployment_endpoint ONCE to check if API is ready + - If ready=true: Proceed to run_benchmark immediately + - If ready=false: Call get_deployment_logs ONCE +2. Based on logs, make a QUICK decision: + - If "Loading model" in logs: Wait ONE more time with wait_for_deployment(timeout_seconds=120) + - If any error: Call stop_deployment and try next config + - If unclear: Call stop_deployment and try next config + +STRICT RULES TO AVOID LOOPS: +- Maximum 2 calls to wait_for_deployment per config +- Maximum 2 calls to test_deployment_endpoint per config +- If deployment not ready after 2 waits, STOP it and move to next config +- Do NOT repeatedly check status - make a decision and move on! +- Small models (< 1B) should load in 60s, if not working after 2 attempts, skip it + +A 0.6B model should load in under 60 seconds. If it takes longer, something is wrong. 
=== HANDLING LOW GPU MEMORY === If deploy_model fails with "GPU memory is low": @@ -416,6 +416,8 @@ def __init__(self, db: AsyncSession, job: TuningJob): self.db = db self.job = job self.created_deployments: list[int] = [] + self.benchmark_results: list[dict] = [] # Track completed benchmarks + self.hardware_checked: bool = False # Track if hardware was checked async def execute(self, tool_name: str, args: dict) -> str: """Execute a tool and return result as string""" @@ -431,6 +433,7 @@ async def execute(self, tool_name: str, args: dict) -> str: async def _tool_get_hardware_info(self, worker_id: int) -> dict: """Get hardware info for a worker""" + self.hardware_checked = True # Mark hardware as checked result = await self.db.execute(select(Worker).where(Worker.id == worker_id)) worker = result.scalar_one_or_none() @@ -707,6 +710,30 @@ async def _tool_wait_for_deployment( "error": "Deployment not found. It may have been deleted.", } + # Update job status to show deployment progress + elapsed = int(time.time() - start_time) + status_map = { + "pending": "Preparing deployment...", + "starting": "Loading model into GPU memory...", + "running": "Model ready!", + "error": "Deployment failed", + "stopped": "Deployment stopped", + } + status_desc = status_map.get(deployment.status, deployment.status) + self.job.status_message = f"Waiting for model: {status_desc} ({elapsed}s)" + self.job.progress = { + "step": self.job.progress.get("step", 0) if self.job.progress else 0, + "total_steps": ( + self.job.progress.get("total_steps", 15) if self.job.progress else 15 + ), + "step_name": "wait_for_deployment", + "step_description": f"Deployment #{deployment_id}: {status_desc}", + "deployment_status": deployment.status, + "deployment_message": deployment.status_message or "", + "elapsed_seconds": elapsed, + } + await self.db.commit() + if deployment.status == DeploymentStatus.RUNNING.value: return { "success": True, @@ -735,9 +762,9 @@ async def _tool_wait_for_deployment( 
"deployment_id": deployment_id, "error": f"Timeout after {timeout_seconds}s", "action_required": ( - f"1. Call get_deployment_logs({deployment_id}) to check what's happening\n" - f"2. If model is still loading, wait more with wait_for_deployment(timeout_seconds=300)\n" - f"3. If there's an error, call stop_deployment({deployment_id}) and try a different config" + f"1. Call test_deployment_endpoint({deployment_id}) to check if it's actually ready\n" + f"2. If not ready, call stop_deployment({deployment_id}) and try next config\n" + f"DO NOT wait again - move on to the next configuration!" ), } @@ -784,6 +811,18 @@ async def _tool_run_benchmark( output_tokens=output_tokens, ) + # Save successful benchmark results for tracking + if metrics.get("success"): + self.benchmark_results.append( + { + "deployment_id": deployment_id, + "engine": deployment.backend, + "gpu_indexes": deployment.gpu_indexes or [0], + "extra_params": deployment.extra_params or {}, + "metrics": metrics.get("metrics", {}), + } + ) + return metrics async def _tool_stop_deployment(self, deployment_id: int) -> dict: @@ -800,19 +839,42 @@ async def _tool_stop_deployment(self, deployment_id: int) -> dict: # Stop container if running if deployment.container_id: - deployer = DeployerService() - await deployer.stop(deployment_id) - - # Delete deployment record - await self.db.delete(deployment) + try: + deployer = DeployerService() + await deployer.stop(deployment_id) + except Exception as e: + logger.warning(f"Failed to stop container for deployment {deployment_id}: {e}") + + # Always update status to stopped first (in case delete fails) + deployment.status = DeploymentStatus.STOPPED.value + deployment.status_message = "Stopped by tuning agent" await self.db.commit() + # Then try to delete + try: + await self.db.delete(deployment) + await self.db.commit() + except Exception as e: + logger.warning(f"Failed to delete deployment {deployment_id}: {e}") + # Status is already stopped, so it's ok + if 
deployment_id in self.created_deployments: self.created_deployments.remove(deployment_id) return {"success": True, "message": f"Deployment {deployment_id} stopped and removed"} except Exception as e: logger.exception(f"Failed to stop deployment: {e}") + # Try to at least mark it stopped + try: + result = await self.db.execute( + select(Deployment).where(Deployment.id == deployment_id) + ) + deployment = result.scalar_one_or_none() + if deployment: + deployment.status = DeploymentStatus.STOPPED.value + await self.db.commit() + except Exception: + pass return {"success": False, "error": str(e)} async def _tool_check_deployment_status(self, deployment_id: int) -> dict: @@ -959,6 +1021,25 @@ async def _tool_finish_tuning( self, best_config: dict, reasoning: str, all_results: list | None = None ) -> dict: """Mark tuning as complete and save to knowledge base""" + # Validate that proper steps were completed + if not self.hardware_checked: + return { + "success": False, + "error": "Cannot finish tuning: You must call get_hardware_info first to check the GPU environment.", + "required_action": "Call get_hardware_info(worker_id=...) before finishing.", + } + + if not self.benchmark_results and not all_results: + return { + "success": False, + "error": "Cannot finish tuning: No benchmark results found. 
You must run at least one benchmark.", + "required_action": "Deploy a model, run run_benchmark(), then call finish_tuning with the results.", + } + + # Use tracked benchmark results if all_results not provided + if not all_results: + all_results = self.benchmark_results + # Update job status self.job.status = TuningJobStatus.COMPLETED.value self.job.status_message = "Auto-tuning completed successfully" @@ -1340,10 +1421,20 @@ async def run_tuning_agent(job_id: int, llm_config: dict | None = None): # Initialize OpenAI client (supports OpenAI-compatible endpoints) client = AsyncOpenAI(api_key=api_key, base_url=base_url or "https://api.openai.com/v1") - # Build initial user message - user_message = f"""Help me find the best deployment configuration for {job.model.name} on {job.worker.name}. I want to optimize for {job.optimization_target}. + # Build initial user message with explicit steps + user_message = f"""Find the optimal deployment configuration for {job.model.name} on {job.worker.name}. +Optimization target: {job.optimization_target} +Model ID: {job.model_id}, Worker ID: {job.worker_id} + +REQUIRED STEPS (you must complete all of these): +1. Call get_hardware_info(worker_id={job.worker_id}) to check GPU specs +2. Call query_knowledge_base() to check historical data +3. Deploy the model with deploy_model() and wait for it +4. Run run_benchmark() to test performance +5. Stop the deployment and optionally test other configurations +6. 
Call finish_tuning() with best_config and all benchmark results -Model ID: {job.model_id}, Worker ID: {job.worker_id}""" +Start with Step 1: get_hardware_info""" messages = [ {"role": "system", "content": AGENT_SYSTEM_PROMPT}, @@ -1370,8 +1461,8 @@ async def save_log(): job.conversation_log = conversation_log await db.commit() - # Agent loop - max_iterations = 20 + # Agent loop - limit iterations to prevent infinite loops + max_iterations = 15 iteration = 0 while iteration < max_iterations: @@ -1387,11 +1478,20 @@ async def save_log(): # Call LLM logger.info(f"Agent iteration {iteration}, calling LLM with model: {model_name}...") + # Force tool calls if essential steps not completed + # Use "required" to ensure tool is called when needed + if not executor.hardware_checked or ( + not executor.benchmark_results and iteration < 10 + ): + tool_choice = "required" + else: + tool_choice = "auto" + response = await client.chat.completions.create( model=model_name, messages=messages, tools=get_agent_tools(), - tool_choice="auto", + tool_choice=tool_choice, max_tokens=4096, ) @@ -1419,19 +1519,23 @@ async def save_log(): # Check if no tool calls - prompt the agent to take action if not assistant_message.tool_calls: logger.warning(f"Agent responded without tool calls at iteration {iteration}") - # Add a user message to prompt the agent to take action - prompt_message = ( - "You need to call a tool to proceed. Available actions:\n" - "1. list_deployments - Find existing deployments on the worker\n" - "2. stop_deployment - Stop a deployment to free GPU memory\n" - "3. deploy_model - Deploy a model with specific config\n" - "4. test_deployment_endpoint - Check if deployment is ready\n" - "5. get_deployment_logs - Check container logs\n" - "6. run_benchmark - Run performance benchmark\n" - "7. finish_tuning - Complete with recommendation\n" - "8. abort_tuning - Abort if cannot proceed\n" - "Do not respond with just text - you must call a tool." 
- ) + # Build a context-aware prompt based on current state + if not executor.hardware_checked: + prompt_message = ( + f"You must call get_hardware_info(worker_id={job.worker_id}) first " + "to check the GPU environment before proceeding." + ) + elif not executor.benchmark_results: + prompt_message = ( + "You must run at least one benchmark before finishing. " + f"Call deploy_model(model_id={job.model_id}, worker_id={job.worker_id}, engine='vllm') " + "to deploy the model, then run run_benchmark() after it's ready." + ) + else: + prompt_message = ( + "You have benchmark results. Call finish_tuning() with the best configuration " + "to complete the tuning process." + ) messages.append({"role": "user", "content": prompt_message}) conversation_log.append( { diff --git a/backend/requirements.txt b/backend/requirements.txt index 363d65a..0729e40 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,3 +10,4 @@ python-multipart>=0.0.6 python-jose[cryptography]>=3.3.0 email-validator>=2.0.0 psutil>=5.9.0 +optuna>=3.5.0 diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6bc12a8..a8fdf5d 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -33,13 +33,13 @@ import { } from "@ant-design/icons"; import { AuthProvider, useAuth } from "./contexts/AuthContext"; +import { ChatPanelProvider } from "./contexts/ChatPanelContext"; import { useAppTheme, useResponsive } from "./hooks"; import { Header, Sidebar, MobileSidebar } from "./components/layout"; import { ChatPanel, CHAT_PANEL_STORAGE_KEY, DEFAULT_PANEL_WIDTH, - TUNING_JOB_EVENT_KEY, } from "./components/chat-panel"; import Loading from "./components/Loading"; @@ -287,45 +287,6 @@ function AppLayout() { } }, [chatPanelOpen]); - // Listen for tuning job events to auto-open chat panel - useEffect(() => { - const handleStorageChange = (e: StorageEvent) => { - if (e.key === TUNING_JOB_EVENT_KEY && e.newValue) { - try { - const data = JSON.parse(e.newValue); - if (data.jobId) { - 
setChatPanelOpen(true); - } - } catch { - // Ignore - } - } - }; - - // Check on mount if there's a pending tuning job - const checkInitial = () => { - const stored = localStorage.getItem(TUNING_JOB_EVENT_KEY); - if (stored) { - try { - const data = JSON.parse(stored); - if ( - data.jobId && - data.timestamp && - Date.now() - data.timestamp < 5000 - ) { - setChatPanelOpen(true); - } - } catch { - // Ignore - } - } - }; - - checkInitial(); - window.addEventListener("storage", handleStorageChange); - return () => window.removeEventListener("storage", handleStorageChange); - }, []); - useEffect(() => { document.body.setAttribute("data-theme", isDark ? "dark" : "light"); }, [isDark]); @@ -387,211 +348,216 @@ function AppLayout() { ]; return ( - + - - {isMobile && ( - setMobileDrawerOpen(false)} - menuItems={menuItems} - selectedKey={location.pathname} - openKeys={openKeys} - onNavigate={handleNavigate} - isDark={isDark} - colors={colors} - /> - )} - - {!isMobile && ( - - )} - - -
setMobileDrawerOpen(true)} - /> - + + {isMobile && ( + setMobileDrawerOpen(false)} + menuItems={menuItems} + selectedKey={location.pathname} + openKeys={openKeys} + onNavigate={handleNavigate} + isDark={isDark} + colors={colors} + /> + )} + + {!isMobile && ( + + )} + + - - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - - - - } - /> - - - - } - /> - - - - } - /> - - - - - {/* Floating chat button */} - {!chatPanelOpen && ( - - - + )} + + {/* Clear button */} + {hasMessages && ( + - - )} - {messages.length > 0 && ( - -