diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..39f46d7 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,15 @@ +[run] +source = src/llama_prompt_ops +omit = + */tests/* + */site-packages/* + setup.py + +[report] +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError + if __name__ == .__main__.: + pass + raise ImportError diff --git a/frontend/README.md b/frontend/README.md new file mode 100644 index 0000000..d976ead --- /dev/null +++ b/frontend/README.md @@ -0,0 +1,189 @@ +# Llama Prompt Ops - Frontend + +A modern React frontend interface for [llama-prompt-ops](https://github.com/meta-llama/llama-prompt-ops), providing an intuitive web interface for prompt optimization workflows. + +## Features + +- **Prompt Enhancement**: Optimize prompts for better performance with Llama models +- **Prompt Migration**: Migrate prompts between different model architectures +- **Real-time Optimization**: Monitor optimization progress with live updates +- **Dataset Management**: Upload and manage datasets for optimization +- **Configuration Management**: Flexible configuration for different optimization strategies +- **Clean UI**: Modern, accessible interface with Meta's design language + +## Technology Stack + +- **Frontend**: React 18 + TypeScript +- **UI Components**: Radix UI + shadcn/ui +- **Styling**: Tailwind CSS with Meta/Facebook design system +- **Build Tool**: Vite +- **Backend**: FastAPI with llama-prompt-ops integration + +## Quick Start + +### Prerequisites + +- **Node.js 18+** and npm +- **Python 3.8+** (for backend) +- **OpenRouter API Key** (get one at [OpenRouter](https://openrouter.ai/)) + +### Installation + +1. **Clone the repository** + ```bash + git clone https://github.com/meta-llama/llama-prompt-ops.git + cd llama-prompt-ops/frontend + ``` + +2. **Install frontend dependencies** + ```bash + npm install + ``` + +3. **Set up backend environment** + ```bash + cd backend + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + pip install -r requirements.txt + ``` + +4. **Configure environment variables** + + Create a `.env` file in the `frontend/backend` directory: + ```bash + # In frontend/backend/.env + OPENROUTER_API_KEY=your_openrouter_api_key_here + OPENAI_API_KEY=your_openai_api_key_here # Optional: for fallback enhance feature + ``` + +### Running the Application + +#### Option 1: Use the Development Script (Recommended) +```bash +# From the frontend directory +chmod +x start-dev.sh +./start-dev.sh +``` + +#### Option 2: Manual Start +```bash +# Terminal 1: Start backend +cd backend +source venv/bin/activate +python -m uvicorn main:app --reload --port 8000 + +# Terminal 2: Start frontend +cd .. +npm run dev +``` + +The application will be available at: +- **Frontend**: http://localhost:8080 +- **Backend API**: http://localhost:8000 + +### First Run + +1. **Upload a dataset**: Click "Manage Dataset" and upload a JSON file with your training data +2. **Configure optimization**: Select your preferred model, metrics, and optimization strategy +3. **Enter your prompt**: Paste your existing prompt in the text area +4. **Click "Optimize"**: Watch the real-time progress and get your optimized prompt! 
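If the frontend cannot reach the backend after these steps, a quick sanity check (a minimal example, assuming the backend is on its default port 8000) is to query the configurations endpoint served by `backend/main.py`:

```bash
# Should return the available models, metrics, dataset adapters, and strategies as JSON
curl http://localhost:8000/api/configurations
```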
+ +## Dataset Format + +Upload JSON files in this format: +```json +[ + { + "question": "Your input query here", + "answer": "Expected response here" + }, + { + "question": "Another input query", + "answer": "Another expected response" + } +] +``` + +## Troubleshooting + +### Common Issues + +**Backend won't start:** +- Ensure you've activated the virtual environment +- Check that all requirements are installed: `pip install -r requirements.txt` +- Verify your API keys are set in the `.env` file + +**Frontend can't connect to backend:** +- Make sure the backend is running on port 8000 +- Check browser console for CORS errors +- Verify the backend URL in the frontend code + +**Optimization fails:** +- Check that you've uploaded a valid dataset +- Verify your OpenRouter API key is correct +- Ensure your dataset has the expected format + +**Port already in use:** +- Kill existing processes: `pkill -f "uvicorn\|vite"` +- Or use different ports in the configuration + +## Development + +### Frontend Development + +```bash +# Start with hot reload +npm run dev + +# Build for production +npm run build + +# Preview production build +npm run preview + +# Lint code +npm run lint +``` + +### Backend Development + +```bash +# Start with auto-reload +uvicorn main:app --reload --port 8000 + +# Run with debug logging +uvicorn main:app --reload --port 8000 --log-level debug +``` + +## Project Structure + +``` +frontend/ +├── backend/ # FastAPI backend +│ ├── main.py # API server +│ ├── requirements.txt # Python dependencies +│ └── uploaded_datasets/ # Dataset storage +├── src/ +│ ├── components/ # React components +│ │ ├── ui/ # Reusable UI components +│ │ ├── ConfigurationPanel.tsx +│ │ ├── PromptInput.tsx +│ │ └── ... +│ ├── context/ # React context +│ ├── hooks/ # Custom hooks +│ └── pages/ # Page components +├── package.json +└── start-dev.sh # Development startup script +``` + +## Contributing + +1. Follow the existing code style and patterns +2. Add tests for new features +3. Update documentation for any changes +4. Ensure the application builds and runs successfully + +## License + +This project is licensed under the same terms as llama-prompt-ops. diff --git a/frontend/backend/README.md b/frontend/backend/README.md new file mode 100644 index 0000000..b8952fa --- /dev/null +++ b/frontend/backend/README.md @@ -0,0 +1,79 @@ +# Llama Prompt Ops Frontend Backend + +This is a FastAPI backend for the llama-prompt-ops frontend interface. It provides API endpoints for optimizing prompts using OpenAI's GPT models and the llama-prompt-ops library. + +## Setup + +1. Create a virtual environment: +```bash +python -m venv venv +source venv/bin/activate # On Windows: venv\Scripts\activate +``` + +2. Install dependencies: +```bash +pip install -r requirements.txt +``` + +3. Set up your environment variables: + - Copy `.env.example` to `.env` (if available) + - Add your OpenAI API key and OpenRouter API key to the `.env` file + +## Running the Server + +Start the FastAPI server with: +```bash +uvicorn main:app --reload --port 8000 +``` + +The API will be available at http://localhost:8000 + +## API Endpoints + +### POST /api/enhance-prompt + +Enhances a prompt using OpenAI's GPT model. + +**Request Body:** +```json +{ + "prompt": "Your prompt text here" +} +``` + +**Response:** +```json +{ + "optimizedPrompt": "Enhanced prompt text" +} +``` + +### POST /api/migrate-prompt + +Optimizes a prompt using the llama-prompt-ops library. 
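A hedged example of calling this endpoint with `curl` (assuming a local backend on port 8000; the prompt text is an illustrative placeholder, and the full request body schema follows below):

```bash
curl -X POST http://localhost:8000/api/migrate-prompt \
  -H "Content-Type: application/json" \
  -d '{
        "prompt": "Summarize the customer message and list any action items.",
        "config": {
          "taskModel": "Llama 3.3 70B",
          "proposerModel": "Llama 3.1 8B",
          "optimizer": "MiPro",
          "dataset": "Q&A",
          "metrics": "Exact Match",
          "useLlamaTips": true
        }
      }'
```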
+ +**Request Body:** +```json +{ + "prompt": "Your prompt text here", + "config": { + "taskModel": "Llama 3.3 70B", + "proposerModel": "Llama 3.1 8B", + "optimizer": "MiPro", + "dataset": "Q&A", + "metrics": "Exact Match", + "useLlamaTips": true + } +} +``` + +**Response:** +```json +{ + "optimizedPrompt": "Optimized prompt text" +} +``` + +## Integration with llama-prompt-ops + +This backend serves as a development interface for the llama-prompt-ops library, providing web API access to prompt optimization features. When this frontend is eventually integrated into the main llama-prompt-ops repository, this backend functionality will be incorporated into the library's core API structure. diff --git a/frontend/backend/config.py b/frontend/backend/config.py new file mode 100644 index 0000000..0bd3620 --- /dev/null +++ b/frontend/backend/config.py @@ -0,0 +1,149 @@ +""" +Configuration mappings and settings for the backend API. +""" + +import os + +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Available models (these would be configured based on your available models) +MODEL_MAPPING = { + "Llama 3.3 70B": "meta-llama/llama-3.3-70b-instruct", + "Llama 3.1 8B": "meta-llama/llama-3.1-8b-instruct", + "Llama 3.1 70B": "meta-llama/llama-3.1-70b-instruct", + "GPT-4o": "openai/gpt-4o", + "GPT-4o-mini": "openai/gpt-4o-mini", +} + +# Available metrics from llama-prompt-ops +METRIC_MAPPING = { + "exact_match": { + "class": "llama_prompt_ops.core.metrics.ExactMatchMetric", + "params": {"output_field": "answer"}, + }, + "semantic_similarity": { + "class": "llama_prompt_ops.core.metrics.DSPyMetricAdapter", + "params": { + "signature_name": "similarity", + "score_range": (1, 10), + "normalize_to": (0, 1), + }, + }, + "correctness": { + "class": "llama_prompt_ops.core.metrics.DSPyMetricAdapter", + "params": { + "signature_name": "correctness", + "score_range": (1, 10), + "normalize_to": (0, 1), + }, + }, + "json_structured": { + "class": "llama_prompt_ops.core.metrics.StandardJSONMetric", + "params": { + "output_field": "answer", + "evaluation_mode": "selected_fields_comparison", + "strict_json": False, + }, + }, + # Legacy mappings for backward compatibility + "Facility Support": { + "class": "llama_prompt_ops.core.metrics.FacilityMetric", + "params": {"output_field": "answer", "strict_json": False}, + }, + "HotpotQA": { + "class": "llama_prompt_ops.datasets.hotpotqa.HotpotQAMetric", + "params": {"output_field": "answer"}, + }, + "Standard JSON": { + "class": "llama_prompt_ops.core.metrics.StandardJSONMetric", + "params": {"output_field": "answer"}, + }, + "Exact Match": { + "class": "llama_prompt_ops.core.metrics.ExactMatchMetric", + "params": {}, + }, +} + +# Available dataset adapters from llama-prompt-ops +DATASET_ADAPTER_MAPPING = { + "standard_json": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "description": "Standard JSON format with customizable field mappings", + "example_fields": {"input": "string", "output": "string"}, + "params": {"input_field": "input", "golden_output_field": "output"}, + }, + "hotpotqa": { + "adapter_class": "llama_prompt_ops.datasets.hotpotqa.adapter.HotPotQAAdapter", + "description": "Multi-hop reasoning dataset for question answering", + "example_fields": { + "question": "string", + "answer": "string", + "context": "array", + }, + "params": { + "input_field": "question", + "golden_output_field": "answer", + "context_field": "context", + }, + }, + "facility": { + "adapter_class": 
"llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "description": "Facility support and maintenance dataset with nested field structure", + "example_fields": {"fields": "object", "answer": "string"}, + "params": { + # Nested path for facility dataset + "input_field": ["fields", "input"], + "golden_output_field": "answer", + }, + }, +} + +# Available optimization strategies from llama-prompt-ops +STRATEGY_MAPPING = { + "Basic": "basic", +} + +# System message for OpenRouter operations +ENHANCE_SYSTEM_MESSAGE = """ + You are a highly advanced language model, capable of complex reasoning and problem-solving. + Your goal is to provide accurate and informative responses to the given input, following a structured approach. + Here is the input you'll work with: + + {{USER_INPUT}} + + To accomplish this, follow these steps: + Understand the Task: Carefully read and comprehend the input, identifying the key elements and requirements. + Break Down the Problem: Decompose the task into smaller, manageable sub-problems, using a chain-of-thought (CoT) approach. + Gather Relevant Information: If necessary, use external knowledge sources to gather relevant information and provide provenance for your answers. + Apply Reasoning and Logic: Apply step-by-step reasoning and logical thinking to arrive at a solution, using self-ask prompting to guide your thought process. + Evaluate and Refine: Evaluate your solution, refining it as needed to ensure accuracy and completeness. + Your output must follow these guidelines: + Clear and Concise: Provide clear and concise responses, avoiding ambiguity and jargon. + Well-Structured: Use a well-structured format for your response, including headings and bullet points as needed. + Accurate and Informative: Ensure that your response is accurate and informative, providing relevant details and examples. + Format your final answer inside tags and do not include any of your internal reasoning. + + ...your response... + + Chain of Thought (CoT) Template + To facilitate CoT, use the following template: + Step 1: Identify the key elements and requirements of the task. + Sub-question: What are the essential components of the task? + Answer: [Provide a brief answer] + Step 2: Break down the problem into smaller sub-problems. + Sub-question: How can I decompose the task into manageable parts? + Answer: [Provide a brief answer] + Step 3: Gather relevant information and apply reasoning and logic. + Sub-question: What information do I need to solve the task, and how can I apply logical thinking? + Answer: [Provide a brief answer] + Step 4: Evaluate and refine the solution. + Sub-question: Is my solution accurate and complete, and how can I refine it? + Answer: [Provide a brief answer] + By following this structured approach, you will be able to provide accurate and informative responses to the given input, demonstrating your ability to think critically and solve complex problems.""" + +# Environment settings +OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") +UPLOAD_DIR = os.path.join(os.path.dirname(__file__), "uploaded_datasets") diff --git a/frontend/backend/config_transformer.py b/frontend/backend/config_transformer.py new file mode 100644 index 0000000..969c1e3 --- /dev/null +++ b/frontend/backend/config_transformer.py @@ -0,0 +1,725 @@ +""" +Configuration transformer service for converting onboarding wizard data +to llama-prompt-ops YAML configuration format. 
+""" + +import os +from pathlib import Path +from typing import Any, Dict, List, Union + +import yaml + + +class ConfigurationTransformer: + """ + Transforms onboarding wizard data into llama-prompt-ops compatible YAML configuration. + """ + + # Mapping from frontend metric IDs to backend metric classes + METRIC_ID_MAPPING = { + "exact_match": { + "class": "llama_prompt_ops.core.metrics.ExactMatchMetric", + "default_params": {}, + }, + "semantic_similarity": { + "class": "llama_prompt_ops.core.metrics.DSPyMetricAdapter", + "default_params": {"signature_name": "similarity"}, + }, + "correctness": { + "class": "llama_prompt_ops.core.metrics.DSPyMetricAdapter", + "default_params": {"signature_name": "correctness"}, + }, + "json_structured": { + "class": "llama_prompt_ops.core.metrics.StandardJSONMetric", + "default_params": {}, + }, + "facility_metric": { + "class": "llama_prompt_ops.core.metrics.FacilityMetric", + "default_params": {"strict_json": False}, + }, + } + + # Fallback mapping for legacy use case-based configuration + USE_CASE_FALLBACK_METRICS = { + "qa": "exact_match", + "rag": "semantic_similarity", + "classification": "exact_match", + "summarization": "semantic_similarity", + "extraction": "json_structured", + "custom": "exact_match", + } + + # Mapping from wizard dataset types to adapter configurations + DATASET_FIELD_MAPPING = { + "qa": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "input_field": "question", + "golden_output_field": "answer", + }, + "rag": { + "adapter_class": "llama_prompt_ops.core.datasets.RAGJSONAdapter", + "question_field": "query", + "context_field": "context", + "golden_answer_field": "answer", + }, + "classification": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "input_field": "text", + "golden_output_field": "category", + }, + "summarization": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "input_field": "text", + "golden_output_field": "summary", + }, + "extraction": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + "input_field": "text", + "golden_output_field": "extracted_data", + }, + "custom": { + "adapter_class": "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter", + # No default fields - will be populated from user mappings + }, + } + + def transform( + self, wizard_data: Dict[str, Any], project_name: str = "generated" + ) -> Dict[str, Any]: + """ + Transform onboarding wizard data into YAML configuration. + + Args: + wizard_data: Data collected from the onboarding wizard + project_name: Name for the generated project + + Returns: + Dictionary representing the YAML configuration + """ + config = {} + + # 1. System Prompt Configuration + config["system_prompt"] = self._transform_system_prompt( + wizard_data.get("prompt", {}) + ) + + # 2. Dataset Configuration + config["dataset"] = self._transform_dataset( + wizard_data.get("dataset", {}), wizard_data.get("useCase") + ) + + # 3. Model Configuration + config["model"] = self._transform_model(wizard_data.get("models", {})) + + # 4. Metric Configuration + config["metric"] = self._transform_metric(wizard_data) + + # 5. 
Optimization Configuration + config["optimization"] = self._transform_optimization( + wizard_data.get("optimizer", {}) + ) + + return config + + def _transform_system_prompt(self, prompt_data: Dict[str, Any]) -> Dict[str, Any]: + """Transform prompt configuration.""" + config = { + "inputs": prompt_data.get("inputs", ["question"]), + "outputs": prompt_data.get("outputs", ["answer"]), + } + + # Always use file reference for better project structure + config["file"] = "prompts/prompt.txt" + + return config + + def _transform_dataset( + self, dataset_data: Dict[str, Any], use_case: str + ) -> Dict[str, Any]: + """Transform dataset configuration with support for custom field mappings.""" + + # Get base configuration for the use case + base_config = self.DATASET_FIELD_MAPPING.get( + use_case, self.DATASET_FIELD_MAPPING["custom"] + ) + dataset_config = {"adapter_class": base_config["adapter_class"]} + + # Always use standard relative path for project structure + dataset_config["path"] = "data/dataset.json" + + # Set default train/validation splits (0.5/0.2 as requested) + dataset_config["train_size"] = dataset_data.get("trainSize", 50) / 100.0 + dataset_config["validation_size"] = ( + dataset_data.get("validationSize", 20) / 100.0 + ) + + # Handle field mappings based on use case + if use_case == "custom": + # For custom use case, handle flexible field mappings + field_mappings = dataset_data.get("fieldMappings", {}) + + if field_mappings: + # Find the most likely input and output fields for ConfigurableJSONAdapter + input_candidates = [ + "question", + "input", + "query", + "prompt", + "text", + "user_input", + ] + output_candidates = [ + "answer", + "output", + "response", + "label", + "target", + "expected_output", + ] + + input_field = None + output_field = None + + # Find input field by checking common field name patterns + for candidate in input_candidates: + if candidate in field_mappings and field_mappings[candidate]: + input_field = field_mappings[candidate] + break + + # Find output field by checking common field name patterns + for candidate in output_candidates: + if candidate in field_mappings and field_mappings[candidate]: + output_field = field_mappings[candidate] + break + + # Set the detected fields + if input_field: + dataset_config["input_field"] = input_field + if output_field: + dataset_config["golden_output_field"] = output_field + + # Store all custom field mappings for advanced use cases + dataset_config["custom_field_mappings"] = field_mappings + else: + # Fallback to base config if no field mappings + dataset_config.update( + {k: v for k, v in base_config.items() if k != "adapter_class"} + ) + + elif use_case == "rag": + # RAG-specific field mapping + field_mappings = dataset_data.get("fieldMappings", {}) + + if field_mappings: + # Map user's field mappings to RAG adapter expected fields + dataset_config["question_field"] = field_mappings.get( + "query", field_mappings.get("question", "query") + ) + dataset_config["context_field"] = field_mappings.get( + "context", "context" + ) + dataset_config["golden_answer_field"] = field_mappings.get( + "answer", "answer" + ) + else: + # Use base config defaults + dataset_config.update( + {k: v for k, v in base_config.items() if k != "adapter_class"} + ) + + else: + # Standard use cases (qa, classification, etc.) 
+ field_mappings = dataset_data.get("fieldMappings", {}) + + if field_mappings: + # For Q&A and other standard use cases + dataset_config["input_field"] = field_mappings.get( + "question", + field_mappings.get( + "input", base_config.get("input_field", "question") + ), + ) + dataset_config["golden_output_field"] = field_mappings.get( + "answer", + field_mappings.get( + "output", base_config.get("golden_output_field", "answer") + ), + ) + else: + # Use base config defaults + dataset_config.update( + {k: v for k, v in base_config.items() if k != "adapter_class"} + ) + + return dataset_config + + def _transform_model(self, model_data: Dict[str, Any]) -> Dict[str, Any]: + """Transform model configuration with full provider support.""" + models = model_data.get("selected", []) + + if not models: + # Default model configuration (no name field, use task_model/proposer_model) + return { + "task_model": "openrouter/meta-llama/llama-3.3-70b-instruct", + "proposer_model": "openrouter/meta-llama/llama-3.3-70b-instruct", + "api_base": "https://openrouter.ai/api/v1", + "temperature": 0.0, + } + + # Get the primary model configuration + primary_model = models[0] + + # Build full model name with prefix + model_prefix = primary_model.get("model_prefix", "") + model_name = primary_model.get("model_name", "") + full_model_name = f"{model_prefix}{model_name}" if model_prefix else model_name + + # Start with base model configuration (no name field) + model_config = {} + + # Add API base if provided + if primary_model.get("api_base"): + model_config["api_base"] = primary_model["api_base"] + + # Add generation parameters + if "temperature" in primary_model: + model_config["temperature"] = primary_model["temperature"] + if "max_tokens" in primary_model: + model_config["max_tokens"] = primary_model["max_tokens"] + + # Handle multiple models with different roles + target_models = [m for m in models if m.get("role") in ["target", "both"]] + optimizer_models = [m for m in models if m.get("role") in ["optimizer", "both"]] + + if target_models and optimizer_models and len(models) > 1: + # Separate target and optimizer models + target_model = target_models[0] + optimizer_model = optimizer_models[0] + + target_prefix = target_model.get("model_prefix", "") + target_name = target_model.get("model_name", "") + target_full_name = ( + f"{target_prefix}{target_name}" if target_prefix else target_name + ) + + optimizer_prefix = optimizer_model.get("model_prefix", "") + optimizer_name = optimizer_model.get("model_name", "") + optimizer_full_name = ( + f"{optimizer_prefix}{optimizer_name}" + if optimizer_prefix + else optimizer_name + ) + + # Only set task_model and proposer_model (no name field) + model_config.update( + { + "task_model": target_full_name, + "proposer_model": optimizer_full_name, + } + ) + else: + # Single model for both tasks (no name field) + model_config.update( + { + "task_model": full_model_name, + "proposer_model": full_model_name, + } + ) + + return model_config + + def _transform_metric(self, wizard_data: Dict[str, Any]) -> Dict[str, Any]: + """Transform metric configuration using actual selected metrics.""" + + # Get selected metrics from wizard data + selected_metrics = wizard_data.get("metrics", []) + metric_configurations = wizard_data.get("metricConfigurations", {}) + + # If no metrics selected, fall back to use case default + if not selected_metrics: + use_case = wizard_data.get("useCase", "custom") + fallback_metric_id = self.USE_CASE_FALLBACK_METRICS.get( + use_case, "exact_match" + ) + 
selected_metrics = [fallback_metric_id] + + # Use the first selected metric (for now, could be enhanced for multiple metrics) + primary_metric_id = selected_metrics[0] + + # Get metric configuration from mapping + metric_mapping = self.METRIC_ID_MAPPING.get(primary_metric_id) + if not metric_mapping: + # Fallback to exact match if unknown metric + metric_mapping = self.METRIC_ID_MAPPING["exact_match"] + + # Start with base metric configuration + metric_config = {"class": metric_mapping["class"]} + + # Add default parameters for this metric type + metric_config.update(metric_mapping["default_params"]) + + # Add user-configured parameters if available + if primary_metric_id in metric_configurations: + user_config = metric_configurations[primary_metric_id] + metric_config.update(user_config) + + # Determine output field from field mappings + dataset_data = wizard_data.get("dataset", {}) + field_mappings = dataset_data.get("fieldMappings", {}) + + if field_mappings: + # Find the actual output field from field mappings + output_candidates = [ + "answer", + "output", + "response", + "label", + "target", + "expected_output", + ] + + for candidate in output_candidates: + if candidate in field_mappings and field_mappings[candidate]: + metric_config["output_field"] = ( + candidate # Use the target field name + ) + break + + # Ensure output_field is set (fallback to "answer") + if "output_field" not in metric_config: + metric_config["output_field"] = "answer" + + return metric_config + + def _transform_optimization(self, optimizer_data: Dict[str, Any]) -> Dict[str, Any]: + """Transform optimization configuration using only frontend-controlled parameters.""" + strategy_id = optimizer_data.get("selectedOptimizer", "basic") + custom_params = optimizer_data.get("customParams", {}) + + optimization_config = {"strategy": strategy_id} + + # Add frontend-controlled parameters if provided + if custom_params: + # Only include parameters that are actually controlled by the frontend + frontend_controlled_params = { + "num_candidates": "num_candidates", + "max_bootstrapped_demos": "bootstrap_examples", + "max_labeled_demos": "max_labeled_demos", + "num_threads": "num_threads", + "max_errors": "max_errors", + "seed": "seed", + } + + for frontend_key, backend_key in frontend_controlled_params.items(): + if frontend_key in custom_params: + optimization_config[backend_key] = custom_params[frontend_key] + + return optimization_config + + def _extract_environment_variables( + self, wizard_data: Dict[str, Any] + ) -> Dict[str, str]: + """Extract API keys and other sensitive data for .env file.""" + env_vars = {} + + models = wizard_data.get("models", {}).get("selected", []) + + for model in models: + api_key = model.get("api_key") + provider_id = model.get("provider_id") + + if api_key and api_key.strip(): + # Create environment variable name based on provider + env_var_name = f"{provider_id.upper()}_API_KEY" + env_vars[env_var_name] = api_key + + return env_vars + + def generate_yaml_string( + self, wizard_data: Dict[str, Any], project_name: str = "generated" + ) -> str: + """ + Generate YAML configuration string from wizard data. 
+ + Args: + wizard_data: Data collected from the onboarding wizard + project_name: Name for the generated project + + Returns: + YAML configuration as string + """ + config = self.transform(wizard_data, project_name) + return yaml.dump(config, default_flow_style=False, sort_keys=False) + + def save_config_file( + self, + wizard_data: Dict[str, Any], + output_path: str, + project_name: str = "generated", + ) -> str: + """ + Save YAML configuration file from wizard data. + + Args: + wizard_data: Data collected from the onboarding wizard + output_path: Path where to save the config file + project_name: Name for the generated project + + Returns: + Path to the saved configuration file + """ + config = self.transform(wizard_data, project_name) + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(output_path, "w") as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + return output_path + + def create_project_structure( + self, wizard_data: Dict[str, Any], base_dir: str, project_name: str + ) -> Dict[str, str]: + """ + Create complete project structure with all necessary files. + + Args: + wizard_data: Data collected from the onboarding wizard + base_dir: Base directory where to create the project + project_name: Name of the project + + Returns: + Dictionary mapping file types to their created paths + """ + project_dir = os.path.join(base_dir, project_name) + created_files = {} + + # Create directory structure + os.makedirs(project_dir, exist_ok=True) + os.makedirs(os.path.join(project_dir, "data"), exist_ok=True) + os.makedirs(os.path.join(project_dir, "prompts"), exist_ok=True) + os.makedirs(os.path.join(project_dir, "results"), exist_ok=True) + + # 1. Create config.yaml + config_path = os.path.join(project_dir, "config.yaml") + self.save_config_file(wizard_data, config_path, project_name) + created_files["config"] = config_path + + # 2. Create prompt file + prompt_path = os.path.join(project_dir, "prompts", "prompt.txt") + prompt_text = wizard_data.get("prompt", {}).get( + "text", "# Add your system prompt here" + ) + with open(prompt_path, "w") as f: + f.write(prompt_text) + created_files["prompt"] = prompt_path + + # 3. Copy uploaded dataset file to standard location + dataset_path = os.path.join(project_dir, "data", "dataset.json") + + # Get the uploaded dataset path from wizard data + uploaded_dataset_path = wizard_data.get("dataset", {}).get("path") + + if uploaded_dataset_path and os.path.exists(uploaded_dataset_path): + # Copy the actual uploaded file + import shutil + + shutil.copy2(uploaded_dataset_path, dataset_path) + else: + # Fallback to placeholder data if no uploaded file found + placeholder_data = self._create_placeholder_dataset( + wizard_data.get("useCase", "custom") + ) + with open(dataset_path, "w") as f: + import json + + json.dump(placeholder_data, f, indent=2) + + created_files["dataset"] = dataset_path + + # 4. Create .env file with actual API keys + env_path = os.path.join(project_dir, ".env") + env_vars = self._extract_environment_variables(wizard_data) + + env_content = "# API Keys\n" + if env_vars: + for var_name, var_value in env_vars.items(): + env_content += f"{var_name}={var_value}\n" + else: + env_content += "# Add your API keys here\n" + env_content += "# OPENROUTER_API_KEY=your_api_key_here\n" + env_content += "# ANTHROPIC_API_KEY=your_api_key_here\n" + + with open(env_path, "w") as f: + f.write(env_content) + created_files["env"] = env_path + + # 5. 
Create README.md + readme_path = os.path.join(project_dir, "README.md") + readme_content = self._create_readme(project_name, wizard_data) + with open(readme_path, "w") as f: + f.write(readme_content) + created_files["readme"] = readme_path + + return created_files + + def _create_placeholder_dataset(self, use_case: str) -> List[Dict[str, Any]]: + """Create a placeholder dataset based on use case.""" + datasets = { + "qa": [ + { + "question": "What is artificial intelligence?", + "answer": "AI is the simulation of human intelligence in machines.", + }, + { + "question": "How does machine learning work?", + "answer": "Machine learning uses algorithms to learn patterns from data.", + }, + { + "question": "What is deep learning?", + "answer": "Deep learning uses neural networks with multiple layers.", + }, + { + "question": "What are the benefits of AI?", + "answer": "AI can automate tasks, provide insights, and improve efficiency.", + }, + ], + "rag": [ + { + "question": "What is the capital of France?", + "context": "France is a country in Europe. Its capital city is Paris.", + "answer": "Paris", + }, + { + "question": "Who wrote Romeo and Juliet?", + "context": "William Shakespeare was an English playwright and poet. He wrote many famous plays including Romeo and Juliet.", + "answer": "William Shakespeare", + }, + { + "question": "What is photosynthesis?", + "context": "Photosynthesis is the process by which plants use sunlight, water and carbon dioxide to produce oxygen and energy.", + "answer": "The process by which plants convert sunlight into energy", + }, + { + "question": "When was the internet invented?", + "context": "The internet was developed in the late 1960s as ARPANET by DARPA. The World Wide Web was created by Tim Berners-Lee in 1989.", + "answer": "Late 1960s (ARPANET), World Wide Web in 1989", + }, + ], + "classification": [ + { + "text": "I love this product! It works perfectly.", + "category": "positive", + }, + { + "text": "This is terrible quality and broke immediately.", + "category": "negative", + }, + { + "text": "The product is okay, nothing special but functional.", + "category": "neutral", + }, + {"text": "Amazing service and fast delivery!", "category": "positive"}, + ], + "summarization": [ + { + "text": "The meeting covered quarterly sales figures, upcoming product launches, and budget allocations for the next fiscal year. Sales exceeded expectations by 15% and three new products will launch in Q2.", + "summary": "Meeting discussed Q1 sales (15% above target), Q2 product launches, and budget planning.", + }, + { + "text": "Research shows that exercise improves mental health, reduces stress, and increases cognitive function. Regular physical activity releases endorphins and promotes better sleep patterns.", + "summary": "Exercise benefits mental health through endorphin release, stress reduction, and improved cognition and sleep.", + }, + { + "text": "The new software update includes bug fixes, security improvements, and user interface enhancements. Performance has been optimized and several new features added based on user feedback.", + "summary": "Software update includes bug fixes, security improvements, UI enhancements, and performance optimizations.", + }, + { + "text": "Climate change is causing rising sea levels, extreme weather events, and ecosystem disruption. Scientists recommend immediate action to reduce greenhouse gas emissions.", + "summary": "Climate change causes rising seas, extreme weather, and ecosystem damage. 
Scientists urge immediate emission reductions.", + }, + ], + } + + return datasets.get( + use_case, + [ + {"fields": {"input": "Example input"}, "answer": "Example output"}, + {"fields": {"input": "Another input"}, "answer": "Another output"}, + {"fields": {"input": "Third input"}, "answer": "Third output"}, + {"fields": {"input": "Fourth input"}, "answer": "Fourth output"}, + ], + ) + + def _create_readme(self, project_name: str, wizard_data: Dict[str, Any]) -> str: + """Create README.md content for the project.""" + use_case = wizard_data.get("useCase", "custom") + optimizer = wizard_data.get("optimizer", {}).get("selectedOptimizer", "basic") + + return f"""# {project_name} + +## Project Overview + +This project was generated using the llama-prompt-ops onboarding wizard. + +**Use Case:** {use_case.title()} +**Optimization Strategy:** {optimizer.title()} + +## Project Structure + +``` +{project_name}/ +├── config.yaml # Main configuration file +├── prompts/ +│ └── prompt.txt # System prompt template +├── data/ +│ └── dataset.json # Training dataset +├── results/ # Optimization results +└── .env # Environment variables (API keys) +``` + +## Getting Started + +1. **Set up your API key:** + ```bash + # Edit .env file and add your API key + OPENROUTER_API_KEY=your_actual_api_key_here + ``` + +2. **Install llama-prompt-ops:** + ```bash + pip install llama-prompt-ops + ``` + +3. **Customize your prompt:** + Edit `prompts/prompt.txt` with your actual system prompt. + +4. **Add your dataset:** + Replace the placeholder data in `data/dataset.json` with your actual dataset. + +5. **Run optimization:** + ```bash + prompt-ops migrate --config config.yaml + ``` + +## Configuration + +The `config.yaml` file contains all the settings for your optimization run: + +- **system_prompt**: Path to your prompt file and input/output specifications +- **dataset**: Dataset path and field mappings +- **model**: AI models to use for optimization +- **metric**: Evaluation metric for measuring prompt performance +- **optimization**: Strategy and parameters for optimization + +## Next Steps + +1. Customize the system prompt in `prompts/prompt.txt` +2. Replace placeholder dataset with your actual data +3. Review and adjust configuration parameters +4. Run the optimization and analyze results + +## Support + +For more information, see the [llama-prompt-ops documentation](https://github.com/meta-llama/llama-prompt-ops). +""" diff --git a/frontend/backend/dataset_analyzer.py b/frontend/backend/dataset_analyzer.py new file mode 100644 index 0000000..9ae6415 --- /dev/null +++ b/frontend/backend/dataset_analyzer.py @@ -0,0 +1,491 @@ +""" +Dataset Analysis Service for Dynamic Field Mapping + +This service analyzes uploaded dataset files to detect field structures, +classify field types, and suggest field mappings for different use cases. 
+""" + +import csv +import json +import logging +from collections import Counter +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import pandas as pd +import yaml + +logger = logging.getLogger(__name__) + + +class FieldInfo: + """Information about a detected field in the dataset.""" + + def __init__( + self, + name: str, + field_type: str, + sample_values: List[Any], + coverage: float = 0.0, # percentage of records that have this field + populated_count: int = 0, # number of records with non-null values + total_count: int = 0, # total number of records analyzed + ): + self.name = name + self.field_type = field_type # 'string', 'array', 'object', 'number', 'boolean' + self.sample_values = sample_values + self.coverage = coverage # 0.0 to 1.0 - what % of records have this field + self.populated_count = populated_count # how many records have this field + self.total_count = total_count # total records analyzed + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "type": self.field_type, + "samples": self.sample_values, + "coverage": self.coverage, + "populated_count": self.populated_count, + "total_count": self.total_count, + } + + +class DatasetAnalyzer: + """Analyzes dataset files and provides field mapping suggestions.""" + + def __init__(self): + pass + + def analyze_file(self, file_path: str, sample_size: int = 10) -> Dict[str, Any]: + """ + Analyze a dataset file and return field information. + + Args: + file_path: Path to the dataset file + sample_size: Number of samples to analyze + + Returns: + Dictionary containing field analysis results + """ + try: + # Load data + data = self._load_data(file_path) + if not data: + return {"error": "Could not load or parse file"} + + # For accurate coverage calculation, we need to analyze the full dataset + # But for performance, we'll analyze a larger sample for coverage + coverage_sample_size = min( + len(data), 100 + ) # Use up to 100 records for coverage + coverage_data = data[:coverage_sample_size] + + # Use smaller sample for field type analysis and sample values + sample_data = data[:sample_size] if len(data) > sample_size else data + + # Analyze fields with full dataset size for accurate coverage + fields = self._analyze_fields(coverage_data, len(data)) + + return { + "total_records": len(data), + "sample_size": len(sample_data), + "fields": [field.to_dict() for field in fields], + "suggestions": {}, # Empty suggestions since we removed classification + "sample_data": sample_data[:3], # Show first 3 records + } + + except Exception as e: + logger.error(f"Error analyzing file {file_path}: {str(e)}") + return {"error": f"Analysis failed: {str(e)}"} + + def _load_data(self, file_path: str) -> List[Dict[str, Any]]: + """Load data from file based on extension.""" + path = Path(file_path) + extension = path.suffix.lower() + + try: + if extension == ".json": + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, list) else [data] + + elif extension == ".csv": + df = pd.read_csv(file_path) + return df.to_dict("records") + + elif extension in [".yaml", ".yml"]: + with open(file_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + return data if isinstance(data, list) else [data] + + else: + raise ValueError(f"Unsupported file format: {extension}") + + except Exception as e: + logger.error(f"Error loading file {file_path}: {str(e)}") + return [] + + def _analyze_fields( + self, data: List[Dict[str, Any]], total_dataset_size: 
int + ) -> List[FieldInfo]: + """Analyze fields in the dataset.""" + if not data: + return [] + + field_info = {} + sample_size = len(data) + + # First pass: collect all field values without tracking coverage + for record in data: + self._extract_fields_recursive(record, field_info, track_coverage=False) + + # Second pass: count how many records have each field (for completeness) + field_completeness = {} + for record in data: + # Get all fields present in this record + record_fields = set() + self._get_record_fields(record, record_fields) + + # Count this record for each field it contains + for field_path in record_fields: + if field_path not in field_completeness: + field_completeness[field_path] = 0 + field_completeness[field_path] += 1 + + # Convert to FieldInfo objects + fields = [] + for field_path, info in field_info.items(): + field_type = self._determine_field_type(info["values"]) + sample_values = self._get_sample_values(info["values"]) + + # Calculate coverage: how many records have this field + populated_count_in_sample = field_completeness.get(field_path, 0) + + # Extrapolate coverage to full dataset + if sample_size > 0: + sample_coverage = populated_count_in_sample / sample_size + estimated_populated_count = int(sample_coverage * total_dataset_size) + coverage = min(sample_coverage, 1.0) # Cap at 100% + else: + estimated_populated_count = 0 + coverage = 0.0 + + fields.append( + FieldInfo( + name=field_path, + field_type=field_type, + sample_values=sample_values, + coverage=coverage, + populated_count=estimated_populated_count, + total_count=total_dataset_size, + ) + ) + + return fields + + def _extract_fields_recursive( + self, + obj: Any, + field_info: Dict[str, Any], + prefix: str = "", + track_coverage: bool = False, + ): + """Recursively extract fields from nested objects for value sampling and type detection.""" + if isinstance(obj, dict): + for key, value in obj.items(): + field_path = f"{prefix}.{key}" if prefix else key + + if field_path not in field_info: + field_info[field_path] = {"values": []} + + if isinstance(value, (dict, list)): + self._extract_fields_recursive( + value, field_info, field_path, track_coverage + ) + else: + # Only add meaningful values to samples + if self._has_meaningful_value(value): + field_info[field_path]["values"].append(value) + + elif isinstance(obj, list) and obj: + # Handle arrays - analyze first few elements + if prefix not in field_info: + field_info[prefix] = {"values": []} + + for i, item in enumerate(obj[:3]): # Sample first 3 array elements + if isinstance(item, dict): + self._extract_fields_recursive( + item, field_info, prefix, track_coverage + ) + else: + # Handle arrays of primitives + field_info[prefix]["values"].extend( + obj[:10] + ) # Sample first 10 items + break + + def _has_meaningful_value(self, value: Any) -> bool: + """Check if a value is meaningful (not None, empty string, or empty collection).""" + if value is None: + return False + if isinstance(value, str) and value.strip() == "": + return False + if isinstance(value, (list, dict)) and len(value) == 0: + return False + return True + + def _get_record_fields(self, obj: Any, field_set: set, prefix: str = ""): + """Get all field paths present in a single record (for completeness calculation).""" + if isinstance(obj, dict): + for key, value in obj.items(): + field_path = f"{prefix}.{key}" if prefix else key + + # Only count this field if it has meaningful value + if self._has_meaningful_value(value): + field_set.add(field_path) + + # Recursively check nested 
structures + if isinstance(value, (dict, list)): + self._get_record_fields(value, field_set, field_path) + + elif isinstance(obj, list) and obj: + # For arrays, if the array is not empty, count it as present + if prefix: + field_set.add(prefix) + + # Also check elements for nested fields + for item in obj[:3]: # Sample first 3 elements + if isinstance(item, dict): + self._get_record_fields(item, field_set, prefix) + + def _determine_field_type(self, values: List[Any]) -> str: + """Determine the type of a field based on its values.""" + if not values: + return "unknown" + + # Sample values to determine type + sample_values = values[:20] # Look at first 20 values + + type_counts = Counter() + for value in sample_values: + if isinstance(value, str): + type_counts["string"] += 1 + elif isinstance(value, (int, float)): + type_counts["number"] += 1 + elif isinstance(value, bool): + type_counts["boolean"] += 1 + elif isinstance(value, list): + type_counts["array"] += 1 + elif isinstance(value, dict): + type_counts["object"] += 1 + else: + type_counts["unknown"] += 1 + + # Return the most common type + return type_counts.most_common(1)[0][0] if type_counts else "unknown" + + def _get_sample_values(self, values: List[Any], max_samples: int = 5) -> List[Any]: + """Get sample values for display.""" + if not values: + return [] + + # Filter out None values and get diverse samples + filtered_values = [v for v in values if v is not None] + + # Get unique values up to max_samples + unique_values = [] + seen = set() + for value in filtered_values: + if len(unique_values) >= max_samples: + break + + # Convert to string for hashing + value_str = str(value)[:100] # Truncate long values + if value_str not in seen: + seen.add(value_str) + unique_values.append( + value if len(str(value)) <= 100 else str(value)[:100] + "..." 
+ ) + + return unique_values + + def generate_adapter_config( + self, mappings: Dict[str, str], use_case: str + ) -> Dict[str, Any]: + """Generate ConfigurableJSONAdapter configuration from field mappings.""" + config = { + "use_case": use_case, + "mappings": mappings, + "field_mapping": self._get_use_case_field_mapping(use_case, mappings), + } + + if use_case == "qa": + config["input_field"] = mappings.get("question", "question") + config["golden_output_field"] = mappings.get("answer", "answer") + + elif use_case == "rag": + config["question_field"] = mappings.get("query", "query") + config["context_field"] = mappings.get("context", "context") + config["golden_answer_field"] = mappings.get("answer", "answer") + + elif use_case == "custom": + # For custom use case, create a flexible mapping + input_fields = { + k: v + for k, v in mappings.items() + if k not in ["id", "metadata", "answer"] + } + config["input_fields"] = input_fields + config["golden_output_field"] = mappings.get("answer", "answer") + + return config + + def _get_use_case_field_mapping( + self, use_case: str, mappings: Dict[str, str] + ) -> Dict[str, Dict[str, str]]: + """Get field mapping structure for a use case.""" + use_case_mapping = { + "qa": { + "inputs": ["question"], + "outputs": ["answer"], + "metadata": ["id", "metadata"], + }, + "rag": { + "inputs": ["query", "context"], + "outputs": ["answer"], + "metadata": ["id", "metadata"], + }, + "custom": { + "inputs": [], + "outputs": ["answer"], + "metadata": ["id", "metadata"], + }, + } + + field_mapping = use_case_mapping.get(use_case, use_case_mapping["custom"]) + + # Create actual field mapping based on user's mappings + result = {"inputs": {}, "outputs": {}, "metadata": {}} + + for target_field, source_field in mappings.items(): + if target_field in field_mapping["inputs"]: + result["inputs"][target_field] = source_field + elif target_field in field_mapping["outputs"]: + result["outputs"][target_field] = source_field + elif target_field in field_mapping["metadata"]: + result["metadata"][target_field] = source_field + else: + # For custom use case, use smart placement + if use_case == "custom" and target_field not in ["id", "metadata"]: + # For custom, put answer-like fields in outputs, everything else in inputs + if any( + keyword in target_field.lower() + for keyword in ["answer", "response", "output", "result"] + ): + result["outputs"][target_field] = source_field + else: + result["inputs"][target_field] = source_field + else: + result["metadata"][target_field] = source_field + + return result + + def preview_transformation( + self, + file_path: str, + mappings: Dict[str, str], + use_case: str, + sample_size: int = 5, + ) -> Dict[str, Any]: + """Preview how the data will be transformed with given mappings.""" + try: + # Load sample data + data = self._load_data(file_path) + if not data: + return {"error": "Could not load file"} + + sample_data = data[:sample_size] + + # Generate adapter config + adapter_config = self.generate_adapter_config(mappings, use_case) + + # Transform sample data + transformed_data = [] + for record in sample_data: + transformed_record = self._transform_record(record, mappings, use_case) + transformed_data.append(transformed_record) + + return { + "original_data": sample_data, + "transformed_data": transformed_data, + "adapter_config": adapter_config, + } + + except Exception as e: + logger.error(f"Error previewing transformation: {str(e)}") + return {"error": f"Preview failed: {str(e)}"} + + def _transform_record( + self, record: 
Dict[str, Any], mappings: Dict[str, str], use_case: str + ) -> Dict[str, Any]: + """Transform a single record according to mappings based on use case.""" + transformed = {"inputs": {}, "outputs": {}, "metadata": {}} + + # Define use case field mappings + use_case_mapping = { + "qa": { + "inputs": ["question"], + "outputs": ["answer"], + "metadata": ["id", "metadata"], # everything else goes to metadata + }, + "rag": { + "inputs": ["query", "context"], + "outputs": ["answer"], + "metadata": ["id", "metadata"], + }, + "custom": { + "inputs": [], # For custom, we'll put non-standard fields in inputs + "outputs": ["answer"], # Common output field + "metadata": ["id", "metadata"], + }, + } + + # Get the mapping for this use case + field_mapping = use_case_mapping.get(use_case, use_case_mapping["custom"]) + + # Apply mappings based on use case + for target_field, source_field in mappings.items(): + value = self._get_nested_value(record, source_field) + + if target_field in field_mapping["inputs"]: + transformed["inputs"][target_field] = value + elif target_field in field_mapping["outputs"]: + transformed["outputs"][target_field] = value + elif target_field in field_mapping["metadata"]: + transformed["metadata"][target_field] = value + else: + # For custom use case or unmapped fields, use smart placement + if use_case == "custom" and target_field not in ["id", "metadata"]: + # For custom, put answer-like fields in outputs, everything else in inputs + if any( + keyword in target_field.lower() + for keyword in ["answer", "response", "output", "result"] + ): + transformed["outputs"][target_field] = value + else: + transformed["inputs"][target_field] = value + else: + transformed["metadata"][target_field] = value + + return transformed + + def _get_nested_value(self, obj: Dict[str, Any], field_path: str) -> Any: + """Get value from nested object using dot notation.""" + keys = field_path.split(".") + value = obj + + for key in keys: + if isinstance(value, dict) and key in value: + value = value[key] + else: + return None + + return value diff --git a/frontend/backend/main.py b/frontend/backend/main.py new file mode 100644 index 0000000..38ba2ee --- /dev/null +++ b/frontend/backend/main.py @@ -0,0 +1,223 @@ +""" +FastAPI application setup and configuration. 
+""" + +import importlib +import logging +import os +import subprocess +import sys +from typing import Any, Dict + +# Import configuration and utilities +from config import ( + DATASET_ADAPTER_MAPPING, + METRIC_MAPPING, + MODEL_MAPPING, + STRATEGY_MAPPING, +) +from dotenv import load_dotenv +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, PlainTextResponse +from pydantic import BaseModel + +# Import route modules +from routes import datasets, projects, prompts, websockets + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler("backend.log"), logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + +# Load environment variables from .env file +load_dotenv() + +# Install required dependencies if missing +required_packages = ["scipy", "llama-prompt-ops==0.0.7"] +try: + for package in required_packages: + try: + # Handle special case for llama-prompt-ops + if "llama-prompt-ops" in package: + module_name = "llama_prompt_ops" + else: + module_name = package.replace("-", "_") + + importlib.import_module(module_name) + print(f"✓ {package} is already installed") + except ImportError: + print(f"Installing {package}...") + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + print(f"✓ {package} installed successfully") +except Exception as e: + print(f"Warning: Failed to install dependencies: {e}") + +# Add llama-prompt-ops to Python path +llama_prompt_ops_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "src") +) +if llama_prompt_ops_path not in sys.path: + sys.path.insert(0, llama_prompt_ops_path) + print(f"Added {llama_prompt_ops_path} to Python path") + +# Try to import core modules +try: + from llama_prompt_ops.core.migrator import PromptMigrator + + print("✓ Successfully imported llama_prompt_ops core modules") + LLAMA_PROMPT_OPS_AVAILABLE = True +except ImportError as e: + print(f"Warning: Could not import llama_prompt_ops: {e}") + print("The /api/migrate-prompt endpoint will fall back to OpenRouter") + LLAMA_PROMPT_OPS_AVAILABLE = False + +# FastAPI Application Setup +app = FastAPI(title="Llama Prompt Ops API") + +# CORS for local development +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + +# Include route modules +app.include_router(datasets.router) +app.include_router(prompts.router) +app.include_router(projects.router) +app.include_router(websockets.router) + + +# Pydantic models for remaining endpoints +class ConfigResponse(BaseModel): + models: Dict[str, str] + metrics: Dict[str, Dict] + dataset_adapters: Dict[str, Dict] + strategies: Dict[str, str] + + +# Remaining endpoints that don't fit in other modules +@app.options("/api/enhance-prompt") +@app.options("/api/migrate-prompt") +@app.options("/api/configurations") +@app.options("/api/datasets/upload") +@app.options("/api/datasets") +@app.options("/api/datasets/{filename}") +@app.options("/api/datasets/analyze/{filename}") +@app.options("/api/datasets/preview-transformation") +@app.options("/api/datasets/save-mapping") +@app.options("/api/quick-start-demo") +@app.options("/api/docs/structure") +@app.options("/docs/{file_path:path}") +async def options_route(): + return {"status": "OK"} + + +@app.get("/api/configurations", 
response_model=ConfigResponse) +async def get_configurations(): + """Return available configuration options for the frontend.""" + return { + "models": MODEL_MAPPING, + "metrics": METRIC_MAPPING, + "dataset_adapters": DATASET_ADAPTER_MAPPING, + "strategies": STRATEGY_MAPPING, + } + + +@app.get("/docs/{file_path:path}") +async def get_docs_file(file_path: str): + """Serve documentation files from the docs directory.""" + try: + # Construct the path to the docs file + docs_base_path = os.path.join(os.path.dirname(__file__), "..", "..", "docs") + full_path = os.path.join(docs_base_path, file_path) + + # Security check: ensure the path is within the docs directory + real_docs_path = os.path.realpath(docs_base_path) + real_file_path = os.path.realpath(full_path) + if not real_file_path.startswith(real_docs_path): + raise HTTPException(status_code=403, detail="Access denied") + + # Check if file exists + if not os.path.exists(full_path): + raise HTTPException(status_code=404, detail="Documentation file not found") + + # Read and return the file content + with open(full_path, "r", encoding="utf-8") as f: + content = f.read() + + # Return appropriate content type based on file extension + if file_path.endswith(".md"): + return PlainTextResponse(content, media_type="text/markdown") + elif file_path.endswith(".json"): + return PlainTextResponse(content, media_type="application/json") + else: + return PlainTextResponse(content, media_type="text/plain") + + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error reading documentation file: {str(e)}" + ) + + +@app.get("/api/docs/structure") +async def get_docs_structure(): + """Get the structure of the documentation directory.""" + try: + docs_base_path = os.path.join(os.path.dirname(__file__), "..", "..", "docs") + structure = [] + + def scan_directory(path, relative_path=""): + items = [] + if not os.path.exists(path): + return items + + for item in os.listdir(path): + if item.startswith("."): # Skip hidden files + continue + + item_path = os.path.join(path, item) + item_relative_path = ( + os.path.join(relative_path, item) if relative_path else item + ) + + if os.path.isdir(item_path): + # Recursively scan subdirectories + subitems = scan_directory(item_path, item_relative_path) + items.extend(subitems) + elif item.endswith((".md", ".txt")): + # Get file stats + stat = os.stat(item_path) + items.append( + { + "path": item_relative_path.replace(os.sep, "/"), + "name": os.path.splitext(item)[0], + "type": "file", + "size": stat.st_size, + "modified": stat.st_mtime, + } + ) + + return items + + structure = scan_directory(docs_base_path) + + return {"success": True, "structure": structure, "total_files": len(structure)} + + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error scanning docs directory: {str(e)}" + ) + + +# Main +if __name__ == "__main__": + import uvicorn + + uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) diff --git a/frontend/backend/requirements.txt b/frontend/backend/requirements.txt new file mode 100644 index 0000000..82a68a1 --- /dev/null +++ b/frontend/backend/requirements.txt @@ -0,0 +1,19 @@ +# Core API dependencies +fastapi==0.104.0 +uvicorn==0.23.2 +python-dotenv==1.0.0 +httpx==0.25.0 +python-multipart==0.0.20 + +# llama-prompt-ops dependencies +llama-prompt-ops==0.0.7 +scipy>=1.11.0 +importlib-metadata>=6.0.0 +pyyaml>=6.0 + +# Dataset analysis dependencies +pandas>=1.5.0 + +# Testing dependencies +pytest>=7.0.0 +pytest-asyncio>=0.21.0 diff --git 
a/frontend/backend/routes/__init__.py b/frontend/backend/routes/__init__.py new file mode 100644 index 0000000..a2deb58 --- /dev/null +++ b/frontend/backend/routes/__init__.py @@ -0,0 +1,5 @@ +""" +API route modules for the backend. +""" + +# This file makes the routes directory a Python package diff --git a/frontend/backend/routes/datasets.py b/frontend/backend/routes/datasets.py new file mode 100644 index 0000000..75a6d13 --- /dev/null +++ b/frontend/backend/routes/datasets.py @@ -0,0 +1,245 @@ +""" +Dataset upload, analysis, and management endpoints. +""" + +import json +import logging +import os +from typing import Any, Dict + +from config import UPLOAD_DIR +from dataset_analyzer import DatasetAnalyzer +from fastapi import APIRouter, File, HTTPException, UploadFile +from pydantic import BaseModel +from utils import get_uploaded_datasets + +logger = logging.getLogger(__name__) +router = APIRouter() + +# Initialize dataset analyzer +dataset_analyzer = DatasetAnalyzer() + + +# Pydantic models for this module +class DatasetUploadResponse(BaseModel): + filename: str + path: str + preview: list[Dict[str, Any]] + total_records: int + + +class DatasetListResponse(BaseModel): + datasets: list[Dict[str, Any]] + + +class DatasetAnalysisResponse(BaseModel): + total_records: int + sample_size: int + fields: list[Dict[str, Any]] + suggestions: Dict[str, Any] + sample_data: list[Dict[str, Any]] + error: str = None + + +class FieldMappingRequest(BaseModel): + filename: str + mappings: Dict[str, str] + use_case: str + + +class PreviewTransformationRequest(BaseModel): + filename: str + mappings: Dict[str, str] + use_case: str + + +class PreviewTransformationResponse(BaseModel): + original_data: list[Dict[str, Any]] + transformed_data: list[Dict[str, Any]] + adapter_config: Dict[str, Any] + error: str = None + + +@router.post("/api/datasets/upload", response_model=DatasetUploadResponse) +async def upload_dataset(file: UploadFile = File(...)): + """Upload a dataset file.""" + if not file.filename.endswith(".json"): + raise HTTPException(status_code=400, detail="Only JSON files are supported") + + try: + # Create the directory if it doesn't exist + os.makedirs(UPLOAD_DIR, exist_ok=True) + + # Get list of existing dataset files + existing_files = os.listdir(UPLOAD_DIR) + + # Delete all existing dataset files to enforce one-dataset-at-a-time rule + for existing_file in existing_files: + file_path = os.path.join(UPLOAD_DIR, existing_file) + if os.path.isfile(file_path) and existing_file.endswith(".json"): + os.remove(file_path) + logger.info(f"Deleted existing dataset: {existing_file}") + + # Read and validate JSON + contents = await file.read() + try: + data = json.loads(contents) + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in uploaded file {file.filename}: {e}") + raise HTTPException(status_code=400, detail=f"Invalid JSON file: {str(e)}") + + # Validate that it's a list of objects + if not isinstance(data, list): + raise HTTPException( + status_code=400, detail="Dataset must be a JSON array of objects" + ) + + if len(data) == 0: + raise HTTPException(status_code=400, detail="Dataset cannot be empty") + + # Validate that each item is an object + for i, item in enumerate(data[:5]): # Check first 5 items + if not isinstance(item, dict): + raise HTTPException( + status_code=400, detail=f"Item {i+1} is not an object" + ) + + logger.info(f"Successfully validated dataset with {len(data)} records") + + # Save to disk + file_path = os.path.join(UPLOAD_DIR, file.filename) + with open(file_path, 
"wb") as f: + f.write(contents) + + # Get preview + preview = data[:5] if isinstance(data, list) else [] + + return { + "filename": file.filename, + "path": file_path, + "preview": preview, + "total_records": len(data) if isinstance(data, list) else 0, + } + except HTTPException: + # Re-raise HTTPExceptions to preserve status codes + raise + except Exception as e: + logger.error(f"Error uploading dataset: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/api/datasets", response_model=DatasetListResponse) +async def list_datasets(): + """List all uploaded datasets.""" + return {"datasets": get_uploaded_datasets()} + + +@router.delete("/api/datasets/{filename}") +async def delete_dataset(filename: str): + """Delete an uploaded dataset.""" + file_path = os.path.join(UPLOAD_DIR, filename) + + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Dataset not found") + + try: + os.remove(file_path) + return {"message": f"Dataset {filename} deleted successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/api/datasets/analyze/{filename}", response_model=DatasetAnalysisResponse) +async def analyze_dataset(filename: str): + """Analyze a dataset file and return field information with suggested mappings.""" + file_path = os.path.join(UPLOAD_DIR, filename) + + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Dataset not found") + + try: + analysis_result = dataset_analyzer.analyze_file(file_path) + + if "error" in analysis_result: + return DatasetAnalysisResponse( + total_records=0, + sample_size=0, + fields=[], + suggestions={}, + sample_data=[], + error=analysis_result["error"], + ) + + return DatasetAnalysisResponse(**analysis_result) + + except Exception as e: + logger.error(f"Error analyzing dataset {filename}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") + + +@router.post( + "/api/datasets/preview-transformation", response_model=PreviewTransformationResponse +) +async def preview_transformation(request: PreviewTransformationRequest): + """Preview how dataset will be transformed with given field mappings.""" + file_path = os.path.join(UPLOAD_DIR, request.filename) + + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Dataset not found") + + try: + preview_result = dataset_analyzer.preview_transformation( + file_path, request.mappings, request.use_case + ) + + if "error" in preview_result: + return PreviewTransformationResponse( + original_data=[], + transformed_data=[], + adapter_config={}, + error=preview_result["error"], + ) + + return PreviewTransformationResponse(**preview_result) + + except Exception as e: + logger.error(f"Error previewing transformation: {str(e)}") + raise HTTPException(status_code=500, detail=f"Preview failed: {str(e)}") + + +@router.post("/api/datasets/save-mapping") +async def save_field_mapping(request: FieldMappingRequest): + """Save field mapping configuration for a dataset.""" + file_path = os.path.join(UPLOAD_DIR, request.filename) + + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Dataset not found") + + try: + # Generate adapter configuration + adapter_config = dataset_analyzer.generate_adapter_config( + request.mappings, request.use_case + ) + + # Save mapping configuration alongside the dataset + mapping_file = os.path.join(UPLOAD_DIR, f"{request.filename}.mapping.json") + with open(mapping_file, "w") as f: + 
json.dump( + { + "filename": request.filename, + "use_case": request.use_case, + "mappings": request.mappings, + "adapter_config": adapter_config, + }, + f, + indent=2, + ) + + return { + "message": "Field mapping saved successfully", + "adapter_config": adapter_config, + } + + except Exception as e: + logger.error(f"Error saving field mapping: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to save mapping: {str(e)}") diff --git a/frontend/backend/routes/projects.py b/frontend/backend/routes/projects.py new file mode 100644 index 0000000..0430ad9 --- /dev/null +++ b/frontend/backend/routes/projects.py @@ -0,0 +1,218 @@ +""" +Project creation and management endpoints. +""" + +import logging +import os +from typing import Any, Dict + +from config import UPLOAD_DIR +from config_transformer import ConfigurationTransformer +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse +from pydantic import BaseModel +from utils import generate_unique_project_name + +logger = logging.getLogger(__name__) +router = APIRouter() + + +# Pydantic models +class QuickStartResponse(BaseModel): + success: bool + dataset: Dict[str, Any] + prompt: str + config: Dict[str, Any] + message: str + + +@router.post("/api/quick-start-demo", response_model=QuickStartResponse) +async def quick_start_demo(): + """Load the facility support analyzer demo with dataset, prompt, and optimal configuration.""" + try: + # Define paths to demo files + demo_dataset_path = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "use-cases", + "facility-support-analyzer", + "dataset.json", + ) + demo_prompt_path = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "use-cases", + "facility-support-analyzer", + "facility_prompt_sys.txt", + ) + + # Check if demo files exist + if not os.path.exists(demo_dataset_path): + raise HTTPException(status_code=404, detail="Demo dataset file not found") + if not os.path.exists(demo_prompt_path): + raise HTTPException(status_code=404, detail="Demo prompt file not found") + + # Ensure upload directory exists + os.makedirs(UPLOAD_DIR, exist_ok=True) + + # Copy demo dataset to uploaded datasets directory + demo_filename = "facility_support_demo.json" + destination_path = os.path.join(UPLOAD_DIR, demo_filename) + + # Remove existing demo file if it exists + if os.path.exists(destination_path): + os.remove(destination_path) + + import shutil + + shutil.copy2(demo_dataset_path, destination_path) + + # Load and validate the dataset + import json + + with open(destination_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if not isinstance(data, list): + raise HTTPException( + status_code=400, detail="Demo dataset must be a JSON array" + ) + + # Create preview (first 3 records) + preview = data[:3] if len(data) > 3 else data + + # Load the demo prompt + with open(demo_prompt_path, "r", encoding="utf-8") as f: + demo_prompt = f.read().strip() + + # Define optimal configuration for facility support analyzer + optimal_config = { + "datasetAdapter": "facility", + "metrics": "Exact Match", # Temporarily use simpler metric for debugging + "model": "Llama 3.3 70B", + "proposer": "Llama 3.3 70B", + "strategy": "Basic", + "useLlamaTips": True, + } + + # Create dataset response + dataset_response = { + "filename": demo_filename, + "path": destination_path, + "preview": preview, + "total_records": len(data), + } + + logger.info(f"Quick start demo loaded successfully: {len(data)} records") + + return QuickStartResponse( + 
success=True, + dataset=dataset_response, + prompt=demo_prompt, + config=optimal_config, + message=f"Demo loaded successfully! Facility Support Analyzer with {len(data)} sample records ready for optimization.", + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error loading quick start demo: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to load demo: {str(e)}") + + +@router.post("/generate-config") +async def generate_config(request: dict): + """Generate YAML configuration from onboarding wizard data.""" + try: + wizard_data = request.get("wizardData", {}) + project_name = request.get("projectName", "generated-project") + save_to_disk = request.get("saveToYaml", False) + + transformer = ConfigurationTransformer() + config_dict = transformer.transform(wizard_data, project_name) + config_yaml = transformer.generate_yaml_string(wizard_data, project_name) + + response = {"success": True, "config": config_dict, "yaml": config_yaml} + + # Optionally save YAML file to disk + if save_to_disk: + uploads_dir = UPLOAD_DIR + os.makedirs(uploads_dir, exist_ok=True) + + yaml_filename = f"{project_name}-config.yaml" + yaml_path = os.path.join(uploads_dir, yaml_filename) + + with open(yaml_path, "w") as f: + f.write(config_yaml) + + response["saved_path"] = yaml_path + response["filename"] = yaml_filename + response["message"] = f"Configuration saved as {yaml_filename}" + + return response + except Exception as e: + return {"success": False, "error": str(e)} + + +@router.post("/create-project") +async def create_project(request: dict): + """Create a complete project structure with config, prompt, and dataset files.""" + try: + wizard_data = request.get("wizardData", {}) + requested_project_name = request.get("projectName", "generated-project") + + # Create project in uploads directory + uploads_dir = UPLOAD_DIR + + # Generate unique project name to avoid conflicts + unique_project_name = generate_unique_project_name( + requested_project_name, uploads_dir + ) + logger.info(f"Requested project name: {requested_project_name}") + logger.info(f"Using unique project name: {unique_project_name}") + + # Fix dataset path to point to actual uploaded file + dataset_info = wizard_data.get("dataset", {}) + if "path" in dataset_info: + dataset_filename = dataset_info["path"] + dataset_absolute_path = os.path.join(uploads_dir, dataset_filename) + wizard_data["dataset"]["path"] = dataset_absolute_path + + transformer = ConfigurationTransformer() + created_files = transformer.create_project_structure( + wizard_data, uploads_dir, unique_project_name + ) + + return { + "success": True, + "projectPath": os.path.join(uploads_dir, unique_project_name), + "createdFiles": created_files, + "message": f"Project '{unique_project_name}' created successfully", + "actualProjectName": unique_project_name, + "requestedProjectName": requested_project_name, + } + except Exception as e: + return {"success": False, "error": str(e)} + + +@router.get("/download-config/{project_name}") +async def download_config(project_name: str): + """Download the config.yaml file for a generated project.""" + try: + uploads_dir = UPLOAD_DIR + config_path = os.path.join(uploads_dir, project_name, "config.yaml") + + if not os.path.exists(config_path): + raise HTTPException(status_code=404, detail="Config file not found") + + return FileResponse( + config_path, + media_type="application/x-yaml", + filename=f"{project_name}-config.yaml", + ) + except Exception as e: + raise HTTPException(status_code=500, 
detail=str(e)) diff --git a/frontend/backend/routes/prompts.py b/frontend/backend/routes/prompts.py new file mode 100644 index 0000000..8281874 --- /dev/null +++ b/frontend/backend/routes/prompts.py @@ -0,0 +1,296 @@ +""" +Prompt enhancement and migration endpoints. +""" + +import logging +import os +import traceback +from typing import Any, Dict, Optional + +from config import ( + DATASET_ADAPTER_MAPPING, + ENHANCE_SYSTEM_MESSAGE, + METRIC_MAPPING, + MODEL_MAPPING, + OPENROUTER_API_KEY, + STRATEGY_MAPPING, +) +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from utils import ( + create_openrouter_client, + get_uploaded_datasets, + load_class_dynamically, +) + +logger = logging.getLogger(__name__) +router = APIRouter() + +# Check for llama-prompt-ops availability (copy from main.py) +try: + from llama_prompt_ops.core.datasets import ConfigurableJSONAdapter + from llama_prompt_ops.core.metrics import DSPyMetricAdapter + from llama_prompt_ops.core.migrator import PromptMigrator + from llama_prompt_ops.core.model import setup_model + from llama_prompt_ops.core.model_strategies import LlamaStrategy + + LLAMA_PROMPT_OPS_AVAILABLE = True +except ImportError: + LLAMA_PROMPT_OPS_AVAILABLE = False + + +# Pydantic models +class PromptRequest(BaseModel): + prompt: str + config: Optional[Dict[str, Any]] = None + + +class PromptResponse(BaseModel): + optimizedPrompt: str + + +async def enhance_prompt_with_openrouter( + request: PromptRequest, system_message: str, operation_name: str = "processing" +): + """ + Shared function to enhance prompts using OpenRouter. + """ + config = request.config or {} + + # API key precedence: env > client supplied + api_key = OPENROUTER_API_KEY or config.get("openrouterApiKey") + if not api_key: + raise HTTPException( + status_code=400, + detail="OpenRouter API key missing. 
Supply via UI or set OPENROUTER_API_KEY environment variable.", + ) + + try: + # Create OpenRouter client with the API key + openrouter_client = create_openrouter_client(api_key) + + # Use the model from config or default to Llama 3.3 70B + model = config.get("model", "meta-llama/llama-3.3-70b-instruct") + if model in MODEL_MAPPING: + model = MODEL_MAPPING[model] + + response = openrouter_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": request.prompt}, + ], + temperature=0.7, + ) + + enhanced_prompt = response.choices[0].message.content.strip() + return {"optimizedPrompt": enhanced_prompt} + + except Exception as e: + print(f"Error in {operation_name}: {e}") + traceback.print_exc() + raise HTTPException( + status_code=500, detail=f"Error {operation_name} prompt: {str(e)}" + ) + + +@router.post("/api/enhance-prompt", response_model=PromptResponse) +async def enhance_prompt(request: PromptRequest): + """Enhance prompt using OpenRouter.""" + return await enhance_prompt_with_openrouter( + request, ENHANCE_SYSTEM_MESSAGE, "enhance" + ) + + +@router.post("/api/migrate-prompt", response_model=PromptResponse) +async def migrate_prompt(request: PromptRequest): + """Run llama-prompt-ops optimization based on frontend config.""" + # Check if llama-prompt-ops is available + if not LLAMA_PROMPT_OPS_AVAILABLE: + # Fall back to OpenRouter for prompt migration + return await enhance_prompt_with_openrouter( + request, ENHANCE_SYSTEM_MESSAGE, "fallback_migrate" + ) + + try: + config = request.config or {} + + # API key precedence: env > client supplied + api_key = os.getenv("OPENROUTER_API_KEY") or config.get("openrouterApiKey") + if not api_key: + raise HTTPException( + status_code=400, + detail="OpenRouter API key missing. Supply via UI or set OPENROUTER_API_KEY.", + ) + + # Set the API key in the environment so all components can access it + if api_key and not os.getenv("OPENROUTER_API_KEY"): + os.environ["OPENROUTER_API_KEY"] = api_key + print("Set OPENROUTER_API_KEY from frontend configuration") + + # Get configuration from request or use defaults + task_model_name = MODEL_MAPPING.get( + config.get("model", "Llama 3.1 8B"), "meta-llama/llama-3.1-8b-instruct" + ) + proposer_model_name = MODEL_MAPPING.get( + config.get("proposer", "Llama 3.1 8B"), "meta-llama/llama-3.1-8b-instruct" + ) + optimization_level = STRATEGY_MAPPING.get( + config.get("strategy", "Basic"), "basic" + ) + use_llama_tips = config.get("useLlamaTips", True) + + # Get dataset adapter configuration + dataset_adapter_name = config.get("datasetAdapter", "facility") + dataset_adapter_cfg = DATASET_ADAPTER_MAPPING.get(dataset_adapter_name) + + if not dataset_adapter_cfg: + raise HTTPException( + status_code=400, + detail=f"Dataset adapter '{dataset_adapter_name}' not found.", + ) + + # Get uploaded dataset info + uploaded_datasets = get_uploaded_datasets() + if not uploaded_datasets: + raise HTTPException( + status_code=400, + detail="No dataset uploaded. 
Please upload a dataset first.", + ) + + # Use the first (and only) uploaded dataset + dataset_info = uploaded_datasets[0] + + # Use selected metric from configuration + metrics_config = config.get("metrics", "Exact Match") + metric_configurations = config.get("metricConfigurations", {}) + + # Handle new format (array of metrics) vs old format (single metric string) + if isinstance(metrics_config, list) and len(metrics_config) > 0: + metric_name = metrics_config[0] + metric_cfg = METRIC_MAPPING.get(metric_name) + if metric_cfg and metric_name in metric_configurations: + user_config = metric_configurations[metric_name] + metric_cfg = metric_cfg.copy() + metric_cfg["params"] = {**metric_cfg["params"], **user_config} + else: + metric_name = metrics_config + metric_cfg = METRIC_MAPPING.get(metric_name) + + if not metric_cfg: + metric_name = "exact_match" + metric_cfg = { + "class": "llama_prompt_ops.core.metrics.ExactMatchMetric", + "params": {}, + } + + # Handle model configurations + model_configurations = config.get("modelConfigurations", []) + target_model_config = None + optimizer_model_config = None + + if model_configurations: + for model_config in model_configurations: + role = model_config.get("role", "both") + if role in ["target", "both"]: + target_model_config = model_config + if role in ["optimizer", "both"]: + optimizer_model_config = model_config + + if model_config.get("api_key"): + provider_id = model_config["provider_id"] + if provider_id == "openrouter": + os.environ["OPENROUTER_API_KEY"] = model_config["api_key"] + elif provider_id == "together": + os.environ["TOGETHER_API_KEY"] = model_config["api_key"] + + if target_model_config: + target_model_name = f"{target_model_config['provider_id']}/{target_model_config['model_name']}" + if optimizer_model_config: + optimizer_model_name = f"{optimizer_model_config['provider_id']}/{optimizer_model_config['model_name']}" + + # Instantiate components + try: + task_model = setup_model( + model_name=f"openrouter/{task_model_name}", + api_key=api_key, + temperature=0.0, + ) + proposer_model = setup_model( + model_name=f"openrouter/{proposer_model_name}", + api_key=api_key, + temperature=0.7, + ) + + metric_cls = load_class_dynamically(metric_cfg["class"]) + metric_params = metric_cfg.get("params", {}) + if issubclass(metric_cls, DSPyMetricAdapter): + metric = metric_cls(model=task_model, **metric_params) + else: + metric = metric_cls(**metric_params) + + # Create dataset adapter + adapter_cls = load_class_dynamically(dataset_adapter_cfg["adapter_class"]) + adapter_params = dataset_adapter_cfg.get("params", {}).copy() + adapter = adapter_cls(dataset_path=dataset_info["path"], **adapter_params) + + strategy = LlamaStrategy( + model_name=task_model_name, + metric=metric, + auto=optimization_level, + task_model=task_model, + prompt_model=proposer_model, + ) + + migrator = PromptMigrator( + strategy=strategy, task_model=task_model, prompt_model=proposer_model + ) + + # Dataset split + trainset, valset, testset = migrator.load_dataset_with_adapter( + adapter, train_size=0.7, validation_size=0.15 + ) + + # Query adapter for actual field names + sample_data = adapter.adapt()[:1] + if sample_data: + input_fields = list(sample_data[0]["inputs"].keys()) + output_fields = list(sample_data[0]["outputs"].keys()) + else: + input_fields = ["question"] + output_fields = ["answer"] + + # Prepare prompt data + prompt_data = { + "text": request.prompt, + "inputs": input_fields, + "outputs": output_fields, + } + + # Execute optimization + optimized_program = 
migrator.optimize( + prompt_data, + trainset=trainset, + valset=valset, + testset=testset, + use_llama_tips=use_llama_tips, + ) + + # Extract the optimized prompt + optimized_prompt = optimized_program.signature.instructions + return {"optimizedPrompt": optimized_prompt} + + except Exception as component_error: + print(f"Error during llama-prompt-ops component setup: {component_error}") + traceback.print_exc() + print("Falling back to OpenRouter for prompt migration") + return await enhance_prompt_with_openrouter( + request, ENHANCE_SYSTEM_MESSAGE, "fallback_migrate" + ) + + except Exception as exc: + print(f"Unexpected error in migrate_prompt: {exc}") + traceback.print_exc() + raise HTTPException( + status_code=500, detail=f"Error migrating prompt: {str(exc)}" + ) diff --git a/frontend/backend/routes/websockets.py b/frontend/backend/routes/websockets.py new file mode 100644 index 0000000..568d324 --- /dev/null +++ b/frontend/backend/routes/websockets.py @@ -0,0 +1,229 @@ +""" +WebSocket endpoints for real-time optimization streaming. +""" + +import logging +import os + +import yaml +from config import UPLOAD_DIR +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from utils import OptimizationManager, load_class_dynamically + +logger = logging.getLogger(__name__) +router = APIRouter() + +# Check for llama-prompt-ops availability +try: + from llama_prompt_ops.core.metrics import DSPyMetricAdapter + from llama_prompt_ops.core.migrator import PromptMigrator + from llama_prompt_ops.core.model import setup_model + from llama_prompt_ops.core.model_strategies import LlamaStrategy + + LLAMA_PROMPT_OPS_AVAILABLE = True +except ImportError: + LLAMA_PROMPT_OPS_AVAILABLE = False + + +@router.websocket("/ws/optimize/{project_name}") +async def optimize_with_streaming(websocket: WebSocket, project_name: str): + """WebSocket endpoint for real-time optimization with streaming logs.""" + await websocket.accept() + + # Initialize optimization manager + manager = OptimizationManager(websocket) + + try: + await manager.send_status("Initializing optimization...", "setup") + + # Set up log streaming to capture all output + manager.setup_log_streaming() + + # Find the project directory + uploads_dir = UPLOAD_DIR + project_path = os.path.join(uploads_dir, project_name) + + if not os.path.exists(project_path): + await manager.send_error(f"Project '{project_name}' not found") + return + + config_path = os.path.join(project_path, "config.yaml") + if not os.path.exists(config_path): + await manager.send_error( + f"Config file not found in project '{project_name}'" + ) + return + + await manager.send_status("Loading configuration...", "config") + + # Load configuration + try: + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + except Exception as e: + await manager.send_error(f"Failed to load config: {str(e)}") + return + + await manager.send_status("Setting up models and dataset...", "setup") + + # Check if llama-prompt-ops is available + if not LLAMA_PROMPT_OPS_AVAILABLE: + await manager.send_error( + "llama-prompt-ops is not available. Please install it." + ) + return + + # Get API key from config or environment + api_key = os.getenv("OPENROUTER_API_KEY") + if not api_key: + await manager.send_error( + "API key not found. Please set OPENROUTER_API_KEY environment variable." 
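+                # Note: unlike /api/migrate-prompt, this WebSocket flow reads the key only
+                # from the environment; it does not fall back to client-supplied config.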
+ ) + return + + # Set API key in environment for all components + os.environ["OPENROUTER_API_KEY"] = api_key + + # Extract configuration + model_config = config_dict.get("model", {}) + dataset_config = config_dict.get("dataset", {}) + metric_config = config_dict.get("metric", {}) + optimization_config = config_dict.get("optimization", {}) + + # Setup models + task_model_name = model_config.get( + "task_model", "openrouter/meta-llama/llama-3.3-70b-instruct" + ) + proposer_model_name = model_config.get("proposer_model", task_model_name) + + await manager.send_progress( + "models", 25, f"Setting up task model: {task_model_name}" + ) + + task_model = setup_model( + model_name=task_model_name, + api_key=api_key, + temperature=model_config.get("temperature", 0.0), + ) + + await manager.send_progress( + "models", 50, f"Setting up proposer model: {proposer_model_name}" + ) + + proposer_model = setup_model( + model_name=proposer_model_name, + api_key=api_key, + temperature=0.7, + ) + + # Setup metric + await manager.send_progress("metric", 75, "Setting up evaluation metric...") + + metric_class_path = metric_config.get( + "class", "llama_prompt_ops.core.metrics.ExactMatchMetric" + ) + metric_cls = load_class_dynamically(metric_class_path) + metric_params = {k: v for k, v in metric_config.items() if k != "class"} + + if issubclass(metric_cls, DSPyMetricAdapter): + metric = metric_cls(model=task_model, **metric_params) + else: + metric = metric_cls(**metric_params) + + # Setup dataset adapter + await manager.send_progress("dataset", 85, "Loading dataset...") + + adapter_class_path = dataset_config.get( + "adapter_class", "llama_prompt_ops.core.datasets.ConfigurableJSONAdapter" + ) + adapter_cls = load_class_dynamically(adapter_class_path) + + dataset_path = os.path.join(project_path, "data", "dataset.json") + adapter_params = { + k: v + for k, v in dataset_config.items() + if k not in ["adapter_class", "path"] + } + adapter = adapter_cls(dataset_path=dataset_path, **adapter_params) + + # Setup strategy and migrator + await manager.send_progress( + "setup", 95, "Initializing optimization strategy..." + ) + + strategy = LlamaStrategy( + model_name=task_model_name, + metric=metric, + auto=optimization_config.get("strategy", "basic"), + task_model=task_model, + prompt_model=proposer_model, + ) + + migrator = PromptMigrator( + strategy=strategy, task_model=task_model, prompt_model=proposer_model + ) + + # Load dataset splits + trainset, valset, testset = migrator.load_dataset_with_adapter( + adapter, + train_size=dataset_config.get("train_size", 0.5), + validation_size=dataset_config.get("validation_size", 0.2), + ) + + # Load prompt + prompt_config = config_dict.get("system_prompt", {}) + prompt_file = prompt_config.get("file") + prompt_text = prompt_config.get("text", "") + + if prompt_file and not prompt_text: + prompt_file_path = os.path.join(project_path, prompt_file) + if os.path.exists(prompt_file_path): + with open(prompt_file_path, "r") as f: + prompt_text = f.read() + + prompt_data = { + "text": prompt_text, + "inputs": prompt_config.get("inputs", ["question"]), + "outputs": prompt_config.get("outputs", ["answer"]), + } + + await manager.send_status("Starting optimization...", "optimize") + await manager.send_progress( + "optimize", 0, "Beginning prompt optimization process..." 
+ ) + + # Run optimization - this will stream all logs automatically + optimized_program = migrator.optimize( + prompt_data, + trainset=trainset, + valset=valset, + testset=testset, + use_llama_tips=optimization_config.get("use_llama_tips", True), + ) + + # Extract results + optimized_prompt = optimized_program.signature.instructions + + await manager.send_result( + { + "success": True, + "optimizedPrompt": optimized_prompt, + "originalPrompt": prompt_text, + "projectName": project_name, + "projectPath": project_path, + "message": "Optimization completed successfully!", + } + ) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected during optimization of {project_name}") + except Exception as e: + logger.error(f"Error during optimization: {str(e)}") + try: + await manager.send_error(f"Optimization failed: {str(e)}") + except: + # WebSocket might be closed + pass + finally: + # Always clean up log handlers + manager.cleanup_log_streaming() diff --git a/frontend/backend/test_main.py b/frontend/backend/test_main.py new file mode 100644 index 0000000..d4a3d5c --- /dev/null +++ b/frontend/backend/test_main.py @@ -0,0 +1,153 @@ +""" +Basic tests for the llama-prompt-ops frontend backend +""" + +import json +import os +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient +from main import app + +client = TestClient(app) + + +def test_health_check(): + """Test that the server is running""" + response = client.get("/api/configurations") + assert response.status_code == 200 + data = response.json() + assert "models" in data + assert "metrics" in data + assert "dataset_adapters" in data + assert "strategies" in data + + +def test_configurations_endpoint(): + """Test the configurations endpoint returns expected structure""" + response = client.get("/api/configurations") + assert response.status_code == 200 + + data = response.json() + + # Check models + assert isinstance(data["models"], dict) + assert "Llama 3.3 70B" in data["models"] + + # Check metrics + assert isinstance(data["metrics"], dict) + assert "Exact Match" in data["metrics"] + + # Check dataset adapters + assert isinstance(data["dataset_adapters"], dict) + assert "standard_json" in data["dataset_adapters"] + + # Check strategies + assert isinstance(data["strategies"], dict) + assert "Basic" in data["strategies"] + + +@patch("main.openai_api_key", "test-key") +@patch("main.client") +def test_enhance_prompt_success(mock_client): + """Test successful prompt enhancement""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Enhanced prompt" + mock_client.chat.completions.create.return_value = mock_response + + response = client.post("/api/enhance-prompt", json={"prompt": "Test prompt"}) + + assert response.status_code == 200 + data = response.json() + assert "optimizedPrompt" in data + assert data["optimizedPrompt"] == "Enhanced prompt" + + +def test_enhance_prompt_no_api_key(): + """Test enhance prompt without API key""" + with patch("main.openai_api_key", None): + response = client.post("/api/enhance-prompt", json={"prompt": "Test prompt"}) + + assert response.status_code == 500 + assert "OPENAI_API_KEY not configured" in response.json()["detail"] + + +def test_dataset_upload_invalid_json(): + """Test dataset upload with invalid JSON""" + response = client.post( + "/api/datasets/upload", + files={"file": ("test.json", "invalid json", "application/json")}, + ) + + assert response.status_code == 400 + assert 
"Invalid JSON" in response.json()["detail"] + + +def test_dataset_upload_non_array(): + """Test dataset upload with non-array JSON""" + response = client.post( + "/api/datasets/upload", + files={"file": ("test.json", '{"key": "value"}', "application/json")}, + ) + + assert response.status_code == 400 + assert "must be a JSON array" in response.json()["detail"] + + +def test_dataset_upload_empty_array(): + """Test dataset upload with empty array""" + response = client.post( + "/api/datasets/upload", files={"file": ("test.json", "[]", "application/json")} + ) + + assert response.status_code == 400 + assert "cannot be empty" in response.json()["detail"] + + +def test_dataset_upload_success(): + """Test successful dataset upload""" + test_data = [ + {"question": "What is AI?", "answer": "Artificial Intelligence"}, + {"question": "What is ML?", "answer": "Machine Learning"}, + ] + + response = client.post( + "/api/datasets/upload", + files={"file": ("test.json", json.dumps(test_data), "application/json")}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["filename"] == "test.json" + assert data["total_records"] == 2 + assert len(data["preview"]) == 2 + + +def test_list_datasets(): + """Test listing datasets""" + response = client.get("/api/datasets") + assert response.status_code == 200 + data = response.json() + assert "datasets" in data + assert isinstance(data["datasets"], list) + + +def test_options_endpoints(): + """Test CORS preflight requests""" + endpoints = [ + "/api/enhance-prompt", + "/api/migrate-prompt", + "/api/configurations", + "/api/datasets/upload", + "/api/datasets", + ] + + for endpoint in endpoints: + response = client.options(endpoint) + assert response.status_code == 200 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/frontend/backend/utils.py b/frontend/backend/utils.py new file mode 100644 index 0000000..e80953d --- /dev/null +++ b/frontend/backend/utils.py @@ -0,0 +1,200 @@ +""" +Shared utility functions and classes for the backend API. 
+""" + +import asyncio +import importlib +import json +import logging +import os +import time +from typing import Any, Dict, List + +import openai +from config import OPENROUTER_API_KEY, UPLOAD_DIR +from fastapi import WebSocket + + +def load_class_dynamically(class_path: str): + """Import and return class from dotted path string.""" + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def create_openrouter_client(api_key: str = None): + """Create OpenRouter client with the provided API key.""" + key_to_use = api_key or OPENROUTER_API_KEY + if not key_to_use: + raise ValueError("OpenRouter API key is required") + + return openai.OpenAI( + api_key=key_to_use, + base_url="https://openrouter.ai/api/v1", + ) + + +def get_uploaded_datasets(): + """Get list of uploaded datasets.""" + datasets = [] + if os.path.exists(UPLOAD_DIR): + for filename in os.listdir(UPLOAD_DIR): + if filename.endswith(".json"): + filepath = os.path.join(UPLOAD_DIR, filename) + try: + with open(filepath, "r") as f: + data = json.load(f) + # Get first few records for preview + preview = data[:3] if isinstance(data, list) else [] + datasets.append( + { + "name": f"Uploaded: {filename}", + "filename": filename, + "path": filepath, + "preview": preview, + "total_records": ( + len(data) if isinstance(data, list) else 0 + ), + } + ) + except Exception as e: + print(f"Error reading dataset {filename}: {e}") + return datasets + + +def generate_unique_project_name(base_name: str, base_dir: str) -> str: + """ + Generate a unique project name by adding incremental suffixes if the project already exists. + + Args: + base_name: The desired project name (e.g., "qa-project-2025-09-15") + base_dir: The directory where projects are created + + Returns: + Unique project name (e.g., "qa-project-2025-09-15-2" if original exists) + """ + project_path = os.path.join(base_dir, base_name) + + # If the base name doesn't exist, use it + if not os.path.exists(project_path): + return base_name + + # Otherwise, find the next available incremental name + counter = 2 + while True: + incremental_name = f"{base_name}-{counter}" + incremental_path = os.path.join(base_dir, incremental_name) + + if not os.path.exists(incremental_path): + return incremental_name + + counter += 1 + + # Safety check to prevent infinite loop (though very unlikely) + if counter > 1000: + # Fallback to timestamp-based naming + timestamp = str(int(time.time())) + return f"{base_name}-{timestamp}" + + +class StreamingLogHandler(logging.Handler): + """Custom log handler that streams log messages to WebSocket clients.""" + + def __init__(self, websocket: WebSocket): + super().__init__() + self.websocket = websocket + self.formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s") + + def emit(self, record): + """Send log record to WebSocket client.""" + try: + log_entry = self.format(record) + # Create task to send message (non-blocking) only if there's an event loop + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task( + self.websocket.send_json( + { + "type": "log", + "message": log_entry, + "level": record.levelname, + "logger": record.name, + "timestamp": record.created, + } + ) + ) + except RuntimeError: + # No event loop running, skip WebSocket logging + pass + except Exception as e: + # Avoid infinite recursion by not logging this error + pass + + +class OptimizationManager: + """Manages the optimization process with real-time streaming.""" 
+ + def __init__(self, websocket: WebSocket): + self.websocket = websocket + self.log_handler = None + + async def send_status(self, message: str, phase: str = None): + """Send status update to client.""" + await self.websocket.send_json( + {"type": "status", "message": message, "phase": phase or "unknown"} + ) + + async def send_progress(self, phase: str, progress: float, message: str): + """Send progress update to client.""" + await self.websocket.send_json( + { + "type": "progress", + "phase": phase, + "progress": progress, + "message": message, + } + ) + + async def send_result(self, result: dict): + """Send final optimization result to client.""" + await self.websocket.send_json({"type": "complete", **result}) + + async def send_error(self, error: str): + """Send error message to client.""" + await self.websocket.send_json({"type": "error", "message": error}) + + def setup_log_streaming(self): + """Set up log handlers to capture all optimization logs.""" + self.log_handler = StreamingLogHandler(self.websocket) + self.log_handler.setLevel(logging.INFO) + + # Add handler to multiple loggers to capture all output + loggers_to_stream = [ + logging.getLogger(), # Root logger + logging.getLogger("prompt_ops"), # llama-prompt-ops logger + logging.getLogger("llama_prompt_ops"), # Alternative logger name + logging.getLogger("dspy"), # DSPy optimization logs + logging.getLogger("LiteLLM"), # LiteLLM API call logs + ] + + for logger in loggers_to_stream: + logger.addHandler(self.log_handler) + + def cleanup_log_streaming(self): + """Clean up log handlers.""" + if self.log_handler: + loggers_to_cleanup = [ + logging.getLogger(), + logging.getLogger("prompt_ops"), + logging.getLogger("llama_prompt_ops"), + logging.getLogger("dspy"), + logging.getLogger("LiteLLM"), + ] + + for logger in loggers_to_cleanup: + try: + logger.removeHandler(self.log_handler) + except ValueError: + # Handler not in logger, ignore + pass diff --git a/frontend/check-deps.js b/frontend/check-deps.js new file mode 100644 index 0000000..8a7ffe4 --- /dev/null +++ b/frontend/check-deps.js @@ -0,0 +1,164 @@ +#!/usr/bin/env node + +/** + * Dependencies checker for llama-prompt-ops frontend + * Checks for required system dependencies and environment setup + */ + +const fs = require('fs'); +const path = require('path'); +const { execSync } = require('child_process'); + +// Colors for output +const colors = { + red: '\x1b[31m', + green: '\x1b[32m', + yellow: '\x1b[33m', + blue: '\x1b[34m', + reset: '\x1b[0m' +}; + +function log(color, message) { + console.log(`${colors[color]}${message}${colors.reset}`); +} + +function checkCommand(command, name) { + try { + execSync(`${command} --version`, { stdio: 'pipe' }); + log('green', `✓ ${name} is available`); + return true; + } catch (error) { + log('red', `✗ ${name} is not available`); + return false; + } +} + +function checkFile(filePath, name) { + if (fs.existsSync(filePath)) { + log('green', `✓ ${name} exists`); + return true; + } else { + log('red', `✗ ${name} missing`); + return false; + } +} + +function checkNodeVersion() { + try { + const version = execSync('node --version', { encoding: 'utf8' }).trim(); + const majorVersion = parseInt(version.slice(1).split('.')[0]); + if (majorVersion >= 18) { + log('green', `✓ Node.js ${version} (>= 18)`); + return true; + } else { + log('red', `✗ Node.js ${version} (< 18)`); + return false; + } + } catch (error) { + log('red', '✗ Node.js not available'); + return false; + } +} + +function checkPythonVersion() { + try { + const version = 
execSync('python --version', { encoding: 'utf8' }).trim(); + const versionMatch = version.match(/Python (\d+)\.(\d+)/); + if (versionMatch) { + const major = parseInt(versionMatch[1]); + const minor = parseInt(versionMatch[2]); + if (major >= 3 && minor >= 8) { + log('green', `✓ ${version} (>= 3.8)`); + return true; + } else { + log('red', `✗ ${version} (< 3.8)`); + return false; + } + } else { + log('red', '✗ Could not determine Python version'); + return false; + } + } catch (error) { + // Try python3 + try { + const version = execSync('python3 --version', { encoding: 'utf8' }).trim(); + const versionMatch = version.match(/Python (\d+)\.(\d+)/); + if (versionMatch) { + const major = parseInt(versionMatch[1]); + const minor = parseInt(versionMatch[2]); + if (major >= 3 && minor >= 8) { + log('green', `✓ ${version} (>= 3.8)`); + return true; + } else { + log('red', `✗ ${version} (< 3.8)`); + return false; + } + } else { + log('red', '✗ Could not determine Python version'); + return false; + } + } catch (error2) { + log('red', '✗ Python not available'); + return false; + } + } +} + +function main() { + log('blue', '🔍 Checking llama-prompt-ops frontend dependencies...\n'); + + let allChecksPass = true; + + // System dependencies + log('blue', '📋 System Dependencies:'); + allChecksPass &= checkNodeVersion(); + allChecksPass &= checkPythonVersion(); + allChecksPass &= checkCommand('npm', 'npm'); + allChecksPass &= checkCommand('git', 'git'); + console.log(); + + // Project files + log('blue', '📁 Project Files:'); + allChecksPass &= checkFile('package.json', 'package.json'); + allChecksPass &= checkFile('backend/main.py', 'backend/main.py'); + allChecksPass &= checkFile('backend/requirements.txt', 'backend/requirements.txt'); + console.log(); + + // Node modules + log('blue', '📦 Frontend Dependencies:'); + const nodeModulesExists = checkFile('node_modules', 'node_modules'); + if (!nodeModulesExists) { + log('yellow', '💡 Run: npm install'); + allChecksPass = false; + } + console.log(); + + // Python virtual environment + log('blue', '🐍 Backend Dependencies:'); + const venvExists = checkFile('backend/venv', 'Python virtual environment'); + if (!venvExists) { + log('yellow', '💡 Run: cd backend && python -m venv venv'); + allChecksPass = false; + } + console.log(); + + // Environment configuration + log('blue', '⚙️ Environment Configuration:'); + const envExists = checkFile('backend/.env', 'Environment configuration'); + if (!envExists) { + log('yellow', '💡 Create backend/.env with your API keys'); + allChecksPass = false; + } + console.log(); + + // Final result + if (allChecksPass) { + log('green', '✅ All dependencies are ready!'); + log('green', '🚀 Run: ./start-dev.sh or npm run dev'); + } else { + log('red', '❌ Some dependencies are missing.'); + log('yellow', '📖 Check the README.md for setup instructions.'); + } +} + +main(); diff --git a/frontend/components.json b/frontend/components.json new file mode 100644 index 0000000..62e1011 --- /dev/null +++ b/frontend/components.json @@ -0,0 +1,20 @@ +{ + "$schema": "https://ui.shadcn.com/schema.json", + "style": "default", + "rsc": false, + "tsx": true, + "tailwind": { + "config": "tailwind.config.ts", + "css": "src/index.css", + "baseColor": "slate", + "cssVariables": true, + "prefix": "" + }, + "aliases": { + "components": "@/components", + "utils": "@/lib/utils", + "ui": "@/components/ui", + "lib": "@/lib", + "hooks": "@/hooks" + } +} diff --git a/frontend/eslint.config.js b/frontend/eslint.config.js new file mode 100644 index 0000000..e67846f 
--- /dev/null +++ b/frontend/eslint.config.js @@ -0,0 +1,29 @@ +import js from "@eslint/js"; +import globals from "globals"; +import reactHooks from "eslint-plugin-react-hooks"; +import reactRefresh from "eslint-plugin-react-refresh"; +import tseslint from "typescript-eslint"; + +export default tseslint.config( + { ignores: ["dist"] }, + { + extends: [js.configs.recommended, ...tseslint.configs.recommended], + files: ["**/*.{ts,tsx}"], + languageOptions: { + ecmaVersion: 2020, + globals: globals.browser, + }, + plugins: { + "react-hooks": reactHooks, + "react-refresh": reactRefresh, + }, + rules: { + ...reactHooks.configs.recommended.rules, + "react-refresh/only-export-components": [ + "warn", + { allowConstantExport: true }, + ], + "@typescript-eslint/no-unused-vars": "off", + }, + } +); diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..0adaa52 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,25 @@ + + + + + + Llama Prompt Ops - Frontend + + + + + + + + + + + + + + + +
+ + + diff --git a/frontend/package.json b/frontend/package.json new file mode 100644 index 0000000..39b7cd0 --- /dev/null +++ b/frontend/package.json @@ -0,0 +1,107 @@ +{ + "name": "llama-prompt-ops-frontend", + "version": "1.0.0", + "description": "Frontend interface for llama-prompt-ops - A comprehensive prompt optimization toolkit", + "private": true, + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "build:dev": "vite build --mode development", + "lint": "eslint .", + "preview": "vite preview", + "check-deps": "node check-deps.js", + "setup": "npm install && cd backend && python -m venv venv && source venv/bin/activate && pip install -r requirements.txt", + "start": "chmod +x start-dev.sh && ./start-dev.sh", + "backend": "cd backend && source venv/bin/activate && python -m uvicorn main:app --reload --port 8000", + "clean": "rm -rf node_modules backend/venv backend/uploaded_datasets/* && npm cache clean --force" + }, + "dependencies": { + "@hookform/resolvers": "^3.9.0", + "@radix-ui/react-accordion": "^1.2.0", + "@radix-ui/react-alert-dialog": "^1.1.1", + "@radix-ui/react-aspect-ratio": "^1.1.0", + "@radix-ui/react-avatar": "^1.1.0", + "@radix-ui/react-checkbox": "^1.1.1", + "@radix-ui/react-collapsible": "^1.1.0", + "@radix-ui/react-context-menu": "^2.2.1", + "@radix-ui/react-dialog": "^1.1.2", + "@radix-ui/react-dropdown-menu": "^2.1.1", + "@radix-ui/react-hover-card": "^1.1.1", + "@radix-ui/react-label": "^2.1.0", + "@radix-ui/react-menubar": "^1.1.1", + "@radix-ui/react-navigation-menu": "^1.2.0", + "@radix-ui/react-popover": "^1.1.1", + "@radix-ui/react-progress": "^1.1.0", + "@radix-ui/react-radio-group": "^1.2.0", + "@radix-ui/react-scroll-area": "^1.1.0", + "@radix-ui/react-select": "^2.1.1", + "@radix-ui/react-separator": "^1.1.0", + "@radix-ui/react-slider": "^1.2.0", + "@radix-ui/react-slot": "^1.1.0", + "@radix-ui/react-switch": "^1.1.0", + "@radix-ui/react-tabs": "^1.1.0", + "@radix-ui/react-toast": "^1.2.1", + "@radix-ui/react-toggle": "^1.1.0", + "@radix-ui/react-toggle-group": "^1.1.0", + "@radix-ui/react-tooltip": "^1.1.4", + "@tanstack/react-query": "^5.56.2", + "class-variance-authority": "^0.7.1", + "clsx": "^2.1.1", + "cmdk": "^1.0.0", + "date-fns": "^3.6.0", + "embla-carousel-react": "^8.3.0", + "input-otp": "^1.2.4", + "lucide-react": "^0.462.0", + "next-themes": "^0.3.0", + "react": "^18.3.1", + "react-day-picker": "^8.10.1", + "react-dom": "^18.3.1", + "react-hook-form": "^7.53.0", + "react-markdown": "^10.1.0", + "react-resizable-panels": "^2.1.3", + "react-router-dom": "^6.26.2", + "recharts": "^2.12.7", + "rehype-highlight": "^7.0.2", + "rehype-slug": "^6.0.0", + "remark-gfm": "^4.0.1", + "sonner": "^1.5.0", + "tailwind-merge": "^2.5.2", + "tailwindcss-animate": "^1.0.7", + "vaul": "^0.9.3", + "zod": "^3.23.8" + }, + "devDependencies": { + "@eslint/js": "^9.9.0", + "@tailwindcss/typography": "^0.5.15", + "@types/node": "^22.5.5", + "@types/react": "^18.3.3", + "@types/react-dom": "^18.3.0", + "@vitejs/plugin-react-swc": "^3.5.0", + "autoprefixer": "^10.4.20", + "eslint": "^9.9.0", + "eslint-plugin-react-hooks": "^5.1.0-rc.0", + "eslint-plugin-react-refresh": "^0.4.9", + "globals": "^15.9.0", + "postcss": "^8.4.47", + "tailwindcss": "^3.4.11", + "typescript": "^5.5.3", + "typescript-eslint": "^8.0.1", + "vite": "^5.4.1" + }, + "repository": { + "type": "git", + "url": "https://github.com/meta-llama/llama-prompt-ops.git" + }, + "keywords": [ + "llama", + "prompt-optimization", + "ai", + "machine-learning", + "frontend", + "react", 
+ "typescript" + ], + "author": "Meta Llama Team", + "license": "MIT" +} diff --git a/frontend/postcss.config.js b/frontend/postcss.config.js new file mode 100644 index 0000000..2e7af2b --- /dev/null +++ b/frontend/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/frontend/public/placeholder.svg b/frontend/public/placeholder.svg new file mode 100644 index 0000000..9b13eb6 --- /dev/null +++ b/frontend/public/placeholder.svg @@ -0,0 +1 @@ + diff --git a/frontend/public/robots.txt b/frontend/public/robots.txt new file mode 100644 index 0000000..6018e70 --- /dev/null +++ b/frontend/public/robots.txt @@ -0,0 +1,14 @@ +User-agent: Googlebot +Allow: / + +User-agent: Bingbot +Allow: / + +User-agent: Twitterbot +Allow: / + +User-agent: facebookexternalhit +Allow: / + +User-agent: * +Allow: / diff --git a/frontend/src/App.css b/frontend/src/App.css new file mode 100644 index 0000000..b9d355d --- /dev/null +++ b/frontend/src/App.css @@ -0,0 +1,42 @@ +#root { + max-width: 1280px; + margin: 0 auto; + padding: 2rem; + text-align: center; +} + +.logo { + height: 6em; + padding: 1.5em; + will-change: filter; + transition: filter 300ms; +} +.logo:hover { + filter: drop-shadow(0 0 2em #646cffaa); +} +.logo.react:hover { + filter: drop-shadow(0 0 2em #61dafbaa); +} + +@keyframes logo-spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } +} + +@media (prefers-reduced-motion: no-preference) { + a:nth-of-type(2) .logo { + animation: logo-spin infinite 20s linear; + } +} + +.card { + padding: 2em; +} + +.read-the-docs { + color: #888; +} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx new file mode 100644 index 0000000..2b0c1b5 --- /dev/null +++ b/frontend/src/App.tsx @@ -0,0 +1,32 @@ +import { Toaster } from "@/components/ui/toaster"; +import { Toaster as Sonner } from "@/components/ui/sonner"; +import { TooltipProvider } from "@/components/ui/tooltip"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { BrowserRouter, Routes, Route } from "react-router-dom"; +import { AppProvider } from "./context/AppContext"; +import Index from "./pages/Index"; +import NotFound from "./pages/NotFound"; +import OptimizationGrid from "./pages/OptimizationGrid"; + +const queryClient = new QueryClient(); + +const App = () => ( + + + + + + + + } /> + } /> + {/* ADD ALL CUSTOM ROUTES ABOVE THE CATCH-ALL "*" ROUTE */} + } /> + + + + + +); + +export default App; diff --git a/frontend/src/components/docs/DocsContent.tsx b/frontend/src/components/docs/DocsContent.tsx new file mode 100644 index 0000000..d70d131 --- /dev/null +++ b/frontend/src/components/docs/DocsContent.tsx @@ -0,0 +1,236 @@ +import React, { useState, useEffect } from 'react'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import rehypeHighlight from 'rehype-highlight'; +import rehypeSlug from 'rehype-slug'; +import { Copy, Check, ExternalLink, Clock, FileText } from 'lucide-react'; +import { Button } from '@/components/ui/button'; +import { DocItem } from './DocsTab'; +import 'highlight.js/styles/github.css'; + +interface DocsContentProps { + doc: DocItem; +} + +export const DocsContent: React.FC = ({ doc }) => { + const [content, setContent] = useState(''); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [copySuccess, setCopySuccess] = useState(false); + + useEffect(() => { + const fetchContent = async () => { + setLoading(true); + 
setError(null); + + try { + // Fetch content from the backend docs endpoint + const response = await fetch(`http://localhost:8000/docs/${doc.path}`); + if (!response.ok) { + throw new Error(`Failed to load ${doc.title}`); + } + const text = await response.text(); + setContent(text); + } catch (err) { + // Fallback content for demo purposes when files aren't available + const fallbackContent = `# ${doc.title} + +This is a demonstration of the documentation system. The actual file \`${doc.path}\` could not be loaded. + +## About This Document + +${doc.description || 'This document provides comprehensive information about the topic.'} + +## Features + +- **Markdown Rendering**: Full support for GitHub Flavored Markdown +- **Syntax Highlighting**: Code blocks with syntax highlighting +- **Responsive Design**: Works great on all screen sizes +- **Search Functionality**: Find what you need quickly + +## Code Example + +\`\`\`javascript +// Example code block +const optimizePrompt = async (prompt, config) => { + const response = await fetch('/api/optimize', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ prompt, config }), + }); + + return response.json(); +}; +\`\`\` + +## Getting Started + +To get started with ${doc.title.toLowerCase()}, follow these steps: + +1. Review the configuration options +2. Prepare your dataset +3. Run the optimization process +4. Analyze the results + +--- + +*This is a demo version. In a real implementation, this content would be loaded from \`docs/${doc.path}\`.*`; + + setContent(fallbackContent); + setError(`Could not load ${doc.path}. Showing demo content instead.`); + } finally { + setLoading(false); + } + }; + + fetchContent(); + }, [doc]); + + const handleCopy = () => { + navigator.clipboard.writeText(content); + setCopySuccess(true); + setTimeout(() => setCopySuccess(false), 2000); + }; + + if (loading) { + return ( +
+      {/* Loading state: animated skeleton placeholder bars shown while the document is
+          fetched (original JSX markup not preserved in this diff). */}
+    );
+  }
+
+  if (error) {
+    return (
+      {/* Error state (markup not preserved): an alert panel with a "Content Not Available"
+          heading, the {error} message, and a note that this is a demo version and that a
+          real implementation would load the actual markdown files. */}
+    );
+  }
+
+  return (
+      {/* Main render (markup not preserved):
+          - Header: doc.icon, the doc.title heading, a "{doc.category} • {doc.path}" subtitle,
+            and a copy Button wired to handleCopy that shows Check instead of Copy while
+            copySuccess is true.
+          - Body: <ReactMarkdown> over {content}, using the imported remarkGfm, rehypeHighlight
+            and rehypeSlug plugins, with custom renderers for h1-h3, p, a, inline and fenced
+            code (language detected via /language-(\w+)/), pre, blockquote, ul/ol/li, and
+            table/th/td. */}
+  );
+};
diff --git a/frontend/src/components/docs/DocsSidebar.tsx b/frontend/src/components/docs/DocsSidebar.tsx new file mode 100644 index 0000000..746f1da --- /dev/null +++ b/frontend/src/components/docs/DocsSidebar.tsx @@ -0,0 +1,138 @@
+import React from 'react';
+import { ChevronDown, ChevronRight, Menu, X } from 'lucide-react';
+import { Button } from '@/components/ui/button';
+import { DocItem } from './DocsTab';
+
+interface DocsSidebarProps {
+  docs: DocItem[];
+  categories: string[];
+  selectedDoc: DocItem | null;
+  onSelectDoc: (doc: DocItem) => void;
+  isOpen: boolean;
+  onToggle: () => void;
+}
+
+export const DocsSidebar: React.FC<DocsSidebarProps> = ({
+  docs,
+  categories,
+  selectedDoc,
+  onSelectDoc,
+  isOpen,
+  onToggle,
+}) => {
+  const [expandedCategories, setExpandedCategories] = React.useState<Set<string>>(
+    new Set(categories)
+  );
+
+  const toggleCategory = (category: string) => {
+    const newExpanded = new Set(expandedCategories);
+    if (newExpanded.has(category)) {
+      newExpanded.delete(category);
+    } else {
+      newExpanded.add(category);
+    }
+    setExpandedCategories(newExpanded);
+  };
+
+  if (!isOpen) {
+    return (
      + +
      + ); + } + + return ( +
      + {/* Header */} +
      +

      Contents

      + +
      + + {/* Navigation */} + + + {/* Footer */} +
      +
      + {docs.length} documents +
      +
      +
      + ); +}; diff --git a/frontend/src/components/docs/DocsTab.tsx b/frontend/src/components/docs/DocsTab.tsx new file mode 100644 index 0000000..3889e39 --- /dev/null +++ b/frontend/src/components/docs/DocsTab.tsx @@ -0,0 +1,227 @@ +import React, { useState, useEffect } from 'react'; +import { Book, Search, FileText, Code, Settings, ChevronRight, ExternalLink } from 'lucide-react'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { DocsContent } from './DocsContent'; +import { DocsSidebar } from './DocsSidebar'; + +export interface DocItem { + id: string; + title: string; + path: string; + category: string; + description?: string; + lastModified?: string; + icon?: React.ElementType; +} + +export const DocsTab = () => { + const [selectedDoc, setSelectedDoc] = useState(null); + const [searchQuery, setSearchQuery] = useState(''); + const [sidebarOpen, setSidebarOpen] = useState(true); + + // Load RunLLM widget when component mounts + useEffect(() => { + const loadRunLLMWidget = () => { + // Check if script already exists + if (document.getElementById('runllm-widget-script')) { + return; + } + + const script = document.createElement('script'); + script.type = 'module'; + script.id = 'runllm-widget-script'; + script.src = 'https://widget.runllm.com'; + script.setAttribute('version', 'stable'); + script.setAttribute('crossorigin', 'true'); + script.setAttribute('runllm-keyboard-shortcut', 'Mod+j'); + script.setAttribute('runllm-name', 'llama-prompt-ops Assistant'); + script.setAttribute('runllm-position', 'BOTTOM_RIGHT'); + // RunLLM Assistant ID from https://app.runllm.com/assistant/1149 + script.setAttribute('runllm-assistant-id', '1149'); + script.setAttribute('runllm-theme-color', '#1877f2'); // Facebook blue + script.setAttribute('runllm-floating-button-text', 'Ask AI'); + script.setAttribute('runllm-disclaimer', 'This AI assistant can help you navigate the llama-prompt-ops documentation.'); + script.async = true; + + document.head.appendChild(script); + }; + + loadRunLLMWidget(); + + // Cleanup function to remove the script when component unmounts + return () => { + const script = document.getElementById('runllm-widget-script'); + if (script) { + document.head.removeChild(script); + } + }; + }, []); + + // Sample docs structure - in a real app, this would come from an API + const docsStructure: DocItem[] = [ + { + id: 'getting-started', + title: 'Getting Started', + path: 'README.md', + category: 'Basics', + description: 'Learn the fundamentals of llama-prompt-ops', + icon: Book + }, + { + id: 'metrics-guide', + title: 'Metric Selection Guide', + path: 'metric_selection_guide.md', + category: 'Guides', + description: 'Choose the right metrics for your optimization', + icon: Settings + }, + { + id: 'dataset-adapter', + title: 'Dataset Adapter Guide', + path: 'dataset_adapter_selection_guide.md', + category: 'Guides', + description: 'Configure dataset adapters for different data formats', + icon: FileText + }, + { + id: 'intermediate-guide', + title: 'Facility YAML Configuration', + path: 'intermediate/readme.md', + category: 'Intermediate', + description: 'Advanced YAML configuration options for facility management tasks', + icon: Code + }, + { + id: 'inference-providers', + title: 'Inference Providers', + path: 'inference_providers.md', + category: 'Advanced', + description: 'Configure and use different inference providers', + icon: Code + } + ]; + + const filteredDocs = docsStructure.filter(doc => + 
doc.title.toLowerCase().includes(searchQuery.toLowerCase()) || + doc.description?.toLowerCase().includes(searchQuery.toLowerCase()) + ); + + const categories = Array.from(new Set(docsStructure.map(doc => doc.category))); + + return ( +
      + {/* Header */} +
      +

      + Documentation +

      +

      + Comprehensive guides, API references, and examples to help you get the most out of llama-prompt-ops. + + 💬 Ask AI for help (Cmd+J) + +

      +
      + + {/* Search Bar */} +
      +
      + + setSearchQuery(e.target.value)} + className="pl-10 h-12 text-lg border-facebook-border focus:border-facebook-blue focus:ring-facebook-blue/20" + /> +
      +
      + + {/* Main Content Area */} +
      + {/* Sidebar */} +
      + setSidebarOpen(!sidebarOpen)} + /> +
      + + {/* Content Area */} +
      + {selectedDoc ? ( + + ) : ( + + )} +
      +
      +
      + ); +}; + +// Overview component showing doc categories when no specific doc is selected +const DocsOverview = ({ docs, onSelectDoc }: { docs: DocItem[], onSelectDoc: (doc: DocItem) => void }) => { + const categories = Array.from(new Set(docs.map(doc => doc.category))); + + return ( +
      +
      + +

      + Welcome to the Documentation +

      +

      + Select a document from the sidebar or browse by category below. +

      +
      + +
      + {categories.map(category => { + const categoryDocs = docs.filter(doc => doc.category === category); + return ( +
      +

      + {category === 'Basics' && } + {category === 'Guides' && } + {category === 'Intermediate' && } + {category === 'Advanced' && } + {category} +

      +
      + {categoryDocs.map(doc => ( + + ))} +
      +
      + ); + })} +
      +
      + ); +}; diff --git a/frontend/src/components/layout/MainContent.tsx b/frontend/src/components/layout/MainContent.tsx new file mode 100644 index 0000000..82b27d2 --- /dev/null +++ b/frontend/src/components/layout/MainContent.tsx @@ -0,0 +1,103 @@ +import React, { useContext } from 'react'; +import { AppContext } from '../../context/AppContext'; +import { PromptInput } from '../optimization/PromptInput'; +import { DocsTab } from '../docs/DocsTab'; +import { Button } from '@/components/ui/button'; +import { Badge } from '@/components/ui/badge'; +import { Lock } from 'lucide-react'; + +export const MainContent = () => { + const { activeMode, setActiveMode, isModeLocked } = useContext(AppContext)!; + + // If in docs mode, show only the docs content + if (activeMode === 'docs') { + return ( +
      + +
      + ); + } + + return ( +
      +
      + {/* Hero Section - Centered */} +
      +

      + Optimize your +
      + + prompt + +

      +
      + + {/* Mode Toggle - Only Migrate and Enhance */} +
      +
      + {/* Container using CSS Grid for equal button widths */} +
      + {/* Sliding indicator with Facebook blue gradient */} +
      + + {/* Lock icon when mode is locked */} + {isModeLocked && ( +
      + +
      + )} + + + + +
      +
      +
      + + {/* Prompt Input - Elevated and Centered */} +
      + +
      +
      +
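MainContent and Sidebar both read their mode state from AppContext. A minimal sketch of the context shape implied by the fields used here (`activeMode`, `setActiveMode`, `isModeLocked`); the name below is illustrative, the real type lives in `src/context/AppContext`:

```typescript
// Illustrative only: the shape implied by how MainContent and Sidebar consume AppContext.
import { createContext } from "react";

type Mode = "migrate" | "enhance" | "docs";

interface AppContextValue {
  activeMode: Mode;                    // which workflow/tab is currently shown
  setActiveMode: (mode: Mode) => void; // switched from the sidebar and the mode toggle
  isModeLocked: boolean;               // true while a run pins the current mode
}

export const ExampleAppContext = createContext<AppContextValue | null>(null);
```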
      + ); +}; diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx new file mode 100644 index 0000000..9dae447 --- /dev/null +++ b/frontend/src/components/layout/Sidebar.tsx @@ -0,0 +1,92 @@ +import React, { useState, useContext } from 'react'; +import { Play, FileText, Book, Github } from 'lucide-react'; +import { Link } from 'react-router-dom'; +import { AppContext } from '../../context/AppContext'; + +export const Sidebar = () => { + const { activeMode, setActiveMode } = useContext(AppContext)!; + + const navItems = [ + { id: 'playground', label: 'Playground', icon: Play, path: '/', mode: 'migrate' }, + { id: 'docs', label: 'Docs', icon: Book, mode: 'docs' }, + { id: 'github', label: 'GitHub', icon: Github, path: 'https://github.com/meta-llama/llama-prompt-ops', external: true }, + ]; + + const handleNavClick = (item: any) => { + if (item.mode) { + setActiveMode(item.mode); + } + }; + + return ( + + ); +}; diff --git a/frontend/src/components/onboarding/FieldMappingInterface.tsx b/frontend/src/components/onboarding/FieldMappingInterface.tsx new file mode 100644 index 0000000..f14758a --- /dev/null +++ b/frontend/src/components/onboarding/FieldMappingInterface.tsx @@ -0,0 +1,563 @@ +import React, { useState, useEffect } from "react"; +import { cn } from "@/lib/utils"; +import { + ArrowRight, + CheckCircle, + AlertCircle, + Eye, + RefreshCw, + FileText, + Database, + Hash, + List, + Settings, +} from "lucide-react"; + +interface FieldInfo { + name: string; + type: string; + samples: any[]; + coverage: number; + populated_count: number; + total_count: number; +} + +interface FieldMappingInterfaceProps { + filename: string; + useCase: string; + onMappingUpdate: (mappings: Record) => void; + className?: string; + existingMappings?: Record; +} + +interface DatasetAnalysis { + total_records: number; + sample_size: number; + fields: FieldInfo[]; + suggestions: Record; + sample_data: any[]; + error?: string; +} + +interface PreviewData { + original_data: any[]; + transformed_data: any[]; + adapter_config: any; + error?: string; +} + +const USE_CASE_REQUIREMENTS = { + qa: { + required: ["question", "answer"], + optional: ["id", "metadata"], + description: "Question-Answer format", + }, + rag: { + required: ["context", "query", "answer"], + optional: ["id", "metadata"], + description: "RAG (Retrieval-Augmented Generation) format", + }, + custom: { + required: [], + optional: [], + description: "Custom configuration", + }, +}; + +const getFieldIcon = (fieldType: string) => { + switch (fieldType) { + case "string": + return ; + case "array": + return ; + case "object": + return ; + case "number": + return ; + default: + return ; + } +}; + + + +// Individual custom field mapping component to prevent input focus loss +const CustomFieldMapping: React.FC<{ + targetField: string; + sourceField: string; + availableFields: FieldInfo[]; + onTargetFieldChange: (oldField: string, newField: string) => void; + onSourceFieldChange: (field: string, value: string) => void; + onRemove: () => void; +}> = ({ + targetField, + sourceField, + availableFields, + onTargetFieldChange, + onSourceFieldChange, + onRemove, +}) => { + const [localTargetField, setLocalTargetField] = useState(targetField); + + // Update local state when prop changes + useEffect(() => { + setLocalTargetField(targetField); + }, [targetField]); + + const handleTargetFieldBlur = () => { + if (localTargetField !== targetField) { + onTargetFieldChange(targetField, localTargetField); + } + }; + + const 
handleKeyPress = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + handleTargetFieldBlur(); + } + }; + + return ( +
      +
      + setLocalTargetField(e.target.value)} + onBlur={handleTargetFieldBlur} + onKeyPress={handleKeyPress} + className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-facebook-blue focus:border-transparent" + /> + +
      + + +
      + ); +}; + +export const FieldMappingInterface: React.FC = ({ + filename, + useCase, + onMappingUpdate, + className, + existingMappings = {}, +}) => { + const [analysis, setAnalysis] = useState(null); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [mappings, setMappings] = + useState>(existingMappings); + const [showPreview, setShowPreview] = useState(false); + const [previewData, setPreviewData] = useState(null); + const [previewLoading, setPreviewLoading] = useState(false); + + const requirements = + USE_CASE_REQUIREMENTS[useCase as keyof typeof USE_CASE_REQUIREMENTS]; + + useEffect(() => { + analyzeDataset(); + }, [filename]); + + useEffect(() => { + setMappings(existingMappings); + }, [existingMappings]); + + const analyzeDataset = async () => { + try { + setLoading(true); + setError(null); + + const response = await fetch( + `http://localhost:8000/api/datasets/analyze/${filename}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + } + ); + + if (!response.ok) { + throw new Error("Failed to analyze dataset"); + } + + const data: DatasetAnalysis = await response.json(); + + if (data.error) { + setError(data.error); + return; + } + + setAnalysis(data); + + // No automatic mappings since we removed suggested mappings + // Just use existing mappings + setMappings(existingMappings); + } catch (err) { + setError( + err instanceof Error ? err.message : "Failed to analyze dataset" + ); + } finally { + setLoading(false); + } + }; + + const handleMappingChange = (targetField: string, sourceField: string) => { + const newMappings = { + ...mappings, + [targetField]: sourceField, + }; + setMappings(newMappings); + onMappingUpdate(newMappings); + }; + + const handlePreview = async () => { + if (!canPreview()) return; + + try { + setPreviewLoading(true); + + const response = await fetch( + "http://localhost:8000/api/datasets/preview-transformation", + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + filename, + mappings, + use_case: useCase, + }), + } + ); + + if (!response.ok) { + throw new Error("Failed to preview transformation"); + } + + const data: PreviewData = await response.json(); + setPreviewData(data); + setShowPreview(true); + } catch (err) { + setError( + err instanceof Error ? err.message : "Failed to preview transformation" + ); + } finally { + setPreviewLoading(false); + } + }; + + const canPreview = () => { + if (useCase === "custom") { + // For custom use cases, allow preview if at least one mapping is defined + return ( + Object.keys(mappings).length > 0 && + Object.values(mappings).some((value) => value !== "") + ); + } + return requirements.required.every((field) => mappings[field]); + }; + + + if (loading) { + return ( +
      +
      + +

      Analyzing dataset...

      +
      +
      + ); + } + + if (error) { + return ( +
      +
      + +

      Analysis Error

      +
      +

      {error}

      + +
      + ); + } + + if (!analysis) { + return ( +
      +

      No analysis data available

      +
      + ); + } + + return ( +
      + {/* Header */} +
      +

      Field Mapping

      +

      + To evaluate your dataset correctly, map your dataset's fields to the required fields below. + Check the Completeness percentages to ensure your selected fields have sufficient data coverage. +

      + +
      + Dataset: {filename} + + {analysis.total_records} records + + {analysis.fields.length} fields detected +
      +
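The mapping step pairs each required field with a column from the uploaded dataset, and the Completeness figure reflects how often that column is actually populated. A minimal sketch, with hypothetical column names and an illustrative `coverage` helper mirroring the `FieldInfo` counts used above:

```typescript
// Illustrative only: a qa-style mapping and how a completeness ratio is derived.
interface FieldCounts {
  name: string;
  populated_count: number;
  total_count: number;
}

// Required target fields (left) mapped to columns found in the uploaded dataset (right).
const qaMappings: Record<string, string> = {
  question: "user_query",   // hypothetical source column
  answer: "agent_response", // hypothetical source column
};

// Completeness as shown in the UI: share of records where the field is populated.
function coverage(field: FieldCounts): number {
  return field.total_count === 0 ? 0 : field.populated_count / field.total_count;
}

const example: FieldCounts = { name: "user_query", populated_count: 180, total_count: 200 };
console.log(qaMappings, `${Math.round(coverage(example) * 100)}%`); // "90%"
```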
      + + {/* Field Mapping */} +
      + {/* Required/Custom Fields */} +
      +

      + {useCase === "custom" ? "Custom Field Mappings" : "Required Fields"} +

      + + {useCase === "custom" ? ( +
      +

      + Create custom field mappings for your dataset. Add as many + mappings as needed for your use case. +

      + + {/* Custom field mapping inputs */} +
      + {Object.entries(mappings).map( + ([targetField, sourceField], index) => ( + { + const newMappings = { ...mappings }; + delete newMappings[oldField]; + if (newField) { + newMappings[newField] = sourceField; + } + setMappings(newMappings); + onMappingUpdate(newMappings); + }} + onSourceFieldChange={(field, value) => { + handleMappingChange(field, value); + }} + onRemove={() => { + const newMappings = { ...mappings }; + delete newMappings[targetField]; + setMappings(newMappings); + onMappingUpdate(newMappings); + }} + /> + ) + )} + + {/* Add new mapping button */} + +
      +
      + ) : ( +
      + {requirements.required.map((requiredField) => ( +
      +
      + + {requiredField} + + + Required + +
      + + + + {mappings[requiredField] && ( +
      + + Mapped to: {mappings[requiredField]} +
      + )} +
      + ))} +
      + )} +
      + +
      +

      + Detected Fields +

      +
      + {analysis.fields.map((field, index) => ( +
      +
      +
      + {getFieldIcon(field.type)} + + {field.name} + + + {field.type} + +
      + {/* Field Completeness Information */} +
      +
      + Completeness: + = 0.9 ? "text-green-600" : + field.coverage >= 0.7 ? "text-yellow-600" : "text-red-600" + )} + > + {Math.round(field.coverage * 100)}% + +
      +
      +
      + + {field.samples.length > 0 && ( +
      +
      Sample values:
      + {field.samples.slice(0, 2).map((sample, i) => ( +
      + {typeof sample === "string" + ? `"${sample}"` + : JSON.stringify(sample)} +
      + ))} +
      + )} +
      + ))} +
      +
      +
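The detected-field list and the preview are driven by the analyze and preview-transformation endpoints called above. A rough sketch of the payloads involved, with shapes following the `DatasetAnalysis` and preview request used in this file and purely illustrative values:

```typescript
// Illustrative payloads only; field names and numbers are made up.
const analysis = {
  total_records: 200,
  sample_size: 25,
  fields: [
    { name: "user_query", type: "string", samples: ["Reset my badge"], coverage: 0.98, populated_count: 196, total_count: 200 },
    { name: "agent_response", type: "string", samples: ["Sure, here is how..."], coverage: 0.95, populated_count: 190, total_count: 200 },
  ],
  suggestions: {},
  sample_data: [{ user_query: "Reset my badge", agent_response: "Sure, here is how..." }],
};

// Body sent to /api/datasets/preview-transformation (see handlePreview above).
const previewRequest = {
  filename: "support_tickets.json", // illustrative filename
  mappings: { question: "user_query", answer: "agent_response" },
  use_case: "qa",
};

console.log(analysis.fields.length, previewRequest.use_case);
```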
      + + {/* Preview Section */} + {showPreview && previewData && ( +
      +

      + Preview Transformation +

      + +
      +
      +

      + Original Data +

      +
      +
      +                  {JSON.stringify(previewData.original_data[0], null, 2)}
      +                
      +
      +
      + +
      +

      + Transformed Data +

      +
      +
      +                  {JSON.stringify(previewData.transformed_data[0], null, 2)}
      +                
      +
      +
      +
      +
      + )} + + {/* Preview Section - Keep preview functionality but remove navigation buttons */} + {canPreview() && ( +
      + +
      + )} +
      + ); +}; diff --git a/frontend/src/components/onboarding/MetricsSelector.tsx b/frontend/src/components/onboarding/MetricsSelector.tsx new file mode 100644 index 0000000..abbd9d1 --- /dev/null +++ b/frontend/src/components/onboarding/MetricsSelector.tsx @@ -0,0 +1,850 @@ +import React, { useState, useEffect } from "react"; +import { cn } from "@/lib/utils"; +import { + Check, + Info, + Settings, + Zap, + Target, + Brain, + FileText, + BarChart3, + AlertCircle, + HelpCircle, +} from "lucide-react"; + +// Helper components for complex parameter types +const ArrayInput: React.FC<{ + value: string[]; + onChange: (value: string[]) => void; + arrayType: "string" | "number"; + placeholder: string; +}> = ({ value, onChange, arrayType, placeholder }) => { + const [inputValue, setInputValue] = useState(value.join(", ")); + + const handleChange = (e: React.ChangeEvent) => { + const newValue = e.target.value; + setInputValue(newValue); + + // Parse comma-separated values + const items = newValue + .split(",") + .map((item) => item.trim()) + .filter((item) => item.length > 0); + + if (arrayType === "number") { + const numbers = items + .map((item) => parseFloat(item)) + .filter((num) => !isNaN(num)); + onChange(numbers.map(String)); + } else { + onChange(items); + } + }; + + return ( + + ); +}; + +const FieldMappingInput: React.FC<{ + value: Record; + onChange: (value: Record) => void; + placeholder: string; +}> = ({ value, onChange, placeholder }) => { + const [inputValue, setInputValue] = useState( + Object.entries(value) + .map(([key, val]) => `${key}: ${val}`) + .join(", ") + ); + + const handleChange = (e: React.ChangeEvent) => { + const newValue = e.target.value; + setInputValue(newValue); + + // Parse field mappings (field_name: weight) + const mappings: Record = {}; + const pairs = newValue + .split(",") + .map((pair) => pair.trim()) + .filter((pair) => pair.length > 0); + + pairs.forEach((pair) => { + const [key, value] = pair.split(":").map((item) => item.trim()); + if (key && value) { + const numValue = parseFloat(value); + if (!isNaN(numValue)) { + mappings[key] = numValue; + } + } + }); + + onChange(mappings); + }; + + return ( + + ); +}; + +interface MetricConfig { + id: string; + name: string; + description: string; + type: "exact" | "semantic" | "structured" | "custom"; + icon: React.ReactNode; + useCases: string[]; + dataRequirements: string[]; + parameters?: { + [key: string]: { + type: + | "boolean" + | "number" + | "string" + | "select" + | "array" + | "object" + | "fieldMapping"; + default: any; + description: string; + options?: string[]; + arrayType?: "string" | "number"; + objectSchema?: { + [key: string]: { + type: "string" | "number" | "boolean"; + required?: boolean; + }; + }; + }; + }; + examples: { + input: string; + output: string; + score: string; + }[]; + pros: string[]; + cons: string[]; + recommendedFor: string[]; +} + +interface MetricsSelectorProps { + useCase: string; + fieldMappings: Record; + selectedMetrics: string[]; + onMetricsChange: ( + metrics: string[], + configurations: Record + ) => void; + className?: string; +} + +const AVAILABLE_METRICS: MetricConfig[] = [ + { + id: "exact_match", + name: "Exact Match", + description: + "Compares predictions to ground truth using exact string matching", + type: "exact", + icon: , + useCases: ["qa", "rag", "custom"], + dataRequirements: ["Clear, unambiguous answers", "Consistent formatting"], + parameters: { + case_sensitive: { + type: "boolean", + default: true, + description: "Whether to perform case-sensitive 
matching", + }, + strip_whitespace: { + type: "boolean", + default: true, + description: "Whether to strip whitespace before comparing", + }, + }, + examples: [ + { + input: "What is the capital of France?", + output: "Paris", + score: "1.0 (Perfect match)", + }, + { + input: "What is the capital of France?", + output: "paris", + score: "0.0 (Case mismatch)", + }, + ], + pros: [ + "Fast and deterministic", + "Easy to understand", + "No API costs", + "Perfect for factual answers", + ], + cons: [ + "Too strict for complex answers", + "Misses semantically correct variations", + "Sensitive to formatting differences", + ], + recommendedFor: [ + "Factual Q&A", + "Classification tasks", + "Exact value extraction", + ], + }, + { + id: "semantic_similarity", + name: "Semantic Similarity", + description: + "Uses AI to evaluate semantic similarity between prediction and ground truth", + type: "semantic", + icon: , + useCases: ["qa", "rag", "custom"], + dataRequirements: [ + "Natural language answers", + "Conceptual understanding needed", + ], + parameters: { + score_range: { + type: "select", + default: "1-10", + options: ["1-5", "1-10", "0-1"], + description: "Score range for evaluation", + }, + normalize_to: { + type: "select", + default: "0-1", + options: ["0-1", "1-10"], + description: "Range to normalize scores to", + }, + }, + examples: [ + { + input: "What is the capital of France?", + output: "The capital city of France is Paris", + score: "0.95 (High semantic similarity)", + }, + { + input: "What is the capital of France?", + output: "It is located in Europe", + score: "0.3 (Low semantic similarity)", + }, + ], + pros: [ + "Understands meaning, not just words", + "Handles paraphrasing well", + "Good for complex answers", + "More human-like evaluation", + ], + cons: [ + "Requires AI model calls", + "Can be inconsistent", + "Slower than exact match", + "API costs involved", + ], + recommendedFor: ["Complex Q&A", "Summary evaluation", "Creative tasks"], + }, + { + id: "correctness", + name: "Correctness Evaluation", + description: + "AI-powered evaluation focusing on factual correctness rather than similarity", + type: "semantic", + icon: , + useCases: ["qa", "rag", "custom"], + dataRequirements: ["Factual content", "Clear ground truth"], + parameters: { + score_range: { + type: "select", + default: "1-10", + options: ["1-5", "1-10", "0-1"], + description: "Score range for evaluation", + }, + }, + examples: [ + { + input: "When was World War II?", + output: "1939-1945", + score: "1.0 (Factually correct)", + }, + { + input: "When was World War II?", + output: "It was a major global conflict", + score: "0.2 (Correct but incomplete)", + }, + ], + pros: [ + "Focuses on factual accuracy", + "Good for knowledge tasks", + "Handles different phrasings", + "Evaluates completeness", + ], + cons: [ + "Requires AI model calls", + "May miss nuanced correctness", + "API costs involved", + ], + recommendedFor: [ + "Factual Q&A", + "Knowledge retrieval", + "Educational content", + ], + }, + { + id: "json_structured", + name: "Structured JSON", + description: + "Evaluates structured JSON responses by comparing specific fields", + type: "structured", + icon: , + useCases: ["qa", "rag", "custom"], + dataRequirements: ["JSON-formatted responses", "Structured data fields"], + parameters: { + evaluation_mode: { + type: "select", + default: "selected_fields_comparison", + options: ["selected_fields_comparison", "full_json_comparison"], + description: "How to evaluate the JSON structure", + }, + strict_json: { + type: 
"boolean", + default: false, + description: "Whether to require strict JSON parsing", + }, + output_fields: { + type: "array", + default: [], + arrayType: "string", + description: + "List of fields to evaluate (leave empty to evaluate all fields)", + }, + required_fields: { + type: "array", + default: [], + arrayType: "string", + description: "Fields that must be present for a valid prediction", + }, + field_weights: { + type: "fieldMapping", + default: {}, + description: + "Weight for each field in the evaluation (field_name: weight)", + }, + output_field: { + type: "string", + default: "answer", + description: "Name of the field containing the ground truth output", + }, + }, + examples: [ + { + input: "Categorize this request", + output: '{"category": "urgent", "sentiment": "negative"}', + score: "1.0 (All fields match)", + }, + { + input: "Categorize this request", + output: '{"category": "urgent", "sentiment": "positive"}', + score: "0.5 (Partial field match)", + }, + ], + pros: [ + "Perfect for structured outputs", + "Field-level granular scoring", + "Fast evaluation", + "Configurable field weights", + ], + cons: [ + "Only works with JSON", + "Requires structured ground truth", + "Less flexible for natural language", + ], + recommendedFor: [ + "Classification tasks", + "Structured data extraction", + "API response evaluation", + ], + }, + { + id: "facility_metric", + name: "Facility Categorization", + description: + "Specialized metric for facility support requests with urgency, sentiment, and categories", + type: "custom", + icon: , + useCases: ["qa", "rag", "custom"], + dataRequirements: [ + "JSON responses with urgency, sentiment, and categories fields", + "Facility support request data", + ], + parameters: { + output_field: { + type: "string", + default: "answer", + description: "Name of the field containing the ground truth output", + }, + strict_json: { + type: "boolean", + default: false, + description: + "Whether to require strict JSON parsing (no code block extraction)", + }, + }, + examples: [ + { + input: "Urgent HVAC issue in Building A", + output: + '{"urgency": "high", "sentiment": "frustrated", "categories": {"hvac": true, "maintenance": true}}', + score: "1.0 (Perfect categorization)", + }, + { + input: "Request for new office supplies", + output: + '{"urgency": "low", "sentiment": "neutral", "categories": {"supplies": true, "administrative": true}}', + score: "1.0 (Correct low-priority request)", + }, + ], + pros: [ + "Domain-specific for facility management", + "Evaluates multiple dimensions (urgency, sentiment, categories)", + "Handles boolean category mappings", + "Fast, deterministic scoring", + ], + cons: [ + "Only works with facility-specific JSON format", + "Limited to urgency/sentiment/categories structure", + "Not suitable for other domains", + ], + recommendedFor: [ + "Facility support automation", + "Maintenance request categorization", + "Support ticket prioritization", + ], + }, +]; + +const METRIC_TYPE_INFO = { + exact: { + name: "Exact Matching", + description: "Fast, deterministic comparison", + color: "bg-green-100 text-green-800", + }, + semantic: { + name: "AI-Powered", + description: "Understands meaning and context", + color: "bg-blue-100 text-blue-800", + }, + structured: { + name: "Structured Data", + description: "Perfect for JSON and structured outputs", + color: "bg-purple-100 text-purple-800", + }, + custom: { + name: "Custom Logic", + description: "Domain-specific evaluation", + color: "bg-orange-100 text-orange-800", + }, +}; + +export 
const MetricsSelector: React.FC = ({ + useCase, + fieldMappings, + selectedMetrics, + onMetricsChange, + className, +}) => { + const [configurations, setConfigurations] = useState>({}); + const [expandedMetric, setExpandedMetric] = useState(null); + const [showAdvanced, setShowAdvanced] = useState(false); + + // Filter metrics based on use case + const availableMetrics = AVAILABLE_METRICS.filter( + (metric) => metric.useCases.includes(useCase) || useCase === "custom" + ); + + // Get recommended metrics based on use case and field mappings + const getRecommendedMetrics = () => { + const hasStructuredFields = Object.keys(fieldMappings).length > 2; + const hasJsonLikeFields = Object.keys(fieldMappings).some( + (key) => + key.includes("category") || + key.includes("sentiment") || + key.includes("urgency") + ); + + const hasFacilityFields = + Object.keys(fieldMappings).some( + (key) => + key.includes("urgency") && + key.includes("sentiment") && + key.includes("categories") + ) || + (Object.keys(fieldMappings).length >= 3 && hasJsonLikeFields); + + if (useCase === "qa") { + return hasJsonLikeFields + ? ["semantic_similarity", "json_structured"] + : ["exact_match", "semantic_similarity"]; + } else if (useCase === "rag") { + return ["correctness", "semantic_similarity"]; + } else { + // Custom use case + if (hasFacilityFields) { + return ["facility_metric", "json_structured"]; + } + return hasStructuredFields ? ["json_structured"] : ["exact_match"]; + } + }; + + const recommendedMetrics = getRecommendedMetrics(); + + const handleMetricToggle = (metricId: string) => { + const newSelectedMetrics = selectedMetrics.includes(metricId) + ? selectedMetrics.filter((id) => id !== metricId) + : [...selectedMetrics, metricId]; + + onMetricsChange(newSelectedMetrics, configurations); + }; + + const handleParameterChange = ( + metricId: string, + paramName: string, + value: any + ) => { + const newConfigurations = { + ...configurations, + [metricId]: { + ...configurations[metricId], + [paramName]: value, + }, + }; + setConfigurations(newConfigurations); + onMetricsChange(selectedMetrics, newConfigurations); + }; + + const MetricCard = ({ metric }: { metric: MetricConfig }) => { + const isSelected = selectedMetrics.includes(metric.id); + const isRecommended = recommendedMetrics.includes(metric.id); + const isExpanded = expandedMetric === metric.id; + const typeInfo = METRIC_TYPE_INFO[metric.type]; + + return ( +
      + {/* Header */} +
      +
      + + +
      +
      + {metric.icon} +

      + {metric.name} +

      + {isRecommended && ( + + Recommended + + )} + + {typeInfo.name} + +
      +

      {metric.description}

      +
      +
      + + +
      + + {/* Expanded Details */} + {isExpanded && ( +
      + {/* Examples */} +
      +

      Examples

      +
      + {metric.examples.map((example, index) => ( +
      +
      + Input: {example.input} +
      +
      + Output: {example.output} +
      +
      + Score: {example.score} +
      +
      + ))} +
      +
      + + {/* Pros and Cons */} +
      +
      +

      + ✓ Advantages +

      +
        + {metric.pros.map((pro, index) => ( +
• {pro}
      • + ))} +
      +
      +
      +

      ⚠ Limitations

      +
        + {metric.cons.map((con, index) => ( +
• {con}
      • + ))} +
      +
      +
      + + {/* Data Requirements */} +
      +

      + Data Requirements +

      +
        + {metric.dataRequirements.map((req, index) => ( +
• {req}
      • + ))} +
      +
      +
      + )} + + {/* Parameters Configuration */} + {isSelected && metric.parameters && ( +
      +

      + + Configuration +

      +
      + {Object.entries(metric.parameters).map(([paramName, param]) => ( +
      + +

      + {param.description} +

      + + {param.type === "boolean" ? ( + + ) : param.type === "select" ? ( + + ) : param.type === "array" ? ( + + handleParameterChange(metric.id, paramName, value) + } + arrayType={param.arrayType || "string"} + placeholder={`Enter ${ + param.arrayType || "string" + } values (comma-separated)`} + /> + ) : param.type === "fieldMapping" ? ( + + handleParameterChange(metric.id, paramName, value) + } + placeholder="field_name: weight (e.g., urgency: 2.0, sentiment: 1.5)" + /> + ) : ( + + handleParameterChange( + metric.id, + paramName, + param.type === "number" + ? parseFloat(e.target.value) + : e.target.value + ) + } + className="w-full p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-facebook-blue focus:border-transparent" + /> + )} +
      + ))} +
      +
      + )} +
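Every toggle or parameter edit ends up in `onMetricsChange(selectedMetrics, configurations)`. A sketch of what that accumulated state can look like for the `json_structured` metric, using its parameters from `AVAILABLE_METRICS`; the field names and weights are illustrative:

```typescript
// Illustrative only: selected metric ids plus per-metric parameter overrides,
// in the { [metricId]: { [param]: value } } shape built by handleParameterChange.
const selectedMetrics = ["json_structured"];

const configurations: Record<string, Record<string, unknown>> = {
  json_structured: {
    evaluation_mode: "selected_fields_comparison",
    strict_json: false,
    output_fields: ["urgency", "sentiment"],          // fields to score
    required_fields: ["urgency"],                     // must be present in the prediction
    field_weights: { urgency: 2.0, sentiment: 1.0 },  // parsed from "urgency: 2.0, sentiment: 1.0"
    output_field: "answer",
  },
};

// Mirrors the call made after every change:
// onMetricsChange(selectedMetrics, configurations);
console.log(selectedMetrics, configurations);
```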
      + ); + }; + + return ( +
      + {/* Header */} +
      +

      + Success Metrics +

      +

      + Choose how to evaluate your optimized prompt's performance +

      +
      + + {/* Smart Recommendations */} + {recommendedMetrics.length > 0 && ( +
      +
      + +

      + Smart Recommendations +

      +
      +

      + Based on your {useCase} use case and field mappings, we recommend: +

      +
      + {recommendedMetrics.map((metricId) => { + const metric = AVAILABLE_METRICS.find((m) => m.id === metricId); + return metric ? ( + + ) : null; + })} +
      +
      + )} + + {/* Metrics Grid */} +
      + {availableMetrics.map((metric) => ( + + ))} +
      + + {/* Selected Summary */} + {selectedMetrics.length > 0 && ( +
      +

      + Selected Metrics Summary +

      +
      + {selectedMetrics.map((metricId) => { + const metric = AVAILABLE_METRICS.find((m) => m.id === metricId); + return metric ? ( +
      + + {metric.name} + + - {metric.description} + +
      + ) : null; + })} +
      +
      + )} +
      + ); +}; diff --git a/frontend/src/components/onboarding/ModelProviderSelector.tsx b/frontend/src/components/onboarding/ModelProviderSelector.tsx new file mode 100644 index 0000000..dd0a3b6 --- /dev/null +++ b/frontend/src/components/onboarding/ModelProviderSelector.tsx @@ -0,0 +1,2073 @@ +import React, { useState, useEffect } from "react"; +import { cn } from "@/lib/utils"; +import { + Check, + Settings, + Zap, + Shield, + Globe, + Server, + Eye, + EyeOff, + AlertCircle, + CheckCircle, + Loader2, + Plus, + Trash2, + HelpCircle, + ExternalLink, + Cpu, + Cloud, + DollarSign, + Clock, + Brain, + Target, + ChevronDown, + Split, + Merge, +} from "lucide-react"; + +interface ProviderConfig { + id: string; + name: string; + description: string; + icon: React.ReactNode; + category: "cloud" | "local" | "enterprise"; + pricing: "free" | "paid" | "usage"; + setup_difficulty: "easy" | "medium" | "hard"; + api_base: string; + model_prefix: string; + popular_models: string[]; + pros: string[]; + cons: string[]; + docs_url: string; + requires_signup: boolean; +} + +interface ModelConfig { + id: string; // Unique identifier for this configuration + provider_id: string; + model_name: string; + role: "target" | "optimizer" | "both"; + api_key?: string; + api_base?: string; + temperature: number; + max_tokens: number; + // Custom provider fields + custom_provider_name?: string; + model_prefix?: string; + auth_method?: "api_key" | "bearer_token" | "custom_headers"; + custom_headers?: Record; +} + +interface ModelProviderSelectorProps { + useCase: string; + fieldMappings: Record; + onConfigurationChange: (configs: ModelConfig[]) => void; +} + +const PROVIDER_CONFIGS: ProviderConfig[] = [ + { + id: "openrouter", + name: "OpenRouter", + description: "Access 200+ models from multiple providers through one API", + icon: , + category: "cloud", + pricing: "usage", + setup_difficulty: "easy", + api_base: "https://openrouter.ai/api/v1", + model_prefix: "openrouter/", + popular_models: [ + "meta-llama/llama-3.1-8b-instruct", + "meta-llama/llama-3.3-70b-instruct", + "anthropic/claude-3.5-sonnet", + "openai/gpt-4o", + ], + pros: ["Huge model selection", "Competitive pricing", "Easy setup"], + cons: ["Usage-based pricing"], + docs_url: "https://openrouter.ai/docs", + requires_signup: true, + }, + + { + id: "vllm", + name: "vLLM (Local)", + description: "Run models locally with fast inference engine", + icon: , + category: "local", + pricing: "free", + setup_difficulty: "medium", + api_base: "http://localhost:8000/v1", + model_prefix: "hosted_vllm/", + popular_models: [ + "meta-llama/Llama-3.1-8B-Instruct", + "microsoft/DialoGPT-medium", + "google/flan-t5-large", + ], + pros: ["Local Inference", "Data privacy", "Full control"], + cons: ["Requires setup", "Hardware dependent", "Local only"], + docs_url: "https://docs.vllm.ai/", + requires_signup: false, + }, + { + id: "nvidia_nim", + name: "NVIDIA NIM", + description: "Optimized containers for NVIDIA GPUs", + icon: , + category: "enterprise", + pricing: "free", + setup_difficulty: "hard", + api_base: "http://localhost:8000/v1", + model_prefix: "openai/", + popular_models: [ + "meta/llama-3.1-8b-instruct", + "microsoft/phi-3-mini-4k-instruct", + "mistralai/mixtral-8x7b-instruct-v0.1", + ], + pros: ["GPU optimized", "Enterprise grade", "High performance"], + cons: ["Requires NVIDIA NIMS", "Complex setup", "Enterprise focused"], + docs_url: "https://docs.nvidia.com/nim/", + requires_signup: true, + }, + { + id: "custom", + name: "Custom Provider", + description: + 
"Configure your own API endpoint (LiteLLM, Azure AI Studio, etc.)", + icon: , + category: "enterprise", + pricing: "usage", + setup_difficulty: "medium", + api_base: "", + model_prefix: "", + popular_models: [ + "your-model-name", + "azure_ai/command-r-plus", + "azure_ai/mistral-large-latest", + "custom/your-model", + ], + pros: ["Full control", "Any provider", "Custom endpoints"], + cons: ["Requires configuration", "Manual setup", "Provider dependent"], + docs_url: "https://docs.litellm.ai/docs/providers", + requires_signup: false, + }, +]; + +const ROLE_DESCRIPTIONS = { + target: + "🎯 Target Model - The model you're optimizing FOR (where your prompt will be deployed in production)", + optimizer: + "🧠 Optimizer Model - The AI that generates improved prompt variations during optimization", + both: "🔄 Dual Role - Single model handles both optimization and deployment", +}; + +export const ModelProviderSelector: React.FC = ({ + useCase, + fieldMappings, + onConfigurationChange, +}) => { + const [selectedProviders, setSelectedProviders] = useState([]); + const [configurations, setConfigurations] = useState([]); + const [showAdvanced, setShowAdvanced] = useState(false); + const [testingConnections, setTestingConnections] = useState< + Record + >({}); + const [connectionStatus, setConnectionStatus] = useState< + Record + >({}); + const [showApiKeys, setShowApiKeys] = useState>({}); + + // Smart defaults based on use case + useEffect(() => { + if (useCase && selectedProviders.length === 0) { + // Auto-select OpenRouter for all use cases + setSelectedProviders(["openrouter"]); + // Still show advanced options for custom use case + if (useCase === "custom") { + setShowAdvanced(true); + } + } + }, [useCase, selectedProviders.length]); + + // Initialize configurations when providers change + useEffect(() => { + setConfigurations((prevConfigs) => { + // Keep existing configurations for selected providers + const existingConfigs = prevConfigs.filter((config) => + selectedProviders.includes(config.provider_id) + ); + + // Add default configurations for newly selected providers + const newConfigs: ModelConfig[] = [...existingConfigs]; + + selectedProviders.forEach((providerId) => { + const provider = PROVIDER_CONFIGS.find((p) => p.id === providerId); + if (!provider) return; + + // Check if we already have any configuration for this provider + const hasExistingConfig = existingConfigs.some( + (config) => config.provider_id === providerId + ); + + // Only add default configuration if none exists for this provider + if (!hasExistingConfig) { + // Determine the best default role for this provider + const globalTarget = newConfigs.find( + (c) => c.role === "target" || c.role === "both" + ); + const globalOptimizer = newConfigs.find( + (c) => c.role === "optimizer" || c.role === "both" + ); + const globalBoth = newConfigs.find((c) => c.role === "both"); + + let defaultRole: "target" | "optimizer" | "both" = "both"; + + // If this is the first provider and no other configs exist, use "both" + if (newConfigs.length === 0) { + defaultRole = "both"; + } + // If someone already has "both", don't add anything (let user manually add if needed) + else if (globalBoth) { + return; // Skip adding default config + } + // If we need a target and this provider is suitable + else if ( + !globalTarget && + (provider.category === "local" || + provider.category === "enterprise") + ) { + defaultRole = "target"; + } + // If we need an optimizer and this provider is suitable + else if (!globalOptimizer && provider.category === 
"cloud") { + defaultRole = "optimizer"; + } + // Otherwise, let user choose manually + else { + return; // Skip adding default config + } + + newConfigs.push({ + id: `${providerId}-${Date.now()}`, + provider_id: providerId, + model_name: provider.popular_models[0], + role: defaultRole, + api_base: provider.api_base, + api_key: "", + temperature: 0.0, + max_tokens: 4096, + // Custom provider fields + custom_provider_name: providerId === "custom" ? "" : undefined, + model_prefix: providerId === "custom" ? "" : provider.model_prefix, + auth_method: providerId === "custom" ? "api_key" : undefined, + custom_headers: providerId === "custom" ? {} : undefined, + }); + } + }); + + return newConfigs; + }); + }, [selectedProviders]); + + // Notify parent of configuration changes + useEffect(() => { + onConfigurationChange(configurations); + }, [configurations]); // Remove onConfigurationChange from dependencies to prevent infinite loop + + const handleProviderToggle = (providerId: string) => { + setSelectedProviders((prev) => { + if (prev.includes(providerId)) { + // Remove provider + return prev.filter((id) => id !== providerId); + } else { + // Add provider, but limit to 2 providers max + if (prev.length >= 2) { + alert( + "Maximum 2 providers allowed. You only need one for both roles, or two for separate target/optimizer models." + ); + return prev; + } + return [...prev, providerId]; + } + }); + }; + + const handleConfigChange = ( + configId: string, + field: keyof ModelConfig, + value: any + ) => { + setConfigurations((prevConfigs) => { + const newConfigs = prevConfigs.map((config) => + config.id === configId ? { ...config, [field]: value } : config + ); + return newConfigs; + }); + }; + + const addConfiguration = ( + providerId: string, + role: "target" | "optimizer" + ) => { + const provider = PROVIDER_CONFIGS.find((p) => p.id === providerId); + if (!provider) return; + + // Check if this role is already filled globally + const existingTargetConfig = configurations.find( + (c) => c.role === "target" || c.role === "both" + ); + const existingOptimizerConfig = configurations.find( + (c) => c.role === "optimizer" || c.role === "both" + ); + + if (role === "target" && existingTargetConfig) { + alert( + "Target model is already configured. Remove the existing target configuration first." + ); + return; + } + + if (role === "optimizer" && existingOptimizerConfig) { + alert( + "Optimizer model is already configured. Remove the existing optimizer configuration first." + ); + return; + } + + const newConfig: ModelConfig = { + id: `${providerId}-${role}-${Date.now()}`, + provider_id: providerId, + model_name: provider.popular_models[0], + role: role, + api_base: provider.api_base, + api_key: "", + temperature: 0.0, + max_tokens: 4096, + // Custom provider fields + custom_provider_name: providerId === "custom" ? "" : undefined, + model_prefix: providerId === "custom" ? "" : provider.model_prefix, + auth_method: providerId === "custom" ? "api_key" : undefined, + custom_headers: providerId === "custom" ? 
{} : undefined, + }; + + setConfigurations((prevConfigs) => [...prevConfigs, newConfig]); + }; + + const removeConfiguration = (configId: string) => { + setConfigurations((prevConfigs) => + prevConfigs.filter((config) => config.id !== configId) + ); + }; + + // Split a "both" configuration into separate target and optimizer configs + const splitConfiguration = (configId: string) => { + const config = configurations.find((c) => c.id === configId); + if (!config || config.role !== "both") return; + + const provider = PROVIDER_CONFIGS.find((p) => p.id === config.provider_id); + if (!provider) return; + + // Create target config + const targetConfig: ModelConfig = { + ...config, + id: `${config.provider_id}-target-${Date.now()}`, + role: "target", + }; + + // Create optimizer config with potentially different model + const optimizerConfig: ModelConfig = { + ...config, + id: `${config.provider_id}-optimizer-${Date.now() + 1}`, + role: "optimizer", + // Default to a more powerful model for optimizer if available + model_name: + provider.popular_models.find( + (model) => + model.includes("claude-3.5") || + model.includes("gpt-4") || + model.includes("70b") + ) || config.model_name, + }; + + setConfigurations((prevConfigs) => + prevConfigs + .map((c) => (c.id === configId ? targetConfig : c)) + .concat([optimizerConfig]) + ); + }; + + // Merge separate target and optimizer configs from same provider into "both" + const mergeConfigurations = (providerId: string) => { + const providerConfigs = configurations.filter( + (c) => c.provider_id === providerId + ); + const targetConfig = providerConfigs.find((c) => c.role === "target"); + const optimizerConfig = providerConfigs.find((c) => c.role === "optimizer"); + + if (!targetConfig || !optimizerConfig) return; + + // Create merged config based on target config + const mergedConfig: ModelConfig = { + ...targetConfig, + role: "both", + }; + + setConfigurations((prevConfigs) => + prevConfigs + .filter((c) => c.id !== targetConfig.id && c.id !== optimizerConfig.id) + .concat([mergedConfig]) + ); + }; + + // Change role of a configuration + const changeConfigRole = ( + configId: string, + newRole: "target" | "optimizer" | "both" + ) => { + // Check if the new role is already taken + const { hasTarget, hasOptimizer } = getRoleStatus(); + + if (newRole === "target" && hasTarget) { + const existingTarget = configurations.find( + (c) => (c.role === "target" || c.role === "both") && c.id !== configId + ); + if (existingTarget) { + alert( + "Target role is already assigned. Remove the existing target configuration first." + ); + return; + } + } + + if (newRole === "optimizer" && hasOptimizer) { + const existingOptimizer = configurations.find( + (c) => + (c.role === "optimizer" || c.role === "both") && c.id !== configId + ); + if (existingOptimizer) { + alert( + "Optimizer role is already assigned. Remove the existing optimizer configuration first." + ); + return; + } + } + + if (newRole === "both" && (hasTarget || hasOptimizer)) { + const existing = configurations.find( + (c) => + (c.role === "target" || + c.role === "optimizer" || + c.role === "both") && + c.id !== configId + ); + if (existing) { + alert( + "Cannot set to 'both' when other roles are already assigned. Remove other configurations first." + ); + return; + } + } + + setConfigurations((prevConfigs) => + prevConfigs.map((config) => + config.id === configId ? 
{ ...config, role: newRole } : config + ) + ); + }; + + // Role Badge Component + const RoleBadge = ({ + role, + configId, + className = "", + }: { + role: "target" | "optimizer" | "both"; + configId: string; + className?: string; + }) => { + const [showDropdown, setShowDropdown] = useState(false); + const dropdownRef = React.useRef(null); + + // Close dropdown when clicking outside + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if ( + dropdownRef.current && + !dropdownRef.current.contains(event.target as Node) + ) { + setShowDropdown(false); + } + }; + + if (showDropdown) { + document.addEventListener("mousedown", handleClickOutside); + return () => + document.removeEventListener("mousedown", handleClickOutside); + } + }, [showDropdown]); + + const getRoleColor = (role: string) => { + switch (role) { + case "target": + return "bg-green-100 text-green-800 border-green-200 hover:bg-green-200"; + case "optimizer": + return "bg-purple-100 text-purple-800 border-purple-200 hover:bg-purple-200"; + case "both": + return "bg-blue-100 text-blue-800 border-blue-200 hover:bg-blue-200"; + default: + return "bg-gray-100 text-gray-800 border-gray-200 hover:bg-gray-200"; + } + }; + + const getRoleIcon = (role: string) => { + switch (role) { + case "target": + return ; + case "optimizer": + return ; + case "both": + return ( +
      + + +
      + ); + default: + return null; + } + }; + + const getRoleLabel = (role: string) => { + switch (role) { + case "target": + return "Target"; + case "optimizer": + return "Optimizer"; + case "both": + return "Target + Optimizer"; + default: + return role; + } + }; + + const availableRoles = ["target", "optimizer", "both"].filter((r) => { + if (r === role) return false; + + const { hasTarget, hasOptimizer } = getRoleStatus(); + if (r === "target" && hasTarget) return false; + if (r === "optimizer" && hasOptimizer) return false; + if (r === "both" && (hasTarget || hasOptimizer)) return false; + + return true; + }); + + return ( +
      + + + {showDropdown && availableRoles.length > 0 && ( +
      +
      + {availableRoles.map((availableRole) => ( + + ))} +
      +
      + )} +
      + ); + }; + + // Helper functions for role management + const getRoleStatus = () => { + const hasTarget = configurations.some( + (c) => c.role === "target" || c.role === "both" + ); + const hasOptimizer = configurations.some( + (c) => c.role === "optimizer" || c.role === "both" + ); + const hasBoth = configurations.some((c) => c.role === "both"); + + return { hasTarget, hasOptimizer, hasBoth }; + }; + + const canAddRole = (providerId: string, role: "target" | "optimizer") => { + const { hasTarget, hasOptimizer, hasBoth } = getRoleStatus(); + const providerConfigs = configurations.filter( + (c) => c.provider_id === providerId + ); + const providerHasRole = providerConfigs.some( + (c) => c.role === role || c.role === "both" + ); + + // If any config has "both" role, can't add anything else + if (hasBoth) return false; + + // If this provider already has this role, can't add another + if (providerHasRole) return false; + + // If role is globally filled, can't add another + if (role === "target" && hasTarget) return false; + if (role === "optimizer" && hasOptimizer) return false; + + return true; + }; + + const getAvailableRoles = (providerId: string) => { + const canAddTarget = canAddRole(providerId, "target"); + const canAddOptimizer = canAddRole(providerId, "optimizer"); + const { hasTarget, hasOptimizer, hasBoth } = getRoleStatus(); + const providerConfigs = configurations.filter( + (c) => c.provider_id === providerId + ); + + // If this provider has no configs and no global "both" exists, can suggest "both" + const canAddBoth = + providerConfigs.length === 0 && !hasBoth && !hasTarget && !hasOptimizer; + + return { canAddTarget, canAddOptimizer, canAddBoth }; + }; + + const testConnection = async (config: ModelConfig) => { + const configKey = `${config.provider_id}-${config.role}-${config.id}`; + setTestingConnections((prev) => ({ ...prev, [configKey]: true })); + + try { + // Basic validation first + const hasApiKey = config.api_key && config.api_key.length > 0; + const provider = PROVIDER_CONFIGS.find( + (p) => p.id === config.provider_id + ); + + // Check if authentication is required + const requiresAuth = + provider?.requires_signup || config.provider_id === "custom"; + + if ( + requiresAuth && + !hasApiKey && + config.auth_method !== "custom_headers" + ) { + throw new Error("Authentication required - please provide an API key"); + } + + // For custom headers, check if headers are provided + if ( + config.auth_method === "custom_headers" && + (!config.custom_headers || + Object.keys(config.custom_headers).length === 0) + ) { + throw new Error("Custom headers required"); + } + + // For custom providers, check if required fields are filled + if (config.provider_id === "custom") { + if ( + !config.api_base || + !config.model_name || + !config.custom_provider_name + ) { + throw new Error( + "Please fill in all required fields for custom provider" + ); + } + } + + // Perform actual API test + await performActualAPITest(config); + + setConnectionStatus((prev) => ({ ...prev, [configKey]: "success" })); + } catch (error) { + console.error("Connection test failed:", error); + setConnectionStatus((prev) => ({ + ...prev, + [configKey]: "error", + })); + + // Show user-friendly error message + const errorMessage = + error instanceof Error ? 
error.message : "Connection test failed"; + // Don't duplicate "Connection test failed" if the error message is already user-friendly + const isUserFriendlyError = + error instanceof Error && + (error.message.includes("API key") || + error.message.includes("Access denied") || + error.message.includes("Invalid request") || + error.message.includes("Provider server error") || + error.message.includes("Connection failed") || + error.message.includes("Network error") || + error.message.includes("timeout")); + + const displayMessage = isUserFriendlyError + ? errorMessage + : `Connection test failed: ${errorMessage}`; + alert(displayMessage); + } finally { + setTestingConnections((prev) => ({ ...prev, [configKey]: false })); + } + }; + + const performActualAPITest = async (config: ModelConfig) => { + const provider = PROVIDER_CONFIGS.find((p) => p.id === config.provider_id); + let apiUrl = config.api_base || provider?.api_base || ""; + + // Ensure URL ends with proper path + if (!apiUrl.endsWith("/")) { + apiUrl += "/"; + } + + // For better API key validation, we'll make a small completion request instead of just checking models + // This actually validates the API key works for the intended purpose + const testUrl = `${apiUrl}chat/completions`; + + // Prepare headers + const headers: Record = { + "Content-Type": "application/json", + }; + + // Add authentication + if (config.auth_method === "custom_headers" && config.custom_headers) { + Object.assign(headers, config.custom_headers); + } else if (config.auth_method === "bearer_token" && config.api_key) { + headers["Authorization"] = `Bearer ${config.api_key}`; + } else if (config.api_key) { + // Default to Authorization header for most providers + headers["Authorization"] = `Bearer ${config.api_key}`; + } + + // Add provider-specific headers + if (config.provider_id === "openrouter") { + headers["HTTP-Referer"] = window.location.origin; + headers["X-Title"] = "Llama Prompt Ops"; + } + + // Prepare test request body - minimal completion request + const testBody = { + model: config.model_name, + messages: [ + { + role: "user", + content: "test", + }, + ], + max_tokens: 1, + temperature: 0, + stream: false, + }; + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 15000); // 15 second timeout + + try { + console.log( + `Testing connection to ${config.provider_id} with URL: ${testUrl}` + ); + console.log("Headers:", headers); + console.log("Body:", testBody); + + const response = await fetch(testUrl, { + method: "POST", + headers, + body: JSON.stringify(testBody), + signal: controller.signal, + }); + + clearTimeout(timeoutId); + + console.log(`Response status: ${response.status}`); + console.log( + "Response headers:", + Object.fromEntries(response.headers.entries()) + ); + + if (!response.ok) { + // Get the response text for better error messages + let errorText = ""; + try { + const errorData = await response.json(); + errorText = + errorData.error?.message || + errorData.message || + JSON.stringify(errorData); + } catch { + errorText = await response.text(); + } + + console.log("Error response:", errorText); + + if (response.status === 401) { + // Make 401 errors more user-friendly + let friendlyMessage = "Incorrect or invalid API key"; + if (errorText.toLowerCase().includes("no auth")) { + friendlyMessage = + "API key is missing or invalid - please check your key"; + } else if (errorText.toLowerCase().includes("unauthorized")) { + friendlyMessage = "API key is incorrect or doesn't have access"; 
+ } else if (errorText.toLowerCase().includes("expired")) { + friendlyMessage = "API key has expired - please generate a new one"; + } else if ( + errorText.toLowerCase().includes("quota") || + errorText.toLowerCase().includes("limit") + ) { + friendlyMessage = "API quota exceeded or rate limit reached"; + } + throw new Error(friendlyMessage); + } else if (response.status === 403) { + let friendlyMessage = "Access denied"; + if ( + errorText.toLowerCase().includes("quota") || + errorText.toLowerCase().includes("limit") + ) { + friendlyMessage = "API quota exceeded or rate limit reached"; + } else if (errorText.toLowerCase().includes("permission")) { + friendlyMessage = "API key doesn't have required permissions"; + } + throw new Error(friendlyMessage); + } else if (response.status === 404) { + throw new Error("API endpoint not found - check your base URL"); + } else if (response.status === 422) { + let friendlyMessage = "Invalid request"; + if (errorText.toLowerCase().includes("model")) { + friendlyMessage = "Model name is invalid or not available"; + } else if (errorText.toLowerCase().includes("parameter")) { + friendlyMessage = "Invalid request parameters"; + } + throw new Error(friendlyMessage); + } else if (response.status >= 500) { + throw new Error("Provider server error - please try again later"); + } else { + throw new Error( + `Connection failed (${response.status}) - please check your configuration` + ); + } + } + + // Try to parse response to ensure it's valid + const data = await response.json(); + console.log("Successful response:", data); + + // Verify it looks like a completion response + if (!data || (!data.choices && !data.id && !data.object)) { + console.warn("Unexpected API response format:", data); + // Still consider it successful if we got a 200 response + } + + return data; + } catch (error) { + clearTimeout(timeoutId); + console.error("API test error:", error); + + if (error instanceof Error) { + if (error.name === "AbortError") { + throw new Error( + "Connection timeout - check your internet connection and API endpoint" + ); + } else if ( + error.message.includes("Failed to fetch") || + error.message.includes("NetworkError") || + error.message.includes("fetch") + ) { + throw new Error( + "Network error - check your internet connection and API endpoint URL" + ); + } else if ( + error.message.includes("CORS") || + error.message.includes("cors") + ) { + throw new Error( + "CORS error - this API may not support browser requests" + ); + } + } + + throw error; + } + }; + + const getRecommendedSetup = () => { + switch (useCase) { + case "rag": + return "For RAG: Use a larger model like Llama 3.3 70B as Optimizer (better prompt generation) and Llama 3.1 8B as Target (cost-effective deployment)."; + case "qa": + return "For Q&A: Use Claude or GPT-4 as Optimizer (advanced reasoning) and your target Llama model for deployment. Custom providers like Azure AI Studio work great for production."; + case "custom": + return "For custom workflows: Consider cloud models for optimization (more capable) and local models for target deployment (cost + privacy). Use Custom Provider for Azure AI Studio, LiteLLM, or other specialized endpoints."; + default: + return "Start with OpenRouter for both roles, then consider separating for cost optimization. Add Custom Provider if you have existing Azure/enterprise endpoints."; + } + }; + + return ( +
      + {/* Header */} +
      +

      + Choose Your +
      + + AI Models + +

      +

      + Select inference providers and configure models for your optimization +

      +
      + + {/* Dual Model Explanation */} +
      +
      +
      +
      + +
      +
      +
      +

      + 🎯 Dual Model Optimization +

      +

      + Llama Prompt Ops uses two AI models working together to optimize + your prompts: +

      +
      +
      +
      + + + Target Model + +
      +

      + The model you're optimizing FOR - where your prompt will be + deployed in production +

      +
      +
      +
      + + + Optimizer Model + +
      +

      + The AI that generates improved prompt variations during + optimization +

      +
      +
      +
      +
      +
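In code, the two roles show up as a `role` field on each model configuration. The shape below is inferred from how this component reads the config objects; the project's actual `ModelConfig` type may declare more, or slightly different, fields.

```typescript
// Inferred from usage in this component; the real ModelConfig type may differ.
type ModelRole = "target" | "optimizer" | "both";

interface ModelConfig {
  id: string;
  provider_id: string;            // e.g. "openrouter", "vllm", "custom"
  role: ModelRole;                // which job the model plays during optimization
  model_name: string;             // e.g. "llama-3.1-8b-instruct"
  api_key?: string;
  api_base?: string;              // needed for vLLM and custom endpoints
  auth_method?: "bearer_token" | "custom_headers";
  custom_headers?: Record<string, string>;
  custom_provider_name?: string;  // display name for custom providers
  model_prefix?: string;          // e.g. "azure_ai/" prepended to model_name
}
```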
      + + {/* Recommendation */} + {/*
      +
      + +
      +

      + Smart Recommendation +

      +

      {getRecommendedSetup()}

      +
      +
      +
      */} + + {/* Provider Selection */} +
      +

      + 1. Select Inference Providers +

      + +
      + {PROVIDER_CONFIGS.map((provider) => ( +
      handleProviderToggle(provider.id)} + > +
      +
      +
      + {provider.icon} +
      +
      +

      + {provider.name} +

      +

      + {provider.description} +

      +
      +
      + + {selectedProviders.includes(provider.id) && ( + + )} +
      + + {/* Quick stats */} +
      +
      + {provider.category === "cloud" ? ( + + ) : provider.category === "local" ? ( + + ) : ( + + )} + {provider.category} +
      +
      + + {provider.pricing} +
      +
      + + + {provider.setup_difficulty} setup + +
      +
      + + {/* Pros/Cons */} +
      +
      + Pros: +
        + {provider.pros.slice(0, 2).map((pro, idx) => ( +
      • {pro}
      + ))} +
      +
      +
      + + Considerations: + +
        + {provider.cons.slice(0, 2).map((con, idx) => ( +
      • {con}
      + ))} +
      +
      +
      + + {/* Documentation link */} + +
      + ))} +
      +
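Each card in the provider grid above is rendered from an entry in `PROVIDER_CONFIGS`. The entry below is a hypothetical example showing only the fields the cards read; the real list and its values live elsewhere in this file.

```typescript
// Hypothetical example entry; field names follow the card rendering above,
// but the actual PROVIDER_CONFIGS values may differ (e.g. icon may be a node).
interface ProviderConfig {
  id: string;
  name: string;
  description: string;
  icon: string;                        // shown in the card header
  category: "cloud" | "local" | "custom";
  pricing: string;                     // short human-readable summary
  setup_difficulty: "easy" | "moderate" | "advanced";
  pros: string[];                      // first two are shown on the card
  cons: string[];                      // shown under "Considerations"
  requires_signup: boolean;
  api_base?: string;                   // default base URL, overridable per config
}

const openRouterExample: ProviderConfig = {
  id: "openrouter",
  name: "OpenRouter",
  description: "One API key for many hosted models",
  icon: "🌐",
  category: "cloud",
  pricing: "Pay per token",
  setup_difficulty: "easy",
  pros: ["Single API key", "Wide model selection"],
  cons: ["Requires internet access", "Per-token cost"],
  requires_signup: true,
  api_base: "https://openrouter.ai/api/v1",
};
```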
      + + {/* Configuration */} + {selectedProviders.length > 0 && ( +
      +
      +

      + 2. Configure Models +

      + +
      + + {/* Interactive Role Status Overview */} +
      +
      +

      + Model Role Assignment +

      +
      + Click role badges below to change assignments +
      +
      +
      + {(() => { + const { hasTarget, hasOptimizer, hasBoth } = getRoleStatus(); + const targetConfig = configurations.find( + (c) => c.role === "target" || c.role === "both" + ); + const optimizerConfig = configurations.find( + (c) => c.role === "optimizer" || c.role === "both" + ); + + return ( + <> + {/* Target Status */} +
      +
      + + + Target Model + + {hasTarget && ( + + )} +
      + {targetConfig ? ( +
      +
      + + {targetConfig.role === "both" && ( + + Also optimizer + + )} +
      +
      +
      + {targetConfig.provider_id === "custom" + ? targetConfig.custom_provider_name + : PROVIDER_CONFIGS.find( + (p) => p.id === targetConfig.provider_id + )?.name} +
      +
      + {targetConfig.model_name} +
      +
      +
      + ) : ( +
      +

      + Not configured +

      + +
      + )} +
      + + {/* Optimizer Status */} +
      +
      + + + Optimizer Model + + {hasOptimizer && ( + + )} +
      + {optimizerConfig ? ( +
      +
      + + {optimizerConfig.role === "both" && ( + + Also target + + )} +
      +
      +
      + {optimizerConfig.provider_id === "custom" + ? optimizerConfig.custom_provider_name + : PROVIDER_CONFIGS.find( + (p) => p.id === optimizerConfig.provider_id + )?.name} +
      +
      + {optimizerConfig.model_name} +
      +
      +
      + ) : ( +
      +

      + Not configured +

      + +
      + )} +
      + + {/* Overall Status */} +
      +
      + + + Setup Status + +
      + {hasTarget && hasOptimizer ? ( +
      +

      + ✅ Ready for optimization +

      +

      + Using separate models for each role +

      +
      + ) : hasBoth ? ( +
      +

      + ✅ Ready for optimization +

      +

      + Using single model for both roles +

      +
      + ) : ( +
      +

      + ⚠️ Configure {!hasTarget ? "target" : ""}{" "} + {!hasTarget && !hasOptimizer ? " & " : ""}{" "} + {!hasOptimizer ? "optimizer" : ""} model + {!hasTarget && !hasOptimizer ? "s" : ""} +

      +

      + Select a provider above to get started +

      +
      + )} +
      + + ); + })()} +
      +
      + +
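The role overview above and the add-configuration buttons further below both destructure the same two helpers. A plausible implementation, written here as pure functions over the configuration list (the component's real helpers close over React state, and `getAvailableRoles` also takes a provider id):

```typescript
// Pure-function sketch; the component's real helpers may differ.
type RoleLike = { role: "target" | "optimizer" | "both" };

function getRoleStatus(configurations: RoleLike[]) {
  const hasBoth = configurations.some((c) => c.role === "both");
  const hasTarget = hasBoth || configurations.some((c) => c.role === "target");
  const hasOptimizer = hasBoth || configurations.some((c) => c.role === "optimizer");
  return { hasTarget, hasOptimizer, hasBoth };
}

function getAvailableRoles(configurations: RoleLike[]) {
  const { hasTarget, hasOptimizer } = getRoleStatus(configurations);
  return {
    // "Both" only makes sense while nothing has been configured yet.
    canAddBoth: configurations.length === 0,
    canAddTarget: !hasTarget,
    canAddOptimizer: !hasOptimizer,
  };
}
```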
      + {selectedProviders.map((providerId) => { + const provider = PROVIDER_CONFIGS.find( + (p) => p.id === providerId + ); + const providerConfigs = configurations.filter( + (config) => config.provider_id === providerId + ); + + if (!provider) return null; + + return ( +
      + {/* Provider Header */} +
      +
      + {provider.icon} +
      +

      + {provider.name} +

      +

      + {providerConfigs.length} configuration + {providerConfigs.length !== 1 ? "s" : ""} +

      +
      +
      + + {/* Add Configuration Buttons */} +
      + {(() => { + const { canAddTarget, canAddOptimizer, canAddBoth } = + getAvailableRoles(providerId); + const { hasTarget, hasOptimizer, hasBoth } = + getRoleStatus(); + + return ( + <> + {/* Add Both button (only if no configs exist anywhere) */} + {canAddBoth && ( + + )} + + {/* Add Target button */} + {canAddTarget && ( + + )} + + {/* Add Optimizer button */} + {canAddOptimizer && ( + + )} + + {/* Status indicator when no buttons available */} + {!canAddTarget && + !canAddOptimizer && + !canAddBoth && ( +
      + + Roles complete +
      + )} + + ); + })()} +
      +
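Behind the "Add target / optimizer / both" buttons above, a new configuration needs little more than a fresh id, the chosen role, and the provider's defaults. A hypothetical factory, since the component's actual add handler is not visible in this hunk:

```typescript
// Hypothetical factory for a new configuration; field names follow the usage
// elsewhere in this component, but the real handler may differ.
function newConfiguration(
  providerId: string,
  role: "target" | "optimizer" | "both",
  defaultApiBase?: string
) {
  return {
    id: crypto.randomUUID(),
    provider_id: providerId,
    role,
    model_name: "",
    api_base: defaultApiBase,
    auth_method: "bearer_token" as const,
  };
}
```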
      + + {/* Split/Merge Controls */} + {(() => { + const hasBothConfig = providerConfigs.some( + (c) => c.role === "both" + ); + const hasTargetAndOptimizer = + providerConfigs.some((c) => c.role === "target") && + providerConfigs.some((c) => c.role === "optimizer"); + + if (hasBothConfig) { + return ( +
      +
      +
      + + + Using one model for both roles + +
      + +
      +
      + ); + } else if (hasTargetAndOptimizer) { + return ( +
      +
      +
      + + + Using separate models for each role + +
      + +
      +
      + ); + } + return null; + })()} + + {/* Configurations */} +
      + {providerConfigs.map((config) => { + const configKey = `${config.provider_id}-${config.role}-${config.id}`; + + return ( +
      +
      +
      + +
      + Model Configuration +
      +
      + +
      + {/* Connection status */} + {testingConnections[configKey] ? ( + + ) : connectionStatus[configKey] === "success" ? ( + + ) : connectionStatus[configKey] === "error" ? ( + + ) : null} + + + + {providerConfigs.length > 1 && ( + + )} +
      +
      + +
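Each configuration's test state is tracked under a composite key, so one provider can hold a target entry and an optimizer entry without their spinners or status icons colliding. A sketch of the assumed shape:

```typescript
// Assumed shape of the per-configuration test state referenced above; the real
// useState declarations sit earlier in this component.
type ConnectionTestState = {
  testingConnections: Record<string, boolean>;           // true while a test is in flight
  connectionStatus: Record<string, "success" | "error">; // set once a test resolves
};

// Keys follow the same composite format as `configKey` above.
const keyFor = (c: { provider_id: string; role: string; id: string }) =>
  `${c.provider_id}-${c.role}-${c.id}`;
```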
      + {/* Custom Provider Name (for custom providers) */} + {config.provider_id === "custom" && ( +
      + + + handleConfigChange( + config.id, + "custom_provider_name", + e.target.value + ) + } + placeholder="e.g., Azure AI Studio, My Custom API" + className="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-facebook-blue" + /> +
      + )} + + {/* API Base URL (for custom providers and vLLM) */} + {(config.provider_id === "custom" || + config.provider_id === "vllm") && ( +
      + + + handleConfigChange( + config.id, + "api_base", + e.target.value + ) + } + placeholder={ + config.provider_id === "vllm" + ? "e.g., http://localhost:8000" + : "e.g., https://your-endpoint.eastus2.inference.ai.azure.com/" + } + className="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-facebook-blue" + /> +
      + )} + + {/* Model Prefix (for custom providers) */} + {config.provider_id === "custom" && ( +
      + + + handleConfigChange( + config.id, + "model_prefix", + e.target.value + ) + } + placeholder="e.g., azure_ai/, custom/" + className="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-facebook-blue" + /> +

      + Leave empty for direct model names +

      +
      + )} + + {/* Authentication Method (for custom providers) */} + {config.provider_id === "custom" && ( +
      + + +
      + )} + + {/* Model Selection */} +
      + + {config.provider_id === "custom" ? ( + + handleConfigChange( + config.id, + "model_name", + e.target.value + ) + } + placeholder="e.g., command-r-plus, mistral-large-latest" + className="w-full p-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-facebook-blue" + /> + ) : ( + + )} + {config.provider_id === "custom" && ( +

      + Final model will be:{" "} + {config.model_prefix || ""} + {config.model_name} +

      + )} +
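The prefix and model name are simply concatenated before being handed to the backend, which is how LiteLLM-style routing strings such as `azure_ai/my-deployment` are formed. A one-line helper equivalent to the inline preview above (the helper name is illustrative):

```typescript
// Equivalent to the inline `${config.model_prefix || ""}${config.model_name}` preview above.
const finalModelName = (config: { model_prefix?: string; model_name: string }): string =>
  `${config.model_prefix || ""}${config.model_name}`;

// finalModelName({ model_prefix: "azure_ai/", model_name: "my-deployment" })
//   -> "azure_ai/my-deployment"
```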
      + + {/* Authentication */} + {(provider?.requires_signup || + config.provider_id === "custom") && ( +
      + + {config.provider_id === "custom" && + config.auth_method === "custom_headers" ? ( +
      +

      + Configure custom headers for + authentication +

      +