diff --git a/Cargo.toml b/Cargo.toml index a0c0f9d4..882c270a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "torc-server", "torc-slurm-job-runner", "torc-dash", + "torc-mcp-server", ] resolver = "2" @@ -96,6 +97,10 @@ hyper-tls = "0.5" hyper-openssl = "0.9" openssl = "0.10" +# MCP server +rmcp = { version = "0.1", features = ["server", "macros", "transport-io"] } +schemars = "1.0" + [package] name = "torc" version.workspace = true @@ -159,6 +164,7 @@ client = [ "dep:signal-hook", "dep:libc", "dep:nvml-wrapper", + "dep:sha2", "config", ] tui = [ diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 1477c8ce..68976d86 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -24,6 +24,7 @@ - [Parallelization Strategies](./explanation/parallelization.md) - [Workflow Actions](./explanation/workflow-actions.md) - [Slurm Workflows](./explanation/slurm-workflows.md) + - [Automatic Failure Recovery](./explanation/automatic-recovery.md) - [Design](./explanation/design/README.md) - [Server API Handler](./explanation/design/server.md) - [Central Database](./explanation/design/database.md) @@ -75,6 +76,8 @@ - [Map Python functions across workers](./tutorials/map_python_function_across_workers.md) - [Filtering CLI Output with Nushell](./tutorials/filtering-with-nushell.md) - [Custom HPC Profile](./tutorials/custom-hpc-profile.md) + - [MCP Server with Claude Code](./tutorials/mcp-server.md) + - [Automatic Failure Recovery](./tutorials/automatic-recovery.md) --- diff --git a/docs/src/explanation/README.md b/docs/src/explanation/README.md index 507ef945..366aa433 100644 --- a/docs/src/explanation/README.md +++ b/docs/src/explanation/README.md @@ -16,3 +16,4 @@ This section provides understanding-oriented discussions of Torc's key concepts - Ready queue optimization for large workflows - Parallelization strategies and job allocation approaches - Workflow actions for automation and dynamic resource allocation +- AI-assisted recovery for diagnosing and fixing job failures diff --git a/docs/src/explanation/automatic-recovery.md b/docs/src/explanation/automatic-recovery.md new file mode 100644 index 00000000..86b168fa --- /dev/null +++ b/docs/src/explanation/automatic-recovery.md @@ -0,0 +1,196 @@ +# Automatic Failure Recovery + +This document explains how Torc's automatic failure recovery system works, its design principles, and when to use automatic vs manual recovery. + +## Overview + +Torc provides **automatic failure recovery** through the `torc watch --auto-recover` command. When jobs fail, the system: + +1. Diagnoses the failure cause (OOM, timeout, or unknown) +2. Applies heuristics to adjust resource requirements +3. Resets failed jobs and submits new Slurm allocations +4. Resumes monitoring until completion or max retries + +This deterministic approach handles the majority of HPC failures without human intervention. + +## Design Principles + +### Why Deterministic Recovery? + +Most HPC job failures fall into predictable categories: + +| Failure Type | Frequency | Solution | +|--------------|-----------|----------| +| Out of Memory | ~60% | Increase memory allocation | +| Timeout | ~25% | Increase runtime limit | +| Transient errors | ~10% | Simple retry | +| Code bugs | ~5% | Manual intervention | + +For 85-90% of failures, the solution is mechanical: increase resources and retry. This doesn't require AI judgment—simple heuristics work well. + +### Recovery Architecture + +```mermaid +flowchart LR + A[torc watch
polling] --> B{Workflow<br/>complete?}
    B -->|No| A
    B -->|Yes, with failures| C[Diagnose failures<br/>check resources]
    C --> D[Apply heuristics<br/>adjust resources]
    D --> E[Submit new
allocations] + E --> A + B -->|Yes, success| F[Exit 0] +``` + +### Failure Detection + +Torc tracks resource usage during job execution: +- Memory usage (RSS and peak) +- CPU utilization +- Execution time + +This data is analyzed to determine failure causes: + +**OOM Detection:** +- Peak memory exceeds specified limit +- Exit code 137 (SIGKILL from OOM killer) +- Flag: `likely_oom: true` + +**Timeout Detection:** +- Execution time within 10% of runtime limit +- Job was killed (not graceful exit) +- Flag: `likely_timeout: true` + +### Recovery Heuristics + +Default multipliers applied to failed jobs: + +| Failure | Default Multiplier | Configurable | +|---------|-------------------|--------------| +| OOM | 1.5x memory | `--memory-multiplier` | +| Timeout | 1.5x runtime | `--runtime-multiplier` | + +Example: A job with 8g memory that fails with OOM gets 12g on retry. + +### Slurm Scheduler Regeneration + +After adjusting resources, the system regenerates Slurm schedulers: + +1. Finds all pending jobs (uninitialized, ready, blocked) +2. Groups by resource requirements +3. Calculates minimum allocations needed +4. Creates new schedulers with appropriate walltimes +5. Submits allocations to Slurm + +This is handled by `torc slurm regenerate --submit`. + +## Configuration + +### Command-Line Options + +```bash +torc watch \ + --auto-recover \ # Enable automatic recovery + --max-retries 3 \ # Maximum recovery attempts + --memory-multiplier 1.5 \ # Memory increase factor for OOM + --runtime-multiplier 1.5 \ # Runtime increase factor for timeout + --poll-interval 60 \ # Seconds between status checks + --output-dir output \ # Directory for job output files + --show-job-counts # Display job counts during polling (optional) +``` + +### Retry Limits + +The `--max-retries` option prevents infinite retry loops. After exceeding this limit, the system exits with an error, indicating manual intervention is needed. + +Default: 3 retries + +## When to Use Manual Recovery + +Automatic recovery works well for resource-related failures, but some situations require manual intervention: + +### Use Manual Recovery When: + +1. **Jobs keep failing after max retries** + - The heuristics aren't solving the problem + - Need to investigate root cause + +2. **Unknown failure modes** + - Exit codes that don't indicate OOM/timeout + - Application-specific errors + +3. **Code bugs** + - Jobs fail consistently with same error + - No resource issue detected + +4. 
**Cost optimization** + - Want to analyze actual usage before increasing + - Need to decide whether job is worth more resources + +### MCP Server for Manual Recovery + +The Torc MCP server provides tools for AI-assisted investigation: + +| Tool | Purpose | +|------|---------| +| `get_workflow_status` | Get overall workflow status | +| `list_failed_jobs` | List failed jobs with error info | +| `get_job_logs` | Read stdout/stderr logs | +| `check_resource_utilization` | Detailed resource analysis | +| `update_job_resources` | Manually adjust resources | +| `restart_jobs` | Reset and restart jobs | +| `resubmit_workflow` | Regenerate Slurm schedulers | + +## Comparison + +| Feature | Automatic | Manual/AI-Assisted | +|---------|-----------|-------------------| +| Human involvement | None | Interactive | +| Speed | Fast | Depends on human | +| Handles OOM/timeout | Yes | Yes | +| Handles unknown errors | Retry only | Full investigation | +| Cost optimization | Basic | Can be sophisticated | +| Use case | Production workflows | Debugging, optimization | + +## Implementation Details + +### The Watch Command + +```bash +torc watch --auto-recover +``` + +Main loop: +1. Poll `is_workflow_complete` API +2. Print status updates +3. On completion, check for failures +4. If failures and auto-recover enabled: + - Run `torc reports check-resource-utilization --include-failed` + - Parse results for `likely_oom` and `likely_timeout` flags + - Update resource requirements via API + - Run `torc workflows reset-status --failed-only --restart` + - Run `torc slurm regenerate --submit` + - Increment retry counter + - Resume polling +5. Exit 0 on success, exit 1 on max retries exceeded + +### The Regenerate Command + +```bash +torc slurm regenerate --submit +``` + +1. Query jobs with status uninitialized/ready/blocked +2. Group by resource requirements +3. For each group: + - Find best partition using HPC profile + - Calculate jobs per node + - Determine number of allocations needed + - Create scheduler config +4. Update jobs with new scheduler reference +5. Submit allocations via sbatch + +## See Also + +- [Automatic Failure Recovery Tutorial](../tutorials/automatic-recovery.md) - Step-by-step guide +- [MCP Server Tutorial](../tutorials/mcp-server.md) - Setting up AI-assisted tools +- [Resource Monitoring](../how-to/resource-monitoring.md) - Understanding resource tracking diff --git a/docs/src/tutorials/README.md b/docs/src/tutorials/README.md index 4d835900..4fdc73e8 100644 --- a/docs/src/tutorials/README.md +++ b/docs/src/tutorials/README.md @@ -16,6 +16,8 @@ This section contains learning-oriented lessons to help you get started with Tor 10. [Map Python Functions](./map_python_function_across_workers.md) - Distribute Python functions across workers 11. [Filtering CLI Output with Nushell](./filtering-with-nushell.md) - Filter jobs, results, and user data with readable queries 12. [Custom HPC Profile](./custom-hpc-profile.md) - Create an HPC profile for unsupported clusters +13. [MCP Server with Claude Code](./mcp-server.md) - Enable Claude to interact with your workflows +14. [Automatic Failure Recovery](./automatic-recovery.md) - Autonomous workflow monitoring with `torc watch` Start with the Configuration Files tutorial to set up your environment, then try the Dashboard Deployment tutorial if you want to use the web interface. 
diff --git a/docs/src/tutorials/ai-failure-recovery.md b/docs/src/tutorials/ai-failure-recovery.md new file mode 100644 index 00000000..3bd9d1dd --- /dev/null +++ b/docs/src/tutorials/ai-failure-recovery.md @@ -0,0 +1 @@ +# Automatic Failure Recovery diff --git a/docs/src/tutorials/automatic-recovery.md b/docs/src/tutorials/automatic-recovery.md new file mode 100644 index 00000000..0653fc12 --- /dev/null +++ b/docs/src/tutorials/automatic-recovery.md @@ -0,0 +1,264 @@ +# Tutorial: Automatic Failure Recovery + +This tutorial shows how to use `torc watch` with automatic recovery to handle workflow failures without manual intervention. + +## Learning Objectives + +By the end of this tutorial, you will: + +- Understand automatic vs manual recovery options +- Know how to configure automatic recovery heuristics +- Monitor workflows with automatic failure handling + +## Prerequisites + +- Torc installed with the client feature +- A running Torc server +- Workflows submitted to Slurm + +## Automatic Recovery + +The `torc watch` command can automatically recover from common failures: + +```bash +torc watch 42 --auto-recover +``` + +This will: +1. Poll the workflow until completion +2. On failure, diagnose the cause (OOM, timeout, etc.) +3. Adjust resource requirements based on heuristics +4. Reset failed jobs and submit new Slurm allocations +5. Resume monitoring +6. Repeat until success or max retries exceeded + +### Recovery Heuristics + +| Failure Type | Detection | Default Action | +|--------------|-----------|----------------| +| Out of Memory | Peak memory > limit, exit code 137 | Increase memory by 1.5x | +| Timeout | Execution time near limit | Increase runtime by 1.5x | +| Unknown | Other exit codes | Retry without changes | + +### Configuration Options + +```bash +torc watch 42 --auto-recover \ + --max-retries 5 \ # Maximum recovery attempts (default: 3) + --memory-multiplier 2.0 \ # Memory increase factor (default: 1.5) + --runtime-multiplier 2.0 \ # Runtime increase factor (default: 1.5) + --poll-interval 120 \ # Seconds between status checks (default: 60) + --output-dir /scratch/output \ + --show-job-counts # Display per-status job counts (optional, adds server load) +``` + +## Example: Complete Workflow + +### 1. Submit a Workflow + +```bash +torc submit-slurm --account myproject workflow.yaml +``` + +Output: +``` +Created workflow 42 with 100 jobs +Submitted to Slurm with 10 allocations +``` + +### 2. Start Watching with Auto-Recovery + +```bash +torc watch 42 --auto-recover --max-retries 3 --show-job-counts +``` + +> **Note:** The `--show-job-counts` flag is optional. Without it, the command polls +> silently until completion, which reduces server load for large workflows. + +Output: +``` +Watching workflow 42 (poll interval: 60s, auto-recover enabled, max retries: 3, job counts enabled) + completed=0, running=10, pending=0, failed=0, blocked=90 + completed=25, running=10, pending=0, failed=0, blocked=65 + ... + completed=95, running=0, pending=0, failed=5, blocked=0 +Workflow 42 is complete + +Workflow completed with failures: + - Failed: 5 + - Canceled: 0 + - Terminated: 0 + - Completed: 95 + +Attempting automatic recovery (attempt 1/3) + +Diagnosing failures... +Applying recovery heuristics... 
+ Job 107 (train_model_7): OOM detected, increasing memory 8g -> 12g + Job 112 (train_model_12): OOM detected, increasing memory 8g -> 12g + Job 123 (train_model_23): OOM detected, increasing memory 8g -> 12g + Job 131 (train_model_31): OOM detected, increasing memory 8g -> 12g + Job 145 (train_model_45): OOM detected, increasing memory 8g -> 12g + Applied fixes: 5 OOM, 0 timeout + +Resetting failed jobs... +Regenerating Slurm schedulers and submitting... + +Recovery initiated. Resuming monitoring... + +Watching workflow 42 (poll interval: 60s, auto-recover enabled, max retries: 3, job counts enabled) + completed=95, running=5, pending=0, failed=0, blocked=0 + ... +Workflow 42 is complete + +✓ Workflow completed successfully (100 jobs) +``` + +### 3. If Max Retries Exceeded + +If failures persist after max retries: + +``` +Max retries (3) exceeded. Manual intervention required. +Use the Torc MCP server with your AI assistant to investigate. +``` + +At this point, you can use the MCP server with an AI assistant to investigate the root cause. + +## Manual Recovery (Without --auto-recover) + +Without the `--auto-recover` flag, `torc watch` simply monitors and reports: + +```bash +torc watch 42 +``` + +On failure, it exits with instructions: + +``` +Workflow completed with failures: + - Failed: 5 + - Completed: 95 + +Auto-recovery disabled. To enable, use --auto-recover flag. +Or use the Torc MCP server with your AI assistant for manual recovery. +``` + +## When to Use Each Approach + +### Use Automatic Recovery (`--auto-recover`) when: +- Running standard compute jobs with predictable failure modes +- You want hands-off operation +- OOM and timeout are the main failure types +- You have HPC allocation budget for retries + +### Use Manual/AI-Assisted Recovery when: +- Failures have complex or unknown causes +- You need to investigate before retrying +- Resource increases aren't solving the problem +- You want to understand why jobs are failing + +## Best Practices + +### 1. Start with Conservative Resources + +Set initial resource requests lower and let auto-recovery increase them: +- Jobs that succeed keep their original allocation +- Only failing jobs get increased resources +- Avoids wasting HPC resources on over-provisioned jobs + +### 2. Set Reasonable Max Retries + +```bash +--max-retries 3 # Good for most workflows +``` + +Too many retries can waste allocation time on jobs that will never succeed. + +### 3. Use Appropriate Multipliers + +For memory-bound jobs: +```bash +--memory-multiplier 2.0 # Double on OOM +``` + +For time-sensitive jobs where you want larger increases: +```bash +--runtime-multiplier 2.0 # Double runtime on timeout +``` + +### 4. Monitor Long-Running Workflows + +**Always run `torc watch` inside tmux or screen** for long-running workflows. 
HPC workflows can run for hours or days, and you don't want to lose your monitoring session if: + +- Your SSH connection drops +- Your laptop goes to sleep +- You need to disconnect and reconnect later + +Using [tmux](https://github.com/tmux/tmux/wiki) (recommended): + +```bash +# Start a new tmux session +tmux new -s torc-watch + +# Run the watch command +torc watch 42 --auto-recover --poll-interval 300 --show-job-counts + +# Detach from session: press Ctrl+b, then d +# Reattach later: tmux attach -t torc-watch +``` + +Using screen: +```bash +screen -S torc-watch +torc watch 42 --auto-recover --poll-interval 300 --show-job-counts +# Detach: Ctrl+a, then d +# Reattach: screen -r torc-watch +``` + +For very large workflows, omit `--show-job-counts` to reduce server load. + +### 5. Check Resource Utilization Afterward + +After completion, review actual usage: +```bash +torc reports check-resource-utilization 42 +``` + +This helps tune future job specifications. + +## Troubleshooting + +### Jobs Keep Failing After Recovery + +If jobs fail repeatedly with the same error: +1. Check if the error is resource-related (OOM/timeout) +2. Review job logs: `torc jobs logs ` +3. Check if there's a code bug +4. Use MCP server with AI assistant to investigate + +### No Slurm Schedulers Generated + +If `slurm regenerate` fails: +1. Ensure workflow was created with `--account` option +2. Check HPC profile is detected: `torc hpc detect` +3. Specify profile explicitly: `--profile kestrel` + +### Resource Limits Too High + +If jobs are requesting more resources than partitions allow: +1. Check partition limits: `torc hpc partitions ` +2. Use smaller multipliers +3. Consider splitting jobs into smaller pieces + +## Summary + +The `torc watch --auto-recover` command provides: + +- **Automatic OOM handling**: Detects memory issues and increases allocations +- **Automatic timeout handling**: Detects slow jobs and increases runtime +- **Configurable heuristics**: Adjust multipliers for your workload +- **Retry limits**: Prevent infinite retry loops +- **Graceful degradation**: Falls back to manual recovery when needed + +For most HPC workflows, automatic recovery handles 80-90% of transient failures without human intervention. diff --git a/docs/src/tutorials/mcp-server.md b/docs/src/tutorials/mcp-server.md new file mode 100644 index 00000000..9f2543a6 --- /dev/null +++ b/docs/src/tutorials/mcp-server.md @@ -0,0 +1,358 @@ +# Tutorial: Using the MCP Server + +This tutorial shows how to use the Torc MCP (Model Context Protocol) server to enable AI assistants to interact with your Torc workflows directly. + +## Learning Objectives + +By the end of this tutorial, you will: + +- Understand what the MCP server provides +- Know how to configure your AI assistant to use the Torc MCP server +- Be able to inspect and manage your workflows using natural language + +## Prerequisites + +- Torc installed +- Torc server running +- One of the following AI assistants: + - [Claude Code](https://claude.ai/code) (terminal) + - [VS Code](https://code.visualstudio.com/) with GitHub Copilot (IDE) + +## What is the MCP Server? + +The Model Context Protocol (MCP) is an open standard for connecting AI assistants to external tools and data sources. The `torc-mcp-server` binary exposes Torc's workflow management capabilities as MCP tools. 
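
An MCP client such as Claude Code talks to `torc-mcp-server` over stdio using JSON-RPC. As a rough illustration, a `tools/call` request for the `get_workflow_status` tool might look like the sketch below; the exact argument names are defined by the server's tool schema, so `workflow_id` here is an assumption for illustration only.

```json
{
  "jsonrpc": "2.0",
  "id": 1,
  "method": "tools/call",
  "params": {
    "name": "get_workflow_status",
    "arguments": { "workflow_id": 42 }
  }
}
```

You never write these requests by hand; the AI assistant issues them for you based on your natural-language prompts.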
+ +**Available Tools:** + +| Tool | Description | +|------|-------------| +| `get_workflow_status` | Get workflow info with job counts by status | +| `get_job_details` | Get detailed job info including resource requirements | +| `get_job_logs` | Read stdout/stderr from job log files | +| `list_failed_jobs` | List all failed jobs in a workflow | +| `list_jobs_by_status` | Filter jobs by status | +| `check_resource_utilization` | Analyze resource usage and detect OOM/timeout issues | +| `update_job_resources` | Modify job resource requirements | +| `restart_jobs` | Reset and restart failed jobs | +| `resubmit_workflow` | Regenerate Slurm schedulers and submit new allocations | +| `cancel_jobs` | Cancel specific jobs | +| `create_workflow_from_spec` | Create a workflow from JSON specification | + +## Configuration + +Choose the setup that matches your environment: + +- **[Claude Code](#claude-code)** - Terminal-based AI assistant +- **[VS Code + Copilot](#vs-code--github-copilot)** - IDE with GitHub Copilot Chat +- **[VS Code + Copilot on HPC](#vs-code-remote-ssh-for-hpc)** - Remote development on HPC clusters + +--- + +## Claude Code + +Claude Code supports MCP configuration at three scopes: + +| Scope | File | Use Case | +|-------|------|----------| +| **Project** | `.mcp.json` in project root | Team-shared configuration (commit to git) | +| **Local** | `.mcp.json` with `--scope local` | Personal project settings (gitignored) | +| **User** | `~/.claude.json` | Cross-project personal tools | + +### Using the CLI (Recommended) + +```bash +# Add the Torc MCP server to your project +claude mcp add torc \ + --scope project \ + -e TORC_API_URL=http://localhost:8080/torc-service/v1 \ + -e TORC_OUTPUT_DIR=/path/to/your/output \ + -- /path/to/torc-mcp-server +``` + +### Manual Configuration + +Create or edit `.mcp.json` in your project root: + +```json +{ + "mcpServers": { + "torc": { + "command": "/path/to/torc-mcp-server", + "env": { + "TORC_API_URL": "http://localhost:8080/torc-service/v1", + "TORC_OUTPUT_DIR": "/path/to/your/output" + } + } + } +} +``` + +Replace `/path/to/torc-mcp-server` with the actual path to your built binary. + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `TORC_API_URL` | Torc server URL | `http://localhost:8080/torc-service/v1` | +| `TORC_OUTPUT_DIR` | Directory containing job logs | `output` | +| `TORC_USERNAME` | Username for authentication (optional) | — | +| `TORC_PASSWORD` | Password for authentication (optional) | — | + +You can also use environment variable expansion in `.mcp.json`: + +```json +{ + "mcpServers": { + "torc": { + "command": "/path/to/torc-mcp-server", + "env": { + "TORC_API_URL": "${TORC_API_URL:-http://localhost:8080/torc-service/v1}", + "TORC_OUTPUT_DIR": "${TORC_OUTPUT_DIR:-./output}" + } + } + } +} +``` + +### Verify Configuration + +Check that the MCP server is configured correctly: + +```bash +# List configured MCP servers +claude mcp list + +# Get details about the torc server +claude mcp get torc +``` + +Start a new Claude Code session in your project directory. Claude will automatically have access to the Torc tools. + +### Example Conversation + +``` +You: What's the status of workflow 15? + +Claude: I'll check the workflow status for you. + +Workflow 15 "hyperparameter_sweep" has 18 jobs: +- Completed: 12 +- Running: 3 +- Failed: 2 +- Blocked: 1 + +Would you like me to investigate the failed jobs? + +You: Yes, show me what went wrong + +Claude: Looking at the failed jobs... 
+ +Job 7 "train_lr0.01_batch32" failed with: +- Return code: 137 (killed by OOM) +- Peak memory: 15.8 GB +- Requested memory: 8 GB + +Job 12 "train_lr0.001_batch64" failed with: +- Return code: 1 +- Error: FileNotFoundError: checkpoint.pt not found + +Recommendations: +1. Job 7 needs more memory - I can restart it with 20 GB +2. Job 12 has a missing dependency - check if upstream jobs completed + +You: Restart job 7 with 20GB memory + +Claude: Done. I've: +1. Updated job 7's memory requirement from 8g to 20g +2. Reset failed jobs and restarted the workflow + +Job 7 is now ready to run again. +``` + +### Troubleshooting Claude Code + +**Claude doesn't see the tools:** +- Verify the MCP server is configured: `claude mcp list` +- Check the config file is valid JSON: `cat .mcp.json | jq .` +- Check that the path to `torc-mcp-server` is correct and the binary exists +- Start a new Claude Code session (MCP servers are loaded at startup) + +**Remove the MCP server:** +```bash +claude mcp remove torc +``` + +--- + +## VS Code + GitHub Copilot + +VS Code with GitHub Copilot Chat supports MCP servers for enhanced AI-assisted workflow management. + +### Prerequisites + +- VS Code 1.99 or later +- GitHub Copilot extension installed +- GitHub Copilot subscription (Business, Enterprise, Pro, or Pro+) + +### Configuration + +Create `.vscode/mcp.json` in your project root: + +```json +{ + "servers": { + "torc": { + "command": "/path/to/torc-mcp-server", + "env": { + "TORC_API_URL": "http://localhost:8080/torc-service/v1", + "TORC_OUTPUT_DIR": "./output" + } + } + } +} +``` + +### Verify Setup + +1. Open the Command Palette (`Ctrl+Shift+P` / `Cmd+Shift+P`) +2. Run "MCP: List Servers" +3. Verify "torc" appears in the list + +### Usage + +In Copilot Chat, use **Agent Mode** (`@workspace` or the agent icon) to access MCP tools: + +> "What's the status of workflow 42?" + +> "Show me the failed jobs and their error logs" + +--- + +## VS Code Remote SSH for HPC + +For users running Torc on HPC clusters, VS Code's Remote SSH extension allows you to use Copilot Chat with the MCP server running directly on the cluster. + +### Architecture + +``` +┌─────────────────────┐ ┌─────────────────────────────────────┐ +│ Local Machine │ SSH │ HPC Cluster │ +│ │◄───────►│ │ +│ VS Code │ │ torc-mcp-server ◄──► torc-server │ +│ (Copilot Chat) │ │ ▲ │ +│ │ │ │ │ +└─────────────────────┘ │ .vscode/mcp.json │ + └─────────────────────────────────────┘ +``` + +The MCP server runs on the HPC, communicates with the Torc server on the HPC, and VS Code proxies requests through SSH. No ports need to be exposed to your local machine. + +### Step 1: Build `torc-mcp-server` on the HPC + +```bash +# On the HPC (via SSH or login node) +cd /path/to/torc +cargo build --release -p torc-mcp-server +``` + +### Step 2: Configure MCP in your project + +Create `.vscode/mcp.json` in your project directory **on the HPC**: + +```json +{ + "servers": { + "torc": { + "command": "/path/on/hpc/torc/target/release/torc-mcp-server", + "env": { + "TORC_API_URL": "http://localhost:8080/torc-service/v1", + "TORC_OUTPUT_DIR": "./output" + } + } + } +} +``` + +> **Important:** MCP servers configured in workspace settings (`.vscode/mcp.json`) run on the remote host, not your local machine. + +### Step 3: Connect and use + +1. Install the [Remote - SSH](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension +2. Connect to the HPC: `Remote-SSH: Connect to Host...` +3. Open your project folder on the HPC +4. 
Open Copilot Chat and use Agent Mode + +### HPC-Specific Tips + +- **Module systems:** If your HPC uses modules, you may need to set `PATH` in the env to include required dependencies +- **Shared filesystems:** Place `.vscode/mcp.json` in a project directory on a shared filesystem accessible from compute nodes +- **Firewalls:** The MCP server only needs to reach the Torc server on the HPC's internal network + +--- + +## Interact with Workflows + +Once configured, you can ask your AI assistant to help manage workflows using natural language: + +**Check workflow status:** +> "What's the status of workflow 42?" + +**Investigate failures:** +> "List all failed jobs in workflow 42 and show me the error logs" + +**Take action:** +> "Restart the failed jobs in workflow 42 with doubled memory" + +**Create workflows:** +> "Create a workflow with 10 parallel jobs that each run `python process.py index`" + +--- + +## How It Works + +The MCP server: + +1. **Receives tool calls** from the AI assistant via stdio +2. **Translates them** to Torc REST API calls +3. **Returns results** in a format the assistant can understand + +The server is stateless—it simply proxies requests to your running Torc server. All workflow state remains in Torc's database. + +## Security Considerations + +- The MCP server has full access to your Torc server +- Consider using authentication (`TORC_USERNAME`/`TORC_PASSWORD`) if your Torc server is exposed +- The server can modify workflows (restart, cancel, update resources) +- Review proposed actions before they execute + +## Troubleshooting + +### "Failed to connect to server" +- Ensure your Torc server is running +- Check that `TORC_API_URL` is correct +- Verify network connectivity + +### "Permission denied" or "Authentication failed" +- Set `TORC_USERNAME` and `TORC_PASSWORD` if your server requires auth +- Check that the credentials are correct + +### Logs not found +- Ensure `TORC_OUTPUT_DIR` points to your job output directory +- Check that jobs have actually run (logs are created at runtime) + +## What You Learned + +In this tutorial, you learned: + +- ✅ What the Torc MCP server provides +- ✅ How to configure Claude Code to use it +- ✅ How to configure VS Code + GitHub Copilot to use it +- ✅ How to set up MCP on HPC clusters via Remote SSH +- ✅ How to interact with workflows using natural language +- ✅ Security considerations for production use + +## Next Steps + +- [Automatic Failure Recovery](./automatic-recovery.md) - Use `torc watch` for automatic failure recovery +- [Automatic Recovery Explained](../explanation/automatic-recovery.md) - Understand the recovery architecture +- [Configuration Files](./configuration.md) - Set up Torc configuration diff --git a/src/cli.rs b/src/cli.rs index 562daa7d..0f1ef5e2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -142,6 +142,54 @@ pub enum Commands { #[arg(long, default_value = "false")] skip_checks: bool, }, + /// Watch a workflow and automatically recover from failures + /// + /// Monitors a workflow until completion. With --auto-recover, automatically + /// diagnoses failures, adjusts resource requirements, and resubmits jobs. + /// + /// Recovery heuristics: + /// - OOM (out of memory): Increase memory by --memory-multiplier (default 1.5x) + /// - Timeout: Increase runtime by --runtime-multiplier (default 1.5x) + /// - Other failures: Retry without changes (transient errors) + /// + /// Without --auto-recover, reports failures and exits for manual intervention + /// or AI-assisted recovery via the MCP server. 
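+    ///
+    /// Example:
+    ///
+    ///   torc watch 42 --auto-recover --max-retries 3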
+ Watch { + /// Workflow ID to watch + #[arg()] + workflow_id: i64, + + /// Poll interval in seconds + #[arg(short, long, default_value = "60")] + poll_interval: u64, + + /// Enable automatic failure recovery + #[arg(long)] + auto_recover: bool, + + /// Maximum number of recovery attempts (default: 3) + #[arg(long, default_value = "3")] + max_retries: u32, + + /// Memory multiplier for OOM failures (default: 1.5 = 50% increase) + #[arg(long, default_value = "1.5")] + memory_multiplier: f64, + + /// Runtime multiplier for timeout failures (default: 1.5 = 50% increase) + #[arg(long, default_value = "1.5")] + runtime_multiplier: f64, + + /// Output directory for job files + #[arg(short, long, default_value = "output")] + output_dir: PathBuf, + + /// Show job counts by status during polling + /// + /// WARNING: This option queries all jobs on each poll, which can cause high + /// server load for large workflows. Only use for debugging or small workflows. + #[arg(long)] + show_job_counts: bool, + }, /// Workflow management commands Workflows { #[command(subcommand)] diff --git a/src/client.rs b/src/client.rs index ae53e318..2c8e3bdc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -18,6 +18,7 @@ pub mod log_paths; pub mod parameter_expansion; pub mod resource_monitor; pub mod utils; +pub mod watch; pub mod workflow_graph; pub mod workflow_manager; pub mod workflow_spec; diff --git a/src/client/commands.rs b/src/client/commands.rs index 26d90cce..057ba15f 100644 --- a/src/client/commands.rs +++ b/src/client/commands.rs @@ -13,6 +13,7 @@ pub mod scheduled_compute_nodes; pub mod slurm; pub mod table_format; pub mod user_data; +pub mod watch; pub mod workflows; use std::env; diff --git a/src/client/commands/reports.rs b/src/client/commands/reports.rs index 582651f5..751069eb 100644 --- a/src/client/commands/reports.rs +++ b/src/client/commands/reports.rs @@ -68,6 +68,9 @@ pub enum ReportCommands { /// Show all jobs (default: only show jobs that exceeded requirements) #[arg(short, long)] all: bool, + /// Include failed jobs in the analysis (for recovery diagnostics) + #[arg(long)] + include_failed: bool, }, /// Generate a comprehensive JSON report of job results including all log file paths Results { @@ -89,8 +92,16 @@ pub fn handle_report_commands(config: &Configuration, command: &ReportCommands, workflow_id, run_id, all, + include_failed, } => { - check_resource_utilization(config, *workflow_id, *run_id, *all, format); + check_resource_utilization( + config, + *workflow_id, + *run_id, + *all, + *include_failed, + format, + ); } ReportCommands::Results { workflow_id, @@ -107,6 +118,7 @@ fn check_resource_utilization( workflow_id: Option, run_id: Option, show_all: bool, + include_failed: bool, format: &str, ) { // Get or select workflow ID @@ -122,21 +134,51 @@ fn check_resource_utilization( }, }; - // Fetch results for the workflow using pagination + // Fetch completed results for the workflow using pagination let mut params = pagination::ResultListParams::new().with_status(models::JobStatus::Completed); if let Some(rid) = run_id { params = params.with_run_id(rid); } - let results = match pagination::paginate_results(config, wf_id, params) { + let completed_results = match pagination::paginate_results(config, wf_id, params) { Ok(results) => results, Err(e) => { - print_error("fetching results", &e); + print_error("fetching completed results", &e); std::process::exit(1); } }; + // Fetch failed results if requested + let failed_results = if include_failed { + let mut failed_params = + 
pagination::ResultListParams::new().with_status(models::JobStatus::Failed); + if let Some(rid) = run_id { + failed_params = failed_params.with_run_id(rid); + } + match pagination::paginate_results(config, wf_id, failed_params) { + Ok(results) => results, + Err(e) => { + print_error("fetching failed results", &e); + std::process::exit(1); + } + } + } else { + Vec::new() + }; + + // Combine results + let mut results = completed_results; + results.extend(failed_results); + if results.is_empty() { - println!("No completed job results found for workflow {}", wf_id); + let msg = if include_failed { + format!( + "No completed or failed job results found for workflow {}", + wf_id + ) + } else { + format!("No completed job results found for workflow {}", wf_id) + }; + println!("{}", msg); std::process::exit(0); } @@ -175,9 +217,11 @@ fn check_resource_utilization( // Analyze each result let mut rows = Vec::new(); let mut over_util_count = 0; + let mut failed_jobs_info: Vec = Vec::new(); for result in &results { let job_id = result.job_id; + let is_failed = result.status == models::JobStatus::Failed; // Get job and its resource requirements let job = match job_map.get(&job_id) { @@ -209,6 +253,52 @@ fn check_resource_utilization( let job_name = job.name.clone(); + // Track failed jobs separately with their resource info + if is_failed { + let mut failed_info = serde_json::json!({ + "job_id": job_id, + "job_name": job_name.clone(), + "return_code": result.return_code, + "exec_time_minutes": result.exec_time_minutes, + "configured_memory": resource_req.memory.clone(), + "configured_runtime": resource_req.runtime.clone(), + "configured_cpus": resource_req.num_cpus, + }); + + // Add resource usage if available + if let Some(peak_mem) = result.peak_memory_bytes { + failed_info["peak_memory_bytes"] = serde_json::json!(peak_mem); + failed_info["peak_memory_formatted"] = + serde_json::json!(format_memory_bytes(peak_mem)); + + // Check if it's an OOM issue + let specified_memory_bytes = parse_memory_string(&resource_req.memory); + if peak_mem > specified_memory_bytes { + failed_info["likely_oom"] = serde_json::json!(true); + let over_pct = + ((peak_mem as f64 / specified_memory_bytes as f64) - 1.0) * 100.0; + failed_info["memory_over_utilization"] = + serde_json::json!(format!("+{:.1}%", over_pct)); + } + } + + // Check if runtime exceeded + let exec_time_seconds = result.exec_time_minutes * 60.0; + if let Ok(specified_runtime_seconds) = duration_string_to_seconds(&resource_req.runtime) + { + let specified_runtime_seconds = specified_runtime_seconds as f64; + if exec_time_seconds > specified_runtime_seconds * 0.9 { + // If job ran for > 90% of its runtime, it might be a timeout + failed_info["likely_timeout"] = serde_json::json!(true); + let pct_of_runtime = (exec_time_seconds / specified_runtime_seconds) * 100.0; + failed_info["runtime_utilization"] = + serde_json::json!(format!("{:.1}%", pct_of_runtime)); + } + } + + failed_jobs_info.push(failed_info); + } + // Check memory over-utilization if let Some(peak_memory_bytes) = result.peak_memory_bytes { let specified_memory_bytes = parse_memory_string(&resource_req.memory); @@ -305,7 +395,7 @@ fn check_resource_utilization( // Output results match format { "json" => { - let json_output = serde_json::json!({ + let mut json_output = serde_json::json!({ "workflow_id": wf_id, "run_id": run_id, "total_results": results.len(), @@ -321,6 +411,13 @@ fn check_resource_utilization( }) }).collect::>(), }); + + // Add failed jobs section if there are any + if 
!failed_jobs_info.is_empty() { + json_output["failed_jobs_count"] = serde_json::json!(failed_jobs_info.len()); + json_output["failed_jobs"] = serde_json::json!(failed_jobs_info); + } + println!("{}", serde_json::to_string_pretty(&json_output).unwrap()); } "table" | _ => { diff --git a/src/client/commands/slurm.rs b/src/client/commands/slurm.rs index 9708d8ab..3f942239 100644 --- a/src/client/commands/slurm.rs +++ b/src/client/commands/slurm.rs @@ -317,6 +317,43 @@ pub enum SlurmCommands { #[arg(long)] force: bool, }, + /// Regenerate Slurm schedulers for an existing workflow based on pending jobs + /// + /// Analyzes jobs that are uninitialized, ready, or blocked and generates new + /// Slurm schedulers to run them. Uses existing scheduler configurations as + /// defaults for account, partition, and other settings. + /// + /// This is useful for recovery after job failures: update job resources, + /// reset failed jobs, then regenerate schedulers to submit new allocations. + Regenerate { + /// Workflow ID + #[arg()] + workflow_id: i64, + + /// Slurm account to use (defaults to account from existing schedulers) + #[arg(long)] + account: Option, + + /// HPC profile to use (if not specified, tries to detect current system) + #[arg(long)] + profile: Option, + + /// Bundle all nodes into a single Slurm allocation per scheduler + #[arg(long)] + single_allocation: bool, + + /// Submit the generated allocations immediately + #[arg(long)] + submit: bool, + + /// Output directory for job output files (used when submitting) + #[arg(short, long, default_value = "output")] + output_dir: PathBuf, + + /// Poll interval in seconds (used when submitting) + #[arg(short, long, default_value = "60")] + poll_interval: i32, + }, } /// Convert seconds to Slurm walltime format (HH:MM:SS or D-HH:MM:SS) @@ -1080,6 +1117,27 @@ pub fn handle_slurm_commands(config: &Configuration, command: &SlurmCommands, fo format, ); } + SlurmCommands::Regenerate { + workflow_id, + account, + profile: profile_name, + single_allocation, + submit, + output_dir, + poll_interval, + } => { + handle_regenerate( + config, + *workflow_id, + account.as_deref(), + profile_name.as_deref(), + *single_allocation, + *submit, + output_dir, + *poll_interval, + format, + ); + } } } @@ -2442,3 +2500,469 @@ fn handle_generate( } } } + +/// Result of regenerating schedulers for an existing workflow +#[derive(Debug, Serialize, Deserialize)] +pub struct RegenerateResult { + pub workflow_id: i64, + pub pending_jobs: usize, + pub schedulers_created: Vec, + pub total_allocations: i64, + pub warnings: Vec, + pub submitted: bool, +} + +/// Information about a created scheduler +#[derive(Debug, Serialize, Deserialize)] +pub struct SchedulerInfo { + pub id: i64, + pub name: String, + pub account: String, + pub partition: Option, + pub walltime: String, + pub nodes: i64, + pub num_allocations: i64, + pub job_count: usize, +} + +/// Handle the regenerate command - regenerates Slurm schedulers for pending jobs +fn handle_regenerate( + config: &Configuration, + workflow_id: i64, + account: Option<&str>, + profile_name: Option<&str>, + single_allocation: bool, + submit: bool, + output_dir: &PathBuf, + poll_interval: i32, + format: &str, +) { + // Load HPC config and registry + let torc_config = TorcConfig::load().unwrap_or_default(); + let registry = create_registry_with_config_public(&torc_config.client.hpc); + + // Get the HPC profile + let profile = if let Some(n) = profile_name { + registry.get(n) + } else { + registry.detect() + }; + + let profile = match 
profile { + Some(p) => p, + None => { + if profile_name.is_some() { + eprintln!("Unknown HPC profile: {}", profile_name.unwrap()); + } else { + eprintln!("No HPC profile specified and no system detected."); + eprintln!("Use --profile or run on an HPC system."); + } + std::process::exit(1); + } + }; + + // Fetch pending jobs (uninitialized, ready, blocked) + let pending_statuses = [ + models::JobStatus::Uninitialized, + models::JobStatus::Ready, + models::JobStatus::Blocked, + ]; + let mut pending_jobs: Vec = Vec::new(); + + for status in &pending_statuses { + match default_api::list_jobs( + config, + workflow_id, + Some(status.clone()), + None, // needs_file_id + None, // upstream_job_id + Some(0), + Some(10000), + None, + None, + None, + ) { + Ok(response) => { + pending_jobs.extend(response.items.unwrap_or_default()); + } + Err(e) => { + print_error(&format!("listing {:?} jobs", status), &e); + std::process::exit(1); + } + } + } + + if pending_jobs.is_empty() { + if format == "json" { + println!( + "{}", + serde_json::to_string_pretty(&RegenerateResult { + workflow_id, + pending_jobs: 0, + schedulers_created: Vec::new(), + total_allocations: 0, + warnings: vec!["No pending jobs found".to_string()], + submitted: false, + }) + .unwrap() + ); + } else { + println!( + "No pending jobs (uninitialized, ready, or blocked) found in workflow {}", + workflow_id + ); + } + return; + } + + // Fetch all resource requirements for the workflow + let resource_requirements = match default_api::list_resource_requirements( + config, + workflow_id, + None, // job_id + Some(0), + Some(10000), + None, // sort_by + None, // reverse_sort + None, // name + None, // memory + None, // num_cpus + None, // num_gpus + None, // num_nodes + None, // runtime + ) { + Ok(response) => response.items.unwrap_or_default(), + Err(e) => { + print_error("listing resource requirements", &e); + std::process::exit(1); + } + }; + + // Build a map of resource requirement ID -> model + let rr_map: HashMap = resource_requirements + .iter() + .filter_map(|rr| rr.id.map(|id| (id, rr))) + .collect(); + + // Get existing schedulers to use as defaults + let existing_schedulers = match default_api::list_slurm_schedulers( + config, + workflow_id, + Some(0), + Some(100), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) { + Ok(response) => response.items.unwrap_or_default(), + Err(e) => { + print_error("listing existing schedulers", &e); + std::process::exit(1); + } + }; + + // Determine account to use + let account_to_use = account + .map(|s| s.to_string()) + .or_else(|| existing_schedulers.first().map(|s| s.account.clone())) + .unwrap_or_else(|| { + eprintln!("No account specified and no existing schedulers found."); + eprintln!("Use --account to specify a Slurm account."); + std::process::exit(1); + }); + + // Group jobs by resource requirements + let mut jobs_by_rr: HashMap> = HashMap::new(); + let mut warnings: Vec = Vec::new(); + + for job in &pending_jobs { + if let Some(rr_id) = job.resource_requirements_id { + jobs_by_rr.entry(rr_id).or_default().push(job); + } else { + warnings.push(format!( + "Job '{}' (ID: {}) has no resource requirements, skipping", + job.name, + job.id.unwrap_or(-1) + )); + } + } + + if jobs_by_rr.is_empty() { + if format == "json" { + println!( + "{}", + serde_json::to_string_pretty(&RegenerateResult { + workflow_id, + pending_jobs: pending_jobs.len(), + schedulers_created: Vec::new(), + total_allocations: 0, + warnings, + submitted: false, + }) + .unwrap() + ); + } else 
{ + println!("No pending jobs with resource requirements found"); + for warning in &warnings { + println!(" Warning: {}", warning); + } + } + return; + } + + // Generate schedulers for each resource requirement group + let mut schedulers_created: Vec = Vec::new(); + let mut total_allocations: i64 = 0; + let timestamp = Utc::now().format("%Y%m%d_%H%M%S"); + + for (rr_id, jobs) in &jobs_by_rr { + let rr = match rr_map.get(rr_id) { + Some(rr) => *rr, + None => { + warnings.push(format!( + "Resource requirements ID {} not found, skipping {} job(s)", + rr_id, + jobs.len() + )); + continue; + } + }; + + // Parse resource requirements + let memory_mb = match parse_memory_mb(&rr.memory) { + Ok(m) => m, + Err(e) => { + warnings.push(format!( + "Failed to parse memory '{}' for RR {}: {}", + rr.memory, rr_id, e + )); + continue; + } + }; + + let runtime_secs = match duration_string_to_seconds(&rr.runtime) { + Ok(s) => s as u64, + Err(e) => { + warnings.push(format!( + "Failed to parse runtime '{}' for RR {}: {}", + rr.runtime, rr_id, e + )); + continue; + } + }; + + let gpus = if rr.num_gpus > 0 { + Some(rr.num_gpus as u32) + } else { + None + }; + + // Find best partition + let partition = match profile.find_best_partition( + rr.num_cpus as u32, + memory_mb, + runtime_secs, + gpus, + ) { + Some(p) => p, + None => { + warnings.push(format!( + "No partition found for resource requirements '{}' (CPUs: {}, Memory: {}, Runtime: {}, GPUs: {:?})", + rr.name, rr.num_cpus, rr.memory, rr.runtime, gpus + )); + continue; + } + }; + + // Calculate jobs per node and total nodes needed + let jobs_per_node_by_cpu = partition.cpus_per_node / rr.num_cpus as u32; + let jobs_per_node_by_mem = (partition.memory_mb / memory_mb) as u32; + let jobs_per_node_by_gpu = match (gpus, partition.gpus_per_node) { + (Some(job_gpus), Some(node_gpus)) if job_gpus > 0 => node_gpus / job_gpus, + _ => u32::MAX, + }; + let jobs_per_node = std::cmp::max( + 1, + std::cmp::min( + jobs_per_node_by_cpu, + std::cmp::min(jobs_per_node_by_mem, jobs_per_node_by_gpu), + ), + ); + + let nodes_per_job = rr.num_nodes as u32; + let total_nodes_needed = + ((jobs.len() as u32 + jobs_per_node - 1) / jobs_per_node) * nodes_per_job; + let total_nodes_needed = std::cmp::max(1, total_nodes_needed) as i64; + + // Allocation strategy + let (nodes_per_alloc, num_allocations) = if single_allocation { + (total_nodes_needed, 1i64) + } else { + (1i64, total_nodes_needed) + }; + + // Create scheduler name with timestamp to avoid conflicts + let scheduler_name = format!("{}_regen_{}", rr.name, timestamp); + + // Create the scheduler in the database + let scheduler = models::SlurmSchedulerModel { + id: None, + workflow_id, + name: Some(scheduler_name.clone()), + account: account_to_use.clone(), + partition: if partition.requires_explicit_request { + Some(partition.name.clone()) + } else { + None + }, + mem: Some(rr.memory.clone()), + walltime: secs_to_walltime(partition.max_walltime_secs), + nodes: nodes_per_alloc, + gres: gpus.map(|g| format!("gpu:{}", g)), + ntasks_per_node: None, + qos: partition.default_qos.clone(), + tmp: None, + extra: None, + }; + + let created_scheduler = match default_api::create_slurm_scheduler(config, scheduler) { + Ok(s) => s, + Err(e) => { + print_error("creating scheduler", &e); + std::process::exit(1); + } + }; + + let scheduler_id = created_scheduler.id.unwrap_or(-1); + + schedulers_created.push(SchedulerInfo { + id: scheduler_id, + name: scheduler_name.clone(), + account: account_to_use.clone(), + partition: 
created_scheduler.partition.clone(), + walltime: created_scheduler.walltime.clone(), + nodes: nodes_per_alloc, + num_allocations, + job_count: jobs.len(), + }); + + total_allocations += num_allocations; + + // Update jobs to reference this scheduler + for job in jobs { + if let Some(job_id) = job.id { + let mut updated_job = (*job).clone(); + updated_job.scheduler_id = Some(scheduler_id); + if let Err(e) = default_api::update_job(config, job_id, updated_job) { + warnings.push(format!( + "Failed to update job {} with scheduler: {}", + job_id, e + )); + } + } + } + } + + // Submit allocations if requested + let mut submitted = false; + if submit && !schedulers_created.is_empty() { + // Create output directory + if let Err(e) = std::fs::create_dir_all(output_dir) { + eprintln!("Error creating output directory: {}", e); + std::process::exit(1); + } + + for scheduler_info in &schedulers_created { + let start_one_worker_per_node = scheduler_info.nodes > 1; + + match schedule_slurm_nodes( + config, + workflow_id, + scheduler_info.id, + scheduler_info.num_allocations as i32, + "worker", + output_dir.to_str().unwrap_or("output"), + poll_interval, + None, // max_parallel_jobs + start_one_worker_per_node, + false, // keep_submission_scripts + ) { + Ok(()) => { + info!( + "Submitted {} allocation(s) for scheduler '{}'", + scheduler_info.num_allocations, scheduler_info.name + ); + } + Err(e) => { + eprintln!( + "Error submitting allocations for scheduler '{}': {}", + scheduler_info.name, e + ); + std::process::exit(1); + } + } + } + submitted = true; + } + + // Output results + let result = RegenerateResult { + workflow_id, + pending_jobs: pending_jobs.len(), + schedulers_created, + total_allocations, + warnings, + submitted, + }; + + if format == "json" { + println!("{}", serde_json::to_string_pretty(&result).unwrap()); + } else { + println!("Regenerated Slurm schedulers for workflow {}", workflow_id); + println!(); + println!("Summary:"); + println!(" Pending jobs: {}", result.pending_jobs); + println!(" Schedulers created: {}", result.schedulers_created.len()); + println!(" Total allocations: {}", result.total_allocations); + println!( + " Profile used: {} ({})", + profile.display_name, profile.name + ); + + if !result.schedulers_created.is_empty() { + println!(); + println!("Schedulers:"); + for sched in &result.schedulers_created { + println!( + " - {} (ID: {}): {} job(s), {} allocation(s) × {} node(s)", + sched.name, sched.id, sched.job_count, sched.num_allocations, sched.nodes + ); + } + } + + if !result.warnings.is_empty() { + println!(); + println!("Warnings:"); + for warning in &result.warnings { + println!(" - {}", warning); + } + } + + if result.submitted { + println!(); + println!("Allocations submitted successfully."); + } else if !result.schedulers_created.is_empty() { + println!(); + println!("To submit the allocations, run:"); + println!(" torc slurm regenerate {} --submit", workflow_id); + } + } +} diff --git a/src/client/commands/watch.rs b/src/client/commands/watch.rs new file mode 100644 index 00000000..ac6fb105 --- /dev/null +++ b/src/client/commands/watch.rs @@ -0,0 +1,615 @@ +//! 
Watch command for monitoring workflows with automatic failure recovery + +use std::collections::HashMap; +use std::path::PathBuf; +use std::process::Command; +use std::time::Duration; + +use crate::client::apis::configuration::Configuration; +use crate::client::apis::default_api; +use crate::time_utils::duration_string_to_seconds; + +/// Arguments for the watch command +pub struct WatchArgs { + pub workflow_id: i64, + pub poll_interval: u64, + pub auto_recover: bool, + pub max_retries: u32, + pub memory_multiplier: f64, + pub runtime_multiplier: f64, + pub output_dir: PathBuf, + pub show_job_counts: bool, +} + +/// Parse memory string (e.g., "8g", "512m", "1024k") to bytes +pub fn parse_memory_bytes(mem: &str) -> Option { + let mem = mem.trim().to_lowercase(); + let (num_str, multiplier) = if mem.ends_with("gb") { + (mem.trim_end_matches("gb"), 1024u64 * 1024 * 1024) + } else if mem.ends_with("g") { + (mem.trim_end_matches("g"), 1024u64 * 1024 * 1024) + } else if mem.ends_with("mb") { + (mem.trim_end_matches("mb"), 1024u64 * 1024) + } else if mem.ends_with("m") { + (mem.trim_end_matches("m"), 1024u64 * 1024) + } else if mem.ends_with("kb") { + (mem.trim_end_matches("kb"), 1024u64) + } else if mem.ends_with("k") { + (mem.trim_end_matches("k"), 1024u64) + } else { + (mem.as_str(), 1u64) + }; + num_str + .parse::() + .ok() + .map(|n| (n * multiplier as f64) as u64) +} + +/// Format bytes to memory string (e.g., "12g", "512m") +pub fn format_memory_bytes_short(bytes: u64) -> String { + if bytes >= 1024 * 1024 * 1024 { + format!("{}g", bytes / (1024 * 1024 * 1024)) + } else if bytes >= 1024 * 1024 { + format!("{}m", bytes / (1024 * 1024)) + } else if bytes >= 1024 { + format!("{}k", bytes / 1024) + } else { + format!("{}b", bytes) + } +} + +/// Format seconds to ISO8601 duration (e.g., "PT2H30M") +pub fn format_duration_iso8601(secs: u64) -> String { + let hours = secs / 3600; + let mins = (secs % 3600) / 60; + if hours > 0 && mins > 0 { + format!("PT{}H{}M", hours, mins) + } else if hours > 0 { + format!("PT{}H", hours) + } else { + format!("PT{}M", mins.max(1)) + } +} + +/// Get job counts by status for a workflow +fn get_job_counts( + config: &Configuration, + workflow_id: i64, +) -> Result, String> { + let jobs_response = default_api::list_jobs( + config, + workflow_id, + None, // status filter + None, // needs_file_id + None, // upstream_job_id + None, // offset + Some(10000), // limit + None, // sort_by + None, // reverse_sort + None, // include_relationships + ) + .map_err(|e| format!("Failed to list jobs: {}", e))?; + + let jobs = jobs_response.items.unwrap_or_default(); + let mut counts = HashMap::new(); + + for job in &jobs { + if let Some(status) = &job.status { + let status_str = format!("{:?}", status); + *counts.entry(status_str).or_insert(0) += 1; + } + } + + Ok(counts) +} + +/// Poll until workflow is complete, optionally printing status updates +fn poll_until_complete( + config: &Configuration, + workflow_id: i64, + poll_interval: u64, + show_job_counts: bool, +) -> Result, String> { + loop { + // Check if workflow is complete + match default_api::is_workflow_complete(config, workflow_id) { + Ok(response) => { + if response.is_complete { + eprintln!("Workflow {} is complete", workflow_id); + break; + } + } + Err(e) => { + return Err(format!("Error checking workflow status: {}", e)); + } + } + + // Print current status if requested + if show_job_counts { + match get_job_counts(config, workflow_id) { + Ok(counts) => { + let completed = counts.get("Completed").unwrap_or(&0); + let 
running = counts.get("Running").unwrap_or(&0); + let pending = counts.get("Pending").unwrap_or(&0); + let failed = counts.get("Failed").unwrap_or(&0); + let blocked = counts.get("Blocked").unwrap_or(&0); + eprintln!( + " completed={}, running={}, pending={}, failed={}, blocked={}", + completed, running, pending, failed, blocked + ); + } + Err(e) => { + eprintln!("Error getting job counts: {}", e); + } + } + } + + std::thread::sleep(Duration::from_secs(poll_interval)); + } + + get_job_counts(config, workflow_id) +} + +/// Diagnose failures and return job IDs that need resource adjustments +fn diagnose_failures(workflow_id: i64, output_dir: &PathBuf) -> Result { + // Run check-resource-utilization command + let output = Command::new("torc") + .args([ + "-f", + "json", + "reports", + "check-resource-utilization", + &workflow_id.to_string(), + "--include-failed", + "-o", + output_dir.to_str().unwrap_or("output"), + ]) + .output() + .map_err(|e| format!("Failed to run check-resource-utilization: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!("check-resource-utilization failed: {}", stderr)); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + serde_json::from_str(&stdout) + .map_err(|e| format!("Failed to parse resource utilization output: {}", e)) +} + +/// Get Slurm log information for failed jobs +fn get_slurm_log_info(workflow_id: i64, output_dir: &PathBuf) -> Result { + // Run reports results command to get log paths + let output = Command::new("torc") + .args([ + "-f", + "json", + "reports", + "results", + &workflow_id.to_string(), + "-o", + output_dir.to_str().unwrap_or("output"), + ]) + .output() + .map_err(|e| format!("Failed to run reports results: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!("reports results failed: {}", stderr)); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + serde_json::from_str(&stdout) + .map_err(|e| format!("Failed to parse reports results output: {}", e)) +} + +/// Correlate failed jobs with their Slurm allocation logs +fn correlate_slurm_logs( + diagnosis: &serde_json::Value, + slurm_info: &serde_json::Value, +) -> HashMap { + let mut log_map = HashMap::new(); + + // Build map from job_id to slurm log paths + if let Some(jobs) = slurm_info.get("jobs").and_then(|v| v.as_array()) { + for job in jobs { + if let Some(job_id) = job.get("job_id").and_then(|v| v.as_i64()) { + let slurm_stdout = job + .get("slurm_stdout") + .and_then(|v| v.as_str()) + .map(String::from); + let slurm_stderr = job + .get("slurm_stderr") + .and_then(|v| v.as_str()) + .map(String::from); + let slurm_job_id = job + .get("slurm_job_id") + .and_then(|v| v.as_str()) + .map(String::from); + + if slurm_stdout.is_some() || slurm_stderr.is_some() { + log_map.insert( + job_id, + SlurmLogInfo { + slurm_job_id, + slurm_stdout, + slurm_stderr, + }, + ); + } + } + } + } + + // Filter to only failed jobs + let mut failed_log_map = HashMap::new(); + if let Some(failed_jobs) = diagnosis.get("failed_jobs").and_then(|v| v.as_array()) { + for job_info in failed_jobs { + if let Some(job_id) = job_info.get("job_id").and_then(|v| v.as_i64()) { + if let Some(log_info) = log_map.remove(&job_id) { + failed_log_map.insert(job_id, log_info); + } + } + } + } + + failed_log_map +} + +/// Information about Slurm logs for a job +#[derive(Debug)] +pub struct SlurmLogInfo { + pub slurm_job_id: Option, + pub slurm_stdout: Option, + pub 
slurm_stderr: Option, +} + +/// Apply recovery heuristics and update job resources +fn apply_recovery_heuristics( + config: &Configuration, + workflow_id: i64, + diagnosis: &serde_json::Value, + memory_multiplier: f64, + runtime_multiplier: f64, + output_dir: &PathBuf, +) -> Result<(usize, usize, usize), String> { + let mut oom_fixed = 0; + let mut timeout_fixed = 0; + let mut other_failures = 0; + + // Try to get Slurm log info for correlation + let slurm_log_map = match get_slurm_log_info(workflow_id, output_dir) { + Ok(slurm_info) => { + let log_map = correlate_slurm_logs(diagnosis, &slurm_info); + if !log_map.is_empty() { + eprintln!(" Found Slurm logs for {} failed job(s)", log_map.len()); + } + log_map + } + Err(e) => { + log::debug!("Could not get Slurm log info: {}", e); + HashMap::new() + } + }; + + // Get failed jobs info from diagnosis + let failed_jobs = diagnosis + .get("failed_jobs") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default(); + + for job_info in &failed_jobs { + let job_id = job_info.get("job_id").and_then(|v| v.as_i64()).unwrap_or(0); + let likely_oom = job_info + .get("likely_oom") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let likely_timeout = job_info + .get("likely_timeout") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if job_id == 0 { + continue; + } + + // Log Slurm info if available + if let Some(slurm_info) = slurm_log_map.get(&job_id) { + if let Some(slurm_job_id) = &slurm_info.slurm_job_id { + log::debug!(" Job {} ran in Slurm allocation {}", job_id, slurm_job_id); + } + } + + // Get current job to find resource requirements + let job = match default_api::get_job(config, job_id) { + Ok(j) => j, + Err(e) => { + eprintln!(" Warning: couldn't get job {}: {}", job_id, e); + continue; + } + }; + + let rr_id = match job.resource_requirements_id { + Some(id) => id, + None => { + eprintln!(" Warning: job {} has no resource requirements", job_id); + other_failures += 1; + continue; + } + }; + + // Get current resource requirements + let rr = match default_api::get_resource_requirements(config, rr_id) { + Ok(r) => r, + Err(e) => { + eprintln!( + " Warning: couldn't get resource requirements {}: {}", + rr_id, e + ); + continue; + } + }; + + let mut updated = false; + let mut new_rr = rr.clone(); + + // Apply OOM heuristic + if likely_oom { + if let Some(current_bytes) = parse_memory_bytes(&rr.memory) { + let new_bytes = (current_bytes as f64 * memory_multiplier) as u64; + let new_memory = format_memory_bytes_short(new_bytes); + eprintln!( + " Job {} ({}): OOM detected, increasing memory {} -> {}", + job_id, job.name, rr.memory, new_memory + ); + new_rr.memory = new_memory; + updated = true; + oom_fixed += 1; + } + } + + // Apply timeout heuristic + if likely_timeout { + // Use duration_string_to_seconds from time_utils + if let Ok(current_secs) = duration_string_to_seconds(&rr.runtime) { + let new_secs = (current_secs as f64 * runtime_multiplier) as u64; + let new_runtime = format_duration_iso8601(new_secs); + eprintln!( + " Job {} ({}): Timeout detected, increasing runtime {} -> {}", + job_id, job.name, rr.runtime, new_runtime + ); + new_rr.runtime = new_runtime; + updated = true; + timeout_fixed += 1; + } + } + + // Update resource requirements if changed + if updated { + if let Err(e) = default_api::update_resource_requirements(config, rr_id, new_rr) { + eprintln!( + " Warning: failed to update resource requirements for job {}: {}", + job_id, e + ); + } + } else if !likely_oom && !likely_timeout { + // Unknown failure - will retry 
without changes + other_failures += 1; + } + } + + Ok((oom_fixed, timeout_fixed, other_failures)) +} + +/// Reset failed jobs and restart workflow +fn reset_failed_jobs(workflow_id: i64) -> Result<(), String> { + let output = Command::new("torc") + .args([ + "workflows", + "reset-status", + &workflow_id.to_string(), + "--failed-only", + "--restart", + "--no-prompts", + ]) + .output() + .map_err(|e| format!("Failed to run reset-status: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!("reset-status failed: {}", stderr)); + } + + Ok(()) +} + +/// Regenerate Slurm schedulers and submit allocations +fn regenerate_and_submit(workflow_id: i64, output_dir: &PathBuf) -> Result<(), String> { + let output = Command::new("torc") + .args([ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--submit", + "-o", + output_dir.to_str().unwrap_or("output"), + ]) + .output() + .map_err(|e| format!("Failed to run slurm regenerate: {}", e))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(format!("slurm regenerate failed: {}", stderr)); + } + + Ok(()) +} + +/// Run the watch command +pub fn run_watch(config: &Configuration, args: &WatchArgs) { + let mut retry_count = 0u32; + + eprintln!( + "Watching workflow {} (poll interval: {}s{}{})", + args.workflow_id, + args.poll_interval, + if args.auto_recover { + format!(", auto-recover enabled, max retries: {}", args.max_retries) + } else { + String::new() + }, + if args.show_job_counts { + ", job counts enabled" + } else { + "" + } + ); + + if !args.show_job_counts { + eprintln!(" (use --show-job-counts to display per-status counts during polling)"); + } + + loop { + // Poll until workflow is complete + let counts = match poll_until_complete( + config, + args.workflow_id, + args.poll_interval, + args.show_job_counts, + ) { + Ok(c) => c, + Err(e) => { + eprintln!("Error: {}", e); + std::process::exit(1); + } + }; + + let completed = *counts.get("Completed").unwrap_or(&0); + let failed = *counts.get("Failed").unwrap_or(&0); + let canceled = *counts.get("Canceled").unwrap_or(&0); + let terminated = *counts.get("Terminated").unwrap_or(&0); + + let needs_recovery = failed > 0 || canceled > 0 || terminated > 0; + + if !needs_recovery { + eprintln!("\n✓ Workflow completed successfully ({} jobs)", completed); + break; + } + + eprintln!("\nWorkflow completed with failures:"); + eprintln!(" - Failed: {}", failed); + eprintln!(" - Canceled: {}", canceled); + eprintln!(" - Terminated: {}", terminated); + eprintln!(" - Completed: {}", completed); + + // Check if we should attempt recovery + if !args.auto_recover { + eprintln!("\nAuto-recovery disabled. To enable, use --auto-recover flag."); + eprintln!("Or use the Torc MCP server with your AI assistant for manual recovery."); + std::process::exit(1); + } + + if retry_count >= args.max_retries { + eprintln!( + "\nMax retries ({}) exceeded. 
Manual intervention required.", + args.max_retries + ); + eprintln!("Use the Torc MCP server with your AI assistant to investigate."); + std::process::exit(1); + } + + retry_count += 1; + eprintln!( + "\nAttempting automatic recovery (attempt {}/{})", + retry_count, args.max_retries + ); + + // Step 1: Diagnose failures + eprintln!("\nDiagnosing failures..."); + let diagnosis = match diagnose_failures(args.workflow_id, &args.output_dir) { + Ok(d) => d, + Err(e) => { + eprintln!("Warning: Could not diagnose failures: {}", e); + eprintln!("Attempting retry without resource adjustments..."); + serde_json::json!({"failed_jobs": []}) + } + }; + + // Step 2: Apply heuristics to adjust resources + eprintln!("\nApplying recovery heuristics..."); + match apply_recovery_heuristics( + config, + args.workflow_id, + &diagnosis, + args.memory_multiplier, + args.runtime_multiplier, + &args.output_dir, + ) { + Ok((oom, timeout, other)) => { + if oom > 0 || timeout > 0 { + eprintln!(" Applied fixes: {} OOM, {} timeout", oom, timeout); + } + if other > 0 { + eprintln!(" {} job(s) with unknown failure cause (will retry)", other); + } + } + Err(e) => { + eprintln!("Warning: Error applying heuristics: {}", e); + } + } + + // Step 3: Reset failed jobs + eprintln!("\nResetting failed jobs..."); + if let Err(e) = reset_failed_jobs(args.workflow_id) { + eprintln!("Error resetting jobs: {}", e); + std::process::exit(1); + } + + // Step 4: Regenerate Slurm schedulers and submit + eprintln!("Regenerating Slurm schedulers and submitting..."); + if let Err(e) = regenerate_and_submit(args.workflow_id, &args.output_dir) { + eprintln!("Error regenerating schedulers: {}", e); + std::process::exit(1); + } + + eprintln!("\nRecovery initiated. Resuming monitoring...\n"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_bytes() { + assert_eq!(parse_memory_bytes("1g"), Some(1024 * 1024 * 1024)); + assert_eq!(parse_memory_bytes("2gb"), Some(2 * 1024 * 1024 * 1024)); + assert_eq!(parse_memory_bytes("512m"), Some(512 * 1024 * 1024)); + assert_eq!(parse_memory_bytes("512mb"), Some(512 * 1024 * 1024)); + assert_eq!(parse_memory_bytes("1024k"), Some(1024 * 1024)); + assert_eq!(parse_memory_bytes("1024kb"), Some(1024 * 1024)); + assert_eq!(parse_memory_bytes("1024"), Some(1024)); + assert_eq!(parse_memory_bytes("invalid"), None); + } + + #[test] + fn test_format_memory_bytes_short() { + assert_eq!(format_memory_bytes_short(1024 * 1024 * 1024), "1g"); + assert_eq!(format_memory_bytes_short(2 * 1024 * 1024 * 1024), "2g"); + assert_eq!(format_memory_bytes_short(512 * 1024 * 1024), "512m"); + assert_eq!(format_memory_bytes_short(1024 * 1024), "1m"); + assert_eq!(format_memory_bytes_short(1024), "1k"); + assert_eq!(format_memory_bytes_short(512), "512b"); + } + + #[test] + fn test_format_duration_iso8601() { + assert_eq!(format_duration_iso8601(3600), "PT1H"); + assert_eq!(format_duration_iso8601(7200), "PT2H"); + assert_eq!(format_duration_iso8601(5400), "PT1H30M"); + assert_eq!(format_duration_iso8601(1800), "PT30M"); + assert_eq!(format_duration_iso8601(60), "PT1M"); + assert_eq!(format_duration_iso8601(30), "PT1M"); // rounds up to minimum 1 minute + } +} diff --git a/src/client/watch/audit.rs b/src/client/watch/audit.rs new file mode 100644 index 00000000..855c1d2a --- /dev/null +++ b/src/client/watch/audit.rs @@ -0,0 +1,150 @@ +//! Audit logging for watch command actions. 
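+//!
+//! Entries are appended as one JSON object per line (JSON Lines). An
+//! illustrative `diagnosis` entry might look like the following (the field
+//! values are invented for this example):
+//!
+//! ```json
+//! {"timestamp":"2024-01-15T10:30:45+00:00","event_type":"diagnosis","job_id":42,"job_name":"work_3","summary":"Likely OOM kill","confidence":0.9}
+//! ```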
+ +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Write}; +use std::path::Path; + +use chrono::Utc; +use log::warn; +use serde::Serialize; + +use super::claude_client::Diagnosis; +use super::recovery::RecoveryAction; + +/// Audit logger for recording all watch command actions. +pub struct AuditLogger { + writer: BufWriter, +} + +impl AuditLogger { + /// Create a new audit logger writing to the specified file. + pub fn new(path: &Path) -> Result { + // Create parent directories if needed + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .map_err(|e| format!("Failed to create audit log directory: {}", e))?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(path) + .map_err(|e| format!("Failed to open audit log: {}", e))?; + + Ok(Self { + writer: BufWriter::new(file), + }) + } + + /// Log a diagnosis event. + pub fn log_diagnosis(&mut self, job_id: i64, job_name: &str, diagnosis: &Diagnosis) { + let entry = AuditEntry { + timestamp: Utc::now().to_rfc3339(), + event_type: "diagnosis".to_string(), + job_id, + job_name: job_name.to_string(), + summary: Some(diagnosis.summary.clone()), + root_cause: diagnosis.root_cause.clone(), + recommended_action: diagnosis.recommended_action.clone(), + confidence: Some(diagnosis.confidence), + action_taken: None, + action_success: None, + notes: diagnosis.notes.clone(), + }; + + self.write_entry(&entry); + } + + /// Log a recovery action event. + pub fn log_recovery( + &mut self, + job_id: i64, + job_name: &str, + action: &RecoveryAction, + success: bool, + ) { + let entry = AuditEntry { + timestamp: Utc::now().to_rfc3339(), + event_type: "recovery".to_string(), + job_id, + job_name: job_name.to_string(), + summary: None, + root_cause: None, + recommended_action: None, + confidence: None, + action_taken: Some(action.clone()), + action_success: Some(success), + notes: None, + }; + + self.write_entry(&entry); + } + + /// Log an arbitrary event. + #[allow(dead_code)] + pub fn log_event( + &mut self, + event_type: &str, + job_id: i64, + job_name: &str, + notes: Option, + ) { + let entry = AuditEntry { + timestamp: Utc::now().to_rfc3339(), + event_type: event_type.to_string(), + job_id, + job_name: job_name.to_string(), + summary: None, + root_cause: None, + recommended_action: None, + confidence: None, + action_taken: None, + action_success: None, + notes, + }; + + self.write_entry(&entry); + } + + fn write_entry(&mut self, entry: &AuditEntry) { + let json = match serde_json::to_string(entry) { + Ok(j) => j, + Err(e) => { + warn!("Failed to serialize audit entry: {}", e); + return; + } + }; + + if let Err(e) = writeln!(self.writer, "{}", json) { + warn!("Failed to write audit entry: {}", e); + } + + // Flush after each entry to ensure logs are written + if let Err(e) = self.writer.flush() { + warn!("Failed to flush audit log: {}", e); + } + } +} + +/// A single audit log entry (JSON lines format). 
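+///
+/// Optional fields such as `summary`, `confidence`, and `action_taken` are
+/// only serialized when present, so diagnosis and recovery entries stay
+/// compact.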
+#[derive(Debug, Serialize)]
+struct AuditEntry {
+    timestamp: String,
+    event_type: String,
+    job_id: i64,
+    job_name: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    summary: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    root_cause: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    recommended_action: Option<RecoveryAction>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    confidence: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    action_taken: Option<RecoveryAction>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    action_success: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    notes: Option<String>,
+}
diff --git a/src/client/watch/claude_client.rs b/src/client/watch/claude_client.rs
new file mode 100644
index 00000000..6eb50357
--- /dev/null
+++ b/src/client/watch/claude_client.rs
@@ -0,0 +1,320 @@
+//! Claude API client for failure diagnosis.
+
+use log::debug;
+use serde::{Deserialize, Serialize};
+
+use crate::client::apis::configuration::Configuration;
+use crate::models::JobModel;
+
+use super::recovery::RecoveryAction;
+
+const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages";
+const ANTHROPIC_VERSION: &str = "2023-06-01";
+
+/// Diagnosis result from Claude.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Diagnosis {
+    /// Summary of the failure
+    pub summary: String,
+    /// Root cause analysis
+    pub root_cause: Option<String>,
+    /// Recommended recovery action
+    pub recommended_action: Option<RecoveryAction>,
+    /// Confidence level (0.0 - 1.0)
+    pub confidence: f64,
+    /// Additional notes or suggestions
+    pub notes: Option<String>,
+}
+
+/// Claude API client for diagnosing job failures.
+pub struct ClaudeClient {
+    api_key: String,
+    model: String,
+    client: reqwest::blocking::Client,
+}
+
+impl ClaudeClient {
+    /// Create a new Claude client.
+    pub fn new(api_key: String, model: String) -> Self {
+        Self {
+            api_key,
+            model,
+            client: reqwest::blocking::Client::new(),
+        }
+    }
+
+    /// Diagnose a job failure using Claude.
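+    ///
+    /// Sends the job's command plus the captured stdout/stderr excerpts to the
+    /// Messages API and expects the model to reply through the
+    /// `diagnose_failure` tool. If the response contains no tool call, the
+    /// plain text reply is wrapped in a low-confidence `Diagnosis` instead of
+    /// returning an error.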
+ pub fn diagnose_failure( + &self, + config: &Configuration, + workflow_id: i64, + job: &JobModel, + stdout: &str, + stderr: &str, + ) -> Result { + let job_id = job.id.ok_or("Job has no ID")?; + let job_name = job.name.clone(); + let command = job.command.clone(); + + // Build the prompt + let prompt = + self.build_diagnosis_prompt(workflow_id, job_id, &job_name, &command, stdout, stderr); + + // Get tools definition + let tools = self.get_tools_definition(); + + // Make API request + let request_body = serde_json::json!({ + "model": self.model, + "max_tokens": 4096, + "tools": tools, + "messages": [ + { + "role": "user", + "content": prompt + } + ], + "system": self.get_system_prompt() + }); + + debug!("Sending request to Claude API"); + let response = self + .client + .post(CLAUDE_API_URL) + .header("Content-Type", "application/json") + .header("x-api-key", &self.api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .json(&request_body) + .send() + .map_err(|e| format!("Failed to send request to Claude API: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response + .text() + .unwrap_or_else(|_| "unknown error".to_string()); + return Err(format!("Claude API error ({}): {}", status, body)); + } + + let response_body: ClaudeResponse = response + .json() + .map_err(|e| format!("Failed to parse Claude API response: {}", e))?; + + // Parse the response + self.parse_response(&response_body, config, workflow_id) + } + + fn get_system_prompt(&self) -> &'static str { + r#"You are an expert HPC workflow failure diagnostician. Your job is to analyze job failures from workflow orchestration systems and recommend recovery actions. + +When analyzing a failure: +1. Look for common error patterns (OOM, timeout, missing files, permission errors, CUDA errors, etc.) +2. Consider the job's resource requirements vs actual usage +3. Recommend specific, actionable recovery steps + +Available recovery actions: +- restart: Restart the job with no changes (for transient failures) +- restart_with_resources: Restart with modified resource requirements (memory, CPUs, runtime) +- cancel: Cancel the job and its dependents (for unrecoverable failures) +- skip: Mark as completed and continue (for optional jobs) + +Always provide: +1. A clear summary of what went wrong +2. Root cause analysis when possible +3. A specific recovery action with parameters +4. Confidence level in your diagnosis + +Use the diagnose_failure tool to report your findings."# + } + + fn build_diagnosis_prompt( + &self, + workflow_id: i64, + job_id: i64, + job_name: &str, + command: &str, + stdout: &str, + stderr: &str, + ) -> String { + format!( + r#"Please diagnose the following job failure and recommend a recovery action. 
+ +## Job Information +- Workflow ID: {} +- Job ID: {} +- Job Name: {} +- Command: {} + +## Standard Output (last 10KB) +``` +{} +``` + +## Standard Error (last 10KB) +``` +{} +``` + +Analyze this failure and use the diagnose_failure tool to report your findings."#, + workflow_id, job_id, job_name, command, stdout, stderr + ) + } + + fn get_tools_definition(&self) -> serde_json::Value { + serde_json::json!([ + { + "name": "diagnose_failure", + "description": "Report the diagnosis of a job failure with recommended recovery action", + "input_schema": { + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "Brief summary of what went wrong (1-2 sentences)" + }, + "root_cause": { + "type": "string", + "description": "Detailed root cause analysis" + }, + "action_type": { + "type": "string", + "enum": ["restart", "restart_with_resources", "cancel", "skip", "none"], + "description": "Type of recovery action to take" + }, + "new_memory": { + "type": "string", + "description": "New memory requirement (e.g., '8g', '16g') for restart_with_resources" + }, + "new_runtime": { + "type": "string", + "description": "New runtime limit (e.g., 'PT2H', 'PT4H') for restart_with_resources" + }, + "new_num_cpus": { + "type": "integer", + "description": "New CPU count for restart_with_resources" + }, + "confidence": { + "type": "number", + "description": "Confidence in the diagnosis (0.0 to 1.0)" + }, + "notes": { + "type": "string", + "description": "Additional notes or suggestions" + } + }, + "required": ["summary", "action_type", "confidence"] + } + } + ]) + } + + fn parse_response( + &self, + response: &ClaudeResponse, + _config: &Configuration, + _workflow_id: i64, + ) -> Result { + // Find tool use in response + for content in &response.content { + if content.content_type == "tool_use" + && content.name.as_deref() == Some("diagnose_failure") + { + let input = content.input.as_ref().ok_or("Tool use has no input")?; + + let summary = input + .get("summary") + .and_then(|v| v.as_str()) + .ok_or("Missing summary in diagnosis")? 
+ .to_string(); + + let root_cause = input + .get("root_cause") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let action_type = input + .get("action_type") + .and_then(|v| v.as_str()) + .unwrap_or("none"); + + let confidence = input + .get("confidence") + .and_then(|v| v.as_f64()) + .unwrap_or(0.5); + + let notes = input + .get("notes") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let recommended_action = match action_type { + "restart" => Some(RecoveryAction::Restart), + "restart_with_resources" => { + let new_memory = input + .get("new_memory") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let new_runtime = input + .get("new_runtime") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let new_num_cpus = input.get("new_num_cpus").and_then(|v| v.as_i64()); + + Some(RecoveryAction::RestartWithResources { + memory: new_memory, + runtime: new_runtime, + num_cpus: new_num_cpus, + }) + } + "cancel" => Some(RecoveryAction::Cancel), + "skip" => Some(RecoveryAction::Skip), + _ => None, + }; + + return Ok(Diagnosis { + summary, + root_cause, + recommended_action, + confidence, + notes, + }); + } + } + + // If no tool use found, try to extract from text + for content in &response.content { + if content.content_type == "text" { + if let Some(text) = &content.text { + return Ok(Diagnosis { + summary: text.chars().take(200).collect(), + root_cause: None, + recommended_action: None, + confidence: 0.3, + notes: Some("Could not parse structured response from Claude".to_string()), + }); + } + } + } + + Err("No diagnosis found in Claude response".to_string()) + } +} + +/// Claude API response structure. +#[derive(Debug, Deserialize)] +struct ClaudeResponse { + content: Vec, + #[allow(dead_code)] + model: String, + #[allow(dead_code)] + stop_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ContentBlock { + #[serde(rename = "type")] + content_type: String, + text: Option, + name: Option, + input: Option, +} diff --git a/src/client/watch/failure_cache.rs b/src/client/watch/failure_cache.rs new file mode 100644 index 00000000..59342674 --- /dev/null +++ b/src/client/watch/failure_cache.rs @@ -0,0 +1,320 @@ +//! Failure pattern cache using SQLite. +//! +//! Caches failure diagnoses to avoid repeated API calls for similar failures. + +use std::path::Path; + +use log::{debug, warn}; +use rusqlite::{Connection, params}; +use sha2::{Digest, Sha256}; + +use super::claude_client::Diagnosis; + +/// Cache for storing failure patterns and their diagnoses. +pub struct FailureCache { + conn: Connection, +} + +impl FailureCache { + /// Open or create a failure cache database. 
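+    ///
+    /// A minimal usage sketch (the database path, job name, and stderr text
+    /// below are illustrative, and the surrounding error handling is elided):
+    ///
+    /// ```ignore
+    /// use std::path::Path;
+    ///
+    /// let mut cache = FailureCache::open(Path::new("/tmp/torc-failure-cache.db"))
+    ///     .expect("open cache");
+    /// let signature = FailureCache::compute_signature("Out of memory: Killed process 1234");
+    /// if cache.lookup("work_17", &signature).expect("lookup").is_none() {
+    ///     // Diagnose the failure, then persist the result for similar jobs:
+    ///     // cache.store("work_17", &signature, &diagnosis).expect("store");
+    /// }
+    /// ```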
+ pub fn open(path: &Path) -> Result { + let conn = + Connection::open(path).map_err(|e| format!("Failed to open cache database: {}", e))?; + + // Create tables if they don't exist + conn.execute( + r#" + CREATE TABLE IF NOT EXISTS failure_patterns ( + id INTEGER PRIMARY KEY, + job_name_pattern TEXT NOT NULL, + error_signature TEXT NOT NULL, + diagnosis_json TEXT NOT NULL, + success_count INTEGER DEFAULT 0, + failure_count INTEGER DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + last_used_at TEXT NOT NULL DEFAULT (datetime('now')), + UNIQUE(job_name_pattern, error_signature) + ) + "#, + [], + ) + .map_err(|e| format!("Failed to create cache table: {}", e))?; + + // Create index for faster lookups + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_failure_patterns_lookup ON failure_patterns(job_name_pattern, error_signature)", + [], + ) + .map_err(|e| format!("Failed to create cache index: {}", e))?; + + Ok(Self { conn }) + } + + /// Compute an error signature from stderr content. + /// + /// This normalizes the error output by: + /// 1. Extracting lines containing error keywords + /// 2. Removing timestamps and PIDs + /// 3. Hashing the result + pub fn compute_signature(stderr: &str) -> String { + let error_keywords = [ + "error", + "Error", + "ERROR", + "exception", + "Exception", + "EXCEPTION", + "failed", + "Failed", + "FAILED", + "oom", + "OOM", + "Out of memory", + "killed", + "Killed", + "KILLED", + "timeout", + "Timeout", + "TIMEOUT", + "cuda", + "CUDA", + "segfault", + "Segmentation fault", + "permission denied", + "Permission denied", + "not found", + "No such file", + ]; + + let mut error_lines: Vec = Vec::new(); + + for line in stderr.lines() { + let line_lower = line.to_lowercase(); + if error_keywords + .iter() + .any(|kw| line_lower.contains(&kw.to_lowercase())) + { + // Normalize the line: remove timestamps, PIDs, paths + let normalized = normalize_error_line(line); + if !normalized.is_empty() { + error_lines.push(normalized); + } + } + } + + // If no error lines found, hash the last 20 lines + if error_lines.is_empty() { + error_lines = stderr + .lines() + .rev() + .take(20) + .map(|l| normalize_error_line(l)) + .filter(|l| !l.is_empty()) + .collect(); + error_lines.reverse(); + } + + // Compute hash + let mut hasher = Sha256::new(); + for line in &error_lines { + hasher.update(line.as_bytes()); + hasher.update(b"\n"); + } + let result = hasher.finalize(); + format!("{:x}", result) + } + + /// Look up a cached diagnosis for a failure pattern. + pub fn lookup( + &self, + job_name: &str, + error_signature: &str, + ) -> Result, String> { + // Extract job name pattern (remove numeric suffixes like _001, _42, etc.) 
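+        // e.g. "work_001" and "work_002" both normalize to "work_[N]", so
+        // similar array jobs can share a single cached diagnosis.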
+ let job_pattern = extract_job_pattern(job_name); + + let result: Result<(String, i64, i64), _> = self.conn.query_row( + r#" + SELECT diagnosis_json, success_count, failure_count + FROM failure_patterns + WHERE job_name_pattern = ?1 AND error_signature = ?2 + "#, + params![job_pattern, error_signature], + |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)), + ); + + match result { + Ok((json, success_count, failure_count)) => { + // Update last_used_at + let _ = self.conn.execute( + "UPDATE failure_patterns SET last_used_at = datetime('now') WHERE job_name_pattern = ?1 AND error_signature = ?2", + params![job_pattern, error_signature], + ); + + // Only use cache if success rate is reasonable + let total = success_count + failure_count; + if total > 0 && failure_count as f64 / total as f64 > 0.5 { + debug!( + "Cache hit for {} but success rate too low ({}/{}), skipping", + job_name, success_count, total + ); + return Ok(None); + } + + let diagnosis: Diagnosis = serde_json::from_str(&json) + .map_err(|e| format!("Failed to parse cached diagnosis: {}", e))?; + + debug!( + "Cache hit for {} (success: {}, failure: {})", + job_name, success_count, failure_count + ); + Ok(Some(diagnosis)) + } + Err(rusqlite::Error::QueryReturnedNoRows) => { + debug!("Cache miss for {}", job_name); + Ok(None) + } + Err(e) => Err(format!("Cache lookup failed: {}", e)), + } + } + + /// Store a diagnosis in the cache. + pub fn store( + &mut self, + job_name: &str, + error_signature: &str, + diagnosis: &Diagnosis, + ) -> Result<(), String> { + let job_pattern = extract_job_pattern(job_name); + let json = serde_json::to_string(diagnosis) + .map_err(|e| format!("Failed to serialize diagnosis: {}", e))?; + + self.conn + .execute( + r#" + INSERT INTO failure_patterns (job_name_pattern, error_signature, diagnosis_json) + VALUES (?1, ?2, ?3) + ON CONFLICT(job_name_pattern, error_signature) DO UPDATE SET + diagnosis_json = ?3, + last_used_at = datetime('now') + "#, + params![job_pattern, error_signature, json], + ) + .map_err(|e| format!("Failed to store diagnosis: {}", e))?; + + debug!("Cached diagnosis for {}", job_name); + Ok(()) + } + + /// Record a successful recovery using a cached diagnosis. + pub fn record_success(&self, job_name: &str, error_signature: &str) { + let job_pattern = extract_job_pattern(job_name); + if let Err(e) = self.conn.execute( + "UPDATE failure_patterns SET success_count = success_count + 1 WHERE job_name_pattern = ?1 AND error_signature = ?2", + params![job_pattern, error_signature], + ) { + warn!("Failed to record cache success: {}", e); + } + } + + /// Record a failed recovery attempt. + pub fn record_failure(&self, job_name: &str, error_signature: &str) { + let job_pattern = extract_job_pattern(job_name); + if let Err(e) = self.conn.execute( + "UPDATE failure_patterns SET failure_count = failure_count + 1 WHERE job_name_pattern = ?1 AND error_signature = ?2", + params![job_pattern, error_signature], + ) { + warn!("Failed to record cache failure: {}", e); + } + } + + /// Get cache statistics. 
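+    ///
+    /// Returns the number of cached patterns plus aggregate success and
+    /// failure counts across all entries.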
+ #[allow(dead_code)] + pub fn stats(&self) -> Result { + let total_entries: i64 = self + .conn + .query_row("SELECT COUNT(*) FROM failure_patterns", [], |row| { + row.get(0) + }) + .map_err(|e| format!("Failed to get cache stats: {}", e))?; + + let total_successes: i64 = self + .conn + .query_row( + "SELECT COALESCE(SUM(success_count), 0) FROM failure_patterns", + [], + |row| row.get(0), + ) + .map_err(|e| format!("Failed to get success count: {}", e))?; + + let total_failures: i64 = self + .conn + .query_row( + "SELECT COALESCE(SUM(failure_count), 0) FROM failure_patterns", + [], + |row| row.get(0), + ) + .map_err(|e| format!("Failed to get failure count: {}", e))?; + + Ok(CacheStats { + total_entries: total_entries as usize, + total_successes: total_successes as usize, + total_failures: total_failures as usize, + }) + } +} + +/// Cache statistics. +#[derive(Debug)] +pub struct CacheStats { + pub total_entries: usize, + pub total_successes: usize, + pub total_failures: usize, +} + +/// Normalize an error line by removing timestamps, PIDs, and paths. +fn normalize_error_line(line: &str) -> String { + let mut result = line.to_string(); + + // Remove timestamps (various formats) + // ISO: 2024-01-15T10:30:45 + // Common: [2024-01-15 10:30:45], Jan 15 10:30:45 + let timestamp_patterns = [ + r"\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(\.\d+)?", + r"\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]", + r"[A-Z][a-z]{2} \d{1,2} \d{2}:\d{2}:\d{2}", + ]; + for pattern in timestamp_patterns { + if let Ok(re) = regex::Regex::new(pattern) { + result = re.replace_all(&result, "[TIME]").to_string(); + } + } + + // Remove PIDs + if let Ok(re) = regex::Regex::new(r"\bpid[=: ]?\d+\b|\bPID[=: ]?\d+\b|\[\d+\]") { + result = re.replace_all(&result, "[PID]").to_string(); + } + + // Remove absolute paths but keep the filename + if let Ok(re) = regex::Regex::new(r"/[^\s:]+/([^\s/:]+)") { + result = re.replace_all(&result, "[PATH]/$1").to_string(); + } + + // Remove memory addresses + if let Ok(re) = regex::Regex::new(r"0x[0-9a-fA-F]+") { + result = re.replace_all(&result, "[ADDR]").to_string(); + } + + // Trim whitespace + result.trim().to_string() +} + +/// Extract a job name pattern by removing numeric suffixes. +fn extract_job_pattern(job_name: &str) -> String { + // Remove common numeric suffixes like _001, _42, -1, etc. + if let Ok(re) = regex::Regex::new(r"[_-]\d+$") { + re.replace(job_name, "[N]").to_string() + } else { + job_name.to_string() + } +} diff --git a/src/client/watch/mod.rs b/src/client/watch/mod.rs new file mode 100644 index 00000000..48d45bb5 --- /dev/null +++ b/src/client/watch/mod.rs @@ -0,0 +1,17 @@ +//! Watch module for AI-powered workflow monitoring and failure recovery. +//! +//! This module provides the `torc watch` command functionality, which monitors +//! a running workflow for failures and uses Claude to diagnose issues and +//! automatically apply recovery actions. + +pub mod audit; +pub mod claude_client; +pub mod failure_cache; +pub mod recovery; +pub mod watcher; + +pub use audit::AuditLogger; +pub use claude_client::ClaudeClient; +pub use failure_cache::FailureCache; +pub use recovery::RecoveryAction; +pub use watcher::{WatchConfig, Watcher}; diff --git a/src/client/watch/recovery.rs b/src/client/watch/recovery.rs new file mode 100644 index 00000000..611df23e --- /dev/null +++ b/src/client/watch/recovery.rs @@ -0,0 +1,196 @@ +//! Recovery action execution. 
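+//!
+//! `RecoveryAction` is serialized with an internal `type` tag, so a
+//! `RestartWithResources` action rendered as JSON looks roughly like this
+//! (the resource values are illustrative):
+//!
+//! ```json
+//! {"type":"RestartWithResources","memory":"16g","runtime":"PT4H","num_cpus":8}
+//! ```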
+
+use log::{debug, info};
+use serde::{Deserialize, Serialize};
+
+use crate::client::apis::configuration::Configuration;
+use crate::client::apis::default_api;
+use crate::models::{JobStatus, ResourceRequirementsModel};
+
+/// Recovery actions that can be taken for failed jobs.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum RecoveryAction {
+    /// Restart the job with no changes (for transient failures)
+    Restart,
+
+    /// Restart with modified resource requirements
+    RestartWithResources {
+        /// New memory requirement (e.g., "8g")
+        memory: Option<String>,
+        /// New runtime limit (e.g., "PT2H")
+        runtime: Option<String>,
+        /// New CPU count
+        num_cpus: Option<i64>,
+    },
+
+    /// Cancel the job and its dependents
+    Cancel,
+
+    /// Mark as completed and continue (skip the job)
+    Skip,
+}
+
+/// Execute a recovery action for a job.
+pub fn execute_recovery(
+    config: &Configuration,
+    job_id: i64,
+    action: &RecoveryAction,
+) -> Result<(), String> {
+    match action {
+        RecoveryAction::Restart => restart_job(config, job_id),
+        RecoveryAction::RestartWithResources {
+            memory,
+            runtime,
+            num_cpus,
+        } => restart_with_resources(config, job_id, memory, runtime, num_cpus),
+        RecoveryAction::Cancel => cancel_job(config, job_id),
+        RecoveryAction::Skip => skip_job(config, job_id),
+    }
+}
+
+/// Restart a job by resetting its status to uninitialized.
+fn restart_job(config: &Configuration, job_id: i64) -> Result<(), String> {
+    info!("Restarting job {}", job_id);
+
+    // Reset job status to uninitialized
+    default_api::manage_status_change(
+        config,
+        job_id,
+        JobStatus::Uninitialized,
+        0, // run_id
+        None,
+    )
+    .map_err(|e| format!("Failed to reset job status: {}", e))?;
+
+    // Get workflow ID to reinitialize
+    let job =
+        default_api::get_job(config, job_id).map_err(|e| format!("Failed to get job: {}", e))?;
+
+    // Reinitialize the workflow to process the reset job
+    default_api::initialize_jobs(config, job.workflow_id, None, None, None)
+        .map_err(|e| format!("Failed to reinitialize jobs: {}", e))?;
+
+    debug!("Job {} reset and workflow reinitialized", job_id);
+    Ok(())
+}
+
+/// Restart a job with updated resource requirements.
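+///
+/// Only the fields that are provided are changed; the updated requirements are
+/// written back through the API and the job is then restarted via
+/// `restart_job`. Jobs without a `resource_requirements_id` are restarted
+/// unchanged.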
+fn restart_with_resources( + config: &Configuration, + job_id: i64, + memory: &Option, + runtime: &Option, + num_cpus: &Option, +) -> Result<(), String> { + info!( + "Restarting job {} with updated resources (memory: {:?}, runtime: {:?}, cpus: {:?})", + job_id, memory, runtime, num_cpus + ); + + // Get the job to find resource requirements + let job = + default_api::get_job(config, job_id).map_err(|e| format!("Failed to get job: {}", e))?; + + // Update resource requirements if the job has them + if let Some(req_id) = job.resource_requirements_id { + let mut reqs = default_api::get_resource_requirements(config, req_id) + .map_err(|e| format!("Failed to get resource requirements: {}", e))?; + + // Update fields if provided + if let Some(mem) = memory { + info!("Updating memory from {} to {}", reqs.memory, mem); + reqs.memory = mem.clone(); + } + if let Some(rt) = runtime { + info!("Updating runtime from {} to {}", reqs.runtime, rt); + reqs.runtime = rt.clone(); + } + if let Some(cpus) = num_cpus { + info!("Updating num_cpus from {} to {}", reqs.num_cpus, cpus); + reqs.num_cpus = *cpus; + } + + // Update the resource requirements + default_api::update_resource_requirements( + config, + req_id, + ResourceRequirementsModel { + id: reqs.id, + workflow_id: reqs.workflow_id, + name: reqs.name, + num_cpus: reqs.num_cpus, + num_gpus: reqs.num_gpus, + num_nodes: reqs.num_nodes, + memory: reqs.memory, + runtime: reqs.runtime, + }, + ) + .map_err(|e| format!("Failed to update resource requirements: {}", e))?; + } else { + debug!("Job {} has no resource requirements to update", job_id); + } + + // Now restart the job + restart_job(config, job_id) +} + +/// Cancel a job. +fn cancel_job(config: &Configuration, job_id: i64) -> Result<(), String> { + info!("Canceling job {}", job_id); + + default_api::manage_status_change( + config, + job_id, + JobStatus::Canceled, + 0, // run_id + None, + ) + .map_err(|e| format!("Failed to cancel job: {}", e))?; + + debug!("Job {} canceled", job_id); + Ok(()) +} + +/// Skip a job by marking it as completed. +fn skip_job(config: &Configuration, job_id: i64) -> Result<(), String> { + info!("Skipping job {} (marking as completed)", job_id); + + // Get the job to find workflow_id + let job = + default_api::get_job(config, job_id).map_err(|e| format!("Failed to get job: {}", e))?; + + // Get the workflow status to find run_id + let workflow_status = default_api::get_workflow_status(config, job.workflow_id) + .map_err(|e| format!("Failed to get workflow status: {}", e))?; + + // Create a minimal result for the skipped job + let result = crate::models::ResultModel { + id: None, + job_id, + workflow_id: job.workflow_id, + run_id: workflow_status.run_id, + compute_node_id: 0, // No compute node for skipped job + return_code: 0, // Success + exec_time_minutes: 0.0, + completion_time: chrono::Utc::now().to_rfc3339(), + peak_memory_bytes: None, + avg_memory_bytes: None, + peak_cpu_percent: None, + avg_cpu_percent: None, + status: JobStatus::Completed, + }; + + // Use complete_job to properly handle completion + default_api::complete_job( + config, + job_id, + JobStatus::Completed, + workflow_status.run_id, + result, + ) + .map_err(|e| format!("Failed to complete job: {}", e))?; + + debug!("Job {} marked as completed (skipped)", job_id); + Ok(()) +} diff --git a/src/client/watch/watcher.rs b/src/client/watch/watcher.rs new file mode 100644 index 00000000..5c109cb8 --- /dev/null +++ b/src/client/watch/watcher.rs @@ -0,0 +1,401 @@ +//! 
Main watch loop for monitoring workflows and recovering from failures.
+
+use std::collections::HashMap;
+use std::path::PathBuf;
+use std::time::{Duration, Instant};
+
+use log::{debug, error, info, warn};
+
+use crate::client::apis::configuration::Configuration;
+use crate::client::apis::default_api;
+use crate::client::log_paths;
+use crate::models::JobStatus;
+
+use super::audit::AuditLogger;
+use super::claude_client::ClaudeClient;
+use super::failure_cache::FailureCache;
+use super::recovery::execute_recovery;
+
+/// Configuration for the watch command.
+#[derive(Debug, Clone)]
+pub struct WatchConfig {
+    /// Poll interval in seconds
+    pub poll_interval: u64,
+    /// Output directory for job logs
+    pub output_dir: PathBuf,
+    /// Maximum recovery attempts per job
+    pub max_retries: u32,
+    /// Cooldown period between retries in seconds
+    pub retry_cooldown: u64,
+    /// Whether to only diagnose (not auto-recover)
+    pub diagnose_only: bool,
+    /// Claude model to use
+    pub model: String,
+    /// Path to failure pattern cache database
+    pub cache_path: Option<PathBuf>,
+    /// Rate limit: max API calls per minute
+    pub rate_limit_per_minute: u32,
+    /// Path to audit log file
+    pub audit_log_path: Option<PathBuf>,
+}
+
+/// Tracks retry state for a job.
+#[derive(Debug, Clone)]
+struct JobRetryState {
+    retry_count: u32,
+    last_retry: Instant,
+    last_failure_signature: Option<String>,
+}
+
+/// The main watcher that monitors workflows for failures.
+pub struct Watcher {
+    config: Configuration,
+    workflow_id: i64,
+    watch_config: WatchConfig,
+    claude_client: ClaudeClient,
+    failure_cache: Option<FailureCache>,
+    audit_logger: Option<AuditLogger>,
+    retry_states: HashMap<i64, JobRetryState>,
+    api_calls_this_minute: u32,
+    minute_start: Instant,
+}
+
+impl Watcher {
+    /// Create a new Watcher instance.
+    pub fn new(
+        config: Configuration,
+        workflow_id: i64,
+        watch_config: WatchConfig,
+        api_key: String,
+    ) -> Result<Self, String> {
+        let claude_client = ClaudeClient::new(api_key, watch_config.model.clone());
+
+        let failure_cache = if let Some(ref path) = watch_config.cache_path {
+            match FailureCache::open(path) {
+                Ok(cache) => Some(cache),
+                Err(e) => {
+                    warn!("Failed to open failure cache at {:?}: {}", path, e);
+                    None
+                }
+            }
+        } else {
+            None
+        };
+
+        let audit_logger = if let Some(ref path) = watch_config.audit_log_path {
+            match AuditLogger::new(path) {
+                Ok(logger) => Some(logger),
+                Err(e) => {
+                    warn!("Failed to create audit logger at {:?}: {}", path, e);
+                    None
+                }
+            }
+        } else {
+            None
+        };
+
+        Ok(Self {
+            config,
+            workflow_id,
+            watch_config,
+            claude_client,
+            failure_cache,
+            audit_logger,
+            retry_states: HashMap::new(),
+            api_calls_this_minute: 0,
+            minute_start: Instant::now(),
+        })
+    }
+
+    /// Run the main watch loop.
+    pub fn run(&mut self) -> Result<(), String> {
+        info!("Starting watch for workflow {}", self.workflow_id);
+        info!("Poll interval: {}s", self.watch_config.poll_interval);
+        info!("Max retries per job: {}", self.watch_config.max_retries);
+        info!("Diagnose only: {}", self.watch_config.diagnose_only);
+
+        loop {
+            // Check if workflow is complete
+            if self.is_workflow_complete()?
{ + info!("Workflow {} is complete", self.workflow_id); + break; + } + + // Get failed jobs + let failed_jobs = self.get_failed_jobs()?; + if !failed_jobs.is_empty() { + info!("Found {} failed jobs", failed_jobs.len()); + for job in failed_jobs { + if let Err(e) = self.handle_failed_job(job) { + error!("Error handling failed job: {}", e); + } + } + } else { + debug!("No failed jobs found"); + } + + // Sleep for poll interval + std::thread::sleep(Duration::from_secs(self.watch_config.poll_interval)); + } + + Ok(()) + } + + /// Check if the workflow is complete (all jobs finished or no more work to do). + fn is_workflow_complete(&self) -> Result { + let response = default_api::is_workflow_complete(&self.config, self.workflow_id) + .map_err(|e| format!("Failed to check workflow completion: {}", e))?; + + Ok(response.is_complete) + } + + /// Get all failed jobs in the workflow. + fn get_failed_jobs(&self) -> Result, String> { + let response = default_api::list_jobs( + &self.config, + self.workflow_id, + Some(JobStatus::Failed), + None, // needs_file_id + None, // upstream_job_id + None, // offset + Some(10000), // limit + None, // sort_by + None, // reverse_sort + None, // include_relationships + ) + .map_err(|e| format!("Failed to list failed jobs: {}", e))?; + + Ok(response.items.unwrap_or_default()) + } + + /// Handle a single failed job. + fn handle_failed_job(&mut self, job: crate::models::JobModel) -> Result<(), String> { + let job_id = job.id.ok_or("Job has no ID")?; + let job_name = job.name.clone(); + let max_retries = self.watch_config.max_retries; + let retry_cooldown = self.watch_config.retry_cooldown; + let diagnose_only = self.watch_config.diagnose_only; + + // Ensure retry state exists with initial values + self.retry_states.entry(job_id).or_insert(JobRetryState { + retry_count: 0, + last_retry: Instant::now() - Duration::from_secs(retry_cooldown + 1), + last_failure_signature: None, + }); + + // Check retry count (read state, drop borrow) + { + let state = self.retry_states.get(&job_id).unwrap(); + if state.retry_count >= max_retries { + debug!( + "Job {} has exceeded max retries ({})", + job_name, max_retries + ); + return Ok(()); + } + + // Check cooldown + let elapsed = state.last_retry.elapsed(); + if elapsed < Duration::from_secs(retry_cooldown) { + debug!( + "Job {} is in cooldown ({:.0}s remaining)", + job_name, + retry_cooldown as f64 - elapsed.as_secs_f64() + ); + return Ok(()); + } + } + + info!("Analyzing failed job: {} (ID: {})", job_name, job_id); + + // Get job logs + let (stdout, stderr) = self.get_job_logs(job_id)?; + + // Compute error signature for cache lookup + let error_signature = FailureCache::compute_signature(&stderr); + + // Check if same failure as last time (avoid retry loops) + { + let state = self.retry_states.get(&job_id).unwrap(); + if state.last_failure_signature.as_ref() == Some(&error_signature) { + debug!( + "Job {} failed with same error signature, skipping", + job_name + ); + return Ok(()); + } + } + + // Check failure cache + let cached_diagnosis = self + .failure_cache + .as_ref() + .and_then(|cache| cache.lookup(&job_name, &error_signature).ok().flatten()); + + let diagnosis = if let Some(cached) = cached_diagnosis { + info!("Found cached diagnosis for failure pattern"); + cached + } else { + // Rate limit check + if !self.check_rate_limit() { + warn!("Rate limit exceeded, skipping Claude API call"); + return Ok(()); + } + + // Call Claude API for diagnosis + info!("Requesting diagnosis from Claude..."); + let diagnosis = 
self.claude_client.diagnose_failure( + &self.config, + self.workflow_id, + &job, + &stdout, + &stderr, + )?; + + // Cache the diagnosis + if let Some(ref mut cache) = self.failure_cache { + if let Err(e) = cache.store(&job_name, &error_signature, &diagnosis) { + warn!("Failed to cache diagnosis: {}", e); + } + } + + diagnosis + }; + + // Log the diagnosis + info!("Diagnosis: {}", diagnosis.summary); + if let Some(ref action) = diagnosis.recommended_action { + info!("Recommended action: {:?}", action); + } + + // Log to audit + if let Some(ref mut audit) = self.audit_logger { + audit.log_diagnosis(job_id, &job_name, &diagnosis); + } + + // Execute recovery if not in diagnose-only mode + if !diagnose_only { + if let Some(action) = diagnosis.recommended_action { + info!("Executing recovery action: {:?}", action); + match execute_recovery(&self.config, job_id, &action) { + Ok(()) => { + info!("Recovery action executed successfully"); + + // Update retry state + if let Some(state) = self.retry_states.get_mut(&job_id) { + state.retry_count += 1; + state.last_retry = Instant::now(); + state.last_failure_signature = Some(error_signature.clone()); + } + + // Update cache success count + if let Some(ref mut cache) = self.failure_cache { + cache.record_success(&job_name, &error_signature); + } + + if let Some(ref mut audit) = self.audit_logger { + audit.log_recovery(job_id, &job_name, &action, true); + } + } + Err(e) => { + error!("Recovery action failed: {}", e); + + // Update retry state + if let Some(state) = self.retry_states.get_mut(&job_id) { + state.retry_count += 1; + state.last_retry = Instant::now(); + } + + // Update cache failure count + if let Some(ref mut cache) = self.failure_cache { + cache.record_failure(&job_name, &error_signature); + } + + if let Some(ref mut audit) = self.audit_logger { + audit.log_recovery(job_id, &job_name, &action, false); + } + } + } + } else { + info!("No recovery action recommended"); + } + } + + Ok(()) + } + + /// Get stdout and stderr logs for a job. + fn get_job_logs(&self, job_id: i64) -> Result<(String, String), String> { + // Get the latest result to find the run_id + let results = default_api::list_results( + &self.config, + self.workflow_id, + Some(job_id), + None, // run_id + None, // offset + Some(1), // limit + None, // sort_by + None, // reverse_sort + None, // return_code + None, // status + None, // all_runs + ) + .map_err(|e| format!("Failed to get job results: {}", e))?; + + let run_id = results + .items + .and_then(|items| items.into_iter().next()) + .map(|r| r.run_id) + .unwrap_or(1); + + let stdout_path = log_paths::get_job_stdout_path( + &self.watch_config.output_dir, + self.workflow_id, + job_id, + run_id, + ); + let stderr_path = log_paths::get_job_stderr_path( + &self.watch_config.output_dir, + self.workflow_id, + job_id, + run_id, + ); + + let stdout = std::fs::read_to_string(&stdout_path).unwrap_or_else(|_| String::new()); + let stderr = std::fs::read_to_string(&stderr_path).unwrap_or_else(|_| String::new()); + + // Truncate logs if too long (keep last 10KB) + let max_len = 10 * 1024; + let stdout = if stdout.len() > max_len { + format!("...[truncated]...\n{}", &stdout[stdout.len() - max_len..]) + } else { + stdout + }; + let stderr = if stderr.len() > max_len { + format!("...[truncated]...\n{}", &stderr[stderr.len() - max_len..]) + } else { + stderr + }; + + Ok((stdout, stderr)) + } + + /// Check and update rate limit. Returns true if API call is allowed. 
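+    ///
+    /// Uses a simple fixed one-minute window: the counter resets once 60
+    /// seconds have passed since the window started, and calls beyond
+    /// `rate_limit_per_minute` within the window are rejected.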
+ fn check_rate_limit(&mut self) -> bool { + let now = Instant::now(); + let elapsed = now.duration_since(self.minute_start); + + // Reset counter every minute + if elapsed >= Duration::from_secs(60) { + self.api_calls_this_minute = 0; + self.minute_start = now; + } + + if self.api_calls_this_minute >= self.watch_config.rate_limit_per_minute { + return false; + } + + self.api_calls_this_minute += 1; + true + } +} diff --git a/src/config/client.rs b/src/config/client.rs index 46b79be7..016e1b8d 100644 --- a/src/config/client.rs +++ b/src/config/client.rs @@ -28,6 +28,9 @@ pub struct ClientConfig { /// HPC profile configuration pub hpc: ClientHpcConfig, + + /// Watch command configuration + pub watch: ClientWatchConfig, } impl Default for ClientConfig { @@ -40,6 +43,7 @@ impl Default for ClientConfig { run: ClientRunConfig::default(), slurm: ClientSlurmConfig::default(), hpc: ClientHpcConfig::default(), + watch: ClientWatchConfig::default(), } } } @@ -104,6 +108,50 @@ impl Default for ClientSlurmConfig { } } +/// Configuration for the `torc watch` command +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct ClientWatchConfig { + /// Poll interval in seconds for checking workflow status + pub poll_interval: u64, + + /// Maximum recovery attempts per job + pub max_retries: u32, + + /// Cooldown period between retries in seconds + pub retry_cooldown: u64, + + /// Claude model to use for diagnosis + pub model: String, + + /// Path to failure pattern cache database + pub cache_path: Option, + + /// Rate limit: max API calls per minute + pub rate_limit_per_minute: u32, + + /// Path to audit log file + pub audit_log_path: Option, + + /// Anthropic API key (fallback if ANTHROPIC_API_KEY env var not set) + pub api_key: Option, +} + +impl Default for ClientWatchConfig { + fn default() -> Self { + Self { + poll_interval: 30, + max_retries: 3, + retry_cooldown: 60, + model: "claude-sonnet-4-20250514".to_string(), + cache_path: None, + rate_limit_per_minute: 10, + audit_log_path: None, + api_key: None, + } + } +} + /// Configuration for HPC profiles #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(default)] diff --git a/src/main.rs b/src/main.rs index 4823e573..31fefb3d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ use torc::client::commands::results::handle_result_commands; use torc::client::commands::scheduled_compute_nodes::handle_scheduled_compute_node_commands; use torc::client::commands::slurm::handle_slurm_commands; use torc::client::commands::user_data::handle_user_data_commands; +use torc::client::commands::watch::{WatchArgs, run_watch}; use torc::client::commands::workflows::handle_workflow_commands; use torc::client::config::TorcConfig; use torc::client::workflow_manager::WorkflowManager; @@ -441,6 +442,28 @@ fn main() { } } } + Commands::Watch { + workflow_id, + poll_interval, + auto_recover, + max_retries, + memory_multiplier, + runtime_multiplier, + output_dir, + show_job_counts, + } => { + let args = WatchArgs { + workflow_id: *workflow_id, + poll_interval: *poll_interval, + auto_recover: *auto_recover, + max_retries: *max_retries, + memory_multiplier: *memory_multiplier, + runtime_multiplier: *runtime_multiplier, + output_dir: output_dir.clone(), + show_job_counts: *show_job_counts, + }; + run_watch(&config, &args); + } Commands::Workflows { command } => { handle_workflow_commands(&config, command, &format); } diff --git a/src/server/api/workflows.rs b/src/server/api/workflows.rs index bc468aa1..bb642351 100644 --- 
a/src/server/api/workflows.rs +++ b/src/server/api/workflows.rs @@ -1337,7 +1337,7 @@ where ); let offset_val = offset.unwrap_or(0); - let limit_val = limit.unwrap_or(100000).min(100000); + let limit_val = limit.unwrap_or(10000).min(10000); // Query job_depends_on table with JOIN to get job names let dependencies = match sqlx::query_as!( @@ -1397,7 +1397,7 @@ where models::ListJobDependenciesResponse { items: Some(dependencies), offset: offset_val, - max_limit: 100000, + max_limit: 10000, count: current_count, total_count, has_more, @@ -1422,7 +1422,7 @@ where ); let offset_val = offset.unwrap_or(0); - let limit_val = limit.unwrap_or(100000).min(100000); + let limit_val = limit.unwrap_or(10000).min(10000); // Query job_input_file and job_output_file tables with JOINs // UNION the input and output relationships @@ -1499,7 +1499,7 @@ where models::ListJobFileRelationshipsResponse { items: Some(relationships), offset: offset_val, - max_limit: 100000, + max_limit: 10000, count: current_count, total_count, has_more, @@ -1524,7 +1524,7 @@ where ); let offset_val = offset.unwrap_or(0); - let limit_val = limit.unwrap_or(100000).min(100000); + let limit_val = limit.unwrap_or(10000).min(10000); // Query job_input_user_data and job_output_user_data tables with JOINs let relationships = match sqlx::query_as!( @@ -1599,7 +1599,7 @@ where models::ListJobUserDataRelationshipsResponse { items: Some(relationships), offset: offset_val, - max_limit: 100000, + max_limit: 10000, count: current_count, total_count, has_more, diff --git a/tests/test_slurm_regenerate.rs b/tests/test_slurm_regenerate.rs new file mode 100644 index 00000000..45cb917a --- /dev/null +++ b/tests/test_slurm_regenerate.rs @@ -0,0 +1,773 @@ +//! Tests for the `torc slurm regenerate` command. +//! +//! These tests simulate failure recovery scenarios where we need to regenerate +//! Slurm schedulers for pending jobs (uninitialized, ready, blocked) after +//! some jobs have completed or failed. + +mod common; + +use common::{ServerProcess, run_cli_with_json, start_server}; +use rstest::rstest; +use std::collections::HashMap; +use torc::client::{Configuration, default_api}; +use torc::models; + +/// Create a workflow with jobs in various states for testing regenerate. 
+/// Returns (workflow_id, job_ids_by_status) +fn create_workflow_with_job_states( + config: &Configuration, + name: &str, + job_configs: &[(String, models::JobStatus)], +) -> (i64, HashMap) { + // Create workflow + let user = "test_user".to_string(); + let workflow = models::WorkflowModel::new(name.to_string(), user); + let created_workflow = + default_api::create_workflow(config, workflow).expect("Failed to create workflow"); + let workflow_id = created_workflow.id.unwrap(); + + // Create resource requirements (using "test_rr" since "default" is reserved) + let mut rr = models::ResourceRequirementsModel::new(workflow_id, "test_rr".to_string()); + rr.num_cpus = 4; + rr.num_gpus = 0; + rr.num_nodes = 1; + rr.memory = "8g".to_string(); + rr.runtime = "PT1H".to_string(); + let rr = default_api::create_resource_requirements(config, rr) + .expect("Failed to create resource requirements"); + let rr_id = rr.id.unwrap(); + + // Create jobs + let mut job_ids = HashMap::new(); + for (job_name, _status) in job_configs { + let mut job = models::JobModel::new( + workflow_id, + job_name.clone(), + format!("echo '{}'", job_name), + ); + job.resource_requirements_id = Some(rr_id); + let created_job = default_api::create_job(config, job).expect("Failed to create job"); + job_ids.insert(job_name.clone(), created_job.id.unwrap()); + } + + // Initialize jobs - after this, jobs without dependencies will be "ready", + // jobs with dependencies will be "blocked" + default_api::initialize_jobs(config, workflow_id, None, None, None) + .expect("Failed to initialize jobs"); + + (workflow_id, job_ids) +} + +/// Create a multi-stage workflow with dependencies. +/// Stage 1: preprocess (no dependencies) -> Stage 2: work jobs (depend on preprocess) -> Stage 3: postprocess (depends on all work) +fn create_multi_stage_workflow( + config: &Configuration, + name: &str, + num_work_jobs: usize, +) -> (i64, HashMap) { + let user = "test_user".to_string(); + let workflow = models::WorkflowModel::new(name.to_string(), user); + let created_workflow = + default_api::create_workflow(config, workflow).expect("Failed to create workflow"); + let workflow_id = created_workflow.id.unwrap(); + + // Create resource requirements (using "test_rr" since "default" is reserved) + let mut rr = models::ResourceRequirementsModel::new(workflow_id, "test_rr".to_string()); + rr.num_cpus = 4; + rr.num_gpus = 0; + rr.num_nodes = 1; + rr.memory = "8g".to_string(); + rr.runtime = "PT1H".to_string(); + let rr = default_api::create_resource_requirements(config, rr) + .expect("Failed to create resource requirements"); + let rr_id = rr.id.unwrap(); + + // Create files for dependencies + let prep_output = default_api::create_file( + config, + models::FileModel::new( + workflow_id, + "prep_output".to_string(), + "/tmp/prep.out".to_string(), + ), + ) + .expect("Failed to create file"); + + let work_outputs: Vec<_> = (0..num_work_jobs) + .map(|i| { + default_api::create_file( + config, + models::FileModel::new( + workflow_id, + format!("work_output_{}", i), + format!("/tmp/work_{}.out", i), + ), + ) + .expect("Failed to create file") + }) + .collect(); + + // Stage 1: preprocess + let mut preprocess = models::JobModel::new( + workflow_id, + "preprocess".to_string(), + "echo preprocess".to_string(), + ); + preprocess.resource_requirements_id = Some(rr_id); + preprocess.output_file_ids = Some(vec![prep_output.id.unwrap()]); + let preprocess = + default_api::create_job(config, preprocess).expect("Failed to create preprocess job"); + + // Stage 2: work jobs 
(depend on preprocess via file) + let mut work_jobs = Vec::new(); + for i in 0..num_work_jobs { + let mut work = models::JobModel::new( + workflow_id, + format!("work_{}", i), + format!("echo work_{}", i), + ); + work.resource_requirements_id = Some(rr_id); + work.input_file_ids = Some(vec![prep_output.id.unwrap()]); + work.output_file_ids = Some(vec![work_outputs[i].id.unwrap()]); + let work = default_api::create_job(config, work).expect("Failed to create work job"); + work_jobs.push(work); + } + + // Stage 3: postprocess (depends on all work jobs via files) + let mut postprocess = models::JobModel::new( + workflow_id, + "postprocess".to_string(), + "echo postprocess".to_string(), + ); + postprocess.resource_requirements_id = Some(rr_id); + postprocess.input_file_ids = Some(work_outputs.iter().map(|f| f.id.unwrap()).collect()); + let postprocess = + default_api::create_job(config, postprocess).expect("Failed to create postprocess job"); + + // Initialize workflow + default_api::initialize_jobs(config, workflow_id, None, None, None) + .expect("Failed to initialize jobs"); + + // Build job_ids map + let mut job_ids = HashMap::new(); + job_ids.insert("preprocess".to_string(), preprocess.id.unwrap()); + for (i, work) in work_jobs.iter().enumerate() { + job_ids.insert(format!("work_{}", i), work.id.unwrap()); + } + job_ids.insert("postprocess".to_string(), postprocess.id.unwrap()); + + (workflow_id, job_ids) +} + +/// Create a workflow with jobs having different resource requirements. +fn create_workflow_with_varied_resources( + config: &Configuration, + name: &str, +) -> (i64, HashMap) { + let user = "test_user".to_string(); + let workflow = models::WorkflowModel::new(name.to_string(), user); + let created_workflow = + default_api::create_workflow(config, workflow).expect("Failed to create workflow"); + let workflow_id = created_workflow.id.unwrap(); + + // Create small resource requirements (many jobs per node) + let mut rr_small = models::ResourceRequirementsModel::new(workflow_id, "small".to_string()); + rr_small.num_cpus = 4; + rr_small.num_gpus = 0; + rr_small.num_nodes = 1; + rr_small.memory = "8g".to_string(); + rr_small.runtime = "PT1H".to_string(); + let rr_small = default_api::create_resource_requirements(config, rr_small) + .expect("Failed to create small resource requirements"); + + // Create large resource requirements (one job per node) + let mut rr_large = models::ResourceRequirementsModel::new(workflow_id, "large".to_string()); + rr_large.num_cpus = 64; + rr_large.num_gpus = 0; + rr_large.num_nodes = 1; + rr_large.memory = "120g".to_string(); + rr_large.runtime = "PT4H".to_string(); + let rr_large = default_api::create_resource_requirements(config, rr_large) + .expect("Failed to create large resource requirements"); + + // Create GPU resource requirements + let mut rr_gpu = models::ResourceRequirementsModel::new(workflow_id, "gpu".to_string()); + rr_gpu.num_cpus = 32; + rr_gpu.num_gpus = 2; + rr_gpu.num_nodes = 1; + rr_gpu.memory = "64g".to_string(); + rr_gpu.runtime = "PT2H".to_string(); + let rr_gpu = default_api::create_resource_requirements(config, rr_gpu) + .expect("Failed to create GPU resource requirements"); + + // Create jobs with different resource requirements + let mut job_ids = HashMap::new(); + + // Small jobs + for i in 0..5 { + let mut job = models::JobModel::new( + workflow_id, + format!("small_job_{}", i), + format!("echo small_{}", i), + ); + job.resource_requirements_id = Some(rr_small.id.unwrap()); + let created = default_api::create_job(config, 
job).expect("Failed to create small job"); + job_ids.insert(format!("small_job_{}", i), created.id.unwrap()); + } + + // Large jobs + for i in 0..3 { + let mut job = models::JobModel::new( + workflow_id, + format!("large_job_{}", i), + format!("echo large_{}", i), + ); + job.resource_requirements_id = Some(rr_large.id.unwrap()); + let created = default_api::create_job(config, job).expect("Failed to create large job"); + job_ids.insert(format!("large_job_{}", i), created.id.unwrap()); + } + + // GPU jobs + for i in 0..2 { + let mut job = models::JobModel::new( + workflow_id, + format!("gpu_job_{}", i), + format!("echo gpu_{}", i), + ); + job.resource_requirements_id = Some(rr_gpu.id.unwrap()); + let created = default_api::create_job(config, job).expect("Failed to create GPU job"); + job_ids.insert(format!("gpu_job_{}", i), created.id.unwrap()); + } + + // Initialize workflow + default_api::initialize_jobs(config, workflow_id, None, None, None) + .expect("Failed to initialize jobs"); + + (workflow_id, job_ids) +} + +/// Helper to get the number of schedulers for a workflow +fn get_scheduler_count(config: &Configuration, workflow_id: i64) -> usize { + let response = default_api::list_slurm_schedulers( + config, + workflow_id, + Some(0), + Some(100), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + .expect("Failed to list schedulers"); + response.items.unwrap_or_default().len() +} + +// ============== Basic Regenerate Tests ============== + +/// Test regenerate with all jobs in ready state (basic case) +#[rstest] +fn test_regenerate_all_jobs_ready(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow with 5 ready jobs + let job_configs: Vec<(String, models::JobStatus)> = (0..5) + .map(|i| (format!("job_{}", i), models::JobStatus::Ready)) + .collect(); + + let (workflow_id, _job_ids) = + create_workflow_with_job_states(config, "test_regenerate_all_ready", &job_configs); + + // Verify no schedulers exist yet + assert_eq!(get_scheduler_count(config, workflow_id), 0); + + // Run regenerate command with kestrel profile + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + assert!(json.get("pending_jobs").is_some()); + assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 5); + + // Verify schedulers were created + assert!(get_scheduler_count(config, workflow_id) > 0); +} + +/// Test regenerate with no pending jobs (empty workflow) +#[rstest] +fn test_regenerate_no_pending_jobs(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow with no jobs (empty workflow) + let user = "test_user".to_string(); + let workflow = models::WorkflowModel::new("test_regenerate_no_pending".to_string(), user); + let created_workflow = + default_api::create_workflow(config, workflow).expect("Failed to create workflow"); + let workflow_id = created_workflow.id.unwrap(); + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should report 0 pending jobs + 
assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 0); + assert!(json.get("warnings").is_some()); + + // No schedulers should be created + assert_eq!(get_scheduler_count(config, workflow_id), 0); +} + +/// Test regenerate with many ready jobs (simple case with different job counts) +#[rstest] +fn test_regenerate_multiple_ready_jobs(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow with many ready jobs + let job_configs: Vec<(String, models::JobStatus)> = (0..10) + .map(|i| (format!("job_{}", i), models::JobStatus::Ready)) + .collect(); + + let (workflow_id, _job_ids) = + create_workflow_with_job_states(config, "test_regenerate_multiple", &job_configs); + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should count all 10 ready jobs + assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 10); + + // Schedulers should be created + assert!(get_scheduler_count(config, workflow_id) > 0); +} + +// ============== Multi-Stage Workflow Tests ============== + +/// Test regenerate with blocked jobs from multi-stage workflow +/// Jobs with unmet dependencies are blocked after initialization +#[rstest] +fn test_regenerate_with_blocked_jobs(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create multi-stage workflow - work and postprocess jobs will be blocked + // because preprocess hasn't completed + let (workflow_id, _job_ids) = create_multi_stage_workflow(config, "test_blocked_jobs", 5); + + // Run regenerate command - should count blocked jobs as pending + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should have: 1 ready (preprocess) + 5 blocked (work) + 1 blocked (postprocess) = 7 + let pending_jobs = json.get("pending_jobs").unwrap().as_i64().unwrap(); + assert_eq!( + pending_jobs, 7, + "Expected 7 pending jobs (1 ready + 6 blocked), got {}", + pending_jobs + ); + + // Schedulers should be created + assert!(get_scheduler_count(config, workflow_id) > 0); +} + +/// Test regenerate counts both ready and blocked jobs correctly +#[rstest] +fn test_regenerate_counts_all_pending_statuses(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create a larger multi-stage workflow to verify counting + let (workflow_id, _job_ids) = create_multi_stage_workflow(config, "test_pending_count", 10); + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should have: 1 ready (preprocess) + 10 blocked (work) + 1 blocked (postprocess) = 12 + let pending_jobs = json.get("pending_jobs").unwrap().as_i64().unwrap(); + assert_eq!( + pending_jobs, 12, + "Expected 12 pending jobs, got {}", + pending_jobs + ); +} + +// ============== Resource Requirement Tests ============== + +/// Test regenerate creates separate 
schedulers for different resource requirements +#[rstest] +fn test_regenerate_varied_resources(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow with varied resource requirements + let (workflow_id, _job_ids) = + create_workflow_with_varied_resources(config, "test_varied_resources"); + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should have all 10 jobs pending + assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 10); + + // Check schedulers_created array + let schedulers_created = json.get("schedulers_created").unwrap().as_array().unwrap(); + // Should have created multiple schedulers for different resource types + assert!( + schedulers_created.len() >= 2, + "Expected at least 2 schedulers for varied resources, got {}", + schedulers_created.len() + ); +} + +/// Test regenerate with single allocation mode +#[rstest] +fn test_regenerate_single_allocation(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow with many jobs + let job_configs: Vec<(String, models::JobStatus)> = (0..20) + .map(|i| (format!("job_{}", i), models::JobStatus::Ready)) + .collect(); + + let (workflow_id, _job_ids) = + create_workflow_with_job_states(config, "test_single_allocation", &job_configs); + + // Run regenerate command with --single-allocation + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + "--single-allocation", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 20); + + // In single allocation mode, should create fewer, larger allocations + let total_allocations = json.get("total_allocations").unwrap().as_i64().unwrap(); + let schedulers = json.get("schedulers_created").unwrap().as_array().unwrap(); + + // With single allocation, we expect 1 scheduler with 1 (larger) allocation + assert_eq!( + schedulers.len(), + 1, + "Single allocation should create 1 scheduler" + ); + assert_eq!( + total_allocations, 1, + "Single allocation mode should create 1 allocation" + ); +} + +// ============== Existing Scheduler Tests ============== + +/// Test regenerate uses existing scheduler's account as default +#[rstest] +fn test_regenerate_uses_existing_account(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow + let job_configs: Vec<(String, models::JobStatus)> = (0..3) + .map(|i| (format!("job_{}", i), models::JobStatus::Ready)) + .collect(); + + let (workflow_id, _job_ids) = + create_workflow_with_job_states(config, "test_existing_account", &job_configs); + + // Create an existing scheduler with a specific account + let scheduler = models::SlurmSchedulerModel { + id: None, + workflow_id, + name: Some("existing_scheduler".to_string()), + account: "existing_project_account".to_string(), + partition: None, + mem: Some("8g".to_string()), + walltime: "01:00:00".to_string(), + nodes: 1, + gres: None, + ntasks_per_node: None, + qos: None, + tmp: None, + extra: None, + }; + default_api::create_slurm_scheduler(config, scheduler).expect("Failed to create 
scheduler"); + + // Run regenerate without specifying account (should use existing) + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + // Verify new scheduler uses the existing account + let response = default_api::list_slurm_schedulers( + config, + workflow_id, + Some(0), + Some(100), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + .expect("Failed to list schedulers"); + + let schedulers = response.items.unwrap_or_default(); + // Should have 2 schedulers now (original + regenerated) + assert!(schedulers.len() >= 2); + + // Find the regenerated scheduler (has "regen" in name) + let regen_scheduler = schedulers.iter().find(|s| { + s.name + .as_ref() + .map(|n| n.contains("regen")) + .unwrap_or(false) + }); + assert!( + regen_scheduler.is_some(), + "Should have regenerated scheduler" + ); + assert_eq!(regen_scheduler.unwrap().account, "existing_project_account"); +} + +// ============== Edge Case Tests ============== + +/// Test regenerate with jobs missing resource requirements +#[rstest] +fn test_regenerate_missing_resource_requirements(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create workflow + let user = "test_user".to_string(); + let workflow = models::WorkflowModel::new("test_missing_rr".to_string(), user); + let created_workflow = + default_api::create_workflow(config, workflow).expect("Failed to create workflow"); + let workflow_id = created_workflow.id.unwrap(); + + // Create job WITHOUT resource requirements + let job = models::JobModel::new( + workflow_id, + "job_no_rr".to_string(), + "echo test".to_string(), + ); + default_api::create_job(config, job).expect("Failed to create job"); + + // Initialize + default_api::initialize_jobs(config, workflow_id, None, None, None) + .expect("Failed to initialize"); + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should have a warning about missing resource requirements + let warnings = json.get("warnings").unwrap().as_array().unwrap(); + assert!( + !warnings.is_empty(), + "Expected warnings about missing resource requirements" + ); +} + +/// Test regenerate with non-existent workflow ID +/// The command should return a result with 0 pending jobs (graceful handling) +#[rstest] +fn test_regenerate_nonexistent_workflow(start_server: &ServerProcess) { + let args = [ + "slurm", + "regenerate", + "999999", // Non-existent workflow ID + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + // Command should succeed but with 0 pending jobs + // (graceful handling for workflows with no pending jobs) + assert!(result.is_ok(), "Command should succeed gracefully"); + + let json = result.unwrap(); + // Should report 0 pending jobs for non-existent workflow + assert_eq!(json.get("pending_jobs").unwrap().as_i64().unwrap(), 0); +} + +/// Test regenerate with blocked jobs (should include them in pending count) +#[rstest] +fn test_regenerate_includes_blocked_jobs(start_server: &ServerProcess) { + let config = &start_server.config; + + // Create multi-stage 
workflow (work jobs will be blocked initially) + let (workflow_id, _job_ids) = create_multi_stage_workflow(config, "test_includes_blocked", 5); + + // Don't complete preprocess - work jobs should remain blocked + + // Run regenerate command + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + // Should count: 1 ready (preprocess) + 5 blocked (work) + 1 blocked (postprocess) = 7 + let pending_jobs = json.get("pending_jobs").unwrap().as_i64().unwrap(); + assert_eq!( + pending_jobs, 7, + "Expected 7 pending jobs (1 ready + 6 blocked), got {}", + pending_jobs + ); +} + +// ============== Output Format Tests ============== + +/// Test regenerate JSON output structure +#[rstest] +fn test_regenerate_json_output_structure(start_server: &ServerProcess) { + let config = &start_server.config; + + let job_configs: Vec<(String, models::JobStatus)> = (0..5) + .map(|i| (format!("job_{}", i), models::JobStatus::Ready)) + .collect(); + + let (workflow_id, _job_ids) = + create_workflow_with_job_states(config, "test_json_output", &job_configs); + + let args = [ + "slurm", + "regenerate", + &workflow_id.to_string(), + "--account", + "test_account", + "--profile", + "kestrel", + ]; + + let result = run_cli_with_json(&args, start_server); + assert!(result.is_ok(), "Regenerate command failed: {:?}", result); + + let json = result.unwrap(); + + // Verify required fields exist + assert!(json.get("workflow_id").is_some()); + assert!(json.get("pending_jobs").is_some()); + assert!(json.get("schedulers_created").is_some()); + assert!(json.get("total_allocations").is_some()); + assert!(json.get("warnings").is_some()); + assert!(json.get("submitted").is_some()); + + // Verify types + assert!(json.get("workflow_id").unwrap().is_i64()); + assert!(json.get("pending_jobs").unwrap().is_i64()); + assert!(json.get("schedulers_created").unwrap().is_array()); + assert!(json.get("total_allocations").unwrap().is_i64()); + assert!(json.get("warnings").unwrap().is_array()); + assert!(json.get("submitted").unwrap().is_boolean()); +} diff --git a/torc-mcp-server/Cargo.toml b/torc-mcp-server/Cargo.toml new file mode 100644 index 00000000..ce8ad46a --- /dev/null +++ b/torc-mcp-server/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "torc-mcp-server" +version.workspace = true +authors.workspace = true +license.workspace = true +edition.workspace = true +description = "MCP server for Torc workflow orchestration" + +[[bin]] +name = "torc-mcp-server" +path = "src/main.rs" + +[dependencies] +torc = { path = "..", features = ["client"] } +clap.workspace = true +rmcp.workspace = true +schemars.workspace = true +tokio.workspace = true +serde.workspace = true +serde_json.workspace = true +anyhow.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +tempfile = "3" diff --git a/torc-mcp-server/src/lib.rs b/torc-mcp-server/src/lib.rs new file mode 100644 index 00000000..d07947cd --- /dev/null +++ b/torc-mcp-server/src/lib.rs @@ -0,0 +1,8 @@ +//! MCP server for Torc workflow orchestration. +//! +//! This crate provides an MCP (Model Context Protocol) server that exposes +//! Torc's workflow and job management capabilities as tools for AI assistants +//! like Claude. 
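+//!
+//! The [`server`] module implements the MCP `ServerHandler` and tool dispatch;
+//! [`tools`] holds the underlying implementations that call the Torc API or
+//! shell out to the `torc` CLI.
+//!
+//! # Example client registration (illustrative)
+//!
+//! MCP clients launch this binary over stdio. A hypothetical Claude Code
+//! `.mcp.json` entry might look like the following; the exact schema depends
+//! on the client:
+//!
+//! ```json
+//! {
+//!   "mcpServers": {
+//!     "torc": {
+//!       "command": "torc-mcp-server",
+//!       "args": ["--output-dir", "output"]
+//!     }
+//!   }
+//! }
+//! ```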
+ +pub mod server; +pub mod tools; diff --git a/torc-mcp-server/src/main.rs b/torc-mcp-server/src/main.rs new file mode 100644 index 00000000..739e78a8 --- /dev/null +++ b/torc-mcp-server/src/main.rs @@ -0,0 +1,104 @@ +//! Torc MCP Server binary. +//! +//! This binary provides an MCP (Model Context Protocol) server that exposes +//! Torc's workflow and job management capabilities as tools for AI assistants. +//! +//! # Usage +//! +//! ```bash +//! # Run with default settings (connects to localhost:8080) +//! torc-mcp-server +//! +//! # Run with custom API URL +//! TORC_API_URL=http://server:8080/torc-service/v1 torc-mcp-server +//! +//! # Run with custom output directory for logs +//! torc-mcp-server --output-dir /path/to/output +//! ``` + +use anyhow::Result; +use clap::Parser; +use rmcp::{ServiceExt, transport::io::stdio}; +use std::path::PathBuf; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use torc_mcp_server::server::TorcMcpServer; + +/// MCP server for Torc workflow orchestration. +/// +/// This server exposes Torc's workflow and job management capabilities +/// as tools for AI assistants via the Model Context Protocol (MCP). +#[derive(Parser, Debug)] +#[command(name = "torc-mcp-server")] +#[command(version, about, long_about = None)] +struct Args { + /// Torc API URL + #[arg( + long, + env = "TORC_API_URL", + default_value = "http://localhost:8080/torc-service/v1" + )] + api_url: String, + + /// Output directory for job logs + #[arg(long, env = "TORC_OUTPUT_DIR", default_value = "output")] + output_dir: PathBuf, + + /// Username for API authentication + #[arg(long, env = "TORC_USERNAME")] + username: Option, + + /// Password for API authentication + #[arg(long, env = "TORC_PASSWORD")] + password: Option, +} + +fn main() -> Result<()> { + // Parse CLI arguments (handles -h/--help before async runtime) + let args = Args::parse(); + + // Initialize logging to stderr (stdout is used for MCP protocol) + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "torc_mcp_server=info".into()), + ) + .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)) + .init(); + + tracing::info!("Starting Torc MCP Server"); + tracing::info!("API URL: {}", args.api_url); + tracing::info!("Output directory: {}", args.output_dir.display()); + + // Create the server BEFORE entering the async runtime. + // This is important because TorcMcpServer::new() creates a reqwest::blocking::Client + // which spawns its own tokio runtime. Creating it inside block_on would cause + // nested runtime issues. 
+ let server = if args.username.is_some() { + TorcMcpServer::with_auth(args.api_url, args.output_dir, args.username, args.password) + } else { + TorcMcpServer::new(args.api_url, args.output_dir) + }; + + // Build runtime and run the async portion + // Use multi-threaded runtime to properly support spawn_blocking for the + // blocking reqwest client calls + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + + runtime.block_on(async_main(server)) +} + +async fn async_main(server: TorcMcpServer) -> Result<()> { + // Serve over stdio transport + let service = server.serve(stdio()).await?; + + tracing::info!("MCP server running"); + + // Wait for the service to complete + service.waiting().await?; + + Ok(()) +} diff --git a/torc-mcp-server/src/server.rs b/torc-mcp-server/src/server.rs new file mode 100644 index 00000000..00a0a1a1 --- /dev/null +++ b/torc-mcp-server/src/server.rs @@ -0,0 +1,395 @@ +//! MCP server implementation for Torc. + +use rmcp::{ + Error as McpError, ServerHandler, + model::{CallToolResult, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo}, + schemars, tool, +}; +use serde::Deserialize; +use std::path::PathBuf; + +use torc::client::apis::configuration::Configuration; + +use crate::tools; + +/// MCP server that exposes Torc workflow operations as tools. +#[derive(Debug, Clone)] +pub struct TorcMcpServer { + config: Configuration, + output_dir: PathBuf, +} + +impl TorcMcpServer { + /// Create a new TorcMcpServer with the given API URL and output directory. + pub fn new(api_url: String, output_dir: PathBuf) -> Self { + let mut config = Configuration::new(); + config.base_path = api_url; + + Self { config, output_dir } + } + + /// Create a new TorcMcpServer with authentication. 
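+    /// If either credential is `None`, basic auth is left unset and requests are unauthenticated.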
+    pub fn with_auth(
+        api_url: String,
+        output_dir: PathBuf,
+        username: Option<String>,
+        password: Option<String>,
+    ) -> Self {
+        let mut config = Configuration::new();
+        config.base_path = api_url;
+
+        if let (Some(user), Some(pass)) = (username, password) {
+            config.basic_auth = Some((user, Some(pass)));
+        }
+
+        Self { config, output_dir }
+    }
+}
+
+// Tool parameter types
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct WorkflowIdParam {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct JobIdParam {
+    #[schemars(description = "The job ID")]
+    pub job_id: i64,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct GetJobLogsParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+    #[schemars(description = "The job ID")]
+    pub job_id: i64,
+    #[schemars(description = "The run ID (1 for first run, increments on restart)")]
+    pub run_id: i64,
+    #[schemars(description = "Log type: 'stdout' or 'stderr'")]
+    pub log_type: String,
+    #[schemars(
+        description = "Number of lines to return from the end (optional, returns all if not specified)"
+    )]
+    pub tail_lines: Option<usize>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct ListJobsByStatusParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+    #[schemars(
+        description = "Job status to filter by: 'uninitialized', 'blocked', 'ready', 'pending', 'running', 'completed', 'failed', 'canceled', 'terminated', 'disabled'"
+    )]
+    pub status: String,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct UpdateJobResourcesParams {
+    #[schemars(description = "The job ID")]
+    pub job_id: i64,
+    #[schemars(description = "Number of CPUs (optional)")]
+    pub num_cpus: Option<i64>,
+    #[schemars(description = "Memory requirement, e.g., '4g', '512m' (optional)")]
+    pub memory: Option<String>,
+    #[schemars(
+        description = "Runtime in ISO8601 duration format, e.g., 'PT30M', 'PT2H' (optional)"
+    )]
+    pub runtime: Option<String>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct CancelJobsParams {
+    #[schemars(description = "List of job IDs to cancel")]
+    pub job_ids: Vec<i64>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct CreateWorkflowParams {
+    #[schemars(description = "Workflow specification as JSON string")]
+    pub spec_json: String,
+    #[schemars(description = "User that owns the workflow (optional, defaults to current user)")]
+    pub user: Option<String>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct CheckResourceUtilizationParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+    #[schemars(
+        description = "Include failed jobs in the analysis (recommended for recovery diagnostics)"
+    )]
+    pub include_failed: Option<bool>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct ResetAndRestartWorkflowParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct ResubmitWorkflowParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+    #[schemars(
+        description = "Slurm account to use (defaults to account from existing schedulers)"
+    )]
+    pub account: Option<String>,
+    #[schemars(description = "HPC profile to use (auto-detected if not specified)")]
+    pub profile: Option<String>,
+    #[schemars(description = "Preview what would be submitted without actually submitting")]
+    pub dry_run: Option<bool>,
+}
+
+#[derive(Debug, Deserialize, schemars::JsonSchema)]
+pub struct RestartJobsParams {
+    #[schemars(description = "The workflow ID")]
+    pub workflow_id: i64,
+    #[schemars(description = "Only restart failed jobs (default: true)")]
+    pub failed_only: Option<bool>,
+    #[schemars(
+        description = "Specific job IDs to restart (optional, restarts all failed if not specified)"
+    )]
+    pub job_ids: Option<Vec<i64>>,
+}
+
+// Tool implementations using #[tool(tool_box)]
+
+#[tool(tool_box)]
+impl TorcMcpServer {
+    /// Get the status of a workflow including job counts by status.
+    #[tool(
+        description = "Get workflow status summary with job counts by status (completed, failed, running, etc.)"
+    )]
+    async fn get_workflow_status(
+        &self,
+        #[tool(aggr)] params: WorkflowIdParam,
+    ) -> Result<CallToolResult, McpError> {
+        let config = self.config.clone();
+        let workflow_id = params.workflow_id;
+        tokio::task::spawn_blocking(move || tools::get_workflow_status(&config, workflow_id))
+            .await
+            .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))?
+    }
+
+    /// Get detailed information about a specific job.
+    #[tool(
+        description = "Get detailed job information including command, status, resource requirements, and latest result"
+    )]
+    async fn get_job_details(
+        &self,
+        #[tool(aggr)] params: JobIdParam,
+    ) -> Result<CallToolResult, McpError> {
+        let config = self.config.clone();
+        let job_id = params.job_id;
+        tokio::task::spawn_blocking(move || tools::get_job_details(&config, job_id))
+            .await
+            .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))?
+    }
+
+    /// Read job stdout or stderr logs.
+    #[tool(
+        description = "Read job execution logs (stdout or stderr). Optionally return only the last N lines."
+    )]
+    async fn get_job_logs(
+        &self,
+        #[tool(aggr)] params: GetJobLogsParams,
+    ) -> Result<CallToolResult, McpError> {
+        let output_dir = self.output_dir.clone();
+        let workflow_id = params.workflow_id;
+        let job_id = params.job_id;
+        let run_id = params.run_id;
+        let log_type = params.log_type;
+        let tail_lines = params.tail_lines;
+        tokio::task::spawn_blocking(move || {
+            tools::get_job_logs(
+                &output_dir,
+                workflow_id,
+                job_id,
+                run_id,
+                &log_type,
+                tail_lines,
+            )
+        })
+        .await
+        .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))?
+    }
+
+    /// List all failed jobs in a workflow.
+    #[tool(
+        description = "List all jobs with 'failed' status in a workflow, including their error information"
+    )]
+    async fn list_failed_jobs(
+        &self,
+        #[tool(aggr)] params: WorkflowIdParam,
+    ) -> Result<CallToolResult, McpError> {
+        let config = self.config.clone();
+        let workflow_id = params.workflow_id;
+        tokio::task::spawn_blocking(move || tools::list_failed_jobs(&config, workflow_id))
+            .await
+            .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))?
+    }
+
+    /// List jobs filtered by status.
+    #[tool(
+        description = "List jobs in a workflow filtered by status (uninitialized, blocked, ready, pending, running, completed, failed, canceled, terminated, disabled)"
+    )]
+    async fn list_jobs_by_status(
+        &self,
+        #[tool(aggr)] params: ListJobsByStatusParams,
+    ) -> Result<CallToolResult, McpError> {
+        let config = self.config.clone();
+        let workflow_id = params.workflow_id;
+        let status = params.status;
+        tokio::task::spawn_blocking(move || {
+            tools::list_jobs_by_status(&config, workflow_id, &status)
+        })
+        .await
+        .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))?
+    }
+
+    /// Check resource utilization for a workflow.
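+    /// Delegates to the `torc reports check-resource-utilization` CLI command rather than calling the HTTP API directly.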
+ #[tool( + description = "Check resource utilization and identify jobs that exceeded their limits (memory, CPU, runtime). Use --include-failed to analyze failed jobs for recovery diagnostics." + )] + async fn check_resource_utilization( + &self, + #[tool(aggr)] params: CheckResourceUtilizationParams, + ) -> Result { + let workflow_id = params.workflow_id; + let include_failed = params.include_failed.unwrap_or(true); + tokio::task::spawn_blocking(move || { + tools::check_resource_utilization(workflow_id, include_failed) + }) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Reset failed jobs and restart the workflow. + #[tool( + description = "Reset all failed jobs in a workflow and restart it. This resets job statuses to uninitialized and re-initializes the workflow. Use after updating resource requirements for failed jobs." + )] + async fn reset_and_restart_workflow( + &self, + #[tool(aggr)] params: ResetAndRestartWorkflowParams, + ) -> Result { + let workflow_id = params.workflow_id; + tokio::task::spawn_blocking(move || tools::reset_and_restart_workflow(workflow_id)) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Update resource requirements for a job. + #[tool( + description = "Update a job's resource requirements (CPU, memory, runtime). Use before restarting a job that failed due to resource constraints." + )] + async fn update_job_resources( + &self, + #[tool(aggr)] params: UpdateJobResourcesParams, + ) -> Result { + let config = self.config.clone(); + let job_id = params.job_id; + let num_cpus = params.num_cpus; + let memory = params.memory; + let runtime = params.runtime; + tokio::task::spawn_blocking(move || { + tools::update_job_resources(&config, job_id, num_cpus, memory, runtime) + }) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Cancel specific jobs. + #[tool(description = "Cancel one or more jobs. Jobs must be in a cancellable state.")] + async fn cancel_jobs( + &self, + #[tool(aggr)] params: CancelJobsParams, + ) -> Result { + let config = self.config.clone(); + let job_ids = params.job_ids; + tokio::task::spawn_blocking(move || tools::cancel_jobs(&config, &job_ids)) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Create a workflow from a specification. + #[tool( + description = "Create a new workflow from a JSON specification. Returns the new workflow ID." + )] + async fn create_workflow_from_spec( + &self, + #[tool(aggr)] params: CreateWorkflowParams, + ) -> Result { + let config = self.config.clone(); + let spec_json = params.spec_json; + let user = params + .user + .unwrap_or_else(|| std::env::var("USER").unwrap_or_else(|_| "unknown".to_string())); + tokio::task::spawn_blocking(move || { + tools::create_workflow_from_spec(&config, &spec_json, &user) + }) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Restart jobs in a workflow. + #[tool( + description = "Restart jobs in a workflow. By default restarts all failed jobs. Can specify specific job IDs." 
+ )] + async fn restart_jobs( + &self, + #[tool(aggr)] params: RestartJobsParams, + ) -> Result { + let workflow_id = params.workflow_id; + let failed_only = params.failed_only; + let job_ids = params.job_ids; + tokio::task::spawn_blocking(move || tools::restart_jobs(workflow_id, failed_only, job_ids)) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } + + /// Resubmit a workflow by regenerating Slurm schedulers and submitting allocations. + #[tool( + description = "Regenerate Slurm schedulers for pending jobs and submit allocations. Use after resetting failed jobs to get new Slurm allocations. Analyzes job resource requirements and calculates the minimum allocations needed." + )] + async fn resubmit_workflow( + &self, + #[tool(aggr)] params: ResubmitWorkflowParams, + ) -> Result { + let output_dir = self.output_dir.clone(); + let workflow_id = params.workflow_id; + let account = params.account; + let profile = params.profile; + let dry_run = params.dry_run.unwrap_or(false); + tokio::task::spawn_blocking(move || { + tools::resubmit_workflow(&output_dir, workflow_id, account, profile, dry_run) + }) + .await + .map_err(|e| McpError::internal_error(format!("Task join error: {}", e), None))? + } +} + +#[tool(tool_box)] +impl ServerHandler for TorcMcpServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::V_2024_11_05, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation::from_build_env(), + instructions: Some( + "Torc MCP Server - Manage computational workflows. \ + Use get_workflow_status to check workflow progress, \ + list_failed_jobs to find failures, \ + get_job_logs to diagnose issues, \ + check_resource_utilization to identify resource problems, \ + update_job_resources to fix resource limits, \ + restart_jobs to reset and restart failed jobs, \ + and resubmit_workflow to regenerate Slurm schedulers and submit new allocations." + .to_string(), + ), + } + } +} diff --git a/torc-mcp-server/src/tools.rs b/torc-mcp-server/src/tools.rs new file mode 100644 index 00000000..53559fb6 --- /dev/null +++ b/torc-mcp-server/src/tools.rs @@ -0,0 +1,614 @@ +//! Tool implementations for the Torc MCP server. + +use rmcp::{Error as McpError, model::CallToolResult}; +use std::fs; +use std::path::PathBuf; +use std::process::Command; + +use torc::client::apis::configuration::Configuration; +use torc::client::apis::default_api; +use torc::client::log_paths; +use torc::models::{JobStatus, ResourceRequirementsModel}; + +/// Maximum number of jobs to retrieve in a single request. +/// Results may be truncated if the workflow has more jobs than this limit. 
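+/// Tools that hit this limit attach a "warning" field to their JSON result.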
+const MAX_JOBS_LIMIT: i64 = 10000; + +/// Helper to create an internal error +fn internal_error(msg: String) -> McpError { + McpError::internal_error(msg, None) +} + +/// Helper to create an invalid params error +fn invalid_params(msg: &str) -> McpError { + McpError::invalid_request(msg.to_string(), None) +} + +/// Parse status string to JobStatus enum +fn parse_status(status: &str) -> Option { + match status.to_lowercase().as_str() { + "uninitialized" => Some(JobStatus::Uninitialized), + "blocked" => Some(JobStatus::Blocked), + "ready" => Some(JobStatus::Ready), + "pending" => Some(JobStatus::Pending), + "running" => Some(JobStatus::Running), + "completed" => Some(JobStatus::Completed), + "failed" => Some(JobStatus::Failed), + "canceled" => Some(JobStatus::Canceled), + "terminated" => Some(JobStatus::Terminated), + "disabled" => Some(JobStatus::Disabled), + _ => None, + } +} + +/// Get workflow status with job counts. +pub fn get_workflow_status( + config: &Configuration, + workflow_id: i64, +) -> Result { + // Get workflow info + let workflow = default_api::get_workflow(config, workflow_id) + .map_err(|e| internal_error(format!("Failed to get workflow: {}", e)))?; + + // Get job counts by status - get all jobs + let jobs_response = default_api::list_jobs( + config, + workflow_id, + None, // status filter + None, // needs_file_id + None, // upstream_job_id + None, // offset + Some(MAX_JOBS_LIMIT), // limit + None, // sort_by + None, // reverse_sort + None, // include_relationships + ) + .map_err(|e| internal_error(format!("Failed to list jobs: {}", e)))?; + + let jobs = jobs_response.items.unwrap_or_default(); + let truncated = jobs.len() as i64 >= MAX_JOBS_LIMIT; + + // Count jobs by status + let mut status_counts = std::collections::HashMap::new(); + for job in &jobs { + if let Some(status) = &job.status { + let status_str = format!("{:?}", status); + *status_counts.entry(status_str).or_insert(0) += 1; + } + } + + let mut result = serde_json::json!({ + "workflow_id": workflow.id, + "name": workflow.name, + "user": workflow.user, + "description": workflow.description, + "total_jobs": jobs.len(), + "job_counts_by_status": status_counts, + }); + + if truncated { + result["warning"] = serde_json::json!(format!( + "Results truncated at {} jobs. Workflow may have more jobs.", + MAX_JOBS_LIMIT + )); + } + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&result).unwrap_or_default(), + )])) +} + +/// Get detailed job information. 
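+/// Combines the job record with its resource requirements and most recent result, when available.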
+pub fn get_job_details(config: &Configuration, job_id: i64) -> Result { + let job = default_api::get_job(config, job_id) + .map_err(|e| internal_error(format!("Failed to get job: {}", e)))?; + + // Get resource requirements if available + let resource_reqs = if let Some(req_id) = job.resource_requirements_id { + default_api::get_resource_requirements(config, req_id).ok() + } else { + None + }; + + // Get latest result if job has run + let result = default_api::list_results( + config, + job.workflow_id, + Some(job_id), + None, // run_id + None, // offset + Some(1), // limit - just get latest + None, // sort_by + None, // reverse_sort + None, // return_code + None, // status + None, // all_runs + ) + .ok() + .and_then(|r| r.items) + .and_then(|items| items.into_iter().next()); + + let response = serde_json::json!({ + "job_id": job.id, + "workflow_id": job.workflow_id, + "name": job.name, + "command": job.command, + "status": format!("{:?}", job.status), + "invocation_script": job.invocation_script, + "supports_termination": job.supports_termination, + "cancel_on_blocking_job_failure": job.cancel_on_blocking_job_failure, + "depends_on_job_ids": job.depends_on_job_ids, + "resource_requirements": resource_reqs.map(|r| serde_json::json!({ + "id": r.id, + "num_cpus": r.num_cpus, + "num_gpus": r.num_gpus, + "memory": r.memory, + "runtime": r.runtime, + })), + "latest_result": result.map(|r| serde_json::json!({ + "run_id": r.run_id, + "return_code": r.return_code, + "exec_time_minutes": r.exec_time_minutes, + "completion_time": r.completion_time, + "peak_memory_bytes": r.peak_memory_bytes, + "avg_cpu_percent": r.avg_cpu_percent, + })), + }); + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&response).unwrap_or_default(), + )])) +} + +/// Read job logs. +pub fn get_job_logs( + output_dir: &PathBuf, + workflow_id: i64, + job_id: i64, + run_id: i64, + log_type: &str, + tail_lines: Option, +) -> Result { + let log_path = match log_type.to_lowercase().as_str() { + "stdout" => log_paths::get_job_stdout_path(output_dir, workflow_id, job_id, run_id), + "stderr" => log_paths::get_job_stderr_path(output_dir, workflow_id, job_id, run_id), + _ => return Err(invalid_params("log_type must be 'stdout' or 'stderr'")), + }; + + let content = fs::read_to_string(&log_path) + .map_err(|e| internal_error(format!("Failed to read log file {}: {}", log_path, e)))?; + + let output = if let Some(n) = tail_lines { + let lines: Vec<&str> = content.lines().collect(); + let start = lines.len().saturating_sub(n); + lines[start..].join("\n") + } else { + content + }; + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + output, + )])) +} + +/// List failed jobs in a workflow. 
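+/// Returns the ID, name, and command of every job currently in the `failed` state.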
+pub fn list_failed_jobs( + config: &Configuration, + workflow_id: i64, +) -> Result { + let jobs_response = default_api::list_jobs( + config, + workflow_id, + Some(JobStatus::Failed), + None, // needs_file_id + None, // upstream_job_id + None, // offset + Some(MAX_JOBS_LIMIT), // limit + None, // sort_by + None, // reverse_sort + None, // include_relationships + ) + .map_err(|e| internal_error(format!("Failed to list jobs: {}", e)))?; + + let jobs = jobs_response.items.unwrap_or_default(); + let truncated = jobs.len() as i64 >= MAX_JOBS_LIMIT; + + let failed_jobs: Vec = jobs + .iter() + .map(|job| { + serde_json::json!({ + "job_id": job.id, + "name": job.name, + "command": job.command, + }) + }) + .collect(); + + let mut result = serde_json::json!({ + "workflow_id": workflow_id, + "failed_job_count": failed_jobs.len(), + "failed_jobs": failed_jobs, + }); + + if truncated { + result["warning"] = serde_json::json!(format!( + "Results truncated at {} jobs. There may be more failed jobs.", + MAX_JOBS_LIMIT + )); + } + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&result).unwrap_or_default(), + )])) +} + +/// List jobs by status. +pub fn list_jobs_by_status( + config: &Configuration, + workflow_id: i64, + status: &str, +) -> Result { + let status_enum = parse_status(status).ok_or_else(|| invalid_params("Invalid status value"))?; + + let jobs_response = default_api::list_jobs( + config, + workflow_id, + Some(status_enum), + None, // needs_file_id + None, // upstream_job_id + None, // offset + Some(MAX_JOBS_LIMIT), // limit + None, // sort_by + None, // reverse_sort + None, // include_relationships + ) + .map_err(|e| internal_error(format!("Failed to list jobs: {}", e)))?; + + let jobs = jobs_response.items.unwrap_or_default(); + let truncated = jobs.len() as i64 >= MAX_JOBS_LIMIT; + + let job_list: Vec = jobs + .iter() + .map(|job| { + serde_json::json!({ + "job_id": job.id, + "name": job.name, + "command": job.command, + }) + }) + .collect(); + + let mut result = serde_json::json!({ + "workflow_id": workflow_id, + "status": status, + "count": job_list.len(), + "jobs": job_list, + }); + + if truncated { + result["warning"] = serde_json::json!(format!( + "Results truncated at {} jobs. There may be more jobs with status '{}'.", + MAX_JOBS_LIMIT, status + )); + } + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&result).unwrap_or_default(), + )])) +} + +/// Check resource utilization for a workflow by running the CLI command. +pub fn check_resource_utilization( + workflow_id: i64, + include_failed: bool, +) -> Result { + let mut cmd = Command::new("torc"); + cmd.args(["-f", "json", "reports", "check-resource-utilization"]); + cmd.arg(workflow_id.to_string()); + + if include_failed { + cmd.arg("--include-failed"); + } + + let output = cmd + .output() + .map_err(|e| internal_error(format!("Failed to execute torc command: {}", e)))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(internal_error(format!( + "torc command failed: {}", + stderr.trim() + ))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + stdout.to_string(), + )])) +} + +/// Reset failed jobs and restart a workflow by running the CLI command. 
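+/// Shells out to `torc workflows reset-status --failed-only --restart --no-prompts`; requires the `torc` binary on PATH.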
+pub fn reset_and_restart_workflow(workflow_id: i64) -> Result { + let output = Command::new("torc") + .args([ + "-f", + "json", + "workflows", + "reset-status", + &workflow_id.to_string(), + "--failed-only", + "--restart", + "--no-prompts", + ]) + .output() + .map_err(|e| internal_error(format!("Failed to execute torc command: {}", e)))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(internal_error(format!( + "torc command failed: {}", + stderr.trim() + ))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + stdout.to_string(), + )])) +} + +/// Update job resource requirements. +pub fn update_job_resources( + config: &Configuration, + job_id: i64, + num_cpus: Option, + memory: Option, + runtime: Option, +) -> Result { + // Get the job to find its resource requirements ID + let job = default_api::get_job(config, job_id) + .map_err(|e| internal_error(format!("Failed to get job: {}", e)))?; + + let req_id = job + .resource_requirements_id + .ok_or_else(|| invalid_params("Job does not have resource requirements to update"))?; + + // Get current requirements + let mut reqs = default_api::get_resource_requirements(config, req_id) + .map_err(|e| internal_error(format!("Failed to get resource requirements: {}", e)))?; + + // Update fields if provided + if let Some(cpus) = num_cpus { + reqs.num_cpus = cpus; + } + if let Some(mem) = memory { + reqs.memory = mem; + } + if let Some(rt) = runtime { + reqs.runtime = rt; + } + + // Update the resource requirements + let updated = default_api::update_resource_requirements( + config, + req_id, + ResourceRequirementsModel { + id: reqs.id, + workflow_id: reqs.workflow_id, + name: reqs.name.clone(), + num_cpus: reqs.num_cpus, + num_gpus: reqs.num_gpus, + num_nodes: reqs.num_nodes, + memory: reqs.memory.clone(), + runtime: reqs.runtime.clone(), + }, + ) + .map_err(|e| internal_error(format!("Failed to update resource requirements: {}", e)))?; + + let result = serde_json::json!({ + "success": true, + "job_id": job_id, + "resource_requirements_id": req_id, + "updated": { + "num_cpus": updated.num_cpus, + "num_gpus": updated.num_gpus, + "memory": updated.memory, + "runtime": updated.runtime, + }, + }); + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&result).unwrap_or_default(), + )])) +} + +/// Cancel jobs. +pub fn cancel_jobs(config: &Configuration, job_ids: &[i64]) -> Result { + let mut canceled = Vec::new(); + let mut errors = Vec::new(); + + for job_id in job_ids { + match default_api::manage_status_change( + config, + *job_id, + JobStatus::Canceled, + 0, // run_id + None, + ) { + Ok(_) => canceled.push(*job_id), + Err(e) => errors.push(serde_json::json!({ + "job_id": job_id, + "error": format!("{}", e), + })), + } + } + + let result = serde_json::json!({ + "canceled_jobs": canceled, + "canceled_count": canceled.len(), + "errors": errors, + }); + + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + serde_json::to_string_pretty(&result).unwrap_or_default(), + )])) +} + +/// Create a workflow from a JSON specification. 
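+/// Writes the spec to a temporary file and loads it through `WorkflowSpec::create_workflow_from_spec`.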
+pub fn create_workflow_from_spec(
+    config: &Configuration,
+    spec_json: &str,
+    user: &str,
+) -> Result<CallToolResult, McpError> {
+    use std::io::Write;
+
+    // Parse the spec to get the name for the result message
+    let spec: serde_json::Value = serde_json::from_str(spec_json)
+        .map_err(|e| invalid_params(&format!("Invalid workflow spec JSON: {}", e)))?;
+
+    let name = spec
+        .get("name")
+        .and_then(|v| v.as_str())
+        .unwrap_or("unnamed");
+
+    // Write spec to a temp file
+    let mut temp_file = tempfile::NamedTempFile::new()
+        .map_err(|e| internal_error(format!("Failed to create temp file: {}", e)))?;
+
+    temp_file
+        .write_all(spec_json.as_bytes())
+        .map_err(|e| internal_error(format!("Failed to write spec to temp file: {}", e)))?;
+
+    let temp_path = temp_file.path();
+
+    // Create the workflow using the existing function
+    let workflow_id = torc::client::workflow_spec::WorkflowSpec::create_workflow_from_spec(
+        config, temp_path, user, false, false,
+    )
+    .map_err(|e| internal_error(format!("Failed to create workflow: {}", e)))?;
+
+    let result = serde_json::json!({
+        "success": true,
+        "workflow_id": workflow_id,
+        "message": format!("Created workflow '{}' with ID {}", name, workflow_id),
+    });
+
+    Ok(CallToolResult::success(vec![rmcp::model::Content::text(
+        serde_json::to_string_pretty(&result).unwrap_or_default(),
+    )]))
+}
+
+/// Restart jobs in a workflow.
+pub fn restart_jobs(
+    workflow_id: i64,
+    failed_only: Option<bool>,
+    job_ids: Option<Vec<i64>>,
+) -> Result<CallToolResult, McpError> {
+    let mut cmd = Command::new("torc");
+    cmd.args(["-f", "json", "workflows", "reset-status"]);
+    cmd.arg(workflow_id.to_string());
+
+    if failed_only.unwrap_or(true) {
+        cmd.arg("--failed-only");
+    }
+
+    cmd.args(["--restart", "--no-prompts"]);
+
+    if let Some(ids) = &job_ids {
+        for id in ids {
+            cmd.args(["--job-id", &id.to_string()]);
+        }
+    }
+
+    let output = cmd
+        .output()
+        .map_err(|e| internal_error(format!("Failed to run torc command: {}", e)))?;
+
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let stderr = String::from_utf8_lossy(&output.stderr);
+
+    if !output.status.success() {
+        return Ok(CallToolResult::success(vec![rmcp::model::Content::text(
+            format!(
+                "{{\"success\": false, \"error\": \"Command failed\", \"stderr\": {:?}}}",
+                stderr.trim()
+            ),
+        )]));
+    }
+
+    // The CLI outputs JSON, pass it through
+    let content = if stdout.trim().is_empty() {
+        serde_json::json!({
+            "success": true,
+            "message": "Jobs restarted successfully"
+        })
+        .to_string()
+    } else {
+        stdout.to_string()
+    };
+
+    Ok(CallToolResult::success(vec![rmcp::model::Content::text(
+        content,
+    )]))
+}
+
+/// Resubmit a workflow by regenerating Slurm schedulers and submitting allocations.
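+/// Runs `torc slurm regenerate`; when `dry_run` is true the `--submit` flag is omitted so nothing is queued.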
+pub fn resubmit_workflow(
+    output_dir: &PathBuf,
+    workflow_id: i64,
+    account: Option<String>,
+    profile: Option<String>,
+    dry_run: bool,
+) -> Result<CallToolResult, McpError> {
+    let mut cmd = Command::new("torc");
+    cmd.args(["-f", "json", "slurm", "regenerate"]);
+    cmd.arg(workflow_id.to_string());
+
+    if let Some(acct) = &account {
+        cmd.args(["--account", acct]);
+    }
+
+    if let Some(prof) = &profile {
+        cmd.args(["--profile", prof]);
+    }
+
+    cmd.args(["--output-dir", output_dir.to_str().unwrap_or("output")]);
+
+    if !dry_run {
+        cmd.arg("--submit");
+    }
+
+    let output = cmd
+        .output()
+        .map_err(|e| internal_error(format!("Failed to run torc command: {}", e)))?;
+
+    let stdout = String::from_utf8_lossy(&output.stdout);
+    let stderr = String::from_utf8_lossy(&output.stderr);
+
+    if !output.status.success() {
+        return Ok(CallToolResult::success(vec![rmcp::model::Content::text(
+            format!(
+                "{{\"success\": false, \"error\": \"Command failed\", \"stderr\": {:?}}}",
+                stderr.trim()
+            ),
+        )]));
+    }
+
+    // The CLI outputs JSON, pass it through
+    let content = if stdout.trim().is_empty() {
+        serde_json::json!({
+            "success": true,
+            "workflow_id": workflow_id,
+            "dry_run": dry_run,
+            "message": if dry_run { "Preview complete" } else { "Allocations submitted" }
+        })
+        .to_string()
+    } else {
+        stdout.to_string()
+    };
+
+    Ok(CallToolResult::success(vec![rmcp::model::Content::text(
+        content,
+    )]))
+}