Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions stratevo/ai_strategy/dna_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
DNA Generator
=============
Natural language → StrategyDNA dict, with validation against known fields and ranges.
"""

from __future__ import annotations

import asyncio
import json
import logging
import re
from dataclasses import fields as dataclass_fields
from typing import Optional

from stratevo.llm.registry import auto_detect_provider, get_provider
from stratevo.llm.base import LLMProvider
from stratevo.ai_strategy.prompt_templates import (
build_dna_system_prompt,
build_dna_user_prompt,
)
from stratevo.evolution.models import StrategyDNA, _PARAM_RANGES

logger = logging.getLogger("stratevo.ai_strategy")


def _extract_json(raw: str) -> str:
"""Extract JSON from LLM response, stripping markdown fences."""
# Try ```json ... ``` blocks
match = re.search(r"```json\s*\n(.*?)```", raw, re.DOTALL)
if match:
return match.group(1).strip()
# Try generic code blocks
match = re.search(r"```\s*\n(.*?)```", raw, re.DOTALL)
if match:
return match.group(1).strip()
# Try to find a JSON object directly
match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", raw, re.DOTALL)
if match:
return match.group(0).strip()
return raw.strip()


def validate_dna_dict(dna_dict: dict) -> tuple[dict, list[str]]:
    """Validate and clamp a DNA dict to valid fields and ranges.

    Rules applied to each ``key: value`` pair:

    * ``custom_weights`` is dropped silently.
    * Keys that are not dataclass fields of ``StrategyDNA`` are dropped with
      a warning.
    * Keys listed in ``_PARAM_RANGES`` are coerced to ``int`` (rounded) or
      ``float`` and clamped to the ``(lo, hi)`` bounds, with a warning when
      clamping happens. Values that cannot be coerced are skipped with a
      warning. This includes NaN, and infinities for integer fields.
    * Any other known field is copied through unchanged.

    Args:
        dna_dict: Candidate parameters, typically parsed from LLM output.

    Returns:
        (cleaned_dict, list_of_warnings).
    """
    import math  # local import: only needed for the NaN guard below

    valid_fields = {f.name for f in dataclass_fields(StrategyDNA)}
    valid_fields.discard("custom_weights")
    warnings: list[str] = []
    cleaned: dict = {}

    for key, value in dna_dict.items():
        if key == "custom_weights":
            continue
        if key not in valid_fields:
            warnings.append(f"Unknown field '{key}' ignored")
            continue

        # Type coercion and range clamping
        if key in _PARAM_RANGES:
            lo, hi, is_int = _PARAM_RANGES[key]
            try:
                number = float(value)
                # NaN compares False against both bounds, so without this
                # guard it would slip through the clamping below untouched.
                if math.isnan(number):
                    raise ValueError("NaN is not a valid parameter value")
                # round() raises OverflowError for +/-inf on integer fields.
                coerced = int(round(number)) if is_int else number
            except (ValueError, TypeError, OverflowError):
                warnings.append(f"Invalid type for '{key}': {value!r}, skipped")
                continue
            value = coerced
            if value < lo:
                warnings.append(f"'{key}' clamped from {value} to min {lo}")
                value = int(lo) if is_int else lo
            elif value > hi:
                warnings.append(f"'{key}' clamped from {value} to max {hi}")
                value = int(hi) if is_int else hi
        # Fields in StrategyDNA but not in _PARAM_RANGES are accepted as-is.

        cleaned[key] = value

    return cleaned, warnings


class DNAGenerator:
    """Generate StrategyDNA dict from natural language descriptions.

    The generator sends the description to an LLM provider, extracts a JSON
    object from the reply, validates it with ``validate_dna_dict`` and checks
    that ``StrategyDNA.from_dict`` accepts the result. Failed attempts are
    retried, feeding the error back to the model.
    """

    def __init__(
        self,
        provider: Optional[LLMProvider] = None,
        provider_name: Optional[str] = None,
    ):
        """Store the provider selection; nothing is resolved until first use.

        Args:
            provider: Ready-to-use provider instance. Takes precedence over
                ``provider_name``.
            provider_name: Name passed to ``get_provider`` when no instance
                is given. If both are ``None`` a provider is auto-detected.
        """
        self._provider = provider
        self._provider_name = provider_name

    def _get_provider(self) -> LLMProvider:
        """Resolve the provider: explicit instance, then name, then auto-detect.

        Raises:
            RuntimeError: If no provider was configured and auto-detection
                returns ``None``.
        """
        # Identity check so an explicitly supplied provider is always honoured,
        # even if the object happens to evaluate as falsy.
        if self._provider is not None:
            return self._provider
        if self._provider_name:
            return get_provider(self._provider_name)
        provider = auto_detect_provider()
        if provider is None:
            raise RuntimeError(
                "No LLM provider available. Set an API key env var "
                "(OPENAI_API_KEY, ANTHROPIC_API_KEY, DEEPSEEK_API_KEY, etc.) "
                "or start a local Ollama instance."
            )
        return provider

    async def generate_async(
        self,
        description: str,
        max_retries: int = 2,
    ) -> dict:
        """Generate StrategyDNA dict from a natural language description.

        Up to ``max_retries + 1`` attempts are made. When the model replied
        but the reply could not be parsed or validated, the reply and the
        error are appended to the conversation so the next attempt can
        correct it. When the provider call itself failed there is no reply
        to correct, so the same conversation is simply retried.

        Args:
            description: Natural language strategy description.
            max_retries: Number of additional attempts after the first.

        Returns:
            {"dna": dict, "valid": bool, "warnings": list[str], "errors": list[str]}
        """
        provider = self._get_provider()
        system = build_dna_system_prompt()
        user = build_dna_user_prompt(description)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]

        errors: list[str] = []
        for _attempt in range(max_retries + 1):
            # Reset on every attempt so a failing chat call can never replay
            # the reply of an earlier attempt.
            raw = None
            try:
                raw = await provider.chat(messages, temperature=0.3)
                dna_dict = json.loads(_extract_json(raw))

                if not isinstance(dna_dict, dict):
                    raise ValueError(f"Expected JSON object, got {type(dna_dict).__name__}")

                cleaned, warnings = validate_dna_dict(dna_dict)

                # Verify it creates a valid StrategyDNA
                StrategyDNA.from_dict(cleaned)

                return {
                    "dna": cleaned,
                    "valid": True,
                    "warnings": warnings,
                    "errors": [],
                }
            except json.JSONDecodeError as e:
                # Must precede the generic handler: JSONDecodeError is an Exception.
                errors.append(f"JSON parse error: {e}")
                feedback = f"Invalid JSON: {e}. Output ONLY a valid JSON object."
            except Exception as e:
                errors.append(str(e))
                feedback = f"Error: {e}. Fix and output ONLY a valid JSON object."

            if raw is None:
                # The provider call failed before producing a reply. Asking the
                # model to "fix" it would add a second consecutive user turn
                # with nothing to correct, so retry with the conversation as-is.
                continue
            messages.append({"role": "assistant", "content": raw})
            messages.append({"role": "user", "content": feedback})

        return {
            "dna": {},
            "valid": False,
            "warnings": [],
            "errors": errors,
        }

    def generate(self, description: str, max_retries: int = 2) -> dict:
        """Sync wrapper for generate_async.

        Uses ``asyncio.run``, so it must not be called from code that is
        already running inside an event loop; use ``generate_async`` there.
        """
        return asyncio.run(self.generate_async(description, max_retries))
129 changes: 129 additions & 0 deletions stratevo/ai_strategy/prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,135 @@ def _format_results(results: dict) -> str:
return "\n".join(lines)


def build_dna_system_prompt() -> str:
    """Return the fixed system prompt for natural language → StrategyDNA generation.

    The prompt tells the model to emit only a JSON object of non-default
    fields and lists every field with its type, default and range.
    """
    system_prompt = """You are StratEvo DNA Architect, an expert at translating natural language trading strategy descriptions into StrategyDNA parameter configurations.

