Skip to content

Commit

Permalink
Alternative ways to visualise grid search (#1073)
Browse files Browse the repository at this point in the history
- Add new charts and animations to grid search analysis, see `tradeexecutor.visualisation.grid_search_advanced`
  • Loading branch information
miohtama authored Oct 29, 2024
1 parent 1717d84 commit 48e4f77
Show file tree
Hide file tree
Showing 12 changed files with 1,671 additions and 331 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
**Note**: A full changelog is not available as long as `trade-executor` package is in active beta developmnt.
**Note**: A full changelog is not available as long as `trade-executor` package is in active beta developmnt.

## 0.2

Expand Down
160 changes: 132 additions & 28 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pytest-timeout = "^2.3.1"
#

# TODO: Disabled. Install wiht pip until dependency version incompatibilies are solved.
# zelos-demeter = {version="^0.7.2", optional = true}
zelos-demeter = {version="^0.7.4", optional = true}

# https://github.com/arynyklas/telegram_bot_logger/pull/1
telegram-bot-logger = {git = "https://github.com/tradingstrategy-ai/telegram_bot_logger.git", branch="patch-bleeding-edges", optional = true}
Expand Down
184 changes: 183 additions & 1 deletion tests/backtest/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInput
from tradeexecutor.strategy.parameters import StrategyParameters
from tradeexecutor.visual.grid_search import visualise_single_grid_search_result_benchmark, visualise_grid_search_equity_curves
from tradeexecutor.visual.grid_search_advanced import calculate_rolling_metrics, BenchmarkMetric, visualise_grid_single_rolling_metric, visualise_grid_rolling_metric_heatmap

from tradeexecutor.visual.grid_search_advanced import visualise_grid_rolling_metric_line_chart
from tradingstrategy.candle import GroupedCandleUniverse
from tradingstrategy.chain import ChainId
from tradingstrategy.exchange import Exchange
Expand Down Expand Up @@ -825,4 +828,183 @@ def create_indicators(parameters: StrategyParameters, indicators: IndicatorSet,

assert len(results) == 2
for r in results:
assert isinstance(r.exception, BacktestExecutionFailed)
assert isinstance(r.exception, BacktestExecutionFailed)



def test_grid_search_visualisation_line_chart(
strategy_universe,
indicator_storage,
tmp_path,
):
"""Advanced calculations and visualisation for grid search results.
"""
class Parameters:
cycle_duration = CycleDuration.cycle_1d
initial_cash = 10_000
allocation = [0.50, 0.75, 0.99]
cycle_divider = [2, 3, 4]
foo_param = ["a", "b"]

def _decide_trades_flip_buy_sell(input: StrategyInput) -> list[TradeExecution]:
# Generate some random trades
position_manager = input.get_position_manager()
parameters = input.parameters
pair = input.strategy_universe.get_single_pair()
cash = position_manager.get_current_cash()
if input.cycle % parameters.cycle_divider == 0:
return position_manager.open_spot(pair, cash * parameters.allocation)
else:
if position_manager.is_any_open():
return position_manager.close_all()
return []

def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext):
# No indicators needed
return IndicatorSet()

combinations = prepare_grid_combinations(
Parameters,
tmp_path,
strategy_universe=strategy_universe,
create_indicators=create_indicators,
execution_context=ExecutionContext(mode=ExecutionMode.unit_testing, grid_search=True),
)

assert len(combinations) == 18

grid_search_results = perform_grid_search(
_decide_trades_flip_buy_sell,
strategy_universe,
combinations,
trading_strategy_engine_version="0.5",
indicator_storage=indicator_storage,
verbose=False,
multiprocess=True,
)

# Calculate rolling sharpe for each month
# x-axis: time
# y-axis: sharpe
# variables as line charts: allocation=0.50, allocation=0.75, allocation=0.99
# other variables are set to their fixed values
df = calculate_rolling_metrics(
grid_search_results,
visualised_parameters="allocation",
fixed_parameters={"cycle_divider": 2, "foo_param": "a"},
benchmarked_metric=BenchmarkMetric.sharpe,
)

assert isinstance(df, pd.DataFrame)
assert len(df) > 0

# Check range is right
assert df.index[0] == pd.Timestamp("2021-06-1")
assert df.index[-1] == pd.Timestamp("2021-12-1")


# pull out some values
# (all negative sharpes, strategy does not make sense)
assert df.loc["2021-07-01"][0.50] < 0
assert df.loc["2021-07-01"][0.75] < 0
assert df.loc["2021-07-01"][0.99] < 0

# Draw line chart over time
fig = visualise_grid_single_rolling_metric(df)
assert isinstance(fig, Figure)

# Draw evolving series of charts as a sublot
fig = visualise_grid_rolling_metric_line_chart(
df,
range_start="2021-07-01",
range_end="2021-09-01",
)
assert isinstance(fig, Figure)


def test_grid_search_visualisation_heatmap(
strategy_universe,
indicator_storage,
tmp_path,
):
"""Advanced calculations and visualisation for grid search results.
"""

class Parameters:
cycle_duration = CycleDuration.cycle_1d
initial_cash = 10_000
allocation = [0.50, 0.75, 0.99]
cycle_divider = [2, 3, 4]
foo_param = ["a", "b"]

def _decide_trades_flip_buy_sell(input: StrategyInput) -> list[TradeExecution]:
# Generate some random trades
position_manager = input.get_position_manager()
parameters = input.parameters
pair = input.strategy_universe.get_single_pair()
cash = position_manager.get_current_cash()
if input.cycle % parameters.cycle_divider == 0:
return position_manager.open_spot(pair, cash * parameters.allocation)
else:
if position_manager.is_any_open():
return position_manager.close_all()
return []

def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext):
# No indicators needed
return IndicatorSet()

