diff --git a/service/server/market_intel.py b/service/server/market_intel.py index 68a27921..37107bfa 100644 --- a/service/server/market_intel.py +++ b/service/server/market_intel.py @@ -290,6 +290,44 @@ def _fetch_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]: return _extract_intraday_quote(payload) +def _fetch_yfinance_quote_payload(symbol: str) -> Optional[dict[str, Any]]: + try: + import yfinance as yf + except ImportError: + return None + + cleaned = (symbol or "").strip().upper() + if not cleaned: + return None + + try: + ticker = yf.Ticker(cleaned) + price: Optional[float] = None + fast_info = getattr(ticker, "fast_info", None) or {} + for key in ("last_price", "lastPrice", "regular_market_price", "regularMarketPrice"): + value = fast_info.get(key) if isinstance(fast_info, dict) else getattr(fast_info, key, None) + if value is not None: + candidate = float(value) + if candidate > 0: + price = candidate + break + if price is None: + history = ticker.history(period="1d", interval="1m") + if history is not None and not history.empty: + price = float(history["Close"].iloc[-1]) + if price is None or price <= 0: + return None + except Exception: + return None + + return { + "available": True, + "current_price": round(price, 2), + "price_as_of": _utc_now_iso_z(), + "price_source": "yfinance_fast_info", + } + + def _get_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]: cached = _stock_quote_cache_get(symbol) if isinstance(cached, dict): @@ -306,6 +344,11 @@ def _get_stock_quote_payload(symbol: str) -> Optional[dict[str, Any]]: _stock_quote_cache_set(symbol, quote, ttl_seconds=STOCK_QUOTE_CACHE_TTL_SECONDS) return quote + quote = _fetch_yfinance_quote_payload(symbol) + if quote: + _stock_quote_cache_set(symbol, quote, ttl_seconds=STOCK_QUOTE_CACHE_TTL_SECONDS) + return quote + unavailable = {"available": False} _stock_quote_cache_set(symbol, unavailable, ttl_seconds=STOCK_QUOTE_FAILURE_CACHE_TTL_SECONDS) return None @@ -325,7 +368,7 @@ def _build_stock_price_metadata(price_as_of: Optional[str], price_source: Option stale = True status = "stale" - if price_source == "alpha_vantage_time_series_intraday": + if price_source in ("alpha_vantage_time_series_intraday", "yfinance_fast_info"): market_open = _is_us_market_open(now_utc) quote_et = parsed_as_of.astimezone(US_EASTERN_TZ) now_et = now_utc.astimezone(US_EASTERN_TZ) diff --git a/service/server/routes_agent.py b/service/server/routes_agent.py index 9512aabd..e7539836 100644 --- a/service/server/routes_agent.py +++ b/service/server/routes_agent.py @@ -580,12 +580,13 @@ async def agent_self_register(data: AgentRegister): now = utc_now_iso_z() if data.positions: for pos in data.positions: - market = validate_market(pos.get('market', 'us-stock')) - symbol = str(pos.get('symbol') or '').strip() - if not symbol: - raise HTTPException(status_code=400, detail='Position symbol is required') + market = validate_market(pos.market) + symbol = pos.symbol.strip() if market != 'polymarket': symbol = symbol.upper() + quantity = abs(float(pos.quantity)) + if pos.side == 'short': + quantity = -quantity cursor.execute( """ INSERT INTO positions (agent_id, symbol, market, side, quantity, entry_price, opened_at) @@ -595,9 +596,9 @@ async def agent_self_register(data: AgentRegister): agent_id, symbol, market, - pos.get('side', 'long'), - pos.get('quantity', 0), - pos.get('entry_price', 0), + pos.side, + quantity, + float(pos.entry_price), now, ), ) diff --git a/service/server/routes_models.py b/service/server/routes_models.py index 9a4ce43a..215a8702 100644 --- a/service/server/routes_models.py +++ b/service/server/routes_models.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, EmailStr, field_validator class AgentLogin(BaseModel): @@ -8,12 +8,53 @@ class AgentLogin(BaseModel): password: str +class AgentPositionInput(BaseModel): + symbol: str + market: str = "us-stock" + side: str = "long" + quantity: float + entry_price: float + + @field_validator("symbol") + @classmethod + def symbol_required(cls, value: str) -> str: + cleaned = (value or "").strip() + if not cleaned: + raise ValueError("symbol is required") + return cleaned + + @field_validator("quantity", "entry_price") + @classmethod + def must_be_positive_numbers(cls, value: float) -> float: + if not isinstance(value, (int, float)): + raise ValueError("must be numeric") + numeric = float(value) + if numeric <= 0: + raise ValueError("must be a positive number") + return numeric + + @field_validator("side") + @classmethod + def side_must_be_long_or_short(cls, value: str) -> str: + normalized = (value or "long").strip().lower() + if normalized not in {"long", "short"}: + raise ValueError("side must be long or short") + return normalized + + class AgentRegister(BaseModel): name: str password: str wallet_address: Optional[str] = None initial_balance: float = 100000.0 - positions: Optional[List[dict]] = None + positions: Optional[List[AgentPositionInput]] = None + + @field_validator("initial_balance") + @classmethod + def initial_balance_must_be_positive(cls, value: float) -> float: + if value <= 0: + raise ValueError("initial_balance must be positive") + return value class AgentTokenRecoveryRequest(BaseModel): diff --git a/service/server/tests/test_agent_register_positions.py b/service/server/tests/test_agent_register_positions.py new file mode 100644 index 00000000..fe9e138d --- /dev/null +++ b/service/server/tests/test_agent_register_positions.py @@ -0,0 +1,84 @@ +import os +import sys +import tempfile +import unittest +from pathlib import Path + +from fastapi.testclient import TestClient + +SERVER_DIR = Path(__file__).resolve().parents[1] +if str(SERVER_DIR) not in sys.path: + sys.path.insert(0, str(SERVER_DIR)) + +import database +from routes import create_app + + +class AgentRegisterPositionTests(unittest.TestCase): + def setUp(self) -> None: + self.tmp = tempfile.TemporaryDirectory() + os.environ["ALLOW_SQLITE"] = "true" + database.DATABASE_URL = "" + database._SQLITE_DB_PATH = os.path.join(self.tmp.name, "test.db") + database.init_database() + self.client = TestClient(create_app()) + + def tearDown(self) -> None: + self.tmp.cleanup() + + def test_rejects_invalid_position_quantity(self) -> None: + response = self.client.post( + "/api/claw/agents/selfRegister", + json={ + "name": "pos_validation_agent", + "password": "secret123", + "positions": [ + { + "symbol": "BTC", + "market": "crypto", + "side": "short", + "quantity": 0, + "entry_price": 100.0, + } + ], + }, + ) + self.assertEqual(response.status_code, 422) + + def test_short_position_stored_with_negative_quantity(self) -> None: + response = self.client.post( + "/api/claw/agents/selfRegister", + json={ + "name": "short_pos_agent", + "password": "secret123", + "positions": [ + { + "symbol": "BTC", + "market": "crypto", + "side": "short", + "quantity": 0.2, + "entry_price": 100.0, + } + ], + }, + ) + self.assertEqual(response.status_code, 200) + agent_id = response.json()["agent_id"] + + conn = database.get_db_connection() + cursor = conn.cursor() + cursor.execute( + "SELECT quantity, entry_price, side FROM positions WHERE agent_id = ?", + (agent_id,), + ) + row = cursor.fetchone() + conn.close() + + self.assertIsNotNone(row) + self.assertAlmostEqual(row["quantity"], -0.2) + self.assertAlmostEqual(row["entry_price"], 100.0) + self.assertEqual(row["side"], "short") + + +if __name__ == "__main__": + unittest.main() diff --git a/service/server/tests/test_market_intel.py b/service/server/tests/test_market_intel.py index 0ed308a1..ea70f8e7 100644 --- a/service/server/tests/test_market_intel.py +++ b/service/server/tests/test_market_intel.py @@ -298,6 +298,28 @@ def test_metadata_unparseable_timestamp_is_stale(self) -> None: self.assertEqual(meta["price_status"], "stale") self.assertIsNone(meta["price_age_seconds"]) + @patch("market_intel.set_json") + @patch("market_intel.get_json", return_value=None) + @patch("market_intel._fetch_stock_quote_payload", return_value=None) + @patch("market_intel._fetch_yfinance_quote_payload") + def test_stock_quote_falls_back_to_yfinance( + self, + mock_yfinance_quote, + _mock_alpha_quote, + _mock_get_json, + _mock_set_json, + ) -> None: + mock_yfinance_quote.return_value = { + "available": True, + "current_price": 271.19, + "price_as_of": "2026-04-17T20:15:00Z", + "price_source": "yfinance_fast_info", + } + quote = market_intel._get_stock_quote_payload("AAPL") + self.assertIsNotNone(quote) + self.assertEqual(quote["price_source"], "yfinance_fast_info") + self.assertEqual(quote["current_price"], 271.19) + if __name__ == "__main__": unittest.main()