## Your Task
Given a natural language description of a trading strategy, output a JSON object containing ONLY the StrategyDNA fields that should differ from their defaults. The JSON must be valid and parseable.

## Output Format
Output ONLY a valid JSON object. No markdown fences, no explanations before or after.
Example: {"hold_days": 10, "stop_loss_pct": 5.0, "w_momentum": 0.3, "w_trend": 0.4}

## StrategyDNA Fields Reference

### Selection Thresholds
- min_score (int, default=6, range 4-8): Minimum composite score to select a stock. Higher = stricter filtering.
- rsi_buy_threshold (float, default=35.0, range 10-50): RSI level below which a stock is considered oversold (buy signal).
- rsi_sell_threshold (float, default=75.0, range 55-95): RSI level above which a stock is considered overbought (sell signal).
- r2_min (float, default=0.5, range 0.1-0.95): Minimum R-squared for trend linearity confirmation.
- slope_min (float, default=0.5, range 0.1-3.0): Minimum daily price slope (%) for trend confirmation.
- volume_ratio_min (float, default=1.2, range 0.5-5.0): Minimum volume ratio vs average to confirm interest.

### Execution Parameters
- hold_days (int, default=3, range 2-60): Number of days to hold a position before exiting.
- stop_loss_pct (float, default=2.0, range 0.5-25.0): Maximum loss (%) before forced exit.
- take_profit_pct (float, default=20.0, range 3.0-200.0): Profit target (%) to take gains.
- max_positions (int, default=2, range 1-10): Maximum number of concurrent positions.