combinations = prepare_grid_combinations(
Parameters,
tmp_path,
strategy_universe=strategy_universe,
create_indicators=create_indicators,
execution_context=ExecutionContext(mode=ExecutionMode.unit_testing, grid_search=True),
)

assert len(combinations) == 18

grid_search_results = perform_grid_search(
_decide_trades_flip_buy_sell,
strategy_universe,
combinations,
trading_strategy_engine_version="0.5",
indicator_storage=indicator_storage,
verbose=False,
multiprocess=True,
)

# Calculate rolling sharpe for each month
# x-axis: time
# y-axis: sharpe
# variables as line charts: allocation=0.50, allocation=0.75, allocation=0.99
# other variables are set to their fixed values
df = calculate_rolling_metrics(
grid_search_results,
visualised_parameters=("allocation", "foo_param"),
fixed_parameters={"cycle_divider": 2},
benchmarked_metric=BenchmarkMetric.sharpe,
)

assert isinstance(df, pd.DataFrame)
assert len(df) > 0

# Check range is right
assert df.index[0] == pd.Timestamp("2021-06-1")
assert df.index[-1] == pd.Timestamp("2021-12-1")

assert df.columns[0] == (0.5, "a")

# pull out some values
# (all negative sharpes, strategy does not make sense)
assert df.loc["2021-07-01"][(0.5, 'a')] < 0
assert df.loc["2021-07-01"][(0.5, 'b')] < 0
assert df.loc["2021-07-01"][(0.75, 'b')] < 0

