diff --git a/examples/shortable_backtest_crypto_loop.py b/examples/shortable_backtest_crypto_loop.py new file mode 100644 index 0000000000..296f061cca --- /dev/null +++ b/examples/shortable_backtest_crypto_loop.py @@ -0,0 +1,94 @@ +import os +import pandas as pd +import qlib +from qlib.data import D +from qlib.constant import REG_CRYPTO + +from qlib.backtest.shortable_backtest import ShortableExecutor, ShortableAccount +from qlib.backtest.shortable_exchange import ShortableExchange +from qlib.backtest.decision import OrderDir +from qlib.contrib.strategy.signal_strategy import LongShortTopKStrategy +from qlib.backtest.utils import CommonInfrastructure + + +def main(): + provider = os.path.expanduser("~/.qlib/qlib_data/crypto_data_perp") + qlib.init(provider_uri=provider, region=REG_CRYPTO, kernels=1) + + start = pd.Timestamp("2021-07-11") + end = pd.Timestamp("2021-08-10") + + # Universe + inst_conf = D.instruments("all") + codes = D.list_instruments(inst_conf, start_time=start, end_time=end, freq="day", as_list=True)[:20] + if not codes: + print("No instruments.") + return + + # Exchange + ex = ShortableExchange( + freq="day", + start_time=start, + end_time=end, + codes=codes, + deal_price="$close", + open_cost=0.0005, + close_cost=0.0015, + min_cost=0.0, + impact_cost=0.0, + limit_threshold=None, + ) + + # Account and executor + account = ShortableAccount(benchmark_config={"benchmark": None}) + exe = ShortableExecutor( + time_per_step="day", + generate_portfolio_metrics=True, + trade_exchange=ex, + region="crypto", + verbose=False, + account=account, + ) + # Build and inject common infrastructure to executor (and later strategy) + common_infra = CommonInfrastructure(trade_account=account, trade_exchange=ex) + exe.reset(common_infra=common_infra, start_time=start, end_time=end) + + # Precompute momentum signal for the whole period (shift=1 used by strategy) + feat = D.features(codes, ["$close"], start, end, freq="day", disk_cache=True) + if feat is None or 
def main():
    """Print, for one day, the sizing inputs a long/short book would use on a small crypto universe.

    The script initialises qlib against the local crypto-perp provider, builds a
    ``ShortableExchange`` over the first ten listed instruments, ranks them by their
    latest one-day close-to-close return and, for the top three (long leg) and bottom
    three (short leg), prints the deal price, factor, trade unit, tradable flag and the
    raw/rounded order size for an equal-weight 50/50 book on 1,000,000 of equity.

    Nothing is returned; all output goes to stdout.
    """
    provider = os.path.expanduser("~/.qlib/qlib_data/crypto_data_perp")
    qlib.init(provider_uri=provider, region=REG_CRYPTO, kernels=1)

    start = pd.Timestamp("2021-07-11")
    end = pd.Timestamp("2021-08-10")
    day = pd.Timestamp("2021-08-10")

    inst_conf = D.instruments("all")
    codes = D.list_instruments(inst_conf, start_time=start, end_time=end, freq="day", as_list=True)[:10]
    if not codes:
        # Same guard as the sibling loop example: nothing to inspect without a universe.
        print("No instruments.")
        return

    ex = ShortableExchange(
        freq="day",
        start_time=start,
        end_time=end,
        codes=codes,
        deal_price="$close",
        open_cost=0.0005,
        close_cost=0.0015,
        min_cost=0.0,
        impact_cost=0.0,
        limit_threshold=None,
    )

    feat = D.features(codes, ["$close"], day - pd.Timedelta(days=10), day, freq="day", disk_cache=True)
    if feat is None or feat.empty:
        print("No features to build signal.")
        return
    # last()/iloc[-2] below are positional, so make the per-instrument order explicit.
    feat = feat.sort_index()
    g = feat.groupby("instrument")["$close"]
    last = g.last()
    # Second-to-last close per instrument. An instrument with a single row in the window
    # has no previous close: yield NaN (dropped below) instead of raising IndexError.
    prev = g.apply(lambda s: s.iloc[-2] if len(s) >= 2 else float("nan"))
    sig = (last / prev - 1.0).dropna().sort_values(ascending=False)

    longs = sig.head(3).index.tolist()
    shorts = sig.tail(3).index.tolist()

    equity = 1_000_000.0
    long_weight = 0.5 / max(len(longs), 1)
    short_weight = -0.5 / max(len(shorts), 1)

    print("day:", day.date())
    for leg, lst, w, dir_ in [
        ("LONG", longs, long_weight, OrderDir.BUY),
        ("SHORT", shorts, short_weight, OrderDir.SELL),
    ]:
        print(f"\n{leg} candidates:")
        for code in lst:
            # Debug tool: report a per-instrument failure and keep going with the rest.
            try:
                px = ex.get_deal_price(code, day, day, dir_)
                fac = ex.get_factor(code, day, day)
                unit = ex.get_amount_of_trade_unit(fac, code, day, day)
                tradable = ex.is_stock_tradable(code, day, day, dir_)
                raw = (w * equity) / px if px else 0.0
                rounded = ex.round_amount_by_trade_unit(abs(raw), fac) if px else 0.0
                if dir_ == OrderDir.SELL:
                    # Short leg is reported as a negative share count.
                    rounded = -rounded
                print(
                    code,
                    {
                        "price": px,
                        "factor": fac,
                        "unit": unit,
                        "tradable": tradable,
                        "raw_shares": raw,
                        "rounded": rounded,
                    },
                )
            except Exception as e:  # pylint: disable=W0718
                print(code, "error:", e)
if __name__ == "__main__":
    # Windows compatibility: spawn mode needs freeze_support and avoid heavy top-level imports
    if sys.platform.startswith("win"):
        mp.freeze_support()
    # Emulate Windows spawn on POSIX if needed
    if os.environ.get("WINDOWS_SPAWN_TEST") == "1" and not sys.platform.startswith("win"):
        try:
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            # The start method was already fixed by an earlier call; keep it.
            pass
    # Lazy imports to avoid circular import issues on Windows spawn mode
    from qlib.workflow import R
    from qlib.workflow.record_temp import SignalRecord, SigAnaRecord
    from qlib.data import D

    # Initialize with crypto perp data provider (ensure this path exists in your env)
    PROVIDER_URI = "~/.qlib/qlib_data/crypto_data_perp"
    # Use crypto-specific region to align trading rules/calendars with provider data
    qlib.init(provider_uri=PROVIDER_URI, region=REG_CRYPTO, kernels=1)

    # Auto-select benchmark by data source: cn_data -> SH000300; crypto -> BTCUSDT
    # Fallback: if path not resolvable, default to SH000300 for safety
    try:
        from qlib.config import C

        data_roots = {k: str(C.dpm.get_data_uri(k)) for k in C.dpm.provider_uri.keys()}
        DATA_ROOTS_STR = " ".join(data_roots.values()).lower()
        IS_CN = "cn_data" in DATA_ROOTS_STR
        BENCHMARK_AUTO = "SH000300" if IS_CN else "BTCUSDT"
    except Exception:  # pylint: disable=W0718
        BENCHMARK_AUTO = "SH000300"

    # Dataset & model
    data_handler_config = {
        "start_time": "2019-01-02",
        "end_time": "2025-08-07",
        "fit_start_time": "2019-01-02",
        "fit_end_time": "2022-12-19",
        "instruments": "all",
        "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
    }

    DEBUG_FAST = os.environ.get("FAST_DEBUG") == "1"

    # Fast-debug segment bounds. They stay None (and the full-history segments below are
    # used) unless FAST_DEBUG=1 *and* the calendar holds at least 45 days.
    VALID_START_DT = VALID_END_DT = TEST_START_DT = TEST_END_DT = None
    HAS_DEBUG_WINDOW = False
    if DEBUG_FAST:
        # Use the latest available calendar to auto-derive a tiny, non-empty window
        cal = D.calendar(freq="day", future=False)
        if len(cal) >= 45:
            HAS_DEBUG_WINDOW = True
            # Last 45 calendar days: 21 for fitting, 10 for validation, 14 for testing.
            FIT_START_DT, FIT_END_DT = cal[-45], cal[-25]
            VALID_START_DT, VALID_END_DT = cal[-24], cal[-15]
            TEST_START_DT, TEST_END_DT = cal[-14], cal[-1]
            data_handler_config.update(
                {
                    "fit_start_time": FIT_START_DT,
                    "fit_end_time": FIT_END_DT,
                    "start_time": FIT_START_DT,
                    "end_time": TEST_END_DT,
                }
            )

    # train always follows the handler's fit window; valid/test come either from the
    # derived debug window or from the fixed full-history split.
    train_segment = (data_handler_config["fit_start_time"], data_handler_config["fit_end_time"])
    if HAS_DEBUG_WINDOW:
        segments = {
            "train": train_segment,
            "valid": (VALID_START_DT, VALID_END_DT),
            "test": (TEST_START_DT, TEST_END_DT),
        }
    else:
        segments = {
            "train": train_segment,
            "valid": ("2022-12-20", "2023-12-31"),
            "test": ("2024-01-01", data_handler_config["end_time"]),
        }

    dataset_config = {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": segments,
        },
    }

    model_config = {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
        "kwargs": {
            "loss": "mse",
            "colsample_bytree": 0.8879,
            "learning_rate": 0.0421,
            "subsample": 0.8789,
            "lambda_l1": 205.6999,
            "lambda_l2": 580.9768,
            "max_depth": 8,
            "num_leaves": 210,
            "num_threads": 20,
        },
    }

    if DEBUG_FAST:
        model_config["kwargs"].update({"num_threads": 2, "num_boost_round": 10})

    model = init_instance_by_config(model_config)
    dataset = init_instance_by_config(dataset_config)

    # Prefer contrib's crypto version; fallback to default PortAnaRecord (no external local dependency)
    try:
        from qlib.contrib.workflow.crypto_record_temp import CryptoPortAnaRecord as PortAnaRecord  # type: ignore

        print("Using contrib's crypto version of CryptoPortAnaRecord as PortAnaRecord")
    except Exception:  # pylint: disable=W0718
        from qlib.workflow.record_temp import PortAnaRecord

        print("Using default version of PortAnaRecord")

    # Align backtest time to test segment
    test_start, test_end = dataset_config["kwargs"]["segments"]["test"]

    # Strategy params (shrink for fast validation)
    TOPK_L, TOPK_S, DROP_L, DROP_S = 20, 20, 10, 10
    if DEBUG_FAST:
        TOPK_L = TOPK_S = 5
        DROP_L = DROP_S = 1

    port_analysis_config = {
        "executor": {
            "class": "ShortableExecutor",
            "module_path": "qlib.backtest.shortable_backtest",
            "kwargs": {
                "time_per_step": "day",
                "generate_portfolio_metrics": True,
            },
        },
        "strategy": {
            "class": "LongShortTopKStrategy",
            "module_path": "qlib.contrib.strategy.signal_strategy",
            "kwargs": {
                "signal": (model, dataset),
                "topk_long": TOPK_L,
                "topk_short": TOPK_S,
                "n_drop_long": DROP_L,
                "n_drop_short": DROP_S,
                "hold_thresh": 3,
                "only_tradable": True,
                "forbid_all_trade_at_limit": False,
            },
        },
        "backtest": {
            "start_time": test_start,
            "end_time": test_end,
            "account": 100000000,
            "benchmark": BENCHMARK_AUTO,
            "exchange_kwargs": {
                "exchange": {
                    "class": "ShortableExchange",
                    "module_path": "qlib.backtest.shortable_exchange",
                },
                "freq": "day",
                # Crypto has no daily price limit; set to 0.0 to avoid false limit locks
                # NOTE(review): the other crypto examples pass limit_threshold=None; confirm
                # that a float 0.0 is treated as "no limit" by ShortableExchange.
                "limit_threshold": 0.0,
                "deal_price": "close",
                "open_cost": 0.0002,
                "close_cost": 0.0005,
                "min_cost": 0,
            },
        },
    }

    # Preview prepared data
    example_df = dataset.prepare("train")
    print(example_df.head())

    # Start experiment
    with R.start(experiment_name="workflow_longshort_crypto"):
        R.log_params(**flatten_dict({"model": model_config, "dataset": dataset_config}))
        model.fit(dataset)
        R.save_objects(**{"params.pkl": model})

        # Prediction
        recorder = R.get_recorder()
        sr = SignalRecord(model, dataset, recorder)
        sr.generate()

        # Signal Analysis
        sar = SigAnaRecord(recorder)
        sar.generate()

        # Backtest with long-short strategy (Crypto metrics)
        par = PortAnaRecord(recorder, port_analysis_config, "day")
        par.generate()
class BaseBorrowFeeModel(ABC):
    """
    Abstract base class for modeling borrowing fees in short selling.

    Concrete models provide an annual borrow rate per instrument and turn a
    dictionary of positions into the borrow cost accrued over one day.
    """

    @abstractmethod
    def get_borrow_rate(self, stock_id: str, date: pd.Timestamp) -> float:
        """
        Get the borrowing rate for a specific stock on a specific date.

        Parameters
        ----------
        stock_id : str
            The stock identifier
        date : pd.Timestamp
            The date for which to get the rate

        Returns
        -------
        float
            Annual borrowing rate as decimal (e.g., 0.03 for 3%)
        """
        raise NotImplementedError

    @abstractmethod
    def calculate_daily_cost(self, positions: Dict, date: pd.Timestamp) -> float:
        """
        Calculate total daily borrowing cost for all short positions.

        Parameters
        ----------
        positions : Dict
            Dictionary of positions with amounts and prices
        date : pd.Timestamp
            The date for calculation

        Returns
        -------
        float
            Total daily borrowing cost
        """
        raise NotImplementedError