### Golden Dip Parameters
- dip_threshold_pct (float, default=10.0, range 3-30): Pullback (%) from recent high to qualify as a dip buy.
- r2_trend_min (float, default=0.6, range 0.2-0.95): Minimum R-squared to confirm a bull trend before dip buying.

### Position & Risk Management
- position_sizing_power (float, default=0.0, range 0-3.0): Score-based allocation power. 0=equal weight, 2=quadratic (high-score stocks get more capital).
- trailing_stop_enabled (float, default=0.0, range 0-1): >0.5 enables trailing stop. Activated after price rises trailing_activation_pct, then trails at trailing_stop_pct below peak.
- trailing_activation_pct (float, default=5.0, range 2-30): % gain to activate trailing stop.
- trailing_stop_pct (float, default=3.0, range 1-15): % below peak to trigger trailing stop exit.
- scale_in_tranches (int, default=1, range 1-3): Split entry into N tranches over consecutive days. 1=all-in.
- stop_loss_per_day (float, default=0.0, range 0-3.0): Per-day stop loss rate. Effective SL = stop_loss_per_day * hold_days. 0=use fixed stop_loss_pct.
- market_regime_sensitivity (float, default=0.0, range 0-1.0): Reduce exposure when market breadth is weak. 0=disabled.
- min_breadth_to_trade (float, default=0.0, range 0-0.8): Minimum market breadth to trade. 0=always trade, 0.5=only trade in bull.
- sector_max_pct (float, default=1.0, range 0.2-1.0): Max fraction of portfolio in one sector. 1.0=no limit.
- profit_target_scaling (float, default=0.0, range 0-1.0): Scale take_profit by signal strength. 0=fixed target.
- time_decay_exit (float, default=0.0, range 0-1.0): After hold_days, tighten stop loss progressively. 0=off.
- weekly_trend_confirm (float, default=0.0, range 0-1.0): >0.5 requires weekly uptrend (5MA>20MA) confirmation.
- kelly_fraction (float, default=0.0, range 0-1.0): Dynamic position sizing via Kelly criterion. 0=disabled, 1.0=full half-Kelly.
- factor_momentum (float, default=0.0, range 0-1.0): Boost weights of recently-profitable factors. 0=static weights.
- capital_utilization (float, default=1.0, range 0.3-1.0): Fraction of capital to deploy. 1.0=100%, 0.5=50% deployed + 50% cash.
- trend_exit_enabled (float, default=0.0, range 0-1.0): >0.5 enables early exit when 5-day MA crosses below 10-day MA.

