diff --git a/stratevo/ai_strategy/dna_generator.py b/stratevo/ai_strategy/dna_generator.py
new file mode 100644
index 00000000..3c87b469
--- /dev/null
+++ b/stratevo/ai_strategy/dna_generator.py
@@ -0,0 +1,215 @@
+"""
+DNA Generator
+=============
+Natural language → StrategyDNA dict, with validation against known fields and ranges.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import math
+import re
+from dataclasses import fields as dataclass_fields
+from typing import Optional
+
+from stratevo.llm.registry import auto_detect_provider, get_provider
+from stratevo.llm.base import LLMProvider
+from stratevo.ai_strategy.prompt_templates import (
+    build_dna_system_prompt,
+    build_dna_user_prompt,
+)
+from stratevo.evolution.models import StrategyDNA, _PARAM_RANGES
+
+logger = logging.getLogger("stratevo.ai_strategy")
+
+
+def _extract_json(raw: str) -> str:
+    """Extract JSON from LLM response, stripping markdown fences.
+
+    Tries a ```json fence, then a generic fence, then the first substring that
+    parses as a complete JSON object (any nesting depth). Falls back to the
+    stripped input so the caller's json.loads reports the parse error.
+    """
+    # Try ```json ... ``` blocks
+    match = re.search(r"```json\s*\n(.*?)```", raw, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+    # Try generic code blocks
+    match = re.search(r"```\s*\n(.*?)```", raw, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+    # Scan for an embedded object. raw_decode stops at the end of the first
+    # complete value, so prose after the object is ignored and nesting depth
+    # is not limited the way a brace-matching regex would be.
+    decoder = json.JSONDecoder()
+    start = raw.find("{")
+    while start != -1:
+        try:
+            _, end = decoder.raw_decode(raw, start)
+        except json.JSONDecodeError:
+            start = raw.find("{", start + 1)
+        else:
+            return raw[start:end]
+    return raw.strip()
+
+
+def validate_dna_dict(dna_dict: dict) -> tuple[dict, list[str]]:
+    """Validate and clamp a DNA dict to valid fields and ranges.
+
+    Keys that are not StrategyDNA fields are dropped with a warning and
+    ``custom_weights`` is dropped silently. Keys present in _PARAM_RANGES are
+    coerced to int/float and clamped to their range; values that cannot be
+    coerced or are NaN/infinite are dropped with a warning. Remaining
+    StrategyDNA fields are passed through unchanged.
+
+    Returns (cleaned_dict, list_of_warnings).
+    """
+    valid_fields = {f.name for f in dataclass_fields(StrategyDNA)}
+    valid_fields.discard("custom_weights")
+    warnings: list[str] = []
+    cleaned: dict = {}
+
+    for key, value in dna_dict.items():
+        if key == "custom_weights":
+            continue
+        if key not in valid_fields:
+            warnings.append(f"Unknown field '{key}' ignored")
+            continue
+
+        # Type coercion and range clamping
+        if key in _PARAM_RANGES:
+            lo, hi, is_int = _PARAM_RANGES[key]
+            try:
+                number = float(value)
+            except (ValueError, TypeError, OverflowError):
+                warnings.append(f"Invalid type for '{key}': {value!r}, skipped")
+                continue
+            # json.loads accepts NaN/Infinity: NaN would pass both range checks
+            # below and int(round(inf)) raises OverflowError, so reject both here.
+            if not math.isfinite(number):
+                warnings.append(f"Non-finite value for '{key}': {value!r}, skipped")
+                continue
+            value = int(round(number)) if is_int else number
+            if value < lo:
+                warnings.append(f"'{key}' clamped from {value} to min {lo}")
+                value = int(lo) if is_int else lo
+            elif value > hi:
+                warnings.append(f"'{key}' clamped from {value} to max {hi}")
+                value = int(hi) if is_int else hi
+        # Fields in StrategyDNA but not in _PARAM_RANGES are accepted as-is.
+
+        cleaned[key] = value
+
+    return cleaned, warnings
+
+
+class DNAGenerator:
+    """Generate StrategyDNA dict from natural language descriptions."""
+
+    def __init__(
+        self,
+        provider: Optional[LLMProvider] = None,
+        provider_name: Optional[str] = None,
+    ):
+        self._provider = provider
+        self._provider_name = provider_name
+
+    def _get_provider(self) -> LLMProvider:
+        """Return the injected provider, else the named one, else auto-detect.
+
+        Raises RuntimeError when no provider can be auto-detected.
+        """
+        if self._provider is not None:
+            return self._provider
+        if self._provider_name:
+            return get_provider(self._provider_name)
+        provider = auto_detect_provider()
+        if provider is None:
+            raise RuntimeError(
+                "No LLM provider available. Set an API key env var "
+                "(OPENAI_API_KEY, ANTHROPIC_API_KEY, DEEPSEEK_API_KEY, etc.) "
+                "or start a local Ollama instance."
+            )
+        return provider
+
+    async def generate_async(
+        self,
+        description: str,
+        max_retries: int = 2,
+    ) -> dict:
+        """Generate StrategyDNA dict from a natural language description.
+
+        The provider is called up to ``max_retries + 1`` times (at least once).
+        After a failed attempt the model's reply and the error are appended to
+        the conversation so the next attempt can correct itself.
+
+        Returns:
+            {"dna": dict, "valid": bool, "warnings": list[str], "errors": list[str]}
+        """
+        provider = self._get_provider()
+        system = build_dna_system_prompt()
+        user = build_dna_user_prompt(description)
+
+        messages = [
+            {"role": "system", "content": system},
+            {"role": "user", "content": user},
+        ]
+
+        errors: list[str] = []
+        attempts = max(0, max_retries) + 1
+        for attempt in range(attempts):
+            # Reset per attempt so a failed provider call never replays the
+            # previous attempt's reply.
+            raw: Optional[str] = None
+            try:
+                raw = await provider.chat(messages, temperature=0.3)
+                json_str = _extract_json(raw)
+                dna_dict = json.loads(json_str)
+
+                if not isinstance(dna_dict, dict):
+                    raise ValueError(f"Expected JSON object, got {type(dna_dict).__name__}")
+
+                cleaned, warnings = validate_dna_dict(dna_dict)
+
+                # Verify it creates a valid StrategyDNA
+                StrategyDNA.from_dict(cleaned)
+
+                return {
+                    "dna": cleaned,
+                    "valid": True,
+                    "warnings": warnings,
+                    "errors": [],
+                }
+            except json.JSONDecodeError as e:
+                errors.append(f"JSON parse error: {e}")
+                feedback = f"Invalid JSON: {e}. Output ONLY a valid JSON object."
+            except Exception as e:
+                errors.append(str(e))
+                feedback = f"Error: {e}. Fix and output ONLY a valid JSON object."
+
+            logger.debug(
+                "DNA generation attempt %d/%d failed: %s",
+                attempt + 1, attempts, errors[-1],
+            )
+            # Only send corrective feedback when the model actually replied;
+            # otherwise the history would hold two consecutive user turns.
+            if raw is not None:
+                messages.append({"role": "assistant", "content": raw})
+                messages.append({"role": "user", "content": feedback})
+
+        return {
+            "dna": {},
+            "valid": False,
+            "warnings": [],
+            "errors": errors,
+        }
+
+    def generate(self, description: str, max_retries: int = 2) -> dict:
+        """Sync wrapper for generate_async.
+
+        Uses asyncio.run, so it must not be called from a running event loop;
+        await generate_async there instead.
+        """
+        return asyncio.run(self.generate_async(description, max_retries))
diff --git a/stratevo/ai_strategy/prompt_templates.py b/stratevo/ai_strategy/prompt_templates.py
index 8de2778e..4a6cf45d 100644
--- a/stratevo/ai_strategy/prompt_templates.py
+++ b/stratevo/ai_strategy/prompt_templates.py
@@ -184,6 +184,135 @@ def _format_results(results: dict) -> str:
     return "\n".join(lines)
 
 
+def build_dna_system_prompt() -> str:
+    """Build system prompt for StrategyDNA generation from natural language."""
+    return """You are StratEvo DNA Architect, an expert at translating natural language trading strategy descriptions into StrategyDNA parameter configurations.
+
+## Your Task
+Given a natural language description of a trading strategy, output a JSON object containing ONLY the StrategyDNA fields that should differ from their defaults. The JSON must be valid and parseable.
+
+## Output Format
+Output ONLY a valid JSON object. No markdown fences, no explanations before or after.
+Example: {"hold_days": 10, "stop_loss_pct": 5.0, "w_momentum": 0.3, "w_trend": 0.4}
+
+## StrategyDNA Fields Reference
+
+### Selection Thresholds
+- min_score (int, default=6, range 4-8): Minimum composite score to select a stock. Higher = stricter filtering.
+- rsi_buy_threshold (float, default=35.0, range 10-50): RSI level below which a stock is considered oversold (buy signal).
+- rsi_sell_threshold (float, default=75.0, range 55-95): RSI level above which a stock is considered overbought (sell signal).
+- r2_min (float, default=0.5, range 0.1-0.95): Minimum R-squared for trend linearity confirmation.
+- slope_min (float, default=0.5, range 0.1-3.0): Minimum daily price slope (%) for trend confirmation.
+- volume_ratio_min (float, default=1.2, range 0.5-5.0): Minimum volume ratio vs average to confirm interest. + +### Execution Parameters +- hold_days (int, default=3, range 2-60): Number of days to hold a position before exiting. +- stop_loss_pct (float, default=2.0, range 0.5-25.0): Maximum loss (%) before forced exit. +- take_profit_pct (float, default=20.0, range 3.0-200.0): Profit target (%) to take gains. +- max_positions (int, default=2, range 1-10): Maximum number of concurrent positions. + +### Golden Dip Parameters +- dip_threshold_pct (float, default=10.0, range 3-30): Pullback (%) from recent high to qualify as a dip buy. +- r2_trend_min (float, default=0.6, range 0.2-0.95): Minimum R-squared to confirm a bull trend before dip buying. + +### Position & Risk Management +- position_sizing_power (float, default=0.0, range 0-3.0): Score-based allocation power. 0=equal weight, 2=quadratic (high-score stocks get more capital). +- trailing_stop_enabled (float, default=0.0, range 0-1): >0.5 enables trailing stop. Activated after price rises trailing_activation_pct, then trails at trailing_stop_pct below peak. +- trailing_activation_pct (float, default=5.0, range 2-30): % gain to activate trailing stop. +- trailing_stop_pct (float, default=3.0, range 1-15): % below peak to trigger trailing stop exit. +- scale_in_tranches (int, default=1, range 1-3): Split entry into N tranches over consecutive days. 1=all-in. +- stop_loss_per_day (float, default=0.0, range 0-3.0): Per-day stop loss rate. Effective SL = stop_loss_per_day * hold_days. 0=use fixed stop_loss_pct. +- market_regime_sensitivity (float, default=0.0, range 0-1.0): Reduce exposure when market breadth is weak. 0=disabled. +- min_breadth_to_trade (float, default=0.0, range 0-0.8): Minimum market breadth to trade. 0=always trade, 0.5=only trade in bull. +- sector_max_pct (float, default=1.0, range 0.2-1.0): Max fraction of portfolio in one sector. 1.0=no limit. 
+- profit_target_scaling (float, default=0.0, range 0-1.0): Scale take_profit by signal strength. 0=fixed target. +- time_decay_exit (float, default=0.0, range 0-1.0): After hold_days, tighten stop loss progressively. 0=off. +- weekly_trend_confirm (float, default=0.0, range 0-1.0): >0.5 requires weekly uptrend (5MA>20MA) confirmation. +- kelly_fraction (float, default=0.0, range 0-1.0): Dynamic position sizing via Kelly criterion. 0=disabled, 1.0=full half-Kelly. +- factor_momentum (float, default=0.0, range 0-1.0): Boost weights of recently-profitable factors. 0=static weights. +- capital_utilization (float, default=1.0, range 0.3-1.0): Fraction of capital to deploy. 1.0=100%, 0.5=50% deployed + 50% cash. +- trend_exit_enabled (float, default=0.0, range 0-1.0): >0.5 enables early exit when 5-day MA crosses below 10-day MA. + +### Bear Market Regime Overrides +- bear_min_score_add (float, default=0.0, range 0-3.0): Add to min_score in bear markets (tighter filtering). +- bear_stop_loss_mult (float, default=1.0, range 0.3-1.0): Tighten stop loss in bear (multiplier on stop_loss_pct). +- bear_hold_days_mult (float, default=1.0, range 0.3-1.0): Shorten hold period in bear (multiplier on hold_days). +- bear_take_profit_mult (float, default=1.0, range 0.3-1.0): Lower targets in bear (multiplier on take_profit_pct). + +### Factor Weights (all float, default varies, range 0-1.0, auto-normalized to sum=1) + +Technical Core: +- w_momentum (default=0.1): RSI + slope-based momentum signal. +- w_mean_reversion (default=0.1): RSI oversold mean-reversion signal. +- w_volume (default=0.1): Volume ratio vs average — confirms institutional interest. +- w_trend (default=0.1): R-squared + MA alignment — trend quality. +- w_pattern (default=0.1): Candlestick pattern recognition. +- w_macd (default=0.1): MACD golden/death cross signal. +- w_bollinger (default=0.1): Bollinger Band position (buy near lower, sell near upper). 
+- w_kdj (default=0.1): KDJ golden cross — popular in Asian markets. +- w_obv (default=0.1): On-Balance Volume trend — price-volume confirmation. +- w_support (default=0.05): Proximity to support/resistance levels. +- w_volume_profile (default=0.05): Volume profile shape analysis. + +Technical Extended: +- w_atr (default=0.0): Average True Range — volatility measure. +- w_adx (default=0.0): Average Directional Index — trend strength. +- w_roc (default=0.0): Rate of Change — price momentum. +- w_williams_r (default=0.0): Williams %R — overbought/oversold oscillator. +- w_cci (default=0.0): Commodity Channel Index — deviation from statistical mean. +- w_mfi (default=0.0): Money Flow Index — volume-weighted RSI. +- w_vwap (default=0.0): Volume Weighted Average Price distance. +- w_donchian (default=0.0): Donchian Channel breakout signal. +- w_ichimoku (default=0.0): Ichimoku cloud position — trend/support system. +- w_elder_ray (default=0.0): Elder Ray bull/bear power indicator. + +Rolling Statistics: +- w_beta (default=0.0): Price regression slope (beta). +- w_r_squared (default=0.0): Trend linearity via R-squared. +- w_residual (default=0.0): Regression residual — mean reversion signal. +- w_quantile_upper (default=0.0): Distance to 80th percentile. +- w_quantile_lower (default=0.0): Distance to 20th percentile. +- w_aroon (default=0.0): Days since high/low — trend indicator. +- w_price_volume_corr (default=0.0): Price-volume correlation. + +Fundamental: +- w_pe (default=0.0): Price-to-Earnings valuation score. +- w_pb (default=0.0): Price-to-Book value score. +- w_roe (default=0.0): Return on Equity. +- w_revenue_growth (default=0.0): Revenue growth rate. +- w_revenue_yoy (default=0.0): Revenue year-over-year growth. +- w_revenue_qoq (default=0.0): Revenue quarter-over-quarter growth. +- w_profit_yoy (default=0.0): Net profit year-over-year growth. +- w_profit_qoq (default=0.0): Net profit quarter-over-quarter growth. +- w_ps (default=0.0): Price-to-Sales ratio. 
+- w_peg (default=0.0): PEG ratio (PE / earnings growth). +- w_gross_margin (default=0.0): Gross margin quality. +- w_debt_ratio (default=0.0): Debt-to-asset ratio. +- w_cashflow (default=0.0): Operating cashflow quality. + +## Rules +1. Output ONLY changed fields — fields that differ from their defaults. +2. All values must respect the valid ranges listed above. +3. Integer fields (min_score, hold_days, max_positions, scale_in_tranches) must be integers. +4. Factor weights are auto-normalized; you only need to set relative magnitudes. +5. If the user describes a momentum strategy, increase w_momentum, w_macd, w_roc, w_adx etc. +6. If the user describes a value/fundamental strategy, increase w_pe, w_pb, w_roe etc. +7. If the user describes a mean-reversion strategy, increase w_mean_reversion, w_bollinger, w_williams_r etc. +8. If the user wants conservative risk, set tighter stop_loss_pct, lower max_positions, enable trailing stops. +9. If the user wants aggressive growth, use wider stops, more positions, longer hold periods. +10. Think carefully about which parameters match the user's intent. +""" + + +def build_dna_user_prompt(description: str) -> str: + """Build user prompt for DNA generation.""" + return ( + f"Translate the following trading strategy description into a StrategyDNA JSON object. " + f"Only include fields that should differ from defaults.\n\n" + f"Strategy description: {description}" + ) + + def build_copilot_system_prompt() -> str: """System prompt for the StratEvo Copilot chat mode.""" return """You are StratEvo Copilot, an AI financial analysis assistant. 
diff --git a/stratevo/cli/commands/strategy.py b/stratevo/cli/commands/strategy.py index 58810eb9..92fd7ec0 100644 --- a/stratevo/cli/commands/strategy.py +++ b/stratevo/cli/commands/strategy.py @@ -370,6 +370,7 @@ def cmd_evolve(args): pareto=getattr(args, "pareto", False), pareto_complexity=getattr(args, "pareto_complexity", False), held_out_ratio=getattr(args, "held_out", 0.2), + seed_dna_path=getattr(args, "seed_dna", None), ) try: diff --git a/stratevo/cli/main.py b/stratevo/cli/main.py index fd6715e0..08d8c1b5 100644 --- a/stratevo/cli/main.py +++ b/stratevo/cli/main.py @@ -714,7 +714,9 @@ def build_parser() -> argparse.ArgumentParser: p_gen.add_argument("--market", default="us_stock", choices=["us_stock", "crypto", "cn_stock"], help="Target market") p_gen.add_argument("--risk", default="medium", choices=["low", "medium", "high"], help="Risk profile") p_gen.add_argument("--provider", default=None, help="LLM provider (openai, anthropic, deepseek, ollama, ...)") - p_gen.add_argument("--output", "-o", default=None, help="Save generated code to file") + p_gen.add_argument("--output", "-o", default=None, help="Save generated code/DNA to file") + p_gen.add_argument("--output-format", default="code", choices=["code", "dna"], + help="Output format: 'code' for StrategyPlugin Python code, 'dna' for StrategyDNA JSON (default: code)") p_opt = sub.add_parser("optimize-strategy", help="AI-optimize an existing strategy") p_opt.add_argument("strategy_file", help="Path to strategy .py file") @@ -815,6 +817,8 @@ def build_parser() -> argparse.ArgumentParser: p_evo.add_argument("--held-out", type=float, default=0.2, help="Held-out validation ratio (0.0-1.0). Last N%% of data reserved for " "final validation after evolution. Set to 0 to disable. 
(default: 0.2)") + p_evo.add_argument("--seed-dna", default=None, + help="Path to a StrategyDNA JSON file to use as initial seed instead of defaults") # compare-markets p_cmp = sub.add_parser("compare-markets", help="Compare strategy performance across different markets") @@ -1084,9 +1088,38 @@ def _handle_paper(args): def _cmd_generate_strategy(args): """Handle: stratevo generate-strategy""" - from stratevo.ai_strategy.strategy_generator import StrategyGenerator import asyncio + output_format = getattr(args, "output_format", "code") + + if output_format == "dna": + from stratevo.ai_strategy.dna_generator import DNAGenerator + import json as _json + + if not args.description: + print(" Usage: stratevo generate-strategy \"description\" --output-format dna [-o file.json]") + return + + gen = DNAGenerator(provider_name=args.provider) + print(f" 🤖 Generating StrategyDNA for: {args.description}") + result = gen.generate(args.description) + + if result["valid"]: + dna_json = _json.dumps(result["dna"], indent=2, ensure_ascii=False) + print(f" ✅ Generated StrategyDNA ({len(result['dna'])} fields)\n") + print(dna_json) + if result["warnings"]: + print(f"\n ⚠ Warnings: {result['warnings']}") + if args.output: + with open(args.output, "w", encoding="utf-8") as f: + f.write(dna_json) + print(f"\n Saved to {args.output}") + else: + print(f" ❌ DNA generation failed: {result['errors']}") + return + + from stratevo.ai_strategy.strategy_generator import StrategyGenerator + gen = StrategyGenerator( provider_name=args.provider, market=args.market, diff --git a/stratevo/evolution/auto_evolve.py b/stratevo/evolution/auto_evolve.py index de91862b..544a97ea 100644 --- a/stratevo/evolution/auto_evolve.py +++ b/stratevo/evolution/auto_evolve.py @@ -381,6 +381,7 @@ def __init__( logic_bias_strength: float = 3.0, max_stocks: int = 500, held_out_ratio: float = 0.2, + seed_dna_path: Optional[str] = None, ): self.data_dir = data_dir self.population_size = population_size @@ -390,6 +391,7 @@ def 
__init__( self.market = market self.max_stocks = max_stocks self.held_out_ratio = held_out_ratio + self.seed_dna_path = seed_dna_path self.rng = random.Random(seed) os.makedirs(results_dir, exist_ok=True) @@ -1960,6 +1962,19 @@ def evolve(self, generations: int = 100, save_interval: int = 10) -> List[Evolut start_gen = self._load_start_gen() if parents: logger.info(f"Resuming from generation {start_gen} with {len(parents)} elite strategies") + elif self.seed_dna_path: + # Load seed DNA from file + try: + with open(self.seed_dna_path, "r", encoding="utf-8") as f: + seed_data = json.load(f) + parents = [StrategyDNA.from_dict(seed_data)] + start_gen = 0 + logger.info(f"Starting with seed DNA from {self.seed_dna_path}") + except Exception as e: + logger.warning(f"Failed to load seed DNA from {self.seed_dna_path}: {e}") + parents = [StrategyDNA()] + start_gen = 0 + logger.info("Falling back to default strategy DNA") else: parents = [StrategyDNA()] # default seed start_gen = 0 diff --git a/tests/snapshots/evolve.txt b/tests/snapshots/evolve.txt index e0184572..dfe2060a 100644 --- a/tests/snapshots/evolve.txt +++ b/tests/snapshots/evolve.txt @@ -6,7 +6,7 @@ usage: stratevo evolve [-h] [--market {crypto,a-share,cn,us}] [--quick] [--seed SEED] [--save-interval SAVE_INTERVAL] [--max-stocks MAX_STOCKS] [--download] [--symbols SYMBOLS] [--pareto] [--pareto-complexity] - [--held-out HELD_OUT] + [--held-out HELD_OUT] [--seed-dna SEED_DNA] options: -h, --help show this help message and exit @@ -41,3 +41,5 @@ options: --held-out HELD_OUT Held-out validation ratio (0.0-1.0). Last N% of data reserved for final validation after evolution. Set to 0 to disable. 
(default: 0.2) + --seed-dna SEED_DNA Path to a StrategyDNA JSON file to use as initial seed + instead of defaults diff --git a/tests/test_dna_generator.py b/tests/test_dna_generator.py new file mode 100644 index 00000000..ac8acaa5 --- /dev/null +++ b/tests/test_dna_generator.py @@ -0,0 +1,368 @@ +"""Tests for DNA Generator — NL → StrategyDNA pipeline.""" + +import asyncio +import json +import os +import sys +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from stratevo.ai_strategy.dna_generator import ( + DNAGenerator, + _extract_json, + validate_dna_dict, +) +from stratevo.ai_strategy.prompt_templates import ( + build_dna_system_prompt, + build_dna_user_prompt, +) +from stratevo.evolution.models import StrategyDNA, _PARAM_RANGES + + +# ── Helpers ────────────────────────────────────────────────────── + +def _mock_provider(response: str): + """Create a mock LLM provider that returns `response` from chat().""" + provider = MagicMock() + provider.chat = AsyncMock(return_value=response) + return provider + + +# ══════════════════════════════════════════════════════════════════ +# 1. 
Prompt template tests +# ══════════════════════════════════════════════════════════════════ + +class TestDNAPromptTemplates: + def test_system_prompt_contains_all_weight_fields(self): + prompt = build_dna_system_prompt() + for key in ( + "w_momentum", "w_macd", "w_bollinger", "w_pe", "w_atr", + "w_beta", "w_revenue_yoy", "w_ps", "w_gross_margin", + ): + assert key in prompt, f"Missing factor {key} in system prompt" + + def test_system_prompt_contains_threshold_fields(self): + prompt = build_dna_system_prompt() + for key in ("min_score", "rsi_buy_threshold", "stop_loss_pct", "hold_days"): + assert key in prompt + + def test_system_prompt_contains_risk_fields(self): + prompt = build_dna_system_prompt() + for key in ( + "trailing_stop_enabled", "kelly_fraction", "capital_utilization", + "market_regime_sensitivity", "bear_stop_loss_mult", + ): + assert key in prompt + + def test_system_prompt_contains_ranges(self): + prompt = build_dna_system_prompt() + assert "range 4-8" in prompt # min_score + assert "range 0.5-25.0" in prompt # stop_loss_pct + + def test_user_prompt_includes_description(self): + desc = "momentum strategy with tight stops" + prompt = build_dna_user_prompt(desc) + assert desc in prompt + + def test_user_prompt_mentions_defaults(self): + prompt = build_dna_user_prompt("test") + assert "differ from defaults" in prompt + + +# ══════════════════════════════════════════════════════════════════ +# 2. JSON extraction tests +# ══════════════════════════════════════════════════════════════════ + +class TestExtractJSON: + def test_extract_from_json_fence(self): + raw = 'Here:\n```json\n{"hold_days": 10}\n```\nDone.' 
+ assert json.loads(_extract_json(raw)) == {"hold_days": 10} + + def test_extract_from_generic_fence(self): + raw = '```\n{"w_momentum": 0.3}\n```' + assert json.loads(_extract_json(raw)) == {"w_momentum": 0.3} + + def test_extract_bare_json(self): + raw = '{"stop_loss_pct": 5.0, "hold_days": 7}' + assert json.loads(_extract_json(raw)) == {"stop_loss_pct": 5.0, "hold_days": 7} + + def test_extract_json_with_surrounding_text(self): + raw = 'Based on your description, here is the DNA: {"hold_days": 5} Hope this helps!' + assert json.loads(_extract_json(raw)) == {"hold_days": 5} + + +# ══════════════════════════════════════════════════════════════════ +# 3. Validation tests +# ══════════════════════════════════════════════════════════════════ + +class TestValidateDNADict: + def test_valid_dict_passes(self): + d = {"hold_days": 10, "stop_loss_pct": 5.0, "w_momentum": 0.3} + cleaned, warnings = validate_dna_dict(d) + assert cleaned == d + assert warnings == [] + + def test_unknown_field_ignored(self): + d = {"hold_days": 5, "unknown_field": 42} + cleaned, warnings = validate_dna_dict(d) + assert "unknown_field" not in cleaned + assert any("Unknown field" in w for w in warnings) + + def test_value_clamped_to_min(self): + d = {"min_score": 1} # min is 4 + cleaned, warnings = validate_dna_dict(d) + assert cleaned["min_score"] == 4 + assert any("clamped" in w for w in warnings) + + def test_value_clamped_to_max(self): + d = {"max_positions": 99} # max is 10 + cleaned, warnings = validate_dna_dict(d) + assert cleaned["max_positions"] == 10 + + def test_float_converted_for_int_field(self): + d = {"hold_days": 5.7} + cleaned, _ = validate_dna_dict(d) + assert cleaned["hold_days"] == 6 + assert isinstance(cleaned["hold_days"], int) + + def test_invalid_type_skipped(self): + d = {"hold_days": "not_a_number"} + cleaned, warnings = validate_dna_dict(d) + assert "hold_days" not in cleaned + assert any("Invalid type" in w for w in warnings) + + def test_custom_weights_ignored(self): 
+ d = {"custom_weights": {"my_factor": 0.5}} + cleaned, _ = validate_dna_dict(d) + assert "custom_weights" not in cleaned + + def test_all_weight_fields_accepted(self): + d = {f"w_{name}": 0.1 for name in [ + "momentum", "mean_reversion", "volume", "trend", "pattern", + "macd", "bollinger", "kdj", "obv", "support", "volume_profile", + "pe", "pb", "roe", "revenue_growth", + ]} + cleaned, warnings = validate_dna_dict(d) + assert len(cleaned) == 15 + assert warnings == [] + + def test_empty_dict(self): + cleaned, warnings = validate_dna_dict({}) + assert cleaned == {} + assert warnings == [] + + +# ══════════════════════════════════════════════════════════════════ +# 4. DNAGenerator tests (with mocked LLM) +# ══════════════════════════════════════════════════════════════════ + +class TestDNAGenerator: + def test_generate_valid_response(self): + mock_response = json.dumps({"hold_days": 10, "w_momentum": 0.4}) + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("momentum strategy") + assert result["valid"] is True + assert result["dna"]["hold_days"] == 10 + assert result["dna"]["w_momentum"] == 0.4 + + def test_generate_with_json_fence(self): + mock_response = '```json\n{"stop_loss_pct": 3.0, "max_positions": 5}\n```' + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("conservative strategy") + assert result["valid"] is True + assert result["dna"]["stop_loss_pct"] == 3.0 + assert result["dna"]["max_positions"] == 5 + + def test_generate_clamps_out_of_range(self): + mock_response = json.dumps({"hold_days": 999, "stop_loss_pct": -5}) + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("test") + assert result["valid"] is True + assert result["dna"]["hold_days"] == 60 # max + assert result["dna"]["stop_loss_pct"] == 0.5 # min + assert len(result["warnings"]) >= 2 + + @pytest.mark.asyncio + async def test_generate_async(self): + mock_response = json.dumps({"w_trend": 0.5}) + 
gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = await gen.generate_async("trend following") + assert result["valid"] is True + assert result["dna"]["w_trend"] == 0.5 + + def test_generate_invalid_json_retries(self): + provider = MagicMock() + # First call returns invalid JSON, second returns valid + provider.chat = AsyncMock(side_effect=[ + "This is not JSON at all {{{", + json.dumps({"hold_days": 7}), + ]) + gen = DNAGenerator(provider=provider) + result = gen.generate("test") + assert result["valid"] is True + assert result["dna"]["hold_days"] == 7 + + def test_generate_all_retries_fail(self): + provider = MagicMock() + provider.chat = AsyncMock(return_value="totally not json ][}{") + gen = DNAGenerator(provider=provider) + result = gen.generate("test", max_retries=1) + assert result["valid"] is False + assert len(result["errors"]) > 0 + + def test_generate_creates_valid_strategy_dna(self): + mock_response = json.dumps({ + "hold_days": 10, "stop_loss_pct": 5.0, + "w_momentum": 0.3, "w_trend": 0.4, "w_bollinger": 0.2, + "trailing_stop_enabled": 0.8, + }) + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("test") + assert result["valid"] is True + dna = StrategyDNA.from_dict(result["dna"]) + assert dna.hold_days == 10 + assert dna.trailing_stop_enabled == 0.8 + + def test_no_provider_raises(self): + gen = DNAGenerator() + with patch("stratevo.ai_strategy.dna_generator.auto_detect_provider", return_value=None): + with pytest.raises(RuntimeError, match="No LLM provider"): + gen.generate("test") + + +# ══════════════════════════════════════════════════════════════════ +# 5. 
CLI argument tests +# ══════════════════════════════════════════════════════════════════ + +class TestCLIArgs: + def test_generate_strategy_has_output_format_arg(self): + """Verify the generate-strategy parser accepts --output-format.""" + from stratevo.cli.main import build_parser + parser = build_parser() + args = parser.parse_args(["generate-strategy", "test", "--output-format", "dna"]) + assert args.output_format == "dna" + + def test_generate_strategy_default_format_is_code(self): + from stratevo.cli.main import build_parser + parser = build_parser() + args = parser.parse_args(["generate-strategy", "test"]) + assert args.output_format == "code" + + def test_evolve_has_seed_dna_arg(self): + """Verify the evolve parser accepts --seed-dna.""" + from stratevo.cli.main import build_parser + parser = build_parser() + args = parser.parse_args(["evolve", "--seed-dna", "my_dna.json"]) + assert args.seed_dna == "my_dna.json" + + def test_evolve_seed_dna_default_none(self): + from stratevo.cli.main import build_parser + parser = build_parser() + args = parser.parse_args(["evolve"]) + assert args.seed_dna is None + + +# ══════════════════════════════════════════════════════════════════ +# 6. 
Seed DNA loading tests (auto_evolve integration) +# ══════════════════════════════════════════════════════════════════ + +class TestSeedDNALoading: + def test_seed_dna_loaded_from_file(self, tmp_path): + """Test that AutoEvolver loads seed DNA from a JSON file.""" + dna_dict = {"hold_days": 15, "stop_loss_pct": 8.0, "w_momentum": 0.5} + dna_file = tmp_path / "seed.json" + dna_file.write_text(json.dumps(dna_dict)) + + from stratevo.evolution.auto_evolve import AutoEvolver + # Create evolver with seed_dna_path — we only test __init__ stores it + evolver = AutoEvolver( + data_dir=str(tmp_path), + seed_dna_path=str(dna_file), + walk_forward=False, + ) + assert evolver.seed_dna_path == str(dna_file) + + def test_seed_dna_file_parsed_correctly(self, tmp_path): + """Test that a seed DNA JSON file creates a valid StrategyDNA.""" + dna_dict = {"hold_days": 20, "w_trend": 0.6, "max_positions": 5} + dna_file = tmp_path / "seed.json" + dna_file.write_text(json.dumps(dna_dict)) + + dna = StrategyDNA.from_dict(json.loads(dna_file.read_text())) + assert dna.hold_days == 20 + assert dna.w_trend == 0.6 + assert dna.max_positions == 5 + + def test_seed_dna_path_none_uses_default(self): + """Test that seed_dna_path=None falls through to default DNA.""" + from stratevo.evolution.auto_evolve import AutoEvolver + evolver = AutoEvolver( + data_dir=".", + walk_forward=False, + ) + assert evolver.seed_dna_path is None + + def test_seed_dna_invalid_file_falls_back(self, tmp_path): + """Test that an invalid seed DNA file doesn't crash — logs warning and falls back.""" + bad_file = tmp_path / "bad_seed.json" + bad_file.write_text("not valid json {{{") + + from stratevo.evolution.auto_evolve import AutoEvolver + evolver = AutoEvolver( + data_dir=str(tmp_path), + seed_dna_path=str(bad_file), + walk_forward=False, + ) + # The evolver should store the path; the fallback happens in evolve() + assert evolver.seed_dna_path == str(bad_file) + + +# 
══════════════════════════════════════════════════════════════════ +# 7. Edge cases and integration +# ══════════════════════════════════════════════════════════════════ + +class TestEdgeCases: + def test_all_param_ranges_covered_in_validation(self): + """Every key in _PARAM_RANGES should be validatable.""" + d = {} + for key, (lo, hi, is_int) in _PARAM_RANGES.items(): + d[key] = lo + cleaned, warnings = validate_dna_dict(d) + assert len(cleaned) == len(_PARAM_RANGES) + assert warnings == [] + + def test_dna_roundtrip_via_dict(self): + """StrategyDNA → dict → StrategyDNA roundtrip preserves values.""" + dna = StrategyDNA(hold_days=15, w_momentum=0.5, stop_loss_pct=3.0) + d = dna.to_dict() + dna2 = StrategyDNA.from_dict(d) + assert dna2.hold_days == 15 + assert dna2.w_momentum == 0.5 + assert dna2.stop_loss_pct == 3.0 + + def test_generate_with_chinese_description(self): + """Test that Chinese descriptions work.""" + mock_response = json.dumps({"w_kdj": 0.3, "hold_days": 5}) + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("使用KDJ金叉策略,短线持有") + assert result["valid"] is True + assert result["dna"]["w_kdj"] == 0.3 + + def test_generate_output_saved_to_file(self, tmp_path): + """Integration test: generated DNA can be saved and reloaded.""" + mock_response = json.dumps({"hold_days": 7, "w_macd": 0.4}) + gen = DNAGenerator(provider=_mock_provider(mock_response)) + result = gen.generate("MACD crossover strategy") + assert result["valid"] is True + + out_file = tmp_path / "dna.json" + out_file.write_text(json.dumps(result["dna"], indent=2)) + + loaded = json.loads(out_file.read_text()) + dna = StrategyDNA.from_dict(loaded) + assert dna.hold_days == 7 + assert dna.w_macd == 0.4