class _ShortCostMixin:
    """
    Plumbing shared by the concrete fee models in this module.

    Holds the day-count handling, the position-key filter and the iteration over
    short positions, which were previously copied verbatim into every model.
    """

    # Bookkeeping keys that can sit next to instruments in a position dictionary.
    _NON_STOCK_KEYS = frozenset(
        {"cash", "cash_delay", "now_account_value", "borrow_cost_accumulated", "short_proceeds"}
    )

    @staticmethod
    def _resolve_days_per_year(days_per_year) -> int:
        """Return ``days_per_year`` as a positive int, or 365 when it is missing or not positive."""
        return int(days_per_year) if days_per_year and days_per_year > 0 else 365

    def set_days_per_year(self, n: int) -> None:
        """
        Set days-per-year divisor used to convert annual rate to daily.

        Best effort: a value that cannot be converted to a positive integer is
        ignored and the current divisor is kept.
        """
        try:
            n = int(n)
        except (TypeError, ValueError, OverflowError):
            return
        if n > 0:
            self.daily_divisor = n

    def _is_valid_stock_id(self, stock_id: str) -> bool:
        """
        Check whether a position key denotes an instrument.

        Known bookkeeping keys, non-string keys and strings shorter than four
        characters are rejected.
        NOTE(review): the length rule also rejects short tickers (1-3 characters);
        confirm that is acceptable for the target markets.
        """
        if stock_id in self._NON_STOCK_KEYS:
            return False
        return isinstance(stock_id, str) and len(stock_id) >= 4

    def _iter_short_values(self, positions: Dict):
        """Yield ``(stock_id, abs(amount * price))`` for every chargeable short position."""
        for stock_id, position_info in positions.items():
            if not self._is_valid_stock_id(stock_id) or not isinstance(position_info, dict):
                continue
            amount = position_info.get("amount", 0)
            price = position_info.get("price", 0)
            # Only negative holdings with a usable price are charged.
            if (amount < 0) and (price > 0):
                yield stock_id, abs(amount * price)


class FixedRateBorrowFeeModel(_ShortCostMixin, BaseBorrowFeeModel):
    """
    Simple borrowing fee model with fixed rates.
    """

    def __init__(
        self,
        default_rate: float = 0.03,
        stock_rates: Optional[Dict[str, float]] = None,
        hard_to_borrow_rate: float = 0.10,
        days_per_year: int = 365,
    ):
        """
        Initialize fixed rate borrow fee model.

        Parameters
        ----------
        default_rate : float
            Default annual borrowing rate for most stocks (default 3%)
        stock_rates : Dict[str, float], optional
            Specific rates for individual stocks
        hard_to_borrow_rate : float
            Rate for hard-to-borrow stocks (default 10%); stored on the instance,
            not consulted by ``get_borrow_rate``
        days_per_year : int
            Divisor used to turn the annual rate into a daily one
            (252 for stocks, 365 for crypto); non-positive values fall back to 365
        """
        self.default_rate = default_rate
        self.stock_rates = stock_rates or {}
        self.hard_to_borrow_rate = hard_to_borrow_rate
        self.daily_divisor = self._resolve_days_per_year(days_per_year)

    def get_borrow_rate(self, stock_id: str, date: pd.Timestamp) -> float:
        """Get annual borrowing rate for a stock: its specific rate if configured, else the default."""
        if stock_id in self.stock_rates:
            return self.stock_rates[stock_id]
        return self.default_rate

    def calculate_daily_cost(self, positions: Dict, date: pd.Timestamp) -> float:
        """Calculate total daily borrowing cost over all short positions in ``positions``."""
        return sum(
            (
                short_value * (self.get_borrow_rate(stock_id, date) / self.daily_divisor)
                for stock_id, short_value in self._iter_short_values(positions)
            ),
            0.0,
        )


class DynamicBorrowFeeModel(_ShortCostMixin, BaseBorrowFeeModel):
    """
    Dynamic borrowing fee model based on market conditions and availability.
    """

    def __init__(
        self,
        rate_data: Optional[pd.DataFrame] = None,
        default_rate: float = 0.03,
        volatility_adjustment: bool = True,
        liquidity_adjustment: bool = True,
        days_per_year: int = 365,
    ):
        """
        Initialize dynamic borrow fee model.

        Parameters
        ----------
        rate_data : pd.DataFrame, optional
            Historical borrowing rate data with MultiIndex (date, stock_id)
            and a ``borrow_rate`` column
        default_rate : float
            Default rate when no data available
        volatility_adjustment : bool
            Adjust rates based on stock volatility
        liquidity_adjustment : bool
            Adjust rates based on stock liquidity
        days_per_year : int
            Divisor used to turn the annual rate into a daily one
            (252 for stocks, 365 for crypto); non-positive values fall back to 365
        """
        self.rate_data = rate_data
        self.default_rate = default_rate
        self.volatility_adjustment = volatility_adjustment
        self.liquidity_adjustment = liquidity_adjustment
        self.daily_divisor = self._resolve_days_per_year(days_per_year)
        # Cache of final annual rates keyed by (stock_id, date). It must exist from
        # construction on, because get_borrow_rate reads it on its first call.
        self._rate_cache = {}

    def get_borrow_rate(self, stock_id: str, date: pd.Timestamp) -> float:
        """
        Get borrowing rate with dynamic adjustments.

        The base rate is scaled by the enabled multipliers, capped at 50% annual
        and memoised per ``(stock_id, date)``.
        """
        cache_key = (stock_id, date)
        if cache_key in self._rate_cache:
            return self._rate_cache[cache_key]

        rate = self._get_base_rate(stock_id, date)
        if self.volatility_adjustment:
            rate *= self._get_volatility_multiplier(stock_id, date)
        if self.liquidity_adjustment:
            rate *= self._get_liquidity_multiplier(stock_id, date)

        final_rate = min(rate, 0.50)  # Cap at 50% annual
        self._rate_cache[cache_key] = final_rate
        return final_rate

    def _get_base_rate(self, stock_id: str, date: pd.Timestamp) -> float:
        """Get base borrowing rate from data if available, otherwise default."""
        if self.rate_data is None:
            return self.default_rate
        try:
            rate = float(self.rate_data.loc[(date, stock_id), "borrow_rate"])
        except (KeyError, IndexError, TypeError, ValueError):
            # No usable row for this (date, stock_id).
            return self.default_rate
        # A missing value (NaN) in the table would turn every total it touches into NaN.
        return rate if pd.notna(rate) else self.default_rate

    def _get_volatility_multiplier(self, stock_id: str, date: pd.Timestamp) -> float:
        """Return volatility multiplier (placeholder=1.0)."""
        return 1.0

    def _get_liquidity_multiplier(self, stock_id: str, date: pd.Timestamp) -> float:
        """Return liquidity multiplier (placeholder=1.0)."""
        return 1.0

    def calculate_daily_cost(self, positions: Dict, date: pd.Timestamp) -> float:
        """Calculate total daily borrowing cost with dynamic rates."""
        return sum(
            (
                short_value * (self.get_borrow_rate(stock_id, date) / self.daily_divisor)
                for stock_id, short_value in self._iter_short_values(positions)
            ),
            0.0,
        )


class TieredBorrowFeeModel(_ShortCostMixin, BaseBorrowFeeModel):
    """
    Tiered borrowing fee model based on position size and stock category.
    """

    def __init__(
        self,
        easy_to_borrow: Optional[set] = None,
        hard_to_borrow: Optional[set] = None,
        size_tiers: Optional[Dict[float, float]] = None,
        days_per_year: int = 365,
    ):
        """
        Initialize tiered borrow fee model.

        Parameters
        ----------
        easy_to_borrow : set
            Set of stock IDs that are easy to borrow
        hard_to_borrow : set
            Set of stock IDs that are hard to borrow
        size_tiers : Dict[float, float]
            Position size tiers and corresponding rate adjustments
            E.g., {100000: 1.0, 1000000: 1.2, 10000000: 1.5}
        days_per_year : int
            Divisor used to turn the annual rate into a daily one
            (252 for stocks, 365 for crypto); non-positive values fall back to 365
        """
        self.easy_to_borrow = easy_to_borrow or set()
        self.hard_to_borrow = hard_to_borrow or set()

        # Default tier structure: upper bound of the short value -> rate multiplier
        self.size_tiers = size_tiers or {
            100000: 1.0,  # <$100k: base rate
            1000000: 1.2,  # $100k-$1M: 1.2x rate
            10000000: 1.5,  # $1M-$10M: 1.5x rate
            float("inf"): 2.0,  # >$10M: 2x rate
        }

        # Base rates by category
        self.easy_rate = 0.01  # 1% for easy-to-borrow
        self.normal_rate = 0.03  # 3% for normal
        self.hard_rate = 0.10  # 10% for hard-to-borrow

        self.daily_divisor = self._resolve_days_per_year(days_per_year)

    def get_borrow_rate(self, stock_id: str, date: pd.Timestamp) -> float:
        """Get base borrowing rate by stock category (before the size multiplier)."""
        if stock_id in self.easy_to_borrow:
            return self.easy_rate
        if stock_id in self.hard_to_borrow:
            return self.hard_rate
        return self.normal_rate

    def _get_size_multiplier(self, position_value: float) -> float:
        """Get rate multiplier based on position size: the first tier whose bound covers the value."""
        for threshold, multiplier in sorted(self.size_tiers.items()):
            if position_value <= threshold:
                return multiplier
        return 2.0  # Default max multiplier when the value exceeds every configured tier

    def calculate_daily_cost(self, positions: Dict, date: pd.Timestamp) -> float:
        """
        Calculate daily cost with tiered rates.

        Each short position is charged its category rate scaled by the multiplier of
        the size tier its absolute value falls into.
        """
        return sum(
            (
                short_value
                * (self.get_borrow_rate(stock_id, date) * self._get_size_multiplier(short_value) / self.daily_divisor)
                for stock_id, short_value in self._iter_short_values(positions)
            ),
            0.0,
        )
fields always needed + necessary_fields = {self.buy_price, self.sell_price, "$close", "$factor", "$volume"} + # only require $change when using float threshold + if self.limit_type == self.LT_FLT: + necessary_fields.add("$change") if self.limit_type == self.LT_TP_EXP: assert isinstance(limit_threshold, tuple) for exp in limit_threshold: @@ -202,14 +206,24 @@ def get_quote_from_qlib(self) -> None: # get stock data from qlib if len(self.codes) == 0: self.codes = D.instruments() - self.quote_df = D.features( - self.codes, - self.all_fields, - self.start_time, - self.end_time, - freq=self.freq, - disk_cache=True, - ) + try: + self.quote_df = D.features( + self.codes, + self.all_fields, + self.start_time, + self.end_time, + freq=self.freq, + disk_cache=True, + ) + except (ValueError, KeyError): + # fallback to available higher/equal frequency (e.g., 60min) when requested freq (e.g., day) is unavailable + from ..utils.resam import get_higher_eq_freq_feature # pylint: disable=C0415 + + _df, _freq = get_higher_eq_freq_feature( + self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=1 + ) + self.quote_df = _df + self.freq = _freq self.quote_df.columns = self.all_fields # check buy_price data and sell_price data diff --git a/qlib/backtest/shortable_backtest.py b/qlib/backtest/shortable_backtest.py new file mode 100644 index 0000000000..9895cc312e --- /dev/null +++ b/qlib/backtest/shortable_backtest.py @@ -0,0 +1,682 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Integration module for short-selling support in Qlib backtest. +This module provides the main executor and strategy components. + +Pylint notes: +- C0301 (line-too-long): Disabled at module level due to verbose logging and URLs. +- W0718 (broad-exception-caught): Used intentionally around optional hooks; safe and logged. +- W0212 (protected-access): Access needed for adapting Qlib internals; guarded carefully. 
class ShortableAccount(Account):
    """
    Account that supports short selling by handling cases where
    stocks don't exist in current position.
    """

    def _update_state_from_order(self, order, trade_val, cost, trade_price):
        """
        Record turnover, cost and realised return for a filled order.

        Override to handle short selling cases where stock may not exist in position:
        when no usable holding price is available the realised return is booked as 0.

        Parameters
        ----------
        order
            The filled order; ``direction`` and ``stock_id`` are read
        trade_val
            Traded value of the fill
        cost
            Transaction cost of the fill
        trade_price
            Deal price; fills with a missing, non-finite or non-positive price are ignored
        """
        # NOTE: keep this method self-contained (no helper methods, no super()).
        # ShortableExecutor.reset_common_infra re-binds it onto existing account
        # objects, which need not be ShortableAccount instances.

        # CRITICAL: Validate price
        if trade_price is None or not np.isfinite(trade_price) or trade_price <= 0:
            return

        if self.is_port_metr_enabled():
            self.accum_info.add_turnover(abs(trade_val))  # Use absolute value for turnover
            self.accum_info.add_cost(cost)

        is_sell = order.direction == OrderDir.SELL
        if not is_sell and order.direction != OrderDir.BUY:
            return

        trade_amount = trade_val / trade_price
        try:
            # For short selling, stock may not exist in position
            p0 = self.current_position.get_stock_price(order.stock_id)
        except (KeyError, AttributeError):
            p0 = None

        if p0 is not None and np.isfinite(p0) and p0 > 0:
            held_value = p0 * trade_amount
            profit = (trade_val - held_value) if is_sell else (held_value - trade_val)
        else:
            profit = 0.0

        if self.is_port_metr_enabled():
            self.accum_info.add_return_value(profit)  # note here do not consider cost

    def get_portfolio_metrics(self):
        """
        Extend parent metrics with long/short-specific fields while keeping return shape unchanged.

        Returns
        -------
        tuple
            ``(df, meta)`` as produced by the parent. When the current position is a
            ``ShortablePosition``, ``meta`` is a new dict that additionally carries
            ``leverage``, ``net_exposure`` and ``total_borrow_cost``.
        """
        # Ask the parent exactly once; a genuine failure there propagates to the caller.
        pm = super().get_portfolio_metrics()
        try:
            df, meta = pm
        except (TypeError, ValueError):
            # Parent did not return a pair: keep a dict result as meta, otherwise start empty.
            df, meta = None, (pm if isinstance(pm, dict) else {})

        # NOTE(review): meta is merged with the extra keys below; confirm that consumers
        # of the parent's second return value tolerate the additional entries.
        try:
            pos = self.current_position
            if isinstance(pos, ShortablePosition):
                extra = {
                    "leverage": pos.get_leverage(),
                    "net_exposure": pos.get_net_exposure(),
                    "total_borrow_cost": pos.borrow_cost_accumulated,
                }
                meta = {**(meta or {}), **extra}
        except Exception:
            # Best effort: the extra fields are optional, the base metrics are still returned.
            pass

        return df, meta