### Bear Market Regime Overrides
- bear_min_score_add (float, default=0.0, range 0-3.0): Add to min_score in bear markets (tighter filtering).
- bear_stop_loss_mult (float, default=1.0, range 0.3-1.0): Tighten stop loss in bear (multiplier on stop_loss_pct).
- bear_hold_days_mult (float, default=1.0, range 0.3-1.0): Shorten hold period in bear (multiplier on hold_days).
- bear_take_profit_mult (float, default=1.0, range 0.3-1.0): Lower targets in bear (multiplier on take_profit_pct).

### Factor Weights (all float, default varies, range 0-1.0, auto-normalized to sum=1)

Technical Core:
- w_momentum (default=0.1): RSI + slope-based momentum signal.
- w_mean_reversion (default=0.1): RSI oversold mean-reversion signal.
- w_volume (default=0.1): Volume ratio vs average — confirms institutional interest.
- w_trend (default=0.1): R-squared + MA alignment — trend quality.
- w_pattern (default=0.1): Candlestick pattern recognition.
- w_macd (default=0.1): MACD golden/death cross signal.
- w_bollinger (default=0.1): Bollinger Band position (buy near lower, sell near upper).
- w_kdj (default=0.1): KDJ golden cross — popular in Asian markets.
- w_obv (default=0.1): On-Balance Volume trend — price-volume confirmation.
- w_support (default=0.05): Proximity to support/resistance levels.
- w_volume_profile (default=0.05): Volume profile shape analysis.

Technical Extended:
- w_atr (default=0.0): Average True Range — volatility measure.
- w_adx (default=0.0): Average Directional Index — trend strength.
- w_roc (default=0.0): Rate of Change — price momentum.
- w_williams_r (default=0.0): Williams %R — overbought/oversold oscillator.
- w_cci (default=0.0): Commodity Channel Index — deviation from statistical mean.
- w_mfi (default=0.0): Money Flow Index — volume-weighted RSI.
- w_vwap (default=0.0): Volume Weighted Average Price distance.
- w_donchian (default=0.0): Donchian Channel breakout signal.
- w_ichimoku (default=0.0): Ichimoku cloud position — trend/support system.
- w_elder_ray (default=0.0): Elder Ray bull/bear power indicator.

Rolling Statistics:
- w_beta (default=0.0): Price regression slope (beta).
- w_r_squared (default=0.0): Trend linearity via R-squared.
- w_residual (default=0.0): Regression residual — mean reversion signal.
- w_quantile_upper (default=0.0): Distance to 80th percentile.
- w_quantile_lower (default=0.0): Distance to 20th percentile.
- w_aroon (default=0.0): Days since high/low — trend indicator.
- w_price_volume_corr (default=0.0): Price-volume correlation.

Fundamental:
- w_pe (default=0.0): Price-to-Earnings valuation score.
- w_pb (default=0.0): Price-to-Book value score.
- w_roe (default=0.0): Return on Equity.
- w_revenue_growth (default=0.0): Revenue growth rate.
- w_revenue_yoy (default=0.0): Revenue year-over-year growth.
- w_revenue_qoq (default=0.0): Revenue quarter-over-quarter growth.
- w_profit_yoy (default=0.0): Net profit year-over-year growth.
- w_profit_qoq (default=0.0): Net profit quarter-over-quarter growth.
- w_ps (default=0.0): Price-to-Sales ratio.
- w_peg (default=0.0): PEG ratio (PE / earnings growth).
- w_gross_margin (default=0.0): Gross margin quality.
- w_debt_ratio (default=0.0): Debt-to-asset ratio.
- w_cashflow (default=0.0): Operating cashflow quality.