# Draw evolving series of charts as a sublot
fig = visualise_grid_rolling_metric_heatmap(
df,
range_start="2021-07-01",
range_end="2021-09-01",
)
assert isinstance(fig, Figure)
16 changes: 10 additions & 6 deletions tradeexecutor/analysis/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def clean(x):
# "Return": r.summary.return_percent,
# "Return2": r.summary.annualised_return_percent,
#"Annualised profit": clean(r.metrics.loc["Expected Yearly"][0]),
"CAGR": clean(r.metrics.loc["Annualised return (raw)"][0]),
"Max DD": clean(r.metrics.loc["Max Drawdown"][0]),
"Sharpe": clean(r.metrics.loc["Sharpe"][0]),
"Sortino": clean(r.metrics.loc["Sortino"][0]),
"CAGR": clean(r.metrics.loc["Annualised return (raw)"].iloc[0]),
"Max DD": clean(r.metrics.loc["Max Drawdown"].iloc[0]),
"Sharpe": clean(r.metrics.loc["Sharpe"].iloc[0]),
"Sortino": clean(r.metrics.loc["Sortino"].iloc[0]),
# "Combination": r.combination.get_label(),
"Time in market": clean(r.metrics.loc["Time in Market"][0]),
"Time in market": clean(r.metrics.loc["Time in Market"].iloc[0]),
"Win rate": clean(r.get_win_rate()),
"Avg pos": r.summary.average_trade, # Average position
"Med pos": r.summary.median_trade, # Median position
Expand Down Expand Up @@ -226,7 +226,11 @@ def render_grid_search_result_table(results: pd.DataFrame | list[GridSearchResul
def enum_to_value(x):
return x.value if isinstance(x, Enum) else x

df = df.applymap(enum_to_value)
if hasattr(df, "map"):
# Pandas 2+
df = df.map(enum_to_value)
else:
df = df.applymap(enum_to_value)

formatted = df.style.background_gradient(
axis = 0,
Expand Down
4 changes: 2 additions & 2 deletions tradeexecutor/analysis/optimiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def profile_optimiser(result: OptimiserResult) -> pd.DataFrame:
- Indexed by result id.
- Durations
"""
sorted_result = sorted(result.results, key=lambda r: r.result.start_at)
sorted_result = sorted(result.results, key=lambda r: r.result.run_start_at)
data = []
r: OptimiserSearchResult
for r in sorted_result:
tc = r.result.get_trade_count()
data.append({
"start_at": r.result.start_at,
"start_at": r.result.run_start_at,
"backtest": r.result.get_backtest_duration(),
"analysis": r.result.get_analysis_duration(),
"delivery": r.result.get_delivery_duration(),
Expand Down
54 changes: 42 additions & 12 deletions tradeexecutor/backtest/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,25 @@ class GridSearchResult:
#:
exception: Exception | None = None

#: When this test was started
start_at: datetime.datetime | None = None
#: What was the backtesting period
#:
backtest_start: datetime.datetime | None = None

#: What was the backtesting period
#:
backtest_end: datetime.datetime | None = None

#: When this test was started.
#:
#: Wall clock time.
#:
run_start_at: datetime.datetime | None = None

#: When this test ended
backtest_end_at: datetime.datetime | None = None
#: When this test run ended.
#:
#: Wall clock time.
#:
run_end_at: datetime.datetime | None = None

#: When we completed the analysis
analysis_end_at: datetime.datetime | None = None
Expand Down Expand Up @@ -582,10 +596,10 @@ def get_trade_count(self) -> int:
return self.summary.total_trades

def get_backtest_duration(self) -> datetime.timedelta:
return self.backtest_end_at - self.start_at
return self.run_end_at - self.run_start_at

def get_analysis_duration(self) -> datetime.timedelta:
return self.analysis_end_at - self.backtest_end_at
return self.analysis_end_at - self.run_end_at

def get_delivery_duration(self) -> datetime.timedelta:
return self.delivered_to_main_thread_at - self.analysis_end_at
Expand Down Expand Up @@ -882,13 +896,15 @@ def run_grid_combination_multiprocess(
data_retention: GridSearchDataRetention,
indicator_storage_path = DEFAULT_INDICATOR_STORAGE_PATH,
ignore_wallet_errors: bool = False,
verbose: bool = True,
):
"""Mutltiproecss runner.
Universe is passed as process global.
:param indicator_storage_path:
Override for unit testing
"""

from tradeexecutor.monkeypatch import cloudpickle_patch # Enable pickle patch that allows multiprocessing in notebooks
Expand Down Expand Up @@ -1035,6 +1051,7 @@ def perform_grid_search(
execution_context: ExecutionContext = grid_search_execution_context,
indicator_storage: DiskIndicatorStorage | None = None,
ignore_wallet_errors=False,
verbose=True,
) -> List[GridSearchResult]:
"""Search different strategy parameters over a grid.
Expand Down Expand Up @@ -1067,6 +1084,9 @@ def perform_grid_search(
:param trading_strategy_engine_version:
Which version of engine we are using.
:param verbose:
Disable progress bas
:return:
Grid search results for different combinations.
Expand Down Expand Up @@ -1154,11 +1174,17 @@ def perform_grid_search(

# Too wide for Datalore notebooks
# label = ", ".join(p.name for p in combinations[0].searchable_parameters)
with tqdm(total=len(task_args), desc=f"Searching") as progress_bar:

if verbose:
progress_bar = tqdm(total=len(task_args))
progress_bar.set_postfix({"processes": max_workers})
# Extract results from the parallel task queue
for task in tm.as_completed():
results.append(task.result)
else:
progress_bar = None

# Extract results from the parallel task queue
for task in tm.as_completed():
results.append(task.result)
if verbose:
progress_bar.update()
else:
#
Expand Down Expand Up @@ -1382,6 +1408,8 @@ def run_grid_search_backtest(

analysis_end = datetime.datetime.utcnow()

period = state.get_trading_time_range()

res = GridSearchResult(
combination=combination,
state=state,
Expand All @@ -1391,9 +1419,11 @@ def run_grid_search_backtest(
equity_curve=equity,
returns=returns,
initial_cash=state.portfolio.get_initial_cash(),
start_at=backtest_start,
backtest_end_at=backtest_end,
run_start_at=backtest_start,
run_end_at=backtest_end,
analysis_end_at=analysis_end,
backtest_start=period[0],
backtest_end=period[1],
)

# Double check we have not broken QuantStats again
Expand Down
Loading

0 comments on commit 48e4f77

Please sign in to comment.