diff --git a/tests/state/test_state_manager.py b/tests/state/test_state_manager.py index c565c2d..5c61afe 100644 --- a/tests/state/test_state_manager.py +++ b/tests/state/test_state_manager.py @@ -96,7 +96,8 @@ def test_side_and_order_id_preserved(self): ticker = mgr.get_ticker('BTC') assert ticker.orders[0].side == OrderSide.SELL assert ticker.orders[0].order_id == 'ORD-001' - assert ticker.orders[0].created_at == '2025-06-01 10:00:00' + from datetime import datetime + assert ticker.orders[0].created_at == datetime(2025, 6, 1, 10, 0) assert ticker.orders[1].side == OrderSide.BUY assert ticker.orders[1].order_id == 'ORD-002' diff --git a/tests/test_momentum_dca.py b/tests/test_momentum_dca.py index f899b89..077115c 100644 --- a/tests/test_momentum_dca.py +++ b/tests/test_momentum_dca.py @@ -537,7 +537,7 @@ def test_created_at_passed_through(self): 'quantity': 20, 'limit_price': 443.0, 'stop_price': 445.0, 'created_at': '2026-02-07 10:30:00', 'order_id': 'abc-123'}, ] - mgr.load_broker_sell_orders('SPY', broker_orders) + mgr.load_broker_orders('SPY', broker_orders) order = mgr.get_ticker('SPY').orders[0] assert order.created_at is not None assert order.created_at.year == 2026 @@ -549,7 +549,7 @@ def test_missing_created_at_is_none(self): {'symbol': 'SPY', 'side': 'SELL', 'order_type': 'Limit', 'quantity': 10, 'limit_price': 460.0, 'stop_price': None}, ] - mgr.load_broker_sell_orders('SPY', broker_orders) + mgr.load_broker_orders('SPY', broker_orders) order = mgr.get_ticker('SPY').orders[0] assert order.created_at is None assert order.order_id is None diff --git a/tests/test_order_replacement.py b/tests/test_order_replacement.py index 6f252ae..da82934 100644 --- a/tests/test_order_replacement.py +++ b/tests/test_order_replacement.py @@ -10,14 +10,14 @@ def _make_system(): """Build a TradingSystem with mocked dependencies so no real I/O occurs.""" with patch('trading_system.main.TwelveDataProvider'), \ - patch('trading_system.main.SafeCashBot') as MockBot: - bot_instance = MockBot.return_value + patch('trading_system.main.RobinhoodClient') as MockBot: + bot_instance = MockBot.create.return_value bot_instance.get_pdt_status.return_value = { 'day_trade_count': 0, 'flagged': False, 'trades': [], } - bot_instance.cancel_order_by_id.return_value = True + bot_instance.cancel_order.return_value = True system = TradingSystem( twelve_data_api_key='fake', @@ -94,7 +94,7 @@ def test_both_sells_cancelled_and_replaced(self): system._handle_order_replacement('SPY', signal, symbol_orders) # Both sells should be cancelled - calls = system.trading_bot.cancel_order_by_id.call_args_list + calls = system.trading_bot.cancel_order.call_args_list cancelled_ids = {c[0][0] for c in calls} assert cancelled_ids == {'S1', 'S2'} @@ -109,7 +109,7 @@ def test_replaces_all_lot_orders(self): system._handle_order_replacement('SPY', signal, symbol_orders) - calls = system.trading_bot.cancel_order_by_id.call_args_list + calls = system.trading_bot.cancel_order.call_args_list cancelled_ids = {c[0][0] for c in calls} assert cancelled_ids == {'SELL-OLD', 'BUY-OLD'} @@ -124,7 +124,7 @@ def test_qty_mismatch_still_cancels_all(self): system._handle_order_replacement('SPY', signal, symbol_orders) # Both should be cancelled regardless of quantity - calls = system.trading_bot.cancel_order_by_id.call_args_list + calls = system.trading_bot.cancel_order.call_args_list cancelled_ids = {c[0][0] for c in calls} assert cancelled_ids == {'SELL-001', 'BUY-001'} @@ -146,7 +146,7 @@ def test_pdt_count_2_alerts_and_skips(self, mock_slack): mock_slack.assert_called_once() assert 'PDT day trade count at 2/3' in mock_slack.call_args[0][0] - system.trading_bot.cancel_order_by_id.assert_not_called() + system.trading_bot.cancel_order.assert_not_called() class TestPdtFlaggedAlertsAndSkips: @@ -166,7 +166,7 @@ def test_pdt_flagged_alerts_and_skips(self, mock_slack): mock_slack.assert_called_once() assert 'PDT FLAGGED' in mock_slack.call_args[0][0] - system.trading_bot.cancel_order_by_id.assert_not_called() + system.trading_bot.cancel_order.assert_not_called() class TestPdtSafeProceeds: @@ -186,7 +186,7 @@ def test_pdt_safe_proceeds(self, mock_slack): mock_slack.assert_not_called() # Both sell and buy should be cancelled - assert system.trading_bot.cancel_order_by_id.call_count == 2 + assert system.trading_bot.cancel_order.call_count == 2 class TestPdtNoneProceeds: @@ -200,15 +200,15 @@ def test_pdt_none_proceeds(self): system._handle_order_replacement('SPY', signal, symbol_orders) # Both sell and buy should be cancelled - assert system.trading_bot.cancel_order_by_id.call_count == 2 + assert system.trading_bot.cancel_order.call_count == 2 class TestCancelFailsStillPlaces: def test_cancel_fails_still_places(self): - """cancel_order_by_id returns False → placement still proceeds + """cancel_order returns False → placement still proceeds (order likely already filled/cancelled).""" system = _make_system() - system.trading_bot.cancel_order_by_id.return_value = False + system.trading_bot.cancel_order.return_value = False signal = _make_signal() symbol_orders = [_sell_order(), _buy_order()] @@ -235,9 +235,9 @@ def test_momentum_pricing_used(self): system._handle_order_replacement('SPY', signal, symbol_orders) - # strategy defaults: stop_offset_pct=0.0125, buy_offset=0.50 - expected_stop = round(current_price * (1 - 0.0125), 2) # 493.75 - expected_buy = round(expected_stop - 0.50, 2) # 493.25 + # strategy defaults: stop_offset_pct=0.015, buy_offset=0.20 + expected_stop = round(current_price * (1 - 0.015), 2) # 492.50 + expected_buy = round(expected_stop - 0.20, 2) # 492.30 sell_call = system._execute_stop_limit_sell_order.call_args assert sell_call[0][1]['stop_price'] == expected_stop @@ -264,7 +264,7 @@ def test_stop_limit_sell_cancels_existing_sell(self): system.process_signal('SPY', signal, open_orders) # Existing sell should be cancelled before new pair is placed - system.trading_bot.cancel_order_by_id.assert_called_once_with('EXISTING-SELL') + system.trading_bot.cancel_order.assert_called_once_with('EXISTING-SELL') system._execute_stop_limit_sell_order.assert_called_once() system._execute_paired_limit_buy.assert_called_once() @@ -284,7 +284,7 @@ def test_uses_target_qty_for_new_order(self): system.process_signal('SPY', signal, open_orders) # Old sell cancelled - system.trading_bot.cancel_order_by_id.assert_called_once_with('OLD-SELL') + system.trading_bot.cancel_order.assert_called_once_with('OLD-SELL') # New sell uses target_qty (250), not gap_qty (100) sell_order = system._execute_stop_limit_sell_order.call_args[0][1] @@ -323,6 +323,6 @@ def test_non_lot_sized_sell_is_cancelled(self): system.process_signal('SPY', signal, open_orders) # Non-lot-sized sell should still be cancelled - system.trading_bot.cancel_order_by_id.assert_called_once_with('PARTIAL-SELL') + system.trading_bot.cancel_order.assert_called_once_with('PARTIAL-SELL') system._execute_stop_limit_sell_order.assert_called_once() system._execute_paired_limit_buy.assert_called_once() diff --git a/trading_system/execution/trade_executor.py b/trading_system/execution/trade_executor.py new file mode 100644 index 0000000..7d133b3 --- /dev/null +++ b/trading_system/execution/trade_executor.py @@ -0,0 +1,153 @@ +""" +Executor — execution quality layer + deferred order queue. + +Sits between the strategy (which creates TradeTasks) and the Brokerage +(which talks to the broker API). Owns: + - PDT gate, spread check, price optimizer, fill logger + - In-memory deferred queue for PDT-blocked orders + +The Brokerage is injected (Dependency Inversion), so the Executor never +imports robin_stocks or any platform-specific code. +""" + +from datetime import date, timedelta +from typing import Optional + +from trading_system.execution.trade_task import TradeTask, DeferredTask + + +def _next_trading_day(from_date: date) -> date: + d = from_date + timedelta(days=1) + while d.weekday() >= 5: # skip Saturday=5, Sunday=6 + d += timedelta(days=1) + return d + + +class Executor: + """ + Processes TradeTasks. Applies pre-flight checks, dispatches to the + Brokerage, and defers PDT-blocked orders to the next trading day. + """ + + def __init__(self, brokerage, pdt_gate=None, spread_checker=None, + price_optimizer=None, fill_logger=None, fill_auditor=None): + self._brokerage = brokerage + self._pdt_gate = pdt_gate + self._spread_checker = spread_checker + self._price_optimizer = price_optimizer + self._fill_logger = fill_logger + self._fill_auditor = fill_auditor + self._deferred_queue: list[DeferredTask] = [] + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def submit(self, task: TradeTask) -> Optional[dict]: + """ + Try to execute a task now. + Returns broker response on success, None if blocked or deferred. + """ + can_execute, reason = task.should_execute() + if not can_execute: + self._enqueue_deferred(task, reason) + return None + + if self._pdt_gate: + allowed, pdt_reason = self._pdt_gate.can_place_order(task.symbol, task.side) + print(f" PDT Gate: {pdt_reason}") + if not allowed: + self._enqueue_deferred(task, pdt_reason) + return None + + if self._spread_checker: + spread_info = self._spread_checker.check_spread(task.symbol) + if spread_info and not spread_info.get('is_acceptable', True): + print(f" Spread Check: BLOCKED — {spread_info.get('reason', '')}") + return None + if spread_info and spread_info.get('should_wait'): + print(f" Spread Check: WARNING — {spread_info.get('reason', '')}") + + return self._dispatch(task) + + def drain_deferred(self): + """ + Re-attempt deferred tasks whose execute_date has been reached. + Call at the top of each engine cycle. + """ + ready = [t for t in self._deferred_queue if t.should_execute()[0]] + submitted = [] + for task in ready: + print(f" [executor] Retrying deferred {task.symbol} {task.side} " + f"(was: {task.deferred_reason})") + result = self.submit(task) + if result: + submitted.append(task) + for task in submitted: + self._deferred_queue.remove(task) + + def get_deferred(self) -> list[DeferredTask]: + """Read-only view of the current deferred queue.""" + return list(self._deferred_queue) + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _dispatch(self, task: TradeTask) -> Optional[dict]: + """Send the task to the brokerage after pre-flight passes.""" + submission_id = None + bid, ask = None, None + + if self._fill_auditor: + nbbo = self._fill_auditor.get_nbbo_now(task.symbol) + if nbbo: + bid = nbbo.get('bid') + ask = nbbo.get('ask') + mid = nbbo.get('mid') + if mid and mid > 0: + print(f" NBBO: bid=${bid:.2f} ask=${ask:.2f} mid=${mid:.2f}") + + if self._fill_logger: + submission_id = self._fill_logger.log_submission( + task.symbol, task.side, task.price, bid, ask) + + try: + if task.side == 'buy': + result = self._brokerage.place_limit_buy( + task.symbol, task.quantity, task.price) + elif task.order_type == 'stop_limit': + result = self._brokerage.place_stop_limit_sell( + task.symbol, task.quantity, task.stop_price, task.price) + else: + result = self._brokerage.place_limit_sell( + task.symbol, task.quantity, task.price) + except Exception as e: + if self._fill_logger and submission_id: + self._fill_logger.log_cancel(submission_id, reason=str(e)) + raise + + if result and self._fill_logger and submission_id: + order_id = result.get('id') if isinstance(result, dict) else None + if order_id: + self._fill_logger.log_fill(submission_id, task.price) + + return result + + def _enqueue_deferred(self, task: TradeTask, reason: str): + next_day = _next_trading_day(date.today()) + retry_count = getattr(task, 'retry_count', 0) + 1 + deferred = DeferredTask( + symbol=task.symbol, + side=task.side, + quantity=task.quantity, + price=task.price, + order_type=task.order_type, + stop_price=getattr(task, 'stop_price', None), + execute_date=next_day, + checks=["pdt_gate"], + deferred_reason=reason, + retry_count=retry_count, + ) + print(f" [executor] Deferred {task.symbol} {task.side} → {next_day} | {reason}") + self._deferred_queue.append(deferred) diff --git a/trading_system/execution/trade_task.py b/trading_system/execution/trade_task.py new file mode 100644 index 0000000..c5141b0 --- /dev/null +++ b/trading_system/execution/trade_task.py @@ -0,0 +1,59 @@ +""" +TradeTask — Command Pattern base for deferred and scheduled orders. + +Each subclass implements should_execute() (Strategy Pattern) so the +Executor can ask "is this task ready?" without knowing which type +it holds. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import date + + +class TradeTask(ABC): + """Base command. Holds trade data; does not execute itself.""" + + @abstractmethod + def should_execute(self) -> tuple[bool, str]: + """Returns (can_execute, reason).""" + ... + + +@dataclass +class ScheduledTask(TradeTask): + """Executes on or after execute_date.""" + + symbol: str + side: str # "buy" | "sell" + quantity: float + price: float + order_type: str # "limit" | "stop_limit" | "market" + execute_date: date + stop_price: float | None = None + + def should_execute(self) -> tuple[bool, str]: + if date.today() >= self.execute_date: + return True, "scheduled date reached" + return False, f"scheduled for {self.execute_date}" + + +@dataclass +class DeferredTask(TradeTask): + """Blocked by PDT or another gate. Retried each cycle once execute_date is reached.""" + + symbol: str + side: str + quantity: float + price: float + order_type: str + execute_date: date + checks: list[str] # gates to re-run: ["pdt_gate", "spread_check"] + deferred_reason: str = "" + stop_price: float | None = None + retry_count: int = 0 + + def should_execute(self) -> tuple[bool, str]: + if date.today() >= self.execute_date: + return True, "deferred date reached" + return False, f"deferred until {self.execute_date}" diff --git a/trading_system/main.py b/trading_system/main.py index 4865dca..2034652 100644 --- a/trading_system/main.py +++ b/trading_system/main.py @@ -5,7 +5,7 @@ import os import sys -from datetime import datetime +from datetime import datetime, date from pathlib import Path from typing import List, Dict import time @@ -24,7 +24,9 @@ from trading_system.market_indicators import fetch_and_write_indicators # noqa: E402 from trading_system.utils.slack import send_slack_alert # noqa: E402 from trading_system.entities.OrderType import OrderSide # noqa: E402 -from utils.safe_cash_bot import SafeCashBot # noqa: E402 +from trading_system.execution.trade_executor import Executor # noqa: E402 +from trading_system.execution.trade_task import ScheduledTask # noqa: E402 +from utils.robinhood_client import RobinhoodClient # noqa: E402 class TradingSystem: @@ -60,9 +62,10 @@ def __init__(self, twelve_data_api_key: str, symbols: List[str], self.data_provider = TwelveDataProvider(twelve_data_api_key) self.metrics_calculator = MetricsCalculator() self.state_manager = StateManager() - self.trading_bot = SafeCashBot() + self.trading_bot = RobinhoodClient.create() - # Initialize execution quality layer + # Initialize TradeExecutor with execution quality layer + self.executor = None self.fill_logger = None try: from trading_system.execution.fill_auditor import FillAuditor @@ -81,12 +84,13 @@ def __init__(self, twelve_data_api_key: str, symbols: List[str], pdt_gate = PDTGate(trading_bot=self.trading_bot) fill_logger = FillLogger() - self.trading_bot.init_execution_layer( - fill_auditor=fill_auditor, + self.executor = Executor( + brokerage=self.trading_bot, + pdt_gate=pdt_gate, spread_checker=spread_checker, price_optimizer=price_optimizer, - pdt_gate=pdt_gate, fill_logger=fill_logger, + fill_auditor=fill_auditor, ) self.fill_logger = fill_logger @@ -270,7 +274,7 @@ def _cancel_orders_by_side(self, symbol: str, side: str, open_orders: list) -> t if (order.get('symbol') == symbol and order.get('side') == side): order_id = order.get('order_id') - if order_id and self.trading_bot.cancel_order_by_id(order_id): + if order_id and self.trading_bot.cancel_order(order_id): cancelled += 1 qty_cancelled += int(float(order.get('quantity', 0))) print(f" Cancelled {side} {order_id} qty={int(float(order.get('quantity', 0)))}") @@ -289,7 +293,17 @@ def _handle_order_replacement(self, symbol: str, signal: Dict, symbol_orders: li """ lot_size = getattr(self.strategy, 'lot_size', None) - # PDT check now handled centrally by PDTGate in SafeCashBot.init_execution_layer() + # PDT pre-flight: order replacement cancels existing orders, which can + # create day trades. Check before touching anything. + pdt_info = self.trading_bot.get_pdt_status() + if pdt_info: + if pdt_info.get('flagged'): + send_slack_alert(f"PDT FLAGGED — skipping order replacement for {symbol}") + return + if pdt_info.get('day_trade_count', 0) >= 2: + send_slack_alert( + f"PDT day trade count at 2/3 — skipping order replacement for {symbol}") + return # Cancel ALL existing orders for the symbol sells_cancelled, _ = self._cancel_orders_by_side(symbol, 'SELL', symbol_orders) @@ -330,11 +344,12 @@ def _handle_order_replacement(self, symbol: str, signal: Dict, symbol_orders: li def _execute_buy_order(self, symbol: str, order: Dict): """Execute buy order""" - # Get available cash cash_info = self.trading_bot.get_cash_balance() + if not cash_info: + print(f" Cannot retrieve cash balance for {symbol}") + return available_cash = cash_info['tradeable_cash'] - # Calculate position size quantity = self.strategy.calculate_position_size( symbol, order['current_price'], available_cash ) @@ -343,6 +358,13 @@ def _execute_buy_order(self, symbol: str, order: Dict): print(f" Insufficient cash to buy {symbol}") return + is_valid, reason = self.trading_bot.validate_buy_order( + symbol, quantity, order['current_price'], buying_power=cash_info['buying_power'] + ) + if not is_valid: + print(f" Buy order invalid: {reason}") + return + # Queue order in state order_details = { 'quantity': quantity, @@ -352,22 +374,25 @@ def _execute_buy_order(self, symbol: str, order: Dict): } self.state_manager.queue_buy_order(symbol, order_details) - if self.verbose: + if self.verbose and self.dry_run: print(f"\n{'='*70}") - print(f"EXECUTING BUY ORDER: {symbol}") + print(f"DRY RUN — BUY ORDER: {symbol}") print(f"{'='*70}") print(f"Quantity: {quantity}") print(f"Price: ${order['current_price']:,.2f}") print(f"Total: ${quantity * order['current_price']:,.2f}") - print(f"Mode: {'DRY RUN' if self.dry_run else 'LIVE'}") print(f"{'='*70}\n") - if not self.dry_run: - # Execute real order - result = self.trading_bot.place_cash_buy_order( - symbol, quantity, order['current_price'], dry_run=False + if not self.dry_run and self.executor: + task = ScheduledTask( + symbol=symbol, + side='buy', + quantity=quantity, + price=order['current_price'], + order_type='limit', + execute_date=date.today(), ) - + result = self.executor.submit(task) if result: order_id = result.get('id') if isinstance(result, dict) else None if order_id: @@ -389,22 +414,25 @@ def _execute_sell_order(self, symbol: str, order: Dict): } self.state_manager.queue_sell_order(symbol, order_details) - if self.verbose: + if self.verbose and self.dry_run: print(f"\n{'='*70}") - print(f"EXECUTING SELL ORDER: {symbol}") + print(f"DRY RUN — SELL ORDER: {symbol}") print(f"{'='*70}") print(f"Quantity: {quantity}") print(f"Price: ${order['current_price']:,.2f}") print(f"Total: ${quantity * order['current_price']:,.2f}") - print(f"Mode: {'DRY RUN' if self.dry_run else 'LIVE'}") print(f"{'='*70}\n") - if not self.dry_run: - # Execute real order - result = self.trading_bot.place_sell_order( - symbol, quantity, order['current_price'], dry_run=False + if not self.dry_run and self.executor: + task = ScheduledTask( + symbol=symbol, + side='sell', + quantity=quantity, + price=order['current_price'], + order_type='limit', + execute_date=date.today(), ) - + result = self.executor.submit(task) if result: order_id = result.get('id', 'unknown') self.state_manager.update_order_status( @@ -428,20 +456,26 @@ def _execute_stop_limit_sell_order(self, symbol: str, order: Dict): } self.state_manager.queue_sell_order(symbol, order_details) - if self.verbose: + if self.verbose and self.dry_run: print(f"\n{'='*70}") - print(f"EXECUTING STOP-LIMIT SELL: {symbol}") + print(f"DRY RUN — STOP-LIMIT SELL: {symbol}") print(f"{'='*70}") print(f"Quantity: {quantity}") print(f"Stop Price: ${stop_price:,.2f}") print(f"Limit Price: ${limit_price:,.2f}") - print(f"Mode: {'DRY RUN' if self.dry_run else 'LIVE'}") print(f"{'='*70}\n") - if not self.dry_run: - result = self.trading_bot.place_stop_limit_sell_order( - symbol, quantity, stop_price, limit_price, dry_run=False + if not self.dry_run and self.executor: + task = ScheduledTask( + symbol=symbol, + side='sell', + quantity=quantity, + price=limit_price, + order_type='stop_limit', + execute_date=date.today(), + stop_price=stop_price, ) + result = self.executor.submit(task) if result: order_id = result.get('id') if isinstance(result, dict) else None if order_id: @@ -463,19 +497,24 @@ def _execute_limit_sell_resubmit(self, symbol: str, order: Dict): } self.state_manager.queue_sell_order(symbol, order_details) - if self.verbose: + if self.verbose and self.dry_run: print(f"\n{'='*70}") - print(f"RESUBMITTING LIMIT SELL: {symbol}") + print(f"DRY RUN — RESUBMIT LIMIT SELL: {symbol}") print(f"{'='*70}") print(f"Quantity: {quantity}") print(f"Limit Price: ${price:,.2f} (original order price)") - print(f"Mode: {'DRY RUN' if self.dry_run else 'LIVE'}") print(f"{'='*70}\n") - if not self.dry_run: - result = self.trading_bot.place_sell_order( - symbol, quantity, price, dry_run=False + if not self.dry_run and self.executor: + task = ScheduledTask( + symbol=symbol, + side='sell', + quantity=quantity, + price=price, + order_type='limit', + execute_date=date.today(), ) + result = self.executor.submit(task) if result: order_id = result.get('id', 'unknown') self.state_manager.update_order_status( @@ -489,12 +528,10 @@ def _execute_stale_refresh(self, symbol: str, signal: Dict): order = signal['order'] if not self.dry_run: - cancelled = self.trading_bot.cancel_order(cancel_id, dry_run=False) + cancelled = self.trading_bot.cancel_order(cancel_id) if not cancelled: print(f" Failed to cancel stale order {cancel_id}, skipping replacement") return - else: - self.trading_bot.cancel_order(cancel_id, dry_run=True) # Place replacement stop-limit sell self._execute_stop_limit_sell_order(symbol, order) @@ -517,19 +554,24 @@ def _execute_paired_limit_buy(self, symbol: str, buy_order: Dict): } self.state_manager.queue_buy_order(order_symbol, order_details) - if self.verbose: + if self.verbose and self.dry_run: print(f"\n{'='*70}") - print(f"PAIRED LIMIT BUY: {order_symbol}") + print(f"DRY RUN — PAIRED LIMIT BUY: {order_symbol}") print(f"{'='*70}") print(f"Quantity: {quantity}") print(f"Limit Price: ${price:,.2f}") - print(f"Mode: {'DRY RUN' if self.dry_run else 'LIVE'}") print(f"{'='*70}\n") - if not self.dry_run: - result = self.trading_bot.place_cash_buy_order( - order_symbol, quantity, price, dry_run=False + if not self.dry_run and self.executor: + task = ScheduledTask( + symbol=order_symbol, + side='buy', + quantity=quantity, + price=price, + order_type='limit', + execute_date=date.today(), ) + result = self.executor.submit(task) if result: order_id = result.get('id') if isinstance(result, dict) else None if order_id: @@ -609,6 +651,10 @@ def run_once(self): # Print initial portfolio allocation self.print_portfolio_allocation() + # Re-attempt any orders deferred by PDT gate from a previous cycle + if self.executor: + self.executor.drain_deferred() + # Fetch open orders once (used by momentum_dca) open_orders = [] if self.strategy_name == 'momentum_dca_long': diff --git a/trading_system/state/blob_logger.py b/trading_system/state/blob_logger.py index 64baae9..90c0947 100644 --- a/trading_system/state/blob_logger.py +++ b/trading_system/state/blob_logger.py @@ -8,7 +8,7 @@ import json import os -from datetime import datetime +import datetime from pathlib import Path import requests @@ -33,7 +33,7 @@ def _serialize_state(state_manager, order_book=None, portfolio=None, recent_option_orders=None) -> dict: """Serialize StateManager state to a JSON-safe dictionary.""" snapshot = { - "timestamp": datetime.now().isoformat(), + "timestamp": datetime.datetime.now().isoformat(), "state": state_manager.state, "tickers": {}, } @@ -73,6 +73,8 @@ def _serialize_state(state_manager, order_book=None, portfolio=None, def _serialize_value(obj): """JSON serializer for objects not serializable by default.""" + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() if hasattr(obj, "value"): return obj.value raise TypeError(f"Object of type {type(obj)} is not JSON serializable") @@ -87,7 +89,7 @@ def _log_local(state_manager, order_book=None, portfolio=None, portfolio=portfolio, drift_metrics=drift_metrics, recent_orders=recent_orders, recent_option_orders=recent_option_orders) - blob_key = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + blob_key = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") payload = json.dumps(snapshot, default=_serialize_value, indent=2) out_path = LOCAL_LOG_DIR / f"{blob_key}.json" @@ -110,7 +112,7 @@ def _log_remote(state_manager, order_book=None, portfolio=None, portfolio=portfolio, drift_metrics=drift_metrics, recent_orders=recent_orders, recent_option_orders=recent_option_orders) - blob_key = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + blob_key = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") payload = json.dumps(snapshot, default=_serialize_value) url = f"{NETLIFY_BLOBS_URL}/{config['site_id']}/{STORE_NAME}/{blob_key}" diff --git a/trading_system/strategies/momentum_dca_strategy.py b/trading_system/strategies/momentum_dca_strategy.py index fbad678..518a6a3 100644 --- a/trading_system/strategies/momentum_dca_strategy.py +++ b/trading_system/strategies/momentum_dca_strategy.py @@ -37,7 +37,7 @@ def __init__(self, symbols: List[str], coverage_threshold: float = 0.20, stop_offset_pct: float = 0.015, proximity_pct: float = 0.0075, coverage_range_pct: float = 0.08, buy_offset: float = 0.20, - lot_size: int = 150, + lot_size: int = DEFAULT_LOT_SIZE, stale_order_age_hours: int = 24, hedge_symbol_map: Dict = None): self.symbols = symbols diff --git a/utils/brokerage.py b/utils/brokerage.py new file mode 100644 index 0000000..e2da562 --- /dev/null +++ b/utils/brokerage.py @@ -0,0 +1,53 @@ +""" +Brokerage — abstract interface for trading platforms. + +The TradeExecutor depends on this interface, not on any concrete broker. +Swap RobinhoodClient for AlpacaClient or any other implementation without +touching the executor. +""" + +from abc import ABC, abstractmethod + + +class Brokerage(ABC): + + @abstractmethod + def place_limit_buy(self, symbol: str, quantity: float, price: float) -> dict | None: + """Place a limit buy order. Returns broker response dict or None on failure.""" + ... + + @abstractmethod + def place_limit_sell(self, symbol: str, quantity: float, price: float) -> dict | None: + """Place a limit sell order.""" + ... + + @abstractmethod + def place_stop_limit_sell(self, symbol: str, quantity: float, + stop_price: float, limit_price: float) -> dict | None: + """Place a stop-limit sell order.""" + ... + + @abstractmethod + def cancel_order(self, order_id: str) -> bool: + """Cancel an open order by ID. Returns True if cancelled.""" + ... + + @abstractmethod + def get_open_orders(self) -> list: + """Return list of open orders.""" + ... + + @abstractmethod + def get_positions(self) -> list: + """Return list of current positions.""" + ... + + @abstractmethod + def get_cash_balance(self) -> dict | None: + """Return cash balance dict with at least 'tradeable_cash' and 'buying_power'.""" + ... + + @abstractmethod + def get_pdt_status(self) -> dict | None: + """Return PDT status dict with 'day_trade_count' and 'flagged'.""" + ... diff --git a/utils/robinhood_client.py b/utils/robinhood_client.py new file mode 100644 index 0000000..8d2d48f --- /dev/null +++ b/utils/robinhood_client.py @@ -0,0 +1,774 @@ +""" +RobinhoodClient — concrete Brokerage implementation for Robinhood. + +Thin wrapper around robin_stocks. Only responsible for: + - Account locking and verification + - Reading account state (positions, orders, balances, PDT) + - Placing and cancelling orders via robin_stocks + +Execution quality logic (PDT gate, spread check, deferred queue) lives in +TradeExecutor, which injects this class via the Brokerage interface. +""" + +import os +import sys +from datetime import datetime, date +from typing import ClassVar + +import robin_stocks.robinhood as r +from dotenv import load_dotenv + +from .rh_auth import RobinhoodAuth +from .brokerage import Brokerage + + +class RobinhoodClient(Brokerage): + """ + Robinhood broker implementation. Cash-only, locked to one account. + + Patterns: + Singleton — one authenticated session per process (via create()) + Factory — create() owns env validation, login, and account lock + Decorator — @_retry wraps robin_stocks calls for transient failures + Facade — hides robin_stocks complexity behind a clean Brokerage interface + """ + + EXPECTED_ACCOUNT = "490706777" + _instance: ClassVar["RobinhoodClient | None"] = None + + def __init__(self, account_number: str, auth: RobinhoodAuth): + self.account_number = account_number + self.auth = auth + + @classmethod + def create(cls) -> "RobinhoodClient": + """Return the singleton client, creating and authenticating it on first call.""" + if cls._instance is not None: + return cls._instance + + load_dotenv() + + account_number = os.getenv('RH_AUTOMATED_ACCOUNT_NUMBER') + if not account_number: + print("[ERR] ERROR: RH_AUTOMATED_ACCOUNT_NUMBER not set in .env") + sys.exit(1) + + if account_number != cls.EXPECTED_ACCOUNT: + print(f"[WARN] WARNING: Expected account {cls.EXPECTED_ACCOUNT}, got {account_number}") + if sys.stdin.isatty(): + if input("Continue anyway? (yes/no): ").lower() != 'yes': + sys.exit(1) + else: + print(" Non-interactive mode: proceeding with configured account") + + auth = RobinhoodAuth() + auth.login() + client = cls(account_number=account_number, auth=auth) + client._verify_account() + cls._instance = client + return cls._instance + + @classmethod + def reset(cls) -> None: + """Clear the singleton — use in tests or after explicit logout.""" + cls._instance = None + + # ------------------------------------------------------------------ + # Brokerage interface — order placement + # ------------------------------------------------------------------ + + def place_limit_buy(self, symbol: str, quantity: float, price: float) -> dict | None: + """Place a limit buy order.""" + print(f"\n{'='*70}") + print(f"BUY ORDER - LIVE") + print(f"{'='*70}") + print(f" Account: {self.account_number}") + print(f" Symbol: {symbol} Qty: {quantity} Limit: ${price:.2f}") + print(f" Total Cost: ${quantity * price:.2f}") + + try: + print("\n Executing order...") + order = r.orders.order_buy_limit( + symbol=symbol, + quantity=quantity, + limitPrice=price, + account_number=self.account_number, + ) + order_id = order.get('id') if isinstance(order, dict) else None + if order_id: + print(f" [OK] Order placed: {order_id} | state: {order.get('state', 'N/A')}") + print(f"{'='*70}\n") + return order + + # Broker rejected + detail = None + if isinstance(order, dict): + detail = order.get('detail') or order.get('non_field_errors') or order.get('message') + print(f" [ERR] Buy order failed: {detail or order}") + + # PDT retry: cancel conflicting buy and resubmit once + if isinstance(detail, str) and 'pdt' in detail.lower(): + print(f" PDT hit — cancelling existing buy(s) for {symbol} and retrying...") + cancelled_qty = self._cancel_existing_orders(symbol, 'buy') + if cancelled_qty: + retry = r.orders.order_buy_limit( + symbol=symbol, + quantity=quantity, + limitPrice=price, + account_number=self.account_number, + ) + retry_id = retry.get('id') if isinstance(retry, dict) else None + if retry_id: + print(f" [OK] Retry placed: {retry_id}") + print(f"{'='*70}\n") + return retry + retry_detail = None + if isinstance(retry, dict): + retry_detail = retry.get('detail') or retry.get('non_field_errors') + print(f" Retry failed: {retry_detail or retry}") + + print(f"{'='*70}\n") + return order + + except Exception as e: + print(f" [ERR] Order failed: {e}") + print(f"{'='*70}\n") + return None + + def place_limit_sell(self, symbol: str, quantity: float, price: float) -> dict | None: + """Place a limit sell order. Validates position first.""" + print(f"\n{'='*70}") + print(f"SELL ORDER - LIVE") + print(f"{'='*70}") + print(f" Account: {self.account_number}") + print(f" Symbol: {symbol} Qty: {quantity} Limit: ${price:.2f}") + + positions = self.get_positions() + position = next((p for p in positions if p['symbol'] == symbol), None) + if not position: + print(f" [ERR] No position in {symbol}") + print(f"{'='*70}\n") + return None + if quantity > position['quantity']: + print(f" [ERR] Insufficient shares (have {position['quantity']})") + print(f"{'='*70}\n") + return None + + try: + print("\n Executing order...") + order = r.orders.order_sell_limit( + symbol=symbol, + quantity=quantity, + limitPrice=price, + account_number=self.account_number, + ) + order_id = order.get('id', 'N/A') if isinstance(order, dict) else 'N/A' + print(f" [OK] Order placed: {order_id} | state: {order.get('state', 'N/A') if isinstance(order, dict) else 'N/A'}") + print(f"{'='*70}\n") + return order + + except Exception as e: + print(f" [ERR] Order failed: {e}") + print(f"{'='*70}\n") + return None + + def place_stop_limit_sell(self, symbol: str, quantity: float, + stop_price: float, limit_price: float) -> dict | None: + """Place a stop-limit sell order. Validates position first.""" + print(f"\n{'='*70}") + print(f"STOP-LIMIT SELL ORDER - LIVE") + print(f"{'='*70}") + print(f" Account: {self.account_number}") + print(f" Symbol: {symbol} Qty: {quantity} Stop: ${stop_price:.2f} Limit: ${limit_price:.2f}") + + positions = self.get_positions() + position = next((p for p in positions if p['symbol'] == symbol), None) + if not position: + print(f" [ERR] No position in {symbol}") + print(f"{'='*70}\n") + return None + if quantity > position['quantity']: + print(f" [ERR] Insufficient shares (have {position['quantity']})") + print(f"{'='*70}\n") + return None + + # Prevent zero-fills on gap-through: ensure stop != limit + if abs(stop_price - limit_price) < 0.01: + limit_price = round(stop_price * 0.995, 2) + print(f" Stop=Limit buffer applied: limit adjusted to ${limit_price:.2f}") + + try: + print("\n Executing order...") + order = r.orders.order_sell_stop_limit( + symbol=symbol, + quantity=quantity, + limitPrice=limit_price, + stopPrice=stop_price, + account_number=self.account_number, + timeInForce='gtc', + ) + order_id = order.get('id') if isinstance(order, dict) else None + if order_id: + print(f" [OK] Order placed: {order_id} | state: {order.get('state', 'N/A')}") + print(f"{'='*70}\n") + return order + + detail = None + if isinstance(order, dict): + detail = order.get('detail') or order.get('non_field_errors') + print(f" [ERR] Stop-limit failed: {detail or order}") + print(f"{'='*70}\n") + return order + + except Exception as e: + print(f" [ERR] Order failed: {e}") + print(f"{'='*70}\n") + return None + + def cancel_order(self, order_id: str) -> bool: + """Cancel an open order by ID. Only cancels if still in a cancellable state.""" + try: + order_info = r.orders.get_stock_order_info(order_id) + if not order_info or not isinstance(order_info, dict): + print(f" cancel_order: order {order_id} not found") + return False + state = order_info.get('state', '') + if state not in ('queued', 'unconfirmed', 'confirmed'): + print(f" cancel_order: order {order_id} in state '{state}', cannot cancel") + return False + r.orders.cancel_stock_order(order_id) + print(f" cancel_order: cancelled {order_id}") + return True + except Exception as e: + print(f" cancel_order: error cancelling {order_id}: {e}") + return False + + # ------------------------------------------------------------------ + # Brokerage interface — account reads + # ------------------------------------------------------------------ + + def get_open_orders(self) -> list: + """Get all open orders for this account.""" + try: + open_orders = r.orders.get_all_open_stock_orders() + orders = [] + if open_orders: + for order in open_orders: + order_id = order.get('id', 'N/A') + symbol = order.get('symbol', 'N/A') + if symbol == 'N/A': + instrument_id = order.get('instrument_id') + if instrument_id: + try: + instrument = r.stocks.get_instrument_by_url( + f"https://api.robinhood.com/instruments/{instrument_id}/" + ) + if instrument: + symbol = instrument.get('symbol', 'N/A') + except Exception: + pass + side = order.get('side', 'N/A') + order_type = order.get('type', 'N/A') + trigger = order.get('trigger', 'immediate') + state = order.get('state', 'N/A') + quantity = float(order.get('quantity', 0)) + limit_price = order.get('price') + stop_price = order.get('stop_price') + created_at = order.get('created_at', 'N/A') + updated_at = order.get('updated_at', 'N/A') + try: + if created_at != 'N/A': + created_dt = datetime.fromisoformat(created_at.replace('Z', '+00:00')) + created_at = created_dt.strftime('%Y-%m-%d %H:%M:%S') + except Exception: + pass + if trigger == 'stop' and order_type == 'limit': + order_desc = 'Stop Limit' + elif trigger == 'stop': + order_desc = 'Stop Loss' + elif order_type == 'limit': + order_desc = 'Limit' + else: + order_desc = 'Market' + orders.append({ + 'order_id': order_id, + 'symbol': symbol, + 'side': side.upper() if side != 'N/A' else 'N/A', + 'order_type': order_desc, + 'trigger': trigger, + 'state': state, + 'quantity': quantity, + 'limit_price': float(limit_price) if limit_price else None, + 'stop_price': float(stop_price) if stop_price else None, + 'created_at': created_at, + 'updated_at': updated_at, + }) + return orders + except Exception as e: + print(f"[ERR] Error getting open orders: {e}") + return [] + + def get_positions(self) -> list: + """Get current equity positions.""" + try: + holdings = r.account.build_holdings() + positions = [] + if holdings: + for symbol, data in holdings.items(): + quantity = float(data.get('quantity', 0)) + if quantity > 0: + avg_price = float(data.get('average_buy_price', 0)) + current_price = float(data.get('price', 0)) + equity = float(data.get('equity', 0)) + profit_loss = (current_price - avg_price) * quantity + profit_loss_pct = ((current_price - avg_price) / avg_price * 100 + if avg_price > 0 else 0) + positions.append({ + 'symbol': symbol, + 'name': data.get('name', ''), + 'type': data.get('type', ''), + 'quantity': quantity, + 'avg_buy_price': avg_price, + 'current_price': current_price, + 'equity': equity, + 'profit_loss': profit_loss, + 'profit_loss_pct': profit_loss_pct, + 'percent_change': self._safe_float(data.get('percent_change')), + 'equity_change': self._safe_float(data.get('equity_change')), + 'pe_ratio': self._safe_float(data.get('pe_ratio')), + 'percentage': self._safe_float(data.get('percentage')), + }) + return positions + except Exception as e: + print(f"[ERR] Error getting positions: {e}") + return [] + + def get_cash_balance(self) -> dict | None: + """Get available cash balance.""" + try: + account = r.profiles.load_account_profile(account_number=self.account_number) + cash = float(account.get('cash', 0)) + cash_available_for_withdrawal = float(account.get('cash_available_for_withdrawal', 0)) + buying_power = float(account.get('buying_power', 0)) + return { + 'cash': cash, + 'cash_available_for_withdrawal': cash_available_for_withdrawal, + 'buying_power': buying_power, + 'tradeable_cash': cash, + } + except Exception as e: + print(f"[ERR] Error getting cash balance: {e}") + return None + + def get_pdt_status(self) -> dict | None: + """Get Pattern Day Trading status.""" + try: + import time + time.sleep(0.5) + account = r.profiles.load_account_profile(account_number=self.account_number) + day_trade_count = int(account.get('day_trade_count') or 0) + flagged = account.get('pattern_day_trader', False) + trades = [] + day_trade_info = account.get('day_trades', []) + for dt in (day_trade_info or []): + opened = dt.get('opened_at', 'N/A') + closed = dt.get('closed_at', 'N/A') + try: + if opened != 'N/A': + opened = datetime.fromisoformat( + opened.replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M') + if closed != 'N/A': + closed = datetime.fromisoformat( + closed.replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M') + except Exception: + pass + trades.append(f"opened {opened} → closed {closed}") + return {'day_trade_count': day_trade_count, 'flagged': flagged, 'trades': trades} + except Exception as e: + print(f" Could not fetch PDT status: {e}") + return None + + # ------------------------------------------------------------------ + # Extended reads (not part of Brokerage ABC, used by main.py) + # ------------------------------------------------------------------ + + def get_portfolio_summary(self, symbols=None): + """Full portfolio summary — positions, cash, orders, PDT, options.""" + try: + r.profiles.load_account_profile(account_number=self.account_number) + portfolio = r.profiles.load_portfolio_profile(account_number=self.account_number) + equity = float(portfolio.get('equity', 0)) + market_value = float(portfolio.get('market_value', 0)) + cash_info = self.get_cash_balance() + positions = self.get_positions() + open_orders = self.get_open_orders() + option_positions = self.get_option_positions() + + if symbols: + positions = [p for p in positions if p['symbol'] in symbols] + open_orders = [o for o in open_orders if o['symbol'] in symbols] + + total_position_value = sum(pos['equity'] for pos in positions) + + return { + 'equity': equity, + 'market_value': market_value, + 'cash': cash_info, + 'positions': positions, + 'open_orders': open_orders, + 'options': option_positions, + 'total_position_value': total_position_value, + } + except Exception as e: + print(f"[ERR] Error getting portfolio summary: {e}") + return None + + def get_recent_orders(self, days=7) -> list: + """Get recently filled/cancelled orders.""" + try: + from datetime import timedelta + cutoff = datetime.utcnow() - timedelta(days=days) + cutoff_str = cutoff.strftime('%Y-%m-%dT00:00:00Z') + all_orders = r.orders.get_all_stock_orders(info=None) + orders = [] + if not all_orders: + return orders + for order in all_orders: + state = order.get('state', '') + if state not in ('filled', 'cancelled', 'failed', 'rejected'): + continue + if order.get('updated_at', '') < cutoff_str: + continue + symbol = order.get('symbol', 'N/A') + if symbol == 'N/A': + instrument_id = order.get('instrument_id') + if instrument_id: + try: + inst = r.stocks.get_instrument_by_url( + f"https://api.robinhood.com/instruments/{instrument_id}/") + if inst: + symbol = inst.get('symbol', 'N/A') + except Exception: + pass + trigger = order.get('trigger', 'immediate') + order_type = order.get('type', 'N/A') + if trigger == 'stop' and order_type == 'limit': + order_desc = 'Stop Limit' + elif trigger == 'stop': + order_desc = 'Stop Loss' + elif order_type == 'limit': + order_desc = 'Limit' + else: + order_desc = 'Market' + created_at = order.get('created_at', 'N/A') + try: + if created_at != 'N/A': + created_at = datetime.fromisoformat( + created_at.replace('Z', '+00:00')).strftime('%Y-%m-%d %H:%M:%S') + except Exception: + pass + limit_price = order.get('price') + stop_price = order.get('stop_price') + average_price = order.get('average_price') + cumulative_quantity = order.get('cumulative_quantity') + orders.append({ + 'order_id': order.get('id', 'N/A'), + 'symbol': symbol, + 'side': order.get('side', 'N/A').upper(), + 'order_type': order_desc, + 'trigger': trigger, + 'state': state, + 'quantity': float(order.get('quantity', 0)), + 'limit_price': float(limit_price) if limit_price else None, + 'stop_price': float(stop_price) if stop_price else None, + 'average_price': float(average_price) if average_price else None, + 'filled_quantity': float(cumulative_quantity) if cumulative_quantity else None, + 'created_at': created_at, + 'updated_at': order.get('updated_at', 'N/A'), + }) + return orders + except Exception as e: + print(f"[ERR] Error getting recent orders: {e}") + return [] + + def get_recent_option_orders(self, days=7) -> list: + """Get recently filled/cancelled option orders.""" + try: + from datetime import timedelta + cutoff = datetime.utcnow() - timedelta(days=days) + cutoff_str = cutoff.strftime('%Y-%m-%dT00:00:00Z') + raw_orders = r.orders.get_all_option_orders() + if not raw_orders: + return [] + orders = [] + for order in raw_orders: + state = order.get('state', '') + if state not in ('filled', 'cancelled', 'failed', 'rejected'): + continue + if order.get('updated_at', '') < cutoff_str: + continue + orders.append({ + 'order_id': order.get('id', 'N/A'), + 'state': state, + 'quantity': float(order.get('quantity', 0)), + 'price': float(order.get('price', 0) or 0), + 'direction': order.get('direction', 'N/A'), + 'order_type': order.get('type', 'N/A'), + 'created_at': order.get('created_at', 'N/A'), + 'updated_at': order.get('updated_at', 'N/A'), + }) + return orders + except Exception as e: + print(f"[ERR] Error getting recent option orders: {e}") + return [] + + def get_open_option_orders(self) -> list: + """Get all open option orders.""" + try: + raw_orders = r.orders.get_all_open_option_orders( + account_number=self.account_number) + if not raw_orders: + return [] + return raw_orders + except Exception as e: + print(f"[ERR] Error getting open option orders: {e}") + return [] + + def get_option_positions(self) -> list: + """Get open option positions with greeks and analytics.""" + try: + raw_positions = r.options.get_open_option_positions( + account_number=self.account_number) + if not raw_positions: + return [] + positions = [] + underlying_symbols = set() + for pos in raw_positions: + quantity = float(pos.get('quantity', 0)) + if quantity == 0: + continue + chain_symbol = pos.get('chain_symbol', 'N/A') + underlying_symbols.add(chain_symbol) + avg_price = float(pos.get('average_price', 0)) / 100 + pos_type = pos.get('type', 'long') + multiplier = float(pos.get('trade_value_multiplier', '100')) + option_url = pos.get('option', '') + option_id = option_url.rstrip('/').split('/')[-1] if option_url else None + instrument = {} + if option_id: + try: + instrument = r.options.get_option_instrument_data_by_id(option_id) or {} + except Exception: + pass + strike = float(instrument.get('strike_price', 0)) + expiration = instrument.get('expiration_date', 'N/A') + option_type = instrument.get('type', 'N/A') + market_data = {} + if option_id: + try: + md = r.options.get_option_market_data_by_id(option_id) + if md and isinstance(md, list) and len(md) > 0: + market_data = md[0] + elif md and isinstance(md, dict): + market_data = md + except Exception: + pass + delta = self._safe_float(market_data.get('delta')) + gamma = self._safe_float(market_data.get('gamma')) + theta = self._safe_float(market_data.get('theta')) + vega = self._safe_float(market_data.get('vega')) + rho = self._safe_float(market_data.get('rho')) + iv = self._safe_float(market_data.get('implied_volatility')) + mark_price = self._safe_float(market_data.get('adjusted_mark_price')) + chance_profit_long = self._safe_float(market_data.get('chance_of_profit_long')) + chance_profit_short = self._safe_float(market_data.get('chance_of_profit_short')) + break_even = self._safe_float(market_data.get('break_even_price')) + underlying_price = None + try: + price_data = r.stocks.get_latest_price(chain_symbol) + if price_data and price_data[0]: + underlying_price = float(price_data[0]) + except Exception: + pass + dte = None + if expiration and expiration != 'N/A': + try: + exp_date = datetime.strptime(expiration, '%Y-%m-%d').date() + dte = (exp_date - date.today()).days + except Exception: + pass + current_value = (mark_price or 0) * quantity * multiplier + recommendation = self._recommend_option_action( + option_type, pos_type, delta, theta, iv, dte, mark_price, + avg_price, underlying_price, strike, chance_profit_long, + chance_profit_short) + expected_pnl = self._calculate_expected_pnl( + delta, gamma, underlying_price, quantity, multiplier, pos_type) + greeks = {'delta': delta, 'gamma': gamma, 'theta': theta, + 'vega': vega, 'rho': rho, 'iv': iv} + positions.append({ + 'chain_symbol': chain_symbol, + 'option_type': option_type, + 'strike': strike, + 'expiration': expiration, + 'quantity': quantity, + 'avg_price': avg_price, + 'mark_price': mark_price, + 'current_value': current_value, + 'position_type': pos_type, + 'dte': dte, + 'underlying_price': underlying_price, + 'greeks': greeks, + 'break_even': break_even, + 'chance_profit_long': chance_profit_long, + 'chance_profit_short': chance_profit_short, + 'expected_pnl': expected_pnl, + 'recommendation': recommendation, + }) + return positions + except Exception as e: + print(f"[ERR] Error getting option positions: {e}") + return [] + + # ------------------------------------------------------------------ + # Validation helpers + # ------------------------------------------------------------------ + + def validate_buy_order(self, symbol, quantity, price, buying_power: float | None = None): + """Returns (is_valid, reason). Pass buying_power to skip a redundant API call.""" + if buying_power is None: + cash_info = self.get_cash_balance() + if not cash_info: + return False, "Cannot retrieve cash balance" + buying_power = cash_info['buying_power'] + total_cost_with_buffer = quantity * price * 1.01 + if total_cost_with_buffer > buying_power: + return False, (f"Insufficient buying power: need " + f"${total_cost_with_buffer:,.2f}, " + f"have ${buying_power:,.2f}") + try: + quote = r.stocks.get_quotes(symbol) + if not quote or len(quote) == 0: + return False, f"Invalid symbol: {symbol}" + except Exception: + return False, f"Cannot get quote for {symbol}" + return True, "Order validated" + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _verify_account(self): + try: + account = r.profiles.load_account_profile(account_number=self.account_number) + account_type = account.get('type', 'unknown') + print(f"\n{'='*70}") + print(f"[LOCKED] LOCKED TO ACCOUNT: {self.account_number}") + print(f"{'='*70}") + print(f" Account Type: {account_type}") + if account_type != 'cash': + print(f" [WARN] Account type is '{account_type}', not 'cash'") + print(f"{'='*70}\n") + except Exception as e: + print(f"[ERR] Cannot access account {self.account_number}: {e}") + sys.exit(1) + + def _cancel_existing_orders(self, symbol, side) -> int: + """Cancel all open orders for symbol on the given side. Returns cancelled qty.""" + target_instrument_url = None + try: + instruments = r.stocks.get_instruments_by_symbols(symbol, info='url') + if instruments: + target_instrument_url = instruments[0] + except Exception as e: + print(f" Failed to resolve instrument for {symbol}: {e}") + return 0 + if not target_instrument_url: + return 0 + open_orders = r.orders.get_all_open_stock_orders(account_number=self.account_number) + cancelled_qty = 0 + if open_orders: + for order in open_orders: + if (order.get('instrument', '') == target_instrument_url + and order.get('side', '') == side): + existing_id = order.get('id') + r.orders.cancel_stock_order(existing_id) + try: + cancelled_qty += int(float(order.get('quantity', '0'))) + except (ValueError, TypeError): + pass + return cancelled_qty + + @staticmethod + def _safe_float(value): + try: + return float(value) if value is not None else None + except (ValueError, TypeError): + return None + + def _calculate_expected_pnl(self, delta, gamma, underlying_price, + quantity, multiplier, pos_type): + if not underlying_price or delta is None: + return None + sign = 1.0 if pos_type == 'long' else -1.0 + scenarios = {} + for pct_label, pct in [('-5%', -0.05), ('-1%', -0.01), ('+1%', 0.01), ('+5%', 0.05)]: + dollar_move = underlying_price * pct + option_delta_price = (delta or 0) * dollar_move + if gamma: + option_delta_price += 0.5 * gamma * dollar_move ** 2 + scenarios[pct_label] = round(sign * option_delta_price * quantity * multiplier, 2) + if delta is not None: + scenarios['theta_daily'] = round(sign * (delta or 0) * quantity * multiplier, 2) + return scenarios + + def _recommend_option_action(self, option_type, pos_type, delta, theta, iv, + dte, mark_price, avg_price, underlying_price, + strike, chance_profit_long, chance_profit_short): + reasons = [] + action = 'HOLD' + if dte is not None and dte <= 0: + return {'action': 'CLOSE', 'reasons': ['Expired or expiring today']} + chance_of_profit = chance_profit_long if pos_type == 'long' else chance_profit_short + if pos_type == 'long': + if mark_price and avg_price and avg_price > 0: + gain_pct = (mark_price - avg_price) / avg_price * 100 + if gain_pct >= 100: + action = 'CLOSE' + reasons.append(f'Up {gain_pct:.0f}% — take profit') + elif gain_pct >= 50: + reasons.append(f'Up {gain_pct:.0f}% — consider partial close') + if dte is not None and dte <= 7 and theta is not None and theta < -0.03: + action = 'CLOSE' + reasons.append(f'DTE={dte}, heavy theta decay (${theta:.3f}/day)') + elif dte is not None and dte <= 14: + reasons.append(f'DTE={dte} — monitor theta decay') + if chance_of_profit is not None and chance_of_profit < 0.20: + action = 'CLOSE' + reasons.append(f'Low probability of profit ({chance_of_profit:.0%})') + if underlying_price and strike and option_type in ('call', 'put'): + if option_type == 'call' and underlying_price < strike * 0.90: + reasons.append('Deep OTM call') + elif option_type == 'put' and underlying_price > strike * 1.10: + reasons.append('Deep OTM put') + else: + if mark_price and avg_price and avg_price > 0: + decay_pct = (avg_price - mark_price) / avg_price * 100 + if decay_pct >= 80: + action = 'CLOSE' + reasons.append(f'Captured {decay_pct:.0f}% of premium — close to lock in') + elif decay_pct >= 50: + reasons.append(f'Captured {decay_pct:.0f}% of premium — consider closing') + if iv is not None and iv > 0.80: + reasons.append(f'High IV ({iv:.0%}) — increased risk') + if dte is not None and dte <= 3 and underlying_price and strike: + if option_type == 'call' and underlying_price >= strike: + action = 'CLOSE' + reasons.append('ITM near expiration — assignment risk') + elif option_type == 'put' and underlying_price <= strike: + action = 'CLOSE' + reasons.append('ITM near expiration — assignment risk') + if not reasons: + reasons.append('No immediate signals') + return {'action': action, 'reasons': reasons} +