## Rules
1. Output ONLY changed fields — fields that differ from their defaults.
2. All values must respect the valid ranges listed above.
3. Integer fields (min_score, hold_days, max_positions, scale_in_tranches) must be integers.
4. Factor weights are auto-normalized; you only need to set relative magnitudes.
5. If the user describes a momentum strategy, increase w_momentum, w_macd, w_roc, w_adx etc.
6. If the user describes a value/fundamental strategy, increase w_pe, w_pb, w_roe etc.
7. If the user describes a mean-reversion strategy, increase w_mean_reversion, w_bollinger, w_williams_r etc.
8. If the user wants conservative risk, set tighter stop_loss_pct, lower max_positions, enable trailing stops.
9. If the user wants aggressive growth, use wider stops, more positions, longer hold periods.
10. Think carefully about which parameters match the user's intent.
"""
    return system_prompt


def build_dna_user_prompt(description: str) -> str:
    """Wrap a strategy description in the user-turn instruction for DNA generation.

    The fixed instruction comes first, followed by a blank line and the
    description inserted verbatim.
    """
    instruction = (
        "Translate the following trading strategy description into a StrategyDNA JSON object. "
        "Only include fields that should differ from defaults."
    )
    return f"{instruction}\n\nStrategy description: {description}"


def build_copilot_system_prompt() -> str:
"""System prompt for the StratEvo Copilot chat mode."""
return """You are StratEvo Copilot, an AI financial analysis assistant.
Expand Down
1 change: 1 addition & 0 deletions stratevo/cli/commands/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def cmd_evolve(args):
pareto=getattr(args, "pareto", False),
pareto_complexity=getattr(args, "pareto_complexity", False),
held_out_ratio=getattr(args, "held_out", 0.2),
seed_dna_path=getattr(args, "seed_dna", None),
)

try:
Expand Down
37 changes: 35 additions & 2 deletions stratevo/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,9 @@ def build_parser() -> argparse.ArgumentParser:
p_gen.add_argument("--market", default="us_stock", choices=["us_stock", "crypto", "cn_stock"], help="Target market")
p_gen.add_argument("--risk", default="medium", choices=["low", "medium", "high"], help="Risk profile")
p_gen.add_argument("--provider", default=None, help="LLM provider (openai, anthropic, deepseek, ollama, ...)")
p_gen.add_argument("--output", "-o", default=None, help="Save generated code to file")
p_gen.add_argument("--output", "-o", default=None, help="Save generated code/DNA to file")
p_gen.add_argument("--output-format", default="code", choices=["code", "dna"],
help="Output format: 'code' for StrategyPlugin Python code, 'dna' for StrategyDNA JSON (default: code)")

p_opt = sub.add_parser("optimize-strategy", help="AI-optimize an existing strategy")
p_opt.add_argument("strategy_file", help="Path to strategy .py file")
Expand Down Expand Up @@ -815,6 +817,8 @@ def build_parser() -> argparse.ArgumentParser:
p_evo.add_argument("--held-out", type=float, default=0.2,
help="Held-out validation ratio (0.0-1.0). Last N%% of data reserved for "
"final validation after evolution. Set to 0 to disable. (default: 0.2)")
p_evo.add_argument("--seed-dna", default=None,
help="Path to a StrategyDNA JSON file to use as initial seed instead of defaults")

# compare-markets
p_cmp = sub.add_parser("compare-markets", help="Compare strategy performance across different markets")
Expand Down Expand Up @@ -1084,9 +1088,38 @@ def _handle_paper(args):

def _cmd_generate_strategy(args):
"""Handle: stratevo generate-strategy"""
from stratevo.ai_strategy.strategy_generator import StrategyGenerator
import asyncio

output_format = getattr(args, "output_format", "code")

if output_format == "dna":
from stratevo.ai_strategy.dna_generator import DNAGenerator
import json as _json

if not args.description:
print(" Usage: stratevo generate-strategy \"description\" --output-format dna [-o file.json]")
return

gen = DNAGenerator(provider_name=args.provider)
print(f" 🤖 Generating StrategyDNA for: {args.description}")
result = gen.generate(args.description)

if result["valid"]:
dna_json = _json.dumps(result["dna"], indent=2, ensure_ascii=False)
print(f" ✅ Generated StrategyDNA ({len(result['dna'])} fields)\n")
print(dna_json)
if result["warnings"]:
print(f"\n ⚠ Warnings: {result['warnings']}")
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
f.write(dna_json)
print(f"\n Saved to {args.output}")
else:
print(f" ❌ DNA generation failed: {result['errors']}")
return

from stratevo.ai_strategy.strategy_generator import StrategyGenerator

gen = StrategyGenerator(
provider_name=args.provider,
market=args.market,
Expand Down
Loading
Loading