@staticmethod
def _apply_user_sample_size(
    plan_metadata: Dict[str, Any], sample_size: Optional[int]
) -> Dict[str, Any]:
    """Override the sample size in a generated evaluation plan.

    The override is written to three places when they are present:
    ``plan_metadata["config"]["sample_size"]``, the ``metadata.sample_size``
    field of the YAML document stored under ``plan_metadata["plan_yaml"]``,
    and the ``sample_size`` of every mapping entry in that document's
    ``datasets`` list. The YAML document is then re-serialized back into
    ``plan_metadata["plan_yaml"]``.

    Args:
        plan_metadata: Plan dictionary to update. It is modified in place.
        sample_size: Sample size to apply. ``None`` is treated as "no
            override requested" and leaves the plan untouched.

    Returns:
        The same ``plan_metadata`` object that was passed in.

    Raises:
        yaml.YAMLError: If ``plan_yaml`` is not parseable YAML.
    """
    # Nothing to apply: keep whatever sample size the plan already carries
    # instead of overwriting it with None.
    if sample_size is None:
        return plan_metadata

    config = plan_metadata.get("config")
    if config:
        config["sample_size"] = sample_size

    plan_yaml_raw = plan_metadata.get("plan_yaml", "")
    if not plan_yaml_raw:
        return plan_metadata

    plan_dict = yaml.safe_load(plan_yaml_raw)
    # safe_load returns None for an empty/comment-only document and may
    # return a list or scalar; only a mapping has fields to override, so
    # leave the original text as it is in every other case.
    if not isinstance(plan_dict, dict):
        return plan_metadata

    metadata = plan_dict.get("metadata")
    if metadata and isinstance(metadata, dict):
        metadata["sample_size"] = sample_size

    # `datasets:` may be present but null, and entries may not be mappings.
    for dataset in plan_dict.get("datasets") or []:
        if isinstance(dataset, dict):
            dataset["sample_size"] = sample_size

    # sort_keys=False keeps the key order of the incoming document rather
    # than re-emitting it alphabetically.
    plan_metadata["plan_yaml"] = yaml.safe_dump(
        plan_dict,
        default_flow_style=False,
        allow_unicode=True,
        sort_keys=False,
    )
    return plan_metadata