# pylint: disable=W0613  # some optional parameters are kept for API compatibility
def __init__(
    self,
    time_per_step: str = "day",
    generate_portfolio_metrics: bool = False,
    verbose: bool = False,
    track_data: bool = False,
    trade_exchange: Optional[ShortableExchange] = None,
    borrow_fee_model: Optional[BaseBorrowFeeModel] = None,
    settle_type: str = Position.ST_NO,
    region: str = "cn",  # Tweak #3: parameterize region to follow Qlib standard
    account: Optional[ShortableAccount] = None,
    common_infra: Optional[CommonInfrastructure] = None,
    **kwargs,
):
    """
    Build an executor that can carry short positions and charge borrow fees.

    Parameters
    ----------
    time_per_step : str
        Trading frequency
    generate_portfolio_metrics : bool
        Whether to generate portfolio metrics
    verbose : bool
        Print detailed execution info
    track_data : bool
        Track detailed trading data
    trade_exchange : ShortableExchange
        Exchange instance supporting short selling; a config dict is also
        accepted and instantiated here
    borrow_fee_model : BaseBorrowFeeModel
        Model for calculating borrowing fees; a ``FixedRateBorrowFeeModel``
        with default settings is used when omitted
    settle_type : str
        Settlement type for positions
    region : str
        Region for trading calendar ('cn', 'us', etc.) - follows qlib.init() default.
        ``"crypto"`` selects a 365-day year for borrow fees, anything else 252
    account : ShortableAccount
        Accepted for API compatibility; not used by this constructor
    common_infra : CommonInfrastructure
        Forwarded to the parent executor
    """
    # The parent constructor calls reset(), so these attributes have to be in place first.
    self.region = region
    self.borrow_fee_model = borrow_fee_model if borrow_fee_model else FixedRateBorrowFeeModel()
    self.settle_type = settle_type

    # A config dict is turned into a real exchange object before it is handed on.
    exchange = trade_exchange
    if isinstance(exchange, dict):
        exchange = init_instance_by_config(exchange)

    super().__init__(
        time_per_step=time_per_step,
        generate_portfolio_metrics=generate_portfolio_metrics,
        verbose=verbose,
        track_data=track_data,
        trade_exchange=exchange,
        settle_type=settle_type,
        common_infra=common_infra,
        **kwargs,
    )

    # Day count for converting annual borrow rates: 365 for crypto, 252 otherwise.
    # Models without a set_days_per_year hook are left as they are.
    try:
        if hasattr(self.borrow_fee_model, "set_days_per_year"):
            days_per_year = 365 if self.region == "crypto" else 252
            self.borrow_fee_model.set_days_per_year(days_per_year)
    except Exception:
        pass
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
    """Adapt the trade account for short selling once the parent has bound it.

    After delegating to the parent, the account's current position is
    converted to a ``ShortablePosition`` (holdings and cash are carried
    over) and ``ShortableAccount._update_state_from_order`` is bound onto
    the account. The ``account`` and ``position`` aliases are refreshed.

    The conversion is idempotent: a position that already is a
    ``ShortablePosition`` is kept as-is. Rebuilding it from amounts and
    prices on every call would drop whatever the position tracks outside
    those fields (e.g. ``borrow_cost_accumulated``).

    Parameters
    ----------
    common_infra : CommonInfrastructure
        Shared infrastructure, forwarded to the parent.
    copy_trade_account : bool
        Forwarded to the parent.
    """
    super().reset_common_infra(common_infra, copy_trade_account=copy_trade_account)

    account = getattr(self, "trade_account", None)
    if account is None:
        # No account bound yet: nothing to adapt.
        return

    old_pos = account.current_position
    if not isinstance(old_pos, ShortablePosition):
        # Carry over holdings; on any failure start from an empty book
        # rather than from a half-copied one.
        position_dict = {}
        try:
            if hasattr(old_pos, "get_stock_list"):
                position_dict = {
                    sid: {
                        "amount": old_pos.get_stock_amount(sid),
                        "price": old_pos.get_stock_price(sid),
                    }
                    for sid in old_pos.get_stock_list()
                }
        except Exception:
            position_dict = {}

        # Prefer cash including unsettled amounts, then plain cash, then 1e6.
        try:
            cash = old_pos.get_cash(include_settle=True) if hasattr(old_pos, "get_cash") else None
        except Exception:
            cash = None
        if cash is None:
            try:
                cash = old_pos.get_cash() if hasattr(old_pos, "get_cash") else 1e6
            except Exception:
                cash = 1e6

        new_pos = ShortablePosition(cash=cash, position_dict=position_dict)
        new_pos._settle_type = getattr(self, "settle_type", Position.ST_NO)
        account.current_position = new_pos

    # Monkey-patch: use our fixed _update_state_from_order on existing account
    import types  # pylint: disable=C0415

    account._update_state_from_order = types.MethodType(ShortableAccount._update_state_from_order, account)
    # NOTE: Do not monkey-patch get_portfolio_metrics to avoid super() binding issues.

    # Sync aliases
    self.account = account
    self.position = account.current_position
def _mark_to_market(self, date: pd.Timestamp):
    """
    Refresh the stored price of every holding with the exchange's deal price.

    Does nothing unless the account's current position is a
    ``ShortablePosition``. For each held stock the deal price for ``date``
    is requested from the trade exchange (BUY direction is used as a
    placeholder); when that price is missing, non-finite or non-positive,
    the position's own stored price is used instead. Only a valid price is
    written back. In verbose mode a one-line equity/leverage summary is
    printed afterwards.
    """
    position = self.account.current_position
    if not isinstance(position, ShortablePosition):
        return

    def _usable(value):
        # A price is usable when it exists, is finite and strictly positive.
        return value is not None and np.isfinite(value) and value > 0

    for stock_id in position.get_stock_list():
        # Only dict-shaped holding entries carry a "price" field to update.
        if stock_id not in position.position:
            continue
        if not isinstance(position.position[stock_id], dict):
            continue

        market_px = self.trade_exchange.get_deal_price(
            stock_id=stock_id,
            start_time=date,
            end_time=date,
            direction=OrderDir.BUY,  # Use OrderDir for consistency
        )
        if not _usable(market_px):
            # Fall back to the last price recorded on the position.
            market_px = position.get_stock_price(stock_id)

        if _usable(market_px):
            position.position[stock_id]["price"] = float(market_px)

    if not self.verbose:
        return

    equity = position.calculate_value()
    leverage = position.get_leverage()
    net_exp = position.get_net_exposure()
    print(f"[{date}] Mark-to-market: Equity=${equity:,.0f}, Leverage={leverage:.2f}, NetExp={net_exp:.2%}")
