Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion service/server/market_intel.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,44 @@ def _fetch_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]:
return _extract_intraday_quote(payload)


def _fetch_yfinance_quote_payload(symbol: str) -> Optional[dict[str, Any]]:
try:
import yfinance as yf
except ImportError:
return None

cleaned = (symbol or "").strip().upper()
if not cleaned:
return None

try:
ticker = yf.Ticker(cleaned)
price: Optional[float] = None
fast_info = getattr(ticker, "fast_info", None) or {}
for key in ("last_price", "lastPrice", "regular_market_price", "regularMarketPrice"):
value = fast_info.get(key) if isinstance(fast_info, dict) else getattr(fast_info, key, None)
if value is not None:
candidate = float(value)
if candidate > 0:
price = candidate
break
if price is None:
history = ticker.history(period="1d", interval="1m")
if history is not None and not history.empty:
price = float(history["Close"].iloc[-1])
if price is None or price <= 0:
return None
except Exception:
return None

return {
"available": True,
"current_price": round(price, 2),
"price_as_of": _utc_now_iso_z(),
"price_source": "yfinance_fast_info",
}


def _get_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]:
cached = _stock_quote_cache_get(symbol)
if isinstance(cached, dict):
Expand All @@ -306,6 +344,11 @@ def _get_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]:
_stock_quote_cache_set(symbol, quote, ttl_seconds=STOCK_QUOTE_CACHE_TTL_SECONDS)
return quote

quote = _fetch_yfinance_quote_payload(symbol)
if quote:
_stock_quote_cache_set(symbol, quote, ttl_seconds=STOCK_QUOTE_CACHE_TTL_SECONDS)
return quote