_is_trading_day(self, date): + """Check whether it is a trading day. + + CRITICAL FIX: Only crypto markets trade 24/7, not US markets! + """ + if self.region == "crypto": + return True # Crypto trades every day + + # For all other markets (including US), use trading calendar + try: + from qlib.data import D # pylint: disable=C0415 + + cal = D.calendar(freq=self.time_per_step, future=False) + return date in cal + except Exception: + # Fallback: weekdays only for traditional markets + return date.weekday() < 5 + + def _borrow_fee_step_multiplier(self) -> float: + """Convert per-day borrow fee to current step multiplier.""" + t = (self.time_per_step or "").lower() + if t in ("day", "1d"): + return 1.0 + try: + import re # pylint: disable=C0415 + + m = re.match(r"(\d+)\s*min", t) + if not m: + return 1.0 + step_min = int(m.group(1)) + minutes_per_day = 1440 if self.region == "crypto" else 390 + if step_min <= 0: + return 1.0 + return float(step_min) / float(minutes_per_day) + except Exception: + return 1.0 + + def get_portfolio_metrics(self) -> Dict: + """ + Get enhanced portfolio metrics including short-specific metrics. 
def round_to_lot(shares, lot=100):
    """Round ``shares`` towards zero to a whole number of lots.

    Rounding towards zero never increases the absolute size of a position,
    so the result cannot exceed a limit the input respected.

    Parameters
    ----------
    shares : int or float
        Signed share count; negative values denote short amounts.
    lot : int or float
        Lot size. Values ``<= 1`` disable lot rounding and only truncate
        ``shares`` to an integer.

    Returns
    -------
    int
        Signed share count that is a multiple of ``lot`` (for ``lot > 1``).
    """
    if lot <= 1:
        return int(shares)  # toward zero
    lots = int(abs(shares) // lot)  # toward zero in lot units
    magnitude = int(lots * lot)
    # Re-apply the sign with integer arithmetic: going through
    # math.copysign converts to float and loses precision on very large
    # integer share counts.
    return -magnitude if shares < 0 else magnitude
+ + Parameters + ---------- + gross_leverage : float + Total leverage (long + short), e.g., 1.6 means 160% gross exposure + net_exposure : float + Net market exposure (long - short), e.g., 0.0 for market neutral + top_k : int + Number of stocks in each leg (long and short) + exchange : Exchange + Exchange instance for price queries + risk_limit : Dict + Risk limits (max_leverage, max_position_size, etc.) + lot_size : int + Trading lot size (default 100 for A-shares) + min_trade_threshold : int + Minimum trade threshold in shares (default 100) + """ + self.gross_leverage = gross_leverage + self.net_exposure = net_exposure + self.top_k = top_k + self.exchange = exchange + # Allow None and treat intuitively: None -> no lot limit / no min threshold + self.lot_size = 1 if lot_size is None else lot_size + self.min_trade_threshold = 0 if min_trade_threshold is None else min_trade_threshold + self.risk_limit = risk_limit or { + "max_leverage": 2.0, + "max_position_size": 0.1, + "max_net_exposure": 0.3, + } + + # Compute long/short ratios: gross = long + short, net = long - short + # So: long = (gross + net) / 2, short = (gross - net) / 2 + self.long_ratio = (gross_leverage + net_exposure) / 2 + self.short_ratio = (gross_leverage - net_exposure) / 2 + + def generate_trade_decision( + self, signal: pd.Series, current_position: ShortablePosition, date: pd.Timestamp + ) -> TradeDecisionWO: + """ + Generate trade decisions based on signal using correct weight-to-shares conversion. 
+ """ + # Get current equity + equity = current_position.calculate_value() + + # Select stocks + signal_sorted = signal.sort_values(ascending=False) + long_stocks = signal_sorted.head(self.top_k).index.tolist() + short_stocks = signal_sorted.tail(self.top_k).index.tolist() + + # Fix #3: get prices by direction (consistent with matching) + long_prices = self._get_current_prices(long_stocks, date, self.exchange, OrderDir.BUY) if long_stocks else {} + short_prices = ( + self._get_current_prices(short_stocks, date, self.exchange, OrderDir.SELL) if short_stocks else {} + ) + prices = {**long_prices, **short_prices} + + # Compute per-stock weights + long_weight_per_stock = self.long_ratio / len(long_stocks) if long_stocks else 0 + short_weight_per_stock = -self.short_ratio / len(short_stocks) if short_stocks else 0 # negative + + # Tweak #2: hard cap per-position weight at equity × cap + max_position_weight = self.risk_limit.get("max_position_size", 0.1) # default 10% + long_weight_per_stock = min(long_weight_per_stock, max_position_weight) + short_weight_per_stock = max(short_weight_per_stock, -max_position_weight) # negative, so use max + + orders = [] + + # Long orders + for stock in long_stocks: + if stock in prices: + target_shares = round_to_lot((long_weight_per_stock * equity) / prices[stock], lot=self.lot_size) + current_shares = current_position.get_stock_amount(stock) + delta = target_shares - current_shares + + if abs(delta) >= self.min_trade_threshold: # respect configured trade threshold + direction = OrderDir.BUY if delta > 0 else OrderDir.SELL + orders.append( + Order( + stock_id=stock, amount=abs(int(delta)), direction=direction, start_time=date, end_time=date + ) + ) + + # Short orders + for stock in short_stocks: + if stock in prices: + target_shares = round_to_lot( + (short_weight_per_stock * equity) / prices[stock], lot=self.lot_size # negative + ) + current_shares = current_position.get_stock_amount(stock) + delta = target_shares - current_shares + + 
if abs(delta) >= self.min_trade_threshold: + direction = OrderDir.BUY if delta > 0 else OrderDir.SELL + orders.append( + Order( + stock_id=stock, amount=abs(int(delta)), direction=direction, start_time=date, end_time=date + ) + ) + + # Close positions not in target set + current_stocks = set(current_position.get_stock_list()) + target_stocks = set(long_stocks + short_stocks) + + for stock in current_stocks - target_stocks: + amount = current_position.get_stock_amount(stock) + if abs(amount) >= self.min_trade_threshold: # respect configured trade threshold + direction = OrderDir.SELL if amount > 0 else OrderDir.BUY + orders.append( + Order(stock_id=stock, amount=abs(int(amount)), direction=direction, start_time=date, end_time=date) + ) + + # Fix #2: enable risk limit checks + if orders and not self._check_risk_limits(orders, current_position): + # If exceeding risk limits, scale orders + orders = self._scale_orders_for_risk(orders, current_position) + + # Note: The 2nd arg of TradeDecisionWO should be the strategy per Qlib design + return TradeDecisionWO(orders, self) + + def _get_current_prices(self, stock_list, date, exchange=None, direction=None): + """Fetch prices consistent with matching, supporting order direction.""" + prices = {} + + if exchange is not None: + # Use exchange API to ensure consistency with matching + for stock in stock_list: + try: + # Fix #3: use direction-aware price fetching + price = exchange.get_deal_price( + stock_id=stock, + start_time=date, + end_time=date, + direction=direction, # BUY/SELL direction, aligned with execution + ) + if price is not None and not math.isnan(price): + prices[stock] = float(price) + else: + # Skip this stock if price unavailable + continue + except Exception: + # Price fetch failed; skip + continue + else: + # Fallback: use a fixed price (testing only) + for stock in stock_list: + prices[stock] = 100.0 # placeholder + + return prices + + def _check_risk_limits(self, orders: List[Order], position: 
ShortablePosition) -> bool: + """Check if orders comply with risk limits.""" + # Simulate position after orders + simulated_position = self._simulate_position_change(orders, position) + + leverage = simulated_position.get_leverage() + net_exposure = simulated_position.get_net_exposure() + + return leverage <= self.risk_limit["max_leverage"] and abs(net_exposure) <= self.risk_limit["max_net_exposure"] + + def _simulate_position_change(self, orders: List[Order], position: ShortablePosition) -> ShortablePosition: + """Simulate position after executing orders with improved price sourcing.""" + stock_positions = { + sid: {"amount": position.get_stock_amount(sid), "price": position.get_stock_price(sid)} + for sid in position.get_stock_list() + } + + sim = ShortablePosition(cash=position.get_cash(), position_dict=stock_positions) + + def _valid(p): + return (p is not None) and np.isfinite(p) and (p > 0) + + for od in orders: + cur = sim.get_stock_amount(od.stock_id) + new_amt = cur + od.amount if od.direction == OrderDir.BUY else cur - od.amount + + # Try to get price: position price > exchange price; skip if can't get valid price + price = sim.get_stock_price(od.stock_id) if od.stock_id in sim.position else None + if not _valid(price) and getattr(self, "trade_exchange", None) is not None and hasattr(od, "start_time"): + try: + px = self.trade_exchange.get_deal_price( # pylint: disable=E1101 + od.stock_id, od.start_time, od.end_time or od.start_time, od.direction + ) + if _valid(px): + price = float(px) + except Exception: + pass + + if not _valid(price): + price = None # Don't use placeholder 100, avoid misjudging leverage + + if od.stock_id not in sim.position: + sim._init_stock(od.stock_id, new_amt, price if price is not None else 0.0) + else: + sim.position[od.stock_id]["amount"] = new_amt + if price is not None: + sim.position[od.stock_id]["price"] = price + + # Only adjust cash with valid price (prevent placeholder from polluting risk control) + if price is not 
None: + if od.direction == OrderDir.BUY: + sim.position["cash"] -= price * od.amount + else: + sim.position["cash"] += price * od.amount + + return sim + + def _scale_orders_for_risk(self, orders: List[Order], position: ShortablePosition) -> List[Order]: + """Adaptive risk scaling - scale precisely by the degree of limit breach.""" + # Fix #2: simulate execution first to get leverage and net_exposure + simulated_position = self._simulate_position_change(orders, position) + leverage = simulated_position.get_leverage() + net_exposure = abs(simulated_position.get_net_exposure()) + + # Compute scale factor based on degree of breach + max_leverage = self.risk_limit.get("max_leverage", 2.0) + max_net_exposure = self.risk_limit.get("max_net_exposure", 0.3) + + scale_leverage = max_leverage / leverage if leverage > max_leverage else 1.0 + scale_net = max_net_exposure / net_exposure if net_exposure > max_net_exposure else 1.0 + + # Take stricter constraint with a small safety margin + scale_factor = min(scale_leverage, scale_net) * 0.98 + scale_factor = min(scale_factor, 1.0) # never amplify, only shrink + + if scale_factor >= 0.99: # scaling nearly unnecessary + return orders + + scaled_orders = [] + for order in orders: + # Round by lot size; keep original time fields + scaled_amount = round_to_lot(order.amount * scale_factor, lot=self.lot_size) + if scaled_amount <= 0: # skip zero-after-rounding + continue + + scaled_order = Order( + stock_id=order.stock_id, + amount=int(scaled_amount), + direction=order.direction, + start_time=order.start_time, + end_time=order.end_time, + ) + scaled_orders.append(scaled_order) + + return scaled_orders diff --git a/qlib/backtest/shortable_exchange.py b/qlib/backtest/shortable_exchange.py new file mode 100644 index 0000000000..d8d4d8b1f5 --- /dev/null +++ b/qlib/backtest/shortable_exchange.py @@ -0,0 +1,738 @@ +"""ShortableExchange: extend Exchange to support short selling with zero-crossing logic. 
+ +Pylint notes: +- C0301 (line-too-long): allow long explanatory comments and formulas. +- R1702/R0912/R0915 (nested blocks/branches/statements): complex matching kept for fidelity. +- R0914/R0913 (many locals/args): accepted due to detailed cost/cash handling. +- R1716 (chained-comparison): allow for concise numerical checks. +- W0237 (arguments-renamed), W0613 (unused-argument): compatibility with base signatures. +""" + +# pylint: disable=C0301,R1702,R0912,R0915,R0914,R0913,R1716,W0237,W0613 + +from typing import Optional, Tuple, cast, TYPE_CHECKING +import numpy as np +import pandas as pd +from qlib.backtest.exchange import Exchange +from qlib.backtest.decision import Order +from qlib.backtest.position import BasePosition + +if TYPE_CHECKING: + from qlib.backtest.account import Account + + +class ShortableExchange(Exchange): + """ + Exchange that supports short selling by removing the constraint + that prevents selling more than current holdings. + + Key modifications: + - Allows selling stocks not in current position (short selling) + - Properly determines open/close costs based on position direction + - Splits orders that cross zero position for accurate cost calculation + - Maintains all other constraints (cash, volume limits, etc.) + """ + + def _calc_trade_info_by_order( + self, + order: Order, + position: Optional[BasePosition], + dealt_order_amount: dict, + ) -> Tuple[float, float, float]: + """ + Calculation of trade info with short selling support. 
+ + **IMPORTANT**: Returns (trade_price, trade_val, trade_cost) to match parent class + + For BUY orders: + - If current position < 0: covering short position -> use close_cost + - If current position >= 0: opening/adding long position -> use open_cost + - If crossing zero: split into cover short (close_cost) + open long (open_cost) + + For SELL orders: + - If current position > 0: closing long position -> use close_cost + - If current position <= 0: opening/adding short position -> use open_cost + - If crossing zero: split into close long (close_cost) + open short (open_cost) + + :param order: Order to be processed + :param position: Current position (Optional) + :param dealt_order_amount: Dict tracking dealt amounts {stock_id: float} + :return: Tuple of (trade_price, trade_val, trade_cost) + """ + + # Get deal price first - with NaN/None guard + trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) + if trade_price is None or np.isnan(trade_price) or trade_price <= 0: + self.logger.debug(f"Invalid price for {order.stock_id}, skipping order") + order.deal_amount = 0 + return 0.0, 0.0, 0.0 + trade_price = cast(float, trade_price) + + # Calculate total market volume for impact cost - with NaN/None guard + volume = self.get_volume(order.stock_id, order.start_time, order.end_time) + if volume is None or np.isnan(volume): + total_trade_val = 0.0 + else: + total_trade_val = cast(float, volume) * trade_price + + # Set order factor for rounding + order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) + order.deal_amount = order.amount # Start with full amount + + # Apply volume limits (common for both BUY and SELL) + self._clip_amount_by_volume(order, dealt_order_amount) + + # Get current position amount + current_amount = 0.0 + if position is not None and position.check_stock(order.stock_id): + current_amount = position.get_stock_amount(order.stock_id) + + # Handle BUY orders + if 
order.direction == Order.BUY: + # Check if we're crossing zero (covering short then opening long) + if current_amount < 0 and order.deal_amount > abs(current_amount): + # Split into two legs: cover short + open long + cover_amount = abs(current_amount) + open_amount = order.deal_amount - cover_amount + + # Apply cash constraints for both legs (before rounding) + if position is not None: + cash = position.get_cash(include_settle=True) if hasattr(position, "get_cash") else 0.0 + + # Calculate costs for both legs (pre-rounding) + cover_val = cover_amount * trade_price + open_val = open_amount * trade_price + + # Initial impact cost calculation + if not total_trade_val or np.isnan(total_trade_val): + cover_impact = self.impact_cost + open_impact = self.impact_cost + else: + cover_impact = self.impact_cost * (cover_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + + # Calculate costs WITHOUT min_cost for each leg + cover_cost_no_min = cover_val * (self.close_cost + cover_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + + # Apply min_cost ONCE for the total + total_cost = max(cover_cost_no_min + open_cost_no_min, self.min_cost) + total_val = cover_val + open_val + + # Check cash constraints + if cash < total_cost: + # Can't afford even the costs + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + elif cash < total_val + total_cost: + # Need to reduce the open leg + available_for_open = cash - cover_val - cover_cost_no_min + if available_for_open > 0: + # Calculate max open amount considering the cost + max_open = self._get_buy_amount_by_cash_limit( + trade_price, available_for_open, self.open_cost + open_impact + ) + open_amount = min(max_open, open_amount) + order.deal_amount = cover_amount + open_amount + else: + # Can only cover, not open new + order.deal_amount = cover_amount + + # Round the final amount + order.deal_amount = 
self.round_amount_by_trade_unit(order.deal_amount, order.factor) + + # Re-check cash constraints after rounding + final_val = order.deal_amount * trade_price + if order.deal_amount <= abs(current_amount): + # Only covering + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (self.close_cost + final_impact), self.min_cost) + else: + # Still crossing zero after rounding + cover_amount = abs(current_amount) + open_amount = order.deal_amount - cover_amount + cover_val = cover_amount * trade_price + open_val = open_amount * trade_price + + if not total_trade_val or np.isnan(total_trade_val): + cover_impact = self.impact_cost + open_impact = self.impact_cost + else: + cover_impact = self.impact_cost * (cover_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + + # Calculate costs WITHOUT min_cost, then apply min_cost ONCE + cover_cost_no_min = cover_val * (self.close_cost + cover_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + final_cost = max(cover_cost_no_min + open_cost_no_min, self.min_cost) + + # Final cash check after rounding with trade unit protection + if cash < final_val + final_cost: + trade_unit_amount = self.get_amount_of_trade_unit( + order.factor, order.stock_id, order.start_time, order.end_time + ) + if getattr(self, "impact_cost", 0.0) == 0.0: + feasible = self._compute_feasible_buy_amount_cross_zero( + price=trade_price, + cash=cash, + cover_amount=abs(current_amount), + open_cost_ratio=self.open_cost, + close_cost_ratio=self.close_cost, + min_cost=self.min_cost, + trade_unit_amount=trade_unit_amount or 0.0, + ) + order.deal_amount = min(order.deal_amount, feasible) + else: + # Reduce by trade unit until it fits (fallback) + if trade_unit_amount and trade_unit_amount > 0: + steps = 0 + max_steps = 10000 # Prevent infinite loop + while ( + 
order.deal_amount > 0 + and cash < order.deal_amount * trade_price + final_cost + and steps < max_steps + ): + order.deal_amount -= trade_unit_amount + steps += 1 + final_val = order.deal_amount * trade_price + # Recalculate cost with new amount + if order.deal_amount <= abs(current_amount): + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (self.close_cost + final_impact), self.min_cost) + else: + cover_val = abs(current_amount) * trade_price + open_val = (order.deal_amount - abs(current_amount)) * trade_price + if not total_trade_val or np.isnan(total_trade_val): + cover_impact = self.impact_cost + open_impact = self.impact_cost + else: + cover_impact = self.impact_cost * (cover_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + cover_cost_no_min = cover_val * (self.close_cost + cover_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + final_cost = max(cover_cost_no_min + open_cost_no_min, self.min_cost) + if steps >= max_steps: + self.logger.warning(f"Max iterations reached for order {order}, setting to 0") + order.deal_amount = 0 + else: + order.deal_amount = 0 + else: + # No position info, just round + order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + + # Calculate final trade cost based on split legs + trade_val = order.deal_amount * trade_price + if order.deal_amount <= abs(current_amount): + # Only covering short + if not total_trade_val or np.isnan(total_trade_val): + adj_cost_ratio = self.impact_cost + else: + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + trade_cost = ( + max(trade_val * (self.close_cost + adj_cost_ratio), self.min_cost) if trade_val > 1e-5 else 0 + ) + else: + # Crossing zero: cover short + open long + cover_amount = abs(current_amount) + open_amount = 
order.deal_amount - cover_amount + cover_val = cover_amount * trade_price + open_val = open_amount * trade_price + + if not total_trade_val or np.isnan(total_trade_val): + cover_impact = self.impact_cost + open_impact = self.impact_cost + else: + cover_impact = self.impact_cost * (cover_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + + # Calculate costs WITHOUT min_cost, then apply min_cost ONCE + cover_cost_no_min = cover_val * (self.close_cost + cover_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + trade_cost = max(cover_cost_no_min + open_cost_no_min, self.min_cost) if trade_val > 1e-5 else 0 + + else: + # Simple case: either pure covering short or pure opening long + if current_amount < 0: + # Covering short position - use close_cost + cost_ratio = self.close_cost + else: + # Opening or adding to long position - use open_cost + cost_ratio = self.open_cost + + # Apply cash constraints + if position is not None: + cash = position.get_cash(include_settle=True) if hasattr(position, "get_cash") else 0.0 + trade_val = order.deal_amount * trade_price + + # Pre-calculate impact cost + if not total_trade_val or np.isnan(total_trade_val): + adj_cost_ratio = self.impact_cost + else: + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + + total_cost_ratio = cost_ratio + adj_cost_ratio + + if cash < max(trade_val * total_cost_ratio, self.min_cost): + # Cash cannot cover cost + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + elif cash < trade_val + max(trade_val * total_cost_ratio, self.min_cost): + # Money is not enough for full order + max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, total_cost_ratio) + order.deal_amount = min(max_buy_amount, order.deal_amount) + self.logger.debug(f"Order clipped due to cash limitation: {order}") + + # Round the amount + order.deal_amount = 
self.round_amount_by_trade_unit(order.deal_amount, order.factor) + + # Re-check cash constraint after rounding + final_val = order.deal_amount * trade_price + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (cost_ratio + final_impact), self.min_cost) + + if cash < final_val + final_cost: + trade_unit_amount = self.get_amount_of_trade_unit( + order.factor, order.stock_id, order.start_time, order.end_time + ) + if getattr(self, "impact_cost", 0.0) == 0.0: + feasible = self._compute_feasible_buy_amount( + price=trade_price, + cash=cash, + cost_ratio=cost_ratio, + min_cost=self.min_cost, + trade_unit_amount=trade_unit_amount or 0.0, + ) + order.deal_amount = min(order.deal_amount, feasible) + else: + # Reduce by trade units until it fits + if trade_unit_amount and trade_unit_amount > 0: + steps = 0 + max_steps = 10000 + while ( + order.deal_amount > 0 + and cash < order.deal_amount * trade_price + final_cost + and steps < max_steps + ): + order.deal_amount -= trade_unit_amount + steps += 1 + final_val = order.deal_amount * trade_price + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (cost_ratio + final_impact), self.min_cost) + if steps >= max_steps: + self.logger.warning(f"Max iterations reached for order {order}, setting to 0") + order.deal_amount = 0 + else: + order.deal_amount = 0 + else: + # Unknown amount of money - just round the amount + order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + + # Calculate final cost with final amount + trade_val = order.deal_amount * trade_price + if not total_trade_val or np.isnan(total_trade_val): + adj_cost_ratio = self.impact_cost + else: + adj_cost_ratio = self.impact_cost * (trade_val / 
total_trade_val) ** 2 + trade_cost = max(trade_val * (cost_ratio + adj_cost_ratio), self.min_cost) if trade_val > 1e-5 else 0 + + # Handle SELL orders + elif order.direction == Order.SELL: + # Check if we're crossing zero (closing long then opening short) + if current_amount > 0 and order.deal_amount > current_amount: + # Split into two legs: close long + open short + close_amount = current_amount + open_amount = order.deal_amount - current_amount + + # Apply cash constraint for transaction costs BEFORE rounding + if position is not None: + cash = position.get_cash(include_settle=True) if hasattr(position, "get_cash") else 0.0 + close_val = close_amount * trade_price + open_val = open_amount * trade_price + total_val = close_val + open_val + + # Calculate impact costs for both legs (pre-rounding) + if not total_trade_val or np.isnan(total_trade_val): + close_impact = self.impact_cost + open_impact = self.impact_cost + else: + close_impact = self.impact_cost * (close_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + + # Calculate costs WITHOUT min_cost for each leg + close_cost_no_min = close_val * (self.close_cost + close_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + + # Apply min_cost ONCE for the total + total_cost = max(close_cost_no_min + open_cost_no_min, self.min_cost) + + # Check if we have enough cash to pay transaction costs + # We receive cash from the sale but still need to pay costs + if cash + total_val < total_cost: + # Try to reduce the short leg + if cash + close_val >= max(close_cost_no_min, self.min_cost): + # Can at least close the long position + order.deal_amount = close_amount + else: + # Can't even close the position + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to insufficient cash for transaction costs: {order}") + else: + # Cash is sufficient, keep full amount + order.deal_amount = close_amount + open_amount + + # Now round both legs + if 
order.deal_amount > 0: + if order.deal_amount <= close_amount: + # Only closing, round the close amount + order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + else: + # Crossing zero, round both legs + close_amount = self.round_amount_by_trade_unit(close_amount, order.factor) + open_amount = self.round_amount_by_trade_unit( + order.deal_amount - current_amount, order.factor + ) + order.deal_amount = close_amount + open_amount + + # Re-check cash constraint after rounding + final_val = order.deal_amount * trade_price + if order.deal_amount <= current_amount: + # Only closing + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (self.close_cost + final_impact), self.min_cost) + else: + # Still crossing zero + close_val = current_amount * trade_price + open_val = (order.deal_amount - current_amount) * trade_price + if not total_trade_val or np.isnan(total_trade_val): + close_impact = self.impact_cost + open_impact = self.impact_cost + else: + close_impact = self.impact_cost * (close_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + close_cost_no_min = close_val * (self.close_cost + close_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + final_cost = max(close_cost_no_min + open_cost_no_min, self.min_cost) + + # Final check and potential reduction + if cash + final_val < final_cost: + trade_unit_amount = self.get_amount_of_trade_unit( + order.factor, order.stock_id, order.start_time, order.end_time + ) + if trade_unit_amount and trade_unit_amount > 0: + steps = 0 + max_steps = 10000 + while ( + order.deal_amount > 0 + and cash + order.deal_amount * trade_price < final_cost + and steps < max_steps + ): + order.deal_amount -= trade_unit_amount + steps += 1 + final_val = order.deal_amount * trade_price + # Recalculate cost + 
if order.deal_amount <= current_amount: + if not total_trade_val or np.isnan(total_trade_val): + final_impact = self.impact_cost + else: + final_impact = self.impact_cost * (final_val / total_trade_val) ** 2 + final_cost = max(final_val * (self.close_cost + final_impact), self.min_cost) + else: + close_val = current_amount * trade_price + open_val = (order.deal_amount - current_amount) * trade_price + if not total_trade_val or np.isnan(total_trade_val): + close_impact = self.impact_cost + open_impact = self.impact_cost + else: + close_impact = self.impact_cost * (close_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + close_cost_no_min = close_val * (self.close_cost + close_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + final_cost = max(close_cost_no_min + open_cost_no_min, self.min_cost) + if steps >= max_steps: + self.logger.warning(f"Max iterations reached for order {order}, setting to 0") + order.deal_amount = 0 + else: + order.deal_amount = 0 + else: + # No position info, just round + close_amount = self.round_amount_by_trade_unit(close_amount, order.factor) + open_amount = self.round_amount_by_trade_unit(open_amount, order.factor) + order.deal_amount = close_amount + open_amount + + # Calculate final trade cost based on split legs + trade_val = order.deal_amount * trade_price + if order.deal_amount <= current_amount: + # Only closing long + if not total_trade_val or np.isnan(total_trade_val): + adj_cost_ratio = self.impact_cost + else: + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + trade_cost = ( + max(trade_val * (self.close_cost + adj_cost_ratio), self.min_cost) if trade_val > 1e-5 else 0 + ) + else: + # Crossing zero: close long + open short + close_val = current_amount * trade_price + open_val = (order.deal_amount - current_amount) * trade_price + + if not total_trade_val or np.isnan(total_trade_val): + close_impact = self.impact_cost + open_impact = 
self.impact_cost + else: + close_impact = self.impact_cost * (close_val / total_trade_val) ** 2 + open_impact = self.impact_cost * (open_val / total_trade_val) ** 2 + + # Calculate costs WITHOUT min_cost, then apply min_cost ONCE + close_cost_no_min = close_val * (self.close_cost + close_impact) + open_cost_no_min = open_val * (self.open_cost + open_impact) + trade_cost = max(close_cost_no_min + open_cost_no_min, self.min_cost) if trade_val > 1e-5 else 0 + + else: + # Simple case: either pure closing long or pure opening short + if current_amount > 0: + # Closing long position - use close_cost + cost_ratio = self.close_cost + # Don't sell more than we have when closing long + order.deal_amount = min(current_amount, order.deal_amount) + else: + # Opening or adding to short position - use open_cost + cost_ratio = self.open_cost + # No constraint on amount for short selling + + # Round the amount + order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + + # Apply cash constraint for transaction costs + if position is not None: + cash = position.get_cash(include_settle=True) if hasattr(position, "get_cash") else 0.0 + trade_val = order.deal_amount * trade_price + + # Calculate impact cost with final amount + if not total_trade_val or np.isnan(total_trade_val): + adj_cost_ratio = self.impact_cost + else: + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + + expected_cost = max(trade_val * (cost_ratio + adj_cost_ratio), self.min_cost) + + # Check if we have enough cash to pay transaction costs + # For SELL orders, we receive cash from the sale but still need to pay costs + if cash + trade_val < expected_cost: + # Not enough cash to cover transaction costs even after receiving sale proceeds + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to insufficient cash for transaction costs: {order}") + + # Calculate final cost + trade_val = order.deal_amount * trade_price + if not total_trade_val or 
# ------------------------
# Helpers to compute feasible amounts without slow loops
# ------------------------
def _compute_feasible_buy_value_linear_min_cost(self, cash: float, cost_ratio: float, min_cost: float) -> float:
    """
    Compute the max trade value for a BUY given cash, cost ratio and min_cost.

    The transaction cost of a trade worth ``val`` is ``max(val * cost_ratio, min_cost)``
    (impact_cost is assumed to be 0 here), so the order is affordable when both
    ``val * (1 + cost_ratio) <= cash`` and ``val + min_cost <= cash`` hold. The largest
    such ``val`` is the smaller of the two bounds.

    :param cash: available cash
    :param cost_ratio: proportional cost ratio; non-positive values are treated as 0
    :param min_cost: minimum cost charged per trade; non-positive values are treated as 0
    :return: trade value in currency (not amount), never negative
    """
    if cash <= 0:
        return 0.0
    ratio = max(cost_ratio, 0.0)
    floor_cost = max(min_cost, 0.0)
    return max(0.0, min(cash / (1.0 + ratio), cash - floor_cost))


def _compute_feasible_buy_amount(
    self, price: float, cash: float, cost_ratio: float, min_cost: float, trade_unit_amount: float
) -> float:
    """
    Return the feasible BUY amount honoring trade unit and min_cost (impact_cost assumed 0).

    :param price: deal price; a non-positive price yields 0
    :param cash: available cash; non-positive cash yields 0
    :param cost_ratio: proportional cost ratio
    :param min_cost: minimum cost charged per trade
    :param trade_unit_amount: lot size; when positive the amount is floored to a multiple of it
    :return: amount (not value), never negative
    """
    if price <= 0 or cash <= 0:
        return 0.0
    value = self._compute_feasible_buy_value_linear_min_cost(cash, cost_ratio, min_cost)
    amount = value / price
    if trade_unit_amount and trade_unit_amount > 0:
        # Floor to whole trade units so the rounded order can never exceed the budget
        amount = (amount // trade_unit_amount) * trade_unit_amount
    return max(0.0, amount)


def _compute_feasible_buy_amount_cross_zero(
    self,
    price: float,
    cash: float,
    cover_amount: float,
    open_cost_ratio: float,
    close_cost_ratio: float,
    min_cost: float,
    trade_unit_amount: float,
) -> float:
    """
    For BUY crossing zero: cover a fixed short (cover_amount) then optionally open long.

    Computes the max total amount (cover + open) that fits the cash constraint, with
    min_cost applied once to the whole order:

        total_cost = max(cover_val * close_cost_ratio + open_val * open_cost_ratio, min_cost)

    Assumes impact_cost == 0 so a closed-form solution exists.

    If cash cannot pay for the full cover leg including its cost, the result is the
    largest partial cover that is affordable (no long leg is opened).

    :param price: deal price; a non-positive price yields 0
    :param cash: available cash; non-positive cash yields 0
    :param cover_amount: size of the short to cover (positive number)
    :param open_cost_ratio: cost ratio of the long-opening leg; non-positive treated as 0
    :param close_cost_ratio: cost ratio of the short-covering leg; non-positive treated as 0
    :param min_cost: minimum cost charged once for the whole order
    :param trade_unit_amount: lot size; when positive, the partial cover / open leg is floored to it
    :return: total BUY amount, never negative
    """
    if price <= 0 or cash <= 0:
        return 0.0

    close_ratio = max(close_cost_ratio, 0.0)
    open_ratio = max(open_cost_ratio, 0.0)
    floor_cost = max(min_cost, 0.0)
    use_unit = bool(trade_unit_amount) and trade_unit_amount > 0

    cover_val = cover_amount * price
    cover_cost_lin = cover_val * close_ratio

    # The cost actually charged for the cover leg alone is the LARGER of the linear cost
    # and min_cost, so affordability must be checked against max(), not min().
    if cash < cover_val + max(cover_cost_lin, floor_cost):
        # Partial cover: largest value v with v + max(v * close_ratio, min_cost) <= cash
        max_cover_val = max(0.0, min(cash / (1.0 + close_ratio), cash - floor_cost))
        max_cover_amount = max_cover_val / price
        if use_unit:
            max_cover_amount = (max_cover_amount // trade_unit_amount) * trade_unit_amount
        return max(0.0, min(cover_amount, max_cover_amount))

    # The cover leg is affordable; spend what is left on the long leg.
    # Both regimes of the cost function must hold simultaneously:
    #   min_cost regime: open_val + min_cost                                 <= budget
    #   linear regime:   open_val * (1 + open_ratio) + cover_cost_lin        <= budget
    budget = cash - cover_val
    open_val_max = max(0.0, min(budget - floor_cost, (budget - cover_cost_lin) / (1.0 + open_ratio)))

    # Round the open leg by trade unit
    open_amount = open_val_max / price
    if use_unit:
        open_amount = (open_amount // trade_unit_amount) * trade_unit_amount
    total_amount = cover_amount + max(0.0, open_amount)
    return max(0.0, total_amount)


def generate_amount_position_from_weight_position(
    self,
    weight_position: dict,
    cash: float,
    start_time: pd.Timestamp,
    end_time: pd.Timestamp,
    round_amount: bool = True,
    verbose: bool = False,
    account: "Account" = None,
    gross_leverage: float = 1.0,
) -> dict:
    """
    Generate amount position from weight position with support for negative weights (short positions).

    Uses absolute weight normalization to avoid "double spending" cash on long and short positions:
    each stock receives ``cash * |weight| / sum(|weights|) * gross_leverage`` of value. The
    normalization runs over ALL given weights, so the share of a non-tradable or unpriced stock
    is left unallocated rather than redistributed.

    :param weight_position: Dict of {stock_id: weight}, weights can be negative for short positions
    :param cash: Available cash
    :param start_time: Start time for the trading period
    :param end_time: End time for the trading period
    :param round_amount: Whether to round amounts to trading units
    :param verbose: Whether to log the generated position
    :param account: Account object (optional, not used by this implementation)
    :param gross_leverage: Gross leverage factor (default 1.0).
        Total position value = cash * gross_leverage
    :return: Dict of {stock_id: amount}, negative amounts indicate short positions
    """
    # Total absolute weight is the normalization base
    total_abs_weight = sum(abs(w) for w in weight_position.values())
    if total_abs_weight == 0:
        return {}

    amount_position = {}
    for stock_id, weight in weight_position.items():
        if not self.is_stock_tradable(stock_id, start_time, end_time):
            continue

        # Positive weights are priced as buys, everything else as sells
        direction = Order.BUY if weight > 0 else Order.SELL
        price = self.get_deal_price(stock_id, start_time, end_time, direction)

        # Price protection: skip if price is invalid
        if not price or np.isnan(price) or price <= 0:
            self.logger.debug(f"Invalid price for {stock_id}, skipping position generation")
            continue

        target_value = cash * (abs(weight) / total_abs_weight) * gross_leverage
        # Positive for long, negative for short
        target_amount = target_value / price if weight > 0 else -target_value / price

        if round_amount:
            factor = self.get_factor(stock_id, start_time, end_time)
            if target_amount > 0:
                target_amount = self.round_amount_by_trade_unit(target_amount, factor)
            else:
                # Round the absolute value then make it negative again
                target_amount = -self.round_amount_by_trade_unit(abs(target_amount), factor)

        amount_position[stock_id] = target_amount

    if verbose:
        self.logger.info(f"Generated amount position with gross leverage {gross_leverage}: {amount_position}")

    return amount_position
class ShortablePosition(Position):
    """
    Position that supports negative holdings (short positions).

    Key differences from standard Position:
    1. Allows negative amounts for stocks (short positions)
    2. Properly calculates value for both long and short positions
    3. Tracks borrowing costs and other short-related metrics
    4. Maintains cash settlement consistency with qlib
    """

    # Class constant for position close tolerance
    # Use a slightly larger epsilon to suppress floating residuals in full-window runs
    POSITION_EPSILON = 1e-06  # Can be adjusted based on trade unit requirements

    def __init__(
        self,
        cash: float = 0,
        position_dict: Dict[str, Union[Dict[str, float], float]] = None,
        borrow_rate: float = 0.03,
    ):  # Annual borrowing rate, default 3%
        """
        Initialize ShortablePosition.

        Parameters
        ----------
        cash : float
            Initial cash
        position_dict : dict
            Initial positions (can include negative amounts for shorts)
        borrow_rate : float
            Annual rate for borrowing stocks (as decimal, e.g., 0.03 for 3%)
        """
        # Initialize our attributes BEFORE calling super().__init__
        # because super().__init__ will call calculate_value() which needs these
        self.borrow_rate = borrow_rate
        self._daily_borrow_rate = borrow_rate / 252  # Convert to daily rate
        self.borrow_cost_accumulated = 0.0
        self.short_proceeds = {}  # Track proceeds from short sales {stock_id: proceeds}

        # Initialize logger if available
        try:
            from qlib.log import get_module_logger  # pylint: disable=C0415

            self.logger = get_module_logger("ShortablePosition")
        except ImportError:
            self.logger = None

        # Handle default parameter
        if position_dict is None:
            position_dict = {}

        # Now call parent init which will use our calculate_value() method
        super().__init__(cash=cash, position_dict=position_dict)

        # Ensure cash_delay exists for robustness
        self.position.setdefault("cash_delay", 0.0)

    # ------------------------
    # Internal helpers
    # ------------------------
    @staticmethod
    def _is_valid_price(price) -> bool:
        """Return True when price is a finite, strictly positive number."""
        return bool(price is not None and np.isfinite(price) and price > 0)

    def _log_debug(self, message: str) -> None:
        """Emit a debug message if a logger is attached to this position."""
        logger = getattr(self, "logger", None)
        if logger is not None:
            logger.debug(message)

    def _long_short_values(self) -> tuple:
        """
        Return (long_value, short_value) over holdings with a valid price.

        short_value is reported as a non-negative magnitude. Holdings whose entry is
        not a dict or whose price is invalid are ignored.
        """
        long_value = 0
        short_value = 0
        for stock_id in self.get_stock_list():
            holding = self.position[stock_id]
            if not isinstance(holding, dict):
                continue
            amount = holding.get("amount", 0)
            price = holding.get("price", 0)
            if not self._is_valid_price(price):
                continue
            position_value = amount * price
            if amount > 0:
                long_value += position_value
            else:
                short_value += abs(position_value)
        return long_value, short_value

    # ------------------------
    # Trading primitives
    # ------------------------
    def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
        """
        Sell stock, allowing short positions.

        This overrides the parent method to allow negative positions. Proceeds of the
        short portion of a sale are tracked in ``short_proceeds``.
        """
        trade_amount = trade_val / trade_price

        if stock_id not in self.position:
            # Opening a new short position
            self._init_stock(stock_id=stock_id, amount=-trade_amount, price=trade_price)
            # Track short sale proceeds
            self.short_proceeds[stock_id] = trade_val
        else:
            current_amount = self.position[stock_id]["amount"]
            new_amount = current_amount - trade_amount

            # Use absolute tolerance for position close check
            if abs(new_amount) < self.POSITION_EPSILON:
                # Position closed
                self._del_stock(stock_id)
                self.short_proceeds.pop(stock_id, None)
            else:
                # Update position (can go negative)
                self.position[stock_id]["amount"] = new_amount
                self.position[stock_id]["price"] = trade_price  # Update price on trade

                # Track short proceeds for new or increased short positions
                if new_amount < 0:
                    if current_amount >= 0:
                        # Going from long to short: record short portion proceeds
                        self.short_proceeds[stock_id] = abs(new_amount) * trade_price
                    else:
                        # Increasing short position: only the additional short portion accumulates
                        additional_short_amount = max(0.0, current_amount - new_amount)
                        self.short_proceeds[stock_id] = (
                            self.short_proceeds.get(stock_id, 0) + additional_short_amount * trade_price
                        )

        # Update cash
        new_cash = trade_val - cost
        if self._settle_type == self.ST_CASH:
            self.position["cash_delay"] += new_cash
        elif self._settle_type == self.ST_NO:
            self.position["cash"] += new_cash
        else:
            raise NotImplementedError("This type of input is not supported")

    def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
        """
        Buy stock, which can also mean covering a short position.

        Buy orders reduce cash immediately (never through cash_delay), which prevents
        over-buying within a settlement window.
        """
        trade_amount = trade_val / trade_price

        if stock_id not in self.position:
            # Opening new long position
            self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
        else:
            current_amount = self.position[stock_id]["amount"]

            if current_amount < 0:
                # Covering a short position
                new_amount = current_amount + trade_amount

                # Reduce short_proceeds proportionally to the covered share of the short
                covered_amount = min(trade_amount, abs(current_amount))
                if stock_id in self.short_proceeds and covered_amount > 0:
                    reduction_ratio = covered_amount / abs(current_amount)
                    self.short_proceeds[stock_id] *= 1 - reduction_ratio

                # Drop the proceeds record whenever no open short remains. This includes a
                # residual short smaller than POSITION_EPSILON, because the holding itself
                # is deleted below and the record would otherwise be left dangling.
                short_is_closed = new_amount > -self.POSITION_EPSILON
                if short_is_closed or self.short_proceeds.get(stock_id, 0.0) < self.POSITION_EPSILON:
                    self.short_proceeds.pop(stock_id, None)

                # Use absolute tolerance for position close check
                if abs(new_amount) < self.POSITION_EPSILON:
                    # Position fully closed
                    self._del_stock(stock_id)
                else:
                    self.position[stock_id]["amount"] = new_amount
                    self.position[stock_id]["price"] = trade_price  # Update price on trade
            else:
                # Adding to long position
                self.position[stock_id]["amount"] += trade_amount
                self.position[stock_id]["price"] = trade_price  # Update price on trade

        # Buy orders immediately reduce cash (not delayed)
        self.position["cash"] -= trade_val + cost

    # ------------------------
    # Valuation
    # ------------------------
    def calculate_stock_value(self) -> float:
        """
        Calculate total value of stock positions.

        For long positions: value = amount * price
        For short positions: value = amount * price (negative)

        Holdings without a valid price contribute nothing and are reported at debug level.
        """
        value = 0
        for stock_id in self.get_stock_list():
            amount = self.position[stock_id]["amount"]
            price = self.position[stock_id].get("price", 0)
            if self._is_valid_price(price):
                value += amount * price  # Negative for shorts
            else:
                self._log_debug(f"Invalid price for {stock_id}: {price}")
        return value

    def get_cash(self, include_settle: bool = False) -> float:
        """
        Get available cash.

        Parameters
        ----------
        include_settle : bool
            If True, include cash_delay (pending settlements) in the returned value

        Returns
        -------
        float
            Available cash (optionally including pending settlements)
        """
        cash = self.position.get("cash", 0.0)
        if include_settle:
            cash += self.position.get("cash_delay", 0.0)
        return cash

    def get_stock_amount(self, code: str) -> float:
        """
        Return amount with near-zero values clamped to zero to avoid false residual shorts.
        """
        amount = super().get_stock_amount(code)
        return 0.0 if abs(amount) < self.POSITION_EPSILON else amount

    def set_cash(self, value: float) -> None:
        """
        Set cash value directly.

        Parameters
        ----------
        value : float
            New cash value
        """
        self.position["cash"] = float(value)

    def add_borrow_cost(self, cost: float) -> None:
        """
        Deduct borrowing cost from cash and track accumulated costs.

        Parameters
        ----------
        cost : float
            Borrowing cost to deduct
        """
        cost = float(cost)
        self.position["cash"] -= cost
        self.borrow_cost_accumulated += cost

    def calculate_value(self) -> float:
        """
        Calculate total portfolio value.

        Total value = cash + cash_delay + stock_value
        Borrowing costs are already deducted from cash, so not subtracted again.
        """
        stock_value = self.calculate_stock_value()
        cash = self.position.get("cash", 0.0)
        cash_delay = self.position.get("cash_delay", 0.0)
        return cash + cash_delay + stock_value

    def get_leverage(self) -> float:
        """
        Calculate portfolio leverage.

        Leverage = (Long Value + |Short Value|) / Total Equity

        Returns
        -------
        float
            Portfolio leverage ratio; ``np.inf`` when total equity is not positive
        """
        long_value, short_value = self._long_short_values()
        total_equity = self.calculate_value()
        if total_equity <= 0:
            return np.inf
        return (long_value + short_value) / total_equity

    def get_net_exposure(self) -> float:
        """
        Calculate net market exposure.

        Net Exposure = (Long Value - Short Value) / Total Equity

        Returns
        -------
        float
            Net exposure ratio; 0 when total equity is not positive
        """
        long_value, short_value = self._long_short_values()
        total_equity = self.calculate_value()
        if total_equity <= 0:
            return 0
        return (long_value - short_value) / total_equity

    # ------------------------
    # Borrowing costs
    # ------------------------
    def calculate_daily_borrow_cost(self) -> float:
        """
        Calculate daily borrowing cost for short positions.

        Returns
        -------
        float
            Daily borrowing cost (sum of |short value| * daily borrow rate)
        """
        daily_cost = 0
        for stock_id in self.get_stock_list():
            holding = self.position[stock_id]
            if not isinstance(holding, dict):
                continue
            amount = holding.get("amount", 0)
            if amount >= 0:
                continue  # Only short positions are charged
            price = holding.get("price", 0)
            if self._is_valid_price(price):
                daily_cost += abs(amount * price) * self._daily_borrow_rate
            else:
                self._log_debug(f"Invalid price for short position {stock_id}: {price}")
        return daily_cost

    def settle_daily_costs(self) -> None:
        """
        Settle daily costs including borrowing fees.
        Should be called at the end of each trading day.

        Note: Consider using add_borrow_cost() for more control.
        """
        borrow_cost = self.calculate_daily_borrow_cost()
        if borrow_cost > 0:
            self.add_borrow_cost(borrow_cost)

    # ------------------------
    # Reporting
    # ------------------------
    def get_position_info(self) -> pd.DataFrame:
        """
        Get detailed position information as DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame indexed by stock_id with position details including:
            - amount: position size (negative for shorts)
            - price: current price
            - value: position value (0 when the price is invalid)
            - weight: position weight in portfolio
            - position_type: "long" or "short"
            An empty DataFrame is returned when there are no holdings.
        """
        rows = []
        for stock_id in self.get_stock_list():
            holding = self.position[stock_id]
            amount = holding["amount"]
            price = holding.get("price", 0)
            # Cannot calculate value without valid price
            value = amount * price if self._is_valid_price(price) else 0
            rows.append(
                {
                    "stock_id": stock_id,
                    "amount": amount,
                    "price": price if price is not None else 0,
                    "value": value,
                    "weight": holding.get("weight", 0),
                    "position_type": "long" if amount > 0 else "short",
                }
            )

        if not rows:
            return pd.DataFrame()
        return pd.DataFrame(rows).set_index("stock_id")

    def get_short_positions(self) -> Dict[str, float]:
        """
        Get all short positions.

        Returns
        -------
        dict
            Dictionary of {stock_id: amount} for all short positions
        """
        shorts = {}
        for stock_id in self.get_stock_list():
            amount = self.position[stock_id]["amount"]
            if amount < -self.POSITION_EPSILON:
                shorts[stock_id] = amount
        return shorts

    def get_long_positions(self) -> Dict[str, float]:
        """
        Get all long positions.

        Returns
        -------
        dict
            Dictionary of {stock_id: amount} for all long positions
        """
        longs = {}
        for stock_id in self.get_stock_list():
            amount = self.position[stock_id]["amount"]
            if amount > self.POSITION_EPSILON:
                longs[stock_id] = amount
        return longs

    def get_gross_value(self) -> float:
        """
        Get gross portfolio value (sum of absolute values of all positions).

        Returns
        -------
        float
            Gross portfolio value
        """
        gross = 0.0
        for stock_id in self.get_stock_list():
            holding = self.position[stock_id]
            amount = holding.get("amount", 0.0)
            price = holding.get("price", None)
            if self._is_valid_price(price):
                gross += abs(amount * price)
            else:
                self._log_debug(f"Invalid price for {stock_id} in gross value calculation: {price}")
        return gross

    def get_net_value(self) -> float:
        """
        Get net portfolio value (long value - short value).

        Returns
        -------
        float
            Net portfolio value
        """
        return self.calculate_stock_value()

    def update_all_stock_prices(self, price_dict: Dict[str, float]) -> None:
        """
        Update prices for all positions (mark-to-market).

        This should be called at the end of each trading day with closing prices
        to ensure accurate portfolio valuation. Missing or invalid prices leave the
        stored price untouched.

        Parameters
        ----------
        price_dict : dict
            Dictionary of {stock_id: price} with current market prices
        """
        for stock_id in self.get_stock_list():
            if stock_id not in price_dict:
                continue
            price = price_dict[stock_id]
            if self._is_valid_price(price):
                self.position[stock_id]["price"] = price

    def __str__(self) -> str:
        """String representation showing position details."""
        # Handle potential inf values safely
        leverage = self.get_leverage()
        leverage_str = round(leverage, 2) if np.isfinite(leverage) else "inf"

        net_exp = self.get_net_exposure()
        net_exp_str = round(net_exp, 2) if np.isfinite(net_exp) else "inf"

        info = {
            "cash": self.get_cash(),
            "cash_delay": self.position.get("cash_delay", 0),
            "stock_value": self.calculate_stock_value(),
            "total_value": self.calculate_value(),
            "leverage": leverage_str,
            "net_exposure": net_exp_str,
            "long_positions": len(self.get_long_positions()),
            "short_positions": len(self.get_short_positions()),
            "borrow_cost_accumulated": round(self.borrow_cost_accumulated, 2),
        }
        return f"ShortablePosition({info})"
epsilon_change(self._calendar[next_idx]) + else: + # estimate next boundary by freq delta + n, base = Freq.parse(self.freq) + right = epsilon_change(left + Freq.get_timedelta(n, base)) + return left, right def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: """ diff --git a/qlib/config.py b/qlib/config.py index a0b4aad28b..cde650784e 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -22,7 +22,7 @@ from typing import Callable, Optional, Union from typing import TYPE_CHECKING -from qlib.constant import REG_CN, REG_US, REG_TW +from qlib.constant import REG_CN, REG_US, REG_TW, REG_CRYPTO if TYPE_CHECKING: from qlib.utils.time import Freq @@ -307,6 +307,12 @@ def register_from_C(config, skip_register=True): "limit_threshold": 0.1, "deal_price": "close", }, + # Crypto region: 24/7, no limit_threshold, unit=1, default deal_price=close + REG_CRYPTO: { + "trade_unit": 1, + "limit_threshold": None, + "deal_price": "close", + }, } diff --git a/qlib/constant.py b/qlib/constant.py index ac6c76ae22..3a8f738522 100644 --- a/qlib/constant.py +++ b/qlib/constant.py @@ -10,6 +10,7 @@ REG_CN = "cn" REG_US = "us" REG_TW = "tw" +REG_CRYPTO = "crypto" # Epsilon for avoiding division by zero. EPS = 1e-12 diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index bad19ddfdc..0284810582 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -1,3 +1,6 @@ +"""Signal-driven strategies including LongShortTopKStrategy (crypto-ready).""" + +# pylint: disable=C0301,R0912,R0915,R0902,R0913,R0914,C0411,W0511,W0718,W0612,W0613,C0209,W1309,C1802,C0115,C0116 # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
class LongShortTopKStrategy(BaseSignalStrategy):
    """
    Strict TopK-aligned long-short strategy.

    - Uses shift=1 signals (the previous bar's signal drives the current bar's trading),
      like TopkDropoutStrategy.
    - Maintains separate TopK pools for the long and short legs, each rotated
      independently (``n_drop_long`` / ``n_drop_short``).
    - Checks tradability through the exchange before closing or opening any position.
    - Requires a shortable exchange to open short positions; otherwise SELL will be
      clipped by Exchange.
    """

    def __init__(
        self,
        *,
        topk_long: int,
        topk_short: int,
        n_drop_long: int,
        n_drop_short: int,
        method_sell: str = "bottom",
        method_buy: str = "top",
        hold_thresh: int = 1,
        only_tradable: bool = False,
        forbid_all_trade_at_limit: bool = True,
        rebalance_to_weights: bool = True,
        long_share: Optional[float] = None,
        debug: bool = False,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        topk_long, topk_short : int
            Target number of holdings per leg. ``None`` or ``<= 0`` disables the leg.
        n_drop_long, n_drop_short : int
            Maximum number of holdings rotated out of each leg per step.
        method_sell : str
            ``"bottom"`` or ``"random"``; how positions to close are chosen.
        method_buy : str
            ``"top"`` or ``"random"``; how positions to open are chosen.
        hold_thresh : int
            A holding is only closed when its hold count (in bars) is at least this value.
        only_tradable : bool
            When True, candidate picking skips instruments the exchange reports as untradable.
        forbid_all_trade_at_limit : bool
            When True, tradability is checked with ``direction=None``; otherwise the
            order direction is passed to the exchange.
        rebalance_to_weights : bool
            True: emit orders that move every holding to an equal target weight.
            False: TopK-style orders (close dropped names fully, open new names only).
        long_share : float, optional
            Fraction of the risk degree given to the long leg when both legs are
            active. ``None`` means 0.5.
        debug : bool
            Print per-step diagnostics.
        """
        super().__init__(**kwargs)
        self.topk_long = topk_long
        self.topk_short = topk_short
        self.n_drop_long = n_drop_long
        self.n_drop_short = n_drop_short
        self.method_sell = method_sell
        self.method_buy = method_buy
        self.hold_thresh = hold_thresh
        self.only_tradable = only_tradable
        self.forbid_all_trade_at_limit = forbid_all_trade_at_limit
        self.rebalance_to_weights = rebalance_to_weights
        # When both legs enabled, split risk_degree by long_share (0~1). None -> 0.5 default.
        self.long_share = long_share
        self._debug = debug

    def _select_leg(self, pred_score, held, topk, n_drop, ascending, pick, filter_stock):
        """Choose what to close and what to open for one leg.

        ``ascending=False`` ranks for the long leg (highest score first),
        ``ascending=True`` for the short leg (lowest score first).

        Returns ``(last, to_close, to_open)``: the currently held codes ranked by
        score, the held codes selected for closing, and the new codes to open.
        """
        last = pred_score.reindex(held).sort_values(ascending=ascending).index
        n_to_add = max(0, n_drop + topk - len(last))

        if self.method_buy == "top":
            candidates = pick(
                pred_score[~pred_score.index.isin(last)].sort_values(ascending=ascending).index,
                n_to_add,
            )
        elif self.method_buy == "random":
            topk_candi = pick(pred_score.sort_values(ascending=ascending).index, topk)
            candi = [code for code in topk_candi if code not in last]
            try:
                candidates = list(np.random.choice(candi, n_to_add, replace=False)) if n_to_add > 0 else []
            except ValueError:
                # fewer candidates than requested: take them all
                candidates = candi
        else:
            raise NotImplementedError

        comb = pred_score.reindex(last.union(pd.Index(candidates))).sort_values(ascending=ascending).index
        if self.method_sell == "bottom":
            to_close = last[last.isin(pick(comb, n_drop, reverse=True))]
        elif self.method_sell == "random":
            candi = filter_stock(last)
            try:
                to_close = pd.Index(np.random.choice(candi, n_drop, replace=False) if len(candi) else [])
            except ValueError:
                to_close = pd.Index(candi)
        else:
            raise NotImplementedError

        # A negative stop would slice from the end of the list; clamp it to "open nothing".
        to_open = candidates[: max(0, len(to_close) + topk - len(last))]
        return last, to_close, to_open

    def generate_trade_decision(self, execute_result=None):
        """Build the order list for the current trade step.

        The signal of the previous step (shift=1) ranks the instruments. Each leg
        is rotated with the TopK/dropout rule. Orders are then produced either by
        rebalancing to equal target weights (``rebalance_to_weights=True``) or by
        TopK-style close/open orders.

        Returns
        -------
        TradeDecisionWO
            The decision wrapping the generated orders (empty when no signal is available).
        """
        import copy  # pylint: disable=C0415

        # Align time windows (shift=1)
        trade_step = self.trade_calendar.get_trade_step()
        trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
        pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
        pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
        if isinstance(pred_score, pd.DataFrame):
            pred_score = pred_score.iloc[:, 0]
        if pred_score is None:
            return TradeDecisionWO([], self)

        # ``None`` means "leg disabled"; normalise once so the arithmetic below cannot fail.
        topk_long = self.topk_long or 0
        topk_short = self.topk_short or 0
        n_drop_long = self.n_drop_long or 0
        n_drop_short = self.n_drop_short or 0
        long_only_mode = topk_short <= 0
        short_only_mode = topk_long <= 0

        def is_tradable(code, direction=None):
            return self.trade_exchange.is_stock_tradable(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=direction
            )

        def pick(li, n, reverse=False):
            """First (or last, if ``reverse``) ``n`` items of ``li``, order preserved."""
            if n <= 0:
                # list[-0:] is the whole list, so a zero request must be handled explicitly
                return []
            if not self.only_tradable:
                seq = list(li)
                return seq[-n:] if reverse else seq[:n]
            res = []
            for si in reversed(li) if reverse else li:
                if is_tradable(si):
                    res.append(si)
                    if len(res) >= n:
                        break
            return res[::-1] if reverse else res

        def filter_stock(li):
            return [si for si in li if is_tradable(si)] if self.only_tradable else li

        def can_trade(code, direction):
            return is_tradable(code, None if self.forbid_all_trade_at_limit else direction)

        def get_price(code, direction):
            """Deal price as float, or None when missing, non-finite or non-positive."""
            px = self.trade_exchange.get_deal_price(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=direction
            )
            return float(px) if (px is not None and np.isfinite(px) and px > 0) else None

        def round_amount(code, amount):
            factor = self.trade_exchange.get_factor(
                stock_id=code, start_time=trade_start_time, end_time=trade_end_time
            )
            return self.trade_exchange.round_amount_by_trade_unit(amount, factor)

        def new_order(code, amount, direction):
            return Order(
                stock_id=code,
                amount=amount,
                start_time=trade_start_time,
                end_time=trade_end_time,
                direction=direction,
            )

        # Work on a copy: simulated deals below must not touch the real position.
        current_temp: Position = copy.deepcopy(self.trade_position)

        # Build current long/short lists by sign of amount
        long_now = []  # amounts > 0
        short_now = []  # amounts < 0
        for code in current_temp.get_stock_list():
            amt = current_temp.get_stock_amount(code)
            if amt > 0:
                long_now.append(code)
            elif amt < 0:
                short_now.append(code)
        if self._debug:
            print(
                f"[LongShortTopKStrategy][{trade_start_time}] init_pos: longs={len(long_now)}, shorts={len(short_now)}"
            )
            if short_now:
                details = [(c, float(current_temp.get_stock_amount(c))) for c in short_now]
                print(f"[LongShortTopKStrategy][{trade_start_time}] short_details: {details}")

        # ---- Leg selection: long ranks by descending score, short by ascending score ----
        last_long, sell_long, buy_long = self._select_leg(
            pred_score, long_now, topk_long, n_drop_long, False, pick, filter_stock
        )
        last_short, cover_short, open_short = self._select_leg(
            pred_score, short_now, topk_short, n_drop_short, True, pick, filter_stock
        )

        time_per_step = self.trade_calendar.get_freq()

        # Apply hold_thresh and tradability when removing
        actual_sold_longs = {
            code
            for code in last_long
            if code in sell_long
            and current_temp.get_stock_count(code, bar=time_per_step) >= self.hold_thresh
            and can_trade(code, OrderDir.SELL)
        }

        actual_covered_shorts = set()
        if long_only_mode:
            # Long-only: fully cover any existing shorts (not limited by n_drop_short or hold_thresh)
            for code in last_short:
                if current_temp.get_stock_amount(code) < 0 and can_trade(code, OrderDir.BUY):
                    actual_covered_shorts.add(code)
        else:
            for code in last_short:
                if (
                    code in cover_short
                    and current_temp.get_stock_count(code, bar=time_per_step) >= self.hold_thresh
                    and can_trade(code, OrderDir.BUY)
                ):
                    actual_covered_shorts.add(code)
        if self._debug:
            print(
                f"[LongShortTopKStrategy][{trade_start_time}] cover_shorts={len(actual_covered_shorts)} buy_longs_plan={len(buy_long)} open_shorts_plan={len(open_short)}"
            )

        # Preserve raw planned lists before tradability filtering to align with TopK semantics
        raw_buy_long = list(buy_long)
        raw_open_short = list(open_short)

        buy_long = [c for c in buy_long if can_trade(c, OrderDir.BUY)]
        buy_long_set = set(buy_long)
        open_short = [c for c in open_short if can_trade(c, OrderDir.SELL)]
        open_short = [c for c in open_short if c not in buy_long_set]  # avoid overlap

        final_long_set = (set(long_now) - actual_sold_longs) | buy_long_set
        final_short_set = (set(short_now) - actual_covered_shorts) | set(open_short)

        rd = float(self.get_risk_degree(trade_step))
        share = self.long_share if (self.long_share is not None) else 0.5
        order_list: List[Order] = []

        # ---- Optional: TopK-style no-rebalance branch (symmetric long/short) ----
        if not self.rebalance_to_weights:
            cash = current_temp.get_cash()

            # 1) Sell dropped longs entirely
            for code in long_now:
                if code not in actual_sold_longs:
                    continue
                sell_amount = current_temp.get_stock_amount(code=code)
                if sell_amount <= 0:
                    continue
                sell_order = new_order(code, sell_amount, OrderDir.SELL)
                if self.trade_exchange.check_order(sell_order):
                    order_list.append(sell_order)
                    trade_val, trade_cost, _ = self.trade_exchange.deal_order(sell_order, position=current_temp)
                    cash += trade_val - trade_cost

            # Snapshot cash AFTER long sells but BEFORE short covers, so that cash
            # consumed by covering does not leak into the long-buy budget.
            cash_after_long_sells = cash

            # 2) Cover dropped shorts entirely (BUY to cover)
            for code in short_now:
                if code not in actual_covered_shorts:
                    continue
                cover_amount = abs(current_temp.get_stock_amount(code=code))
                if cover_amount <= 0:
                    continue
                cover_order = new_order(code, cover_amount, OrderDir.BUY)
                if self.trade_exchange.check_order(cover_order):
                    order_list.append(cover_order)
                    # simulate the deal so the equity used for the short leg reflects the cover
                    self.trade_exchange.deal_order(cover_order, position=current_temp)

            # 3) Buy new longs with equal cash split, honoring risk_degree
            if long_only_mode:
                rd_long, rd_short = rd, 0.0
            elif short_only_mode:
                rd_long, rd_short = 0.0, rd
            else:
                rd_long, rd_short = rd * share, rd * (1.0 - share)
            if self._debug:
                print(
                    f"[LongShortTopKStrategy][{trade_start_time}] rd={rd:.4f} rd_long={rd_long:.4f} rd_short={rd_short:.4f} cash_after_long_sells={cash_after_long_sells:.2f}"
                )
            # Split by planned count (raw), then drop names that cannot be traded
            value_per_buy = cash_after_long_sells * rd_long / len(raw_buy_long) if raw_buy_long else 0.0
            for code in raw_buy_long:
                if not can_trade(code, OrderDir.BUY):
                    continue
                price = get_price(code, OrderDir.BUY)
                if price is None:
                    continue
                buy_amount = round_amount(code, value_per_buy / price)
                if buy_amount <= 0:
                    continue
                order_list.append(new_order(code, buy_amount, OrderDir.BUY))

            # 4) Open new shorts equally by target short notional derived from rd.
            # Equity is taken after the simulated sells/covers above and before
            # any new long buy or short open.
            equity = max(1e-12, float(current_temp.calculate_value()))

            current_short_value = 0.0
            for sid in current_temp.get_stock_list():
                amt = current_temp.get_stock_amount(sid)
                if amt < 0:
                    px = get_price(sid, OrderDir.BUY)  # price to cover
                    if px is not None:
                        current_short_value += abs(float(amt)) * px

            desired_short_value = equity * rd_short
            remaining_short_value = max(0.0, desired_short_value - current_short_value)
            value_per_short_open = remaining_short_value / len(raw_open_short) if raw_open_short else 0.0
            if self._debug:
                print(
                    f"[LongShortTopKStrategy][{trade_start_time}] equity={equity:.2f} cur_short_val={current_short_value:.2f} desired_short_val={desired_short_value:.2f} rem_short_val={remaining_short_value:.2f} v_per_short={value_per_short_open:.2f}"
                )

            for code in raw_open_short:
                # never short a name that is being bought in this same step
                if code in buy_long_set or not can_trade(code, OrderDir.SELL):
                    continue
                price = get_price(code, OrderDir.SELL)
                if price is None:
                    continue
                sell_amount = round_amount(code, value_per_short_open / price)
                if sell_amount <= 0:
                    continue
                order_list.append(new_order(code, sell_amount, OrderDir.SELL))

            return TradeDecisionWO(order_list, self)

        # ---- Rebalance to target weights to bound gross leverage and net exposure ----
        long_total = 0.0
        short_total = 0.0
        if final_long_set and final_short_set:
            long_total = rd * share
            short_total = rd * (1.0 - share)
        elif final_long_set:
            long_total = rd
        elif final_short_set:
            short_total = rd

        # Sorted iteration keeps the order sequence reproducible across runs.
        target_weight: Dict[str, float] = {}
        if final_long_set:
            lw = long_total / len(final_long_set)
            for c in sorted(final_long_set):
                target_weight[c] = lw
        if final_short_set:
            sw = -short_total / len(final_short_set)
            for c in sorted(final_short_set):
                target_weight[c] = sw

        # Stocks to liquidate
        held_codes = set(current_temp.get_stock_list())
        for c in current_temp.get_stock_list():
            if c not in target_weight:
                target_weight[c] = 0.0

        # Generate orders by comparing current vs target
        equity = max(1e-12, float(current_temp.calculate_value()))
        for code, tw in target_weight.items():
            # The direction is only known after the delta is computed, so price with
            # the BUY quote and fall back to the SELL quote when it is unusable.
            price = get_price(code, OrderDir.BUY)
            if price is None:
                price = get_price(code, OrderDir.SELL)
            if price is None:
                continue
            cur_amount = float(current_temp.get_stock_amount(code)) if code in held_codes else 0.0
            delta_value = tw * equity - cur_amount * price
            if delta_value == 0:
                continue
            direction = OrderDir.BUY if delta_value > 0 else OrderDir.SELL
            if not can_trade(code, direction):
                continue
            delta_amount = round_amount(code, abs(delta_value) / price)
            if delta_amount <= 0:
                continue
            order_list.append(new_order(code, delta_amount, direction))

        return TradeDecisionWO(order_list, self)
def _crypto_risk_analysis(r: pd.Series, N: int = 365) -> pd.Series:
    """Run contrib ``risk_analysis`` with crypto-friendly settings.

    The call passes ``freq=None`` together with an explicit annualization
    scaler ``N`` (365 by default) and ``mode="product"`` for compounding.
    """
    return original_risk_analysis(r, freq=None, N=N, mode="product")


class CryptoPortAnaRecord(PortAnaRecord):
    """A crypto-friendly PortAnaRecord.

    Differences vs PortAnaRecord (only when used):
    - Annualization uses ``crypto_annual_days`` (365 by default).
    - Product compounding for cumulative/excess returns.
    - Optionally align exchange freq based on risk_analysis_freq if provided.

    NOTE(review): ``_generate`` here only produces portfolio reports and risk
    analysis; ``indicator_analysis_freq`` / ``indicator_analysis_method`` are
    forwarded to the parent constructor but not used by this override.
    """

    def __init__(
        self,
        recorder,
        config=None,
        risk_analysis_freq: Union[List, str, None] = None,
        indicator_analysis_freq: Union[List, str, None] = None,
        indicator_analysis_method=None,
        crypto_annual_days: int = 365,
        skip_existing: bool = False,
        **kwargs,
    ):
        """Forward everything to ``PortAnaRecord`` and remember the annualization days."""
        super().__init__(
            recorder=recorder,
            config=config,
            risk_analysis_freq=risk_analysis_freq,
            indicator_analysis_freq=indicator_analysis_freq,
            indicator_analysis_method=indicator_analysis_method,
            skip_existing=skip_existing,
            **kwargs,
        )
        self.crypto_annual_days = crypto_annual_days

    @staticmethod
    def _float_column(report: pd.DataFrame, column: str) -> pd.Series:
        """Return ``report[column]`` as float with NaN -> 0, or zeros if the column is absent."""
        values = report.get(column)
        if isinstance(values, pd.Series):
            return values.astype(float).fillna(0)
        return pd.Series(0.0, index=report.index)

    def _return_series(self, report: pd.DataFrame):
        """Return ``(strategy, benchmark, excess_without_cost, excess_with_cost)`` series.

        Excess returns are geometric: ``(1 + r) / (1 + b) - 1``.
        """
        r = self._float_column(report, "return")
        b = self._float_column(report, "bench")
        c = self._float_column(report, "cost")
        return r, b, (1 + r) / (1 + b) - 1, (1 + (r - c)) / (1 + b) - 1

    def _generate(self, **kwargs):  # override only the generation logic
        """Run the backtest and return the artifacts to save.

        Returns
        -------
        dict
            Maps artifact file names to report, position, indicator and
            port-analysis objects.
        """
        import warnings  # pylint: disable=C0415

        from ...backtest import backtest as normal_backtest  # pylint: disable=C0415
        from ...utils import flatten_dict  # pylint: disable=C0415

        pred = self.load("pred.pkl")

        # Replace the "<PRED>" placeholder in the configs with the loaded prediction
        placeholder_value = {"<PRED>": pred}
        for k in "executor_config", "strategy_config":
            setattr(self, k, fill_placeholder(getattr(self, k), placeholder_value))

        # Auto-extract time range if not set
        dt_values = pred.index.get_level_values("datetime")
        if self.backtest_config["start_time"] is None:
            self.backtest_config["start_time"] = dt_values.min()
        if self.backtest_config["end_time"] is None:
            self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)

        # Optionally align exchange frequency with requested risk analysis frequency
        raf = getattr(self, "risk_analysis_freq", None)
        target_freq = None
        if isinstance(raf, (list, tuple)) and len(raf) > 0:
            target_freq = raf[0]
        elif isinstance(raf, str):
            target_freq = raf
        if isinstance(target_freq, str) and target_freq:
            try:
                ex_kwargs = dict(self.backtest_config.get("exchange_kwargs", {}) or {})
            except (TypeError, ValueError) as e:
                # best effort: keep the user's config untouched if it is not dict-like
                warnings.warn(f"cannot align exchange freq with risk_analysis_freq: {e}")
            else:
                ex_kwargs.setdefault("freq", target_freq)
                self.backtest_config["exchange_kwargs"] = ex_kwargs

        # Run backtest
        portfolio_metric_dict, indicator_dict = normal_backtest(
            executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
        )

        artifact_objects = {}
        annual_days = self.crypto_annual_days

        # Save portfolio metrics; also attach crypto metrics as attrs for consumers
        for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
            if "return" in report_normal.columns:
                # Attach crypto metrics for downstream use (non-breaking, best effort)
                try:
                    r, b, excess_wo_cost, excess_w_cost = self._return_series(report_normal)
                    report_normal.attrs["crypto_metrics"] = {
                        "strategy": _crypto_risk_analysis(r, N=annual_days),
                        "benchmark": _crypto_risk_analysis(b, N=annual_days),
                        "excess_wo_cost": _crypto_risk_analysis(excess_wo_cost, N=annual_days),
                        "excess_w_cost": _crypto_risk_analysis(excess_w_cost, N=annual_days),
                        "annual_days": annual_days,
                    }
                except Exception as e:
                    warnings.warn(f"failed to attach crypto metrics for freq {_freq}: {e}")

            artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
            artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})

        for _freq, indicators_normal in indicator_dict.items():
            artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
            artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})

        # Risk analysis (annual_days, product mode): logged metrics and artifacts
        for _analysis_freq in self.risk_analysis_freq:
            if _analysis_freq not in portfolio_metric_dict:
                warnings.warn(
                    f"the freq {_analysis_freq} report is not found, please set the corresponding env with `generate_portfolio_metrics=True`"
                )
                continue

            report_normal, _ = portfolio_metric_dict.get(_analysis_freq)
            _, _, excess_wo_cost, excess_w_cost = self._return_series(report_normal)
            analysis = {
                "excess_return_without_cost": _crypto_risk_analysis(excess_wo_cost, N=annual_days),
                "excess_return_with_cost": _crypto_risk_analysis(excess_w_cost, N=annual_days),
            }

            analysis_df = pd.concat(analysis)
            analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
            self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
            artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})

        return artifact_objects
def _try_init_qlib():
    """Initialize qlib with real crypto data if available; otherwise skip tests.

    Returns the provider directory that was used.
    """
    candidates = [
        os.path.expanduser("~/.qlib/qlib_data/crypto_data_perp"),  # Prefer user's provided perp path
        os.path.expanduser("~/.qlib/qlib_data/crypto_data"),
        str(Path(__file__).resolve().parents[3] / "crypto-qlib" / "binance_crypto_data_perp"),
        str(Path(__file__).resolve().parents[3] / "crypto-qlib" / "binance_crypto_data"),
    ]
    for p in candidates:
        _p = os.path.expanduser(p)
        # A candidate that is not an existing directory cannot provide data
        if not os.path.isdir(_p):
            continue
        try:
            qlib.init(provider_uri=_p, region=REG_CRYPTO, skip_if_reg=True, kernels=1)
            # Silence known harmless warning from numpy on empty slice in qlib internal mean.
            # ``module`` is a regex, so the dots are escaped once (raw string).
            warnings.filterwarnings(
                "ignore",
                message="Mean of empty slice",
                category=RuntimeWarning,
                module=r".*qlib\.utils\.index_data",
            )
            # Probe one simple call
            _ = D.instruments()
            return _p
        except Exception:
            continue
    pytest.skip("No valid crypto provider_uri found; skipping real-data tests")


def test_shortable_with_real_data_end_to_end():
    """Run one long-short step on real crypto data and sanity-check the metrics."""
    _ = _try_init_qlib()

    # Use a fixed window you confirmed has data
    start_time = pd.Timestamp("2021-07-11")
    end_time = pd.Timestamp("2021-08-10")

    # Pick a small universe via proper API: instruments config -> list
    inst_conf = D.instruments(market="all")
    instruments = D.list_instruments(inst_conf, start_time=start_time, end_time=end_time, freq="day", as_list=True)[:10]
    if not instruments:
        pytest.skip("No instruments available from provider; skipping")

    # Build exchange on real data, restrict to small universe
    ex = ShortableExchange(
        freq="day",
        start_time=start_time,
        end_time=end_time,
        codes=instruments,
        deal_price="$close",
        open_cost=0.0015,
        close_cost=0.0025,
        impact_cost=0.0,
        limit_threshold=None,
    )

    # Avoid default CSI300 benchmark by constructing account with benchmark=None
    account = ShortableAccount(benchmark_config={"benchmark": None})

    exe = ShortableExecutor(
        time_per_step="day",
        generate_portfolio_metrics=True,
        trade_exchange=ex,
        region="crypto",
        verbose=False,
        account=account,
    )

    # Build a simple momentum signal on end_time (fallback to last-close ranking if necessary)
    feat = D.features(
        instruments,
        ["$close"],
        start_time,
        end_time,
        freq="day",
        disk_cache=True,
    )
    if feat is None or feat.empty:
        pytest.skip("No valid features in selected window; skipping")

    g = feat.groupby("instrument")["$close"]
    last = g.last()
    # Momentum needs at least 2 rows per instrument. ``nth`` is not used here because
    # in pandas >= 2 it keeps the original row index and would not align with ``last``.
    prev = g.apply(lambda s: s.iloc[-2] if len(s) >= 2 else float("nan"))
    sig = (last / prev - 1.0).dropna()

    if sig.empty:
        # fallback: rank by last close (descending)
        last = last.dropna()
        if last.empty:
            pytest.skip("No closes to build fallback signal; skipping")
        sig = last - last.mean()  # demeaned last close as pseudo-signal

    # Generate orders for the end_time
    # For crypto, use unit step to ensure orders are generated and avoid empty indicators
    strat = LongShortStrategy(
        gross_leverage=1.0,
        net_exposure=0.0,
        top_k=3,
        exchange=ex,
        lot_size=1,
        min_trade_threshold=1,
    )
    td = strat.generate_trade_decision(sig, exe.position, end_time)

    # Execute one step via standard API
    exe.reset(start_time=start_time, end_time=end_time)
    _ = exe.execute(td)

    # Validate metrics shape and key fields
    df, meta = exe.trade_account.get_portfolio_metrics()
    assert hasattr(df, "shape")
    assert isinstance(meta, dict)
    # leverage should be >= 0
    assert meta.get("leverage", 0) >= 0
    # net_exposure should be finite (a chained comparison is False for NaN and +/-inf)
    assert float("-inf") < float(meta.get("net_exposure", 0.0)) < float("inf")
    # If we have short positions, borrow cost may be > 0
    assert meta.get("total_borrow_cost", 0) >= 0
D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) - _freq = "1min" + # fall back to best available minute frequency (e.g., 1min/5min/60min) + min_freqs = _list_supported_minute_freqs() + if not min_freqs: + # last resort: 1min (original behavior) + min_freqs = ["1min"] + last_exc = None + for mf in min_freqs: + try: + _result = D.features(instruments, fields, start_time, end_time, freq=mf, disk_cache=disk_cache) + _freq = mf + break + except (ValueError, KeyError) as _e: + last_exc = _e + continue + else: + raise ValueError(f"No supported minute frequency found for features; tried: {min_freqs}") from ( + last_exc or value_key_e + ) elif norm_freq == Freq.NORM_FREQ_MINUTE: - _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) - _freq = "1min" + # try requested minute first; if fails, try other supported minute freqs + min_freqs = [freq] + sup_mf = _list_supported_minute_freqs() + for mf in sup_mf: + if mf not in min_freqs: + min_freqs.append(mf) + last_exc = None + for mf in min_freqs: + try: + _result = D.features(instruments, fields, start_time, end_time, freq=mf, disk_cache=disk_cache) + _freq = mf + break + except (ValueError, KeyError) as _e: + last_exc = _e + continue + else: + raise ValueError(f"No supported minute frequency found for features; tried: {min_freqs}") from ( + last_exc or value_key_e + ) else: raise ValueError(f"freq {freq} is not supported") from value_key_e return _result, _freq