unavailable = {"available": False}
_stock_quote_cache_set(symbol, unavailable, ttl_seconds=STOCK_QUOTE_FAILURE_CACHE_TTL_SECONDS)
return None
Expand All @@ -325,7 +368,7 @@ def _build_stock_price_metadata(price_as_of: Optional[str], price_source: Option
stale = True
status = "stale"

if price_source == "alpha_vantage_time_series_intraday":
if price_source in ("alpha_vantage_time_series_intraday", "yfinance_fast_info"):
market_open = _is_us_market_open(now_utc)
quote_et = parsed_as_of.astimezone(US_EASTERN_TZ)
now_et = now_utc.astimezone(US_EASTERN_TZ)
Expand Down
15 changes: 8 additions & 7 deletions service/server/routes_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,13 @@ async def agent_self_register(data: AgentRegister):
now = utc_now_iso_z()
if data.positions:
for pos in data.positions:
market = validate_market(pos.get('market', 'us-stock'))
symbol = str(pos.get('symbol') or '').strip()
if not symbol:
raise HTTPException(status_code=400, detail='Position symbol is required')
market = validate_market(pos.market)
symbol = pos.symbol.strip()
if market != 'polymarket':
symbol = symbol.upper()
quantity = abs(float(pos.quantity))
if pos.side == 'short':
quantity = -quantity
cursor.execute(
"""
INSERT INTO positions (agent_id, symbol, market, side, quantity, entry_price, opened_at)
Expand All @@ -595,9 +596,9 @@ async def agent_self_register(data: AgentRegister):
agent_id,
symbol,
market,
pos.get('side', 'long'),
pos.get('quantity', 0),
pos.get('entry_price', 0),
pos.side,
quantity,
float(pos.entry_price),
now,
),
)
Expand Down
45 changes: 43 additions & 2 deletions service/server/routes_models.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,60 @@
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, field_validator


class AgentLogin(BaseModel):
name: str
password: str


class AgentPositionInput(BaseModel):
symbol: str
market: str = "us-stock"
side: str = "long"
quantity: float
entry_price: float

@field_validator("symbol")
@classmethod
def symbol_required(cls, value: str) -> str:
cleaned = (value or "").strip()
if not cleaned:
raise ValueError("symbol is required")
return cleaned

@field_validator("quantity", "entry_price")
@classmethod
def must_be_positive_numbers(cls, value: float) -> float:
if not isinstance(value, (int, float)):
raise ValueError("must be numeric")
numeric = float(value)
if numeric <= 0:
raise ValueError("must be a positive number")
return numeric

@field_validator("side")
@classmethod
def side_must_be_long_or_short(cls, value: str) -> str:
normalized = (value or "long").strip().lower()
if normalized not in {"long", "short"}:
raise ValueError("side must be long or short")
return normalized


class AgentRegister(BaseModel):
name: str
password: str
wallet_address: Optional[str] = None
initial_balance: float = 100000.0
positions: Optional[List[dict]] = None
positions: Optional[List[AgentPositionInput]] = None

@field_validator("initial_balance")
@classmethod
def initial_balance_must_be_positive(cls, value: float) -> float:
if value <= 0:
raise ValueError("initial_balance must be positive")
return value


class AgentTokenRecoveryRequest(BaseModel):
Expand Down
84 changes: 84 additions & 0 deletions service/server/tests/test_agent_register_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import sys
import tempfile
import unittest
from pathlib import Path

from fastapi.testclient import TestClient

SERVER_DIR = Path(__file__).resolve().parents[1]
if str(SERVER_DIR) not in sys.path:
sys.path.insert(0, str(SERVER_DIR))

import database
from routes import create_app


class AgentRegisterPositionTests(unittest.TestCase):
def setUp(self) -> None:
self.tmp = tempfile.TemporaryDirectory()
os.environ["ALLOW_SQLITE"] = "true"
database.DATABASE_URL = ""
database._SQLITE_DB_PATH = os.path.join(self.tmp.name, "test.db")
database.init_database()
self.client = TestClient(create_app())

def tearDown(self) -> None:
self.tmp.cleanup()

def test_rejects_invalid_position_quantity(self) -> None:
response = self.client.post(
"/api/claw/agents/selfRegister",
json={
"name": "pos_validation_agent",
"password": "secret123",
"positions": [
{
"symbol": "BTC",
"market": "crypto",
"side": "short",
"quantity": 0,
"entry_price": 100.0,
}
],
},
)
self.assertEqual(response.status_code, 422)

def test_short_position_stored_with_negative_quantity(self) -> None:
response = self.client.post(
"/api/claw/agents/selfRegister",
json={
"name": "short_pos_agent",
"password": "secret123",
"positions": [
{
"symbol": "BTC",
"market": "crypto",
"side": "short",
"quantity": 0.2,
"entry_price": 100.0,
}
],
},
)
self.assertEqual(response.status_code, 200)
agent_id = response.json()["agent_id"]

conn = database.get_db_connection()
cursor = conn.cursor()
cursor.execute(
"SELECT quantity, entry_price, side FROM positions WHERE agent_id = ?",
(agent_id,),
)
row = cursor.fetchone()
conn.close()

self.assertIsNotNone(row)
self.assertAlmostEqual(row["quantity"], -0.2)
self.assertAlmostEqual(row["entry_price"], 100.0)
self.assertEqual(row["side"], "short")


if __name__ == "__main__":
unittest.main()
22 changes: 22 additions & 0 deletions service/server/tests/test_market_intel.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,28 @@ def test_metadata_unparseable_timestamp_is_stale(self) -> None:
self.assertEqual(meta["price_status"], "stale")
self.assertIsNone(meta["price_age_seconds"])

@patch("market_intel.set_json")
@patch("market_intel.get_json", return_value=None)
@patch("market_intel._fetch_stock_quote_payload", return_value=None)
@patch("market_intel._fetch_yfinance_quote_payload")
def test_stock_quote_falls_back_to_yfinance(
self,
mock_yfinance_quote,
_mock_alpha_quote,
_mock_get_json,
_mock_set_json,
) -> None:
mock_yfinance_quote.return_value = {
"available": True,
"current_price": 271.19,
"price_as_of": "2026-04-17T20:15:00Z",
"price_source": "yfinance_fast_info",
}
quote = market_intel._get_stock_quote_payload("AAPL")
self.assertIsNotNone(quote)
self.assertEqual(quote["price_source"], "yfinance_fast_info")
self.assertEqual(quote["current_price"], 271.19)


if __name__ == "__main__":
unittest.main()