From 48e4f7719a44f7527b2b2aececd7c985526023b3 Mon Sep 17 00:00:00 2001 From: Mikko Ohtamaa Date: Tue, 29 Oct 2024 19:15:30 +0100 Subject: [PATCH] Alternative ways to visualise grid search (#1073) - Add new charts and animations to grid search analysis, see `tradeexecutor.visualisation.grid_search_advanced` --- CHANGELOG.md | 2 +- poetry.lock | 160 +++- pyproject.toml | 2 +- tests/backtest/test_grid_search.py | 184 ++++- tradeexecutor/analysis/grid_search.py | 16 +- tradeexecutor/analysis/optimiser.py | 4 +- tradeexecutor/backtest/grid_search.py | 54 +- tradeexecutor/visual/equity_curve.py | 14 +- tradeexecutor/visual/grid_search.py | 279 +------ tradeexecutor/visual/grid_search_advanced.py | 682 ++++++++++++++++++ tradeexecutor/visual/grid_search_basic.py | 277 +++++++ .../visual/grid_search_visualisation.py | 328 +++++++++ 12 files changed, 1671 insertions(+), 331 deletions(-) create mode 100644 tradeexecutor/visual/grid_search_advanced.py create mode 100644 tradeexecutor/visual/grid_search_basic.py create mode 100644 tradeexecutor/visual/grid_search_visualisation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b29d98f1..26b0d1d00 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -**Note**: A full changelog is not available as long as `trade-executor` package is in active beta developmnt. + **Note**: A full changelog is not available as long as `trade-executor` package is in active beta developmnt. ## 0.2 diff --git a/poetry.lock b/poetry.lock index eec530190..0c288e8e5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1145,6 +1145,23 @@ files = [ marshmallow = ">=3.18.0,<4.0.0" typing-inspect = ">=0.4.0,<1" +[[package]] +name = "db-dtypes" +version = "1.3.0" +description = "Pandas Data Types for SQL systems (BigQuery, Spanner)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "db_dtypes-1.3.0-py2.py3-none-any.whl", hash = "sha256:7e65c59f849ccbe6f7bc4d0253edcc212a7907662906921caba3e4aadd0bc277"}, + {file = "db_dtypes-1.3.0.tar.gz", hash = "sha256:7bcbc8858b07474dc85b77bb2f3ae488978d1336f5ea73b58c39d9118bc3e91b"}, +] + +[package.dependencies] +numpy = ">=1.16.6" +packaging = ">=17.0" +pandas = ">=0.24.2" +pyarrow = ">=3.0.0" + [[package]] name = "debugpy" version = "1.8.7" @@ -2301,13 +2318,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio [[package]] name = "ipython" -version = "8.28.0" +version = "8.29.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.28.0-py3-none-any.whl", hash = "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35"}, - {file = "ipython-8.28.0.tar.gz", hash = "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a"}, + {file = "ipython-8.29.0-py3-none-any.whl", hash = "sha256:0188a1bd83267192123ccea7f4a8ed0a78910535dbaa3f37671dca76ebd429c8"}, + {file = "ipython-8.29.0.tar.gz", hash = "sha256:40b60e15b22591450eef73e40a027cf77bd652e757523eebc5bd7c7c498290eb"}, ] [package.dependencies] @@ -3851,6 +3868,73 @@ jsonschema-spec = ">=0.1.1,<0.2.0" lazy-object-proxy = ">=1.7.1,<2.0.0" openapi-schema-validator = ">=0.4.2,<0.5.0" +[[package]] +name = "orjson" +version = "3.10.10" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +optional = true +python-versions = ">=3.8" +files = [ + {file = "orjson-3.10.10-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = 
"sha256:b788a579b113acf1c57e0a68e558be71d5d09aa67f62ca1f68e01117e550a998"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:804b18e2b88022c8905bb79bd2cbe59c0cd014b9328f43da8d3b28441995cda4"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9972572a1d042ec9ee421b6da69f7cc823da5962237563fa548ab17f152f0b9b"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc6993ab1c2ae7dd0711161e303f1db69062955ac2668181bfdf2dd410e65258"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d78e4cacced5781b01d9bc0f0cd8b70b906a0e109825cb41c1b03f9c41e4ce86"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6eb2598df518281ba0cbc30d24c5b06124ccf7e19169e883c14e0831217a0bc"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:23776265c5215ec532de6238a52707048401a568f0fa0d938008e92a147fe2c7"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8cc2a654c08755cef90b468ff17c102e2def0edd62898b2486767204a7f5cc9c"}, + {file = "orjson-3.10.10-cp310-none-win32.whl", hash = "sha256:081b3fc6a86d72efeb67c13d0ea7c030017bd95f9868b1e329a376edc456153b"}, + {file = "orjson-3.10.10-cp310-none-win_amd64.whl", hash = "sha256:ff38c5fb749347768a603be1fb8a31856458af839f31f064c5aa74aca5be9efe"}, + {file = "orjson-3.10.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:879e99486c0fbb256266c7c6a67ff84f46035e4f8749ac6317cc83dacd7f993a"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019481fa9ea5ff13b5d5d95e6fd5ab25ded0810c80b150c2c7b1cc8660b662a7"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0dd57eff09894938b4c86d4b871a479260f9e156fa7f12f8cad4b39ea8028bb5"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dbde6d70cd95ab4d11ea8ac5e738e30764e510fc54d777336eec09bb93b8576c"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2625cb37b8fb42e2147404e5ff7ef08712099197a9cd38895006d7053e69d6"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2"}, + {file = "orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b"}, + {file = "orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269"}, + {file = "orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:68b65c93617bcafa7f04b74ae8bc2cc214bd5cb45168a953256ff83015c6747d"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8e28406f97fc2ea0c6150f4c1b6e8261453318930b334abc419214c82314f85"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4d0d9fe174cc7a5bdce2e6c378bcdb4c49b2bf522a8f996aa586020e1b96cee"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3be81c42f1242cbed03cbb3973501fcaa2675a0af638f8be494eaf37143d999"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65f9886d3bae65be026219c0a5f32dbbe91a9e6272f56d092ab22561ad0ea33b"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:730ed5350147db7beb23ddaf072f490329e90a1d059711d364b49fe352ec987b"}, + {file = "orjson-3.10.10-cp312-none-win32.whl", hash = "sha256:a8f4bf5f1c85bea2170800020d53a8877812892697f9c2de73d576c9307a8a5f"}, + {file = "orjson-3.10.10-cp312-none-win_amd64.whl", hash = "sha256:384cd13579a1b4cd689d218e329f459eb9ddc504fa48c5a83ef4889db7fd7a4f"}, + {file = "orjson-3.10.10-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44bffae68c291f94ff5a9b4149fe9d1bdd4cd0ff0fb575bcea8351d48db629a1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27b4c6437315df3024f0835887127dac2a0a3ff643500ec27088d2588fa5ae1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca84df16d6b49325a4084fd8b2fe2229cb415e15c46c529f868c3387bb1339d"}, + {file = "orjson-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c14ce70e8f39bd71f9f80423801b5d10bf93d1dceffdecd04df0f64d2c69bc01"}, + {file = "orjson-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:24ac62336da9bda1bd93c0491eff0613003b48d3cb5d01470842e7b52a40d5b4"}, + {file = "orjson-3.10.10-cp313-none-win32.whl", hash = "sha256:eb0a42831372ec2b05acc9ee45af77bcaccbd91257345f93780a8e654efc75db"}, + {file = "orjson-3.10.10-cp313-none-win_amd64.whl", hash = "sha256:f0c4f37f8bf3f1075c6cc8dd8a9f843689a4b618628f8812d0a71e6968b95ffd"}, + {file = "orjson-3.10.10-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:829700cc18503efc0cf502d630f612884258020d98a317679cd2054af0259568"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0ceb5e0e8c4f010ac787d29ae6299846935044686509e2f0f06ed441c1ca949"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c25908eb86968613216f3db4d3003f1c45d78eb9046b71056ca327ff92bdbd4"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:218cb0bc03340144b6328a9ff78f0932e642199ac184dd74b01ad691f42f93ff"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2277ec2cea3775640dc81ab5195bb5b2ada2fe0ea6eee4677474edc75ea6785"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:848ea3b55ab5ccc9d7bbd420d69432628b691fba3ca8ae3148c35156cbd282aa"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e3e67b537ac0c835b25b5f7d40d83816abd2d3f4c0b0866ee981a045287a54f3"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:7948cfb909353fce2135dcdbe4521a5e7e1159484e0bb024c1722f272488f2b8"}, + {file = "orjson-3.10.10-cp38-none-win32.whl", hash = "sha256:78bee66a988f1a333dc0b6257503d63553b1957889c17b2c4ed72385cd1b96ae"}, + {file = "orjson-3.10.10-cp38-none-win_amd64.whl", hash = "sha256:f1d647ca8d62afeb774340a343c7fc023efacfd3a39f70c798991063f0c681dd"}, + {file = "orjson-3.10.10-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5a059afddbaa6dd733b5a2d76a90dbc8af790b993b1b5cb97a1176ca713b5df8"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f9b5c59f7e2a1a410f971c5ebc68f1995822837cd10905ee255f96074537ee6"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d5ef198bafdef4aa9d49a4165ba53ffdc0a9e1c7b6f76178572ab33118afea25"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf29ce0bb5d3320824ec3d1508652421000ba466abd63bdd52c64bcce9eb1fa"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dddd5516bcc93e723d029c1633ae79c4417477b4f57dad9bfeeb6bc0315e654a"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12f2003695b10817f0fa8b8fca982ed7f5761dcb0d93cff4f2f9f6709903fd7"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:672f9874a8a8fb9bb1b771331d31ba27f57702c8106cdbadad8bda5d10bc1019"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dcbb0ca5fafb2b378b2c74419480ab2486326974826bbf6588f4dc62137570a"}, + {file = "orjson-3.10.10-cp39-none-win32.whl", hash = "sha256:d9bbd3a4b92256875cb058c3381b782649b9a3c68a4aa9a2fff020c2f9cfc1be"}, + {file = "orjson-3.10.10-cp39-none-win_amd64.whl", hash = "sha256:766f21487a53aee8524b97ca9582d5c6541b03ab6210fbaf10142ae2f3ced2aa"}, + {file = "orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -5008,13 +5092,13 @@ dev = ["pre-commit", "pytest-asyncio", "tox"] [[package]] name = "pytest-reverse" -version = "1.7.0" +version = "1.8.0" description = "Pytest plugin to reverse test order." 
optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pytest_reverse-1.7.0-py3-none-any.whl", hash = "sha256:37e83daac57eea3fb1cb718aa9ccdf9ca2ea8ac3645cb5bccf1c7ae25a8ad1d2"}, - {file = "pytest_reverse-1.7.0.tar.gz", hash = "sha256:f943e5b9d253267569fd7ad237afc56b3e98ce9f6d2f6f3bb487b8c759e214fe"}, + {file = "pytest_reverse-1.8.0-py3-none-any.whl", hash = "sha256:e31a2d0b51f2f8b6162aed268f853851f55c62ac445041a032740e985b0bc8c8"}, + {file = "pytest_reverse-1.8.0.tar.gz", hash = "sha256:eb72ffd57cc91061e837b1d2c4522bfda58eaa83fc97147fd90f7929160e97ab"}, ] [package.dependencies] @@ -5831,17 +5915,17 @@ files = [ [[package]] name = "sigfig" -version = "1.3.3" +version = "1.3.17" description = "Python library for rounding numbers (with expected results)" optional = false -python-versions = "*" +python-versions = "<4.0,>=3.6" files = [ - {file = "sigfig-1.3.3-py3-none-any.whl", hash = "sha256:7df6dfc45d09ee7e43a9418e944fcf06f8654477af263e3f0c2bee4234d6a84e"}, - {file = "sigfig-1.3.3.tar.gz", hash = "sha256:d6a720029c2fdb0f1413b14ba72f92db9ab95fe816d198fe7f6311d2bbdc5b61"}, + {file = "sigfig-1.3.17-py3-none-any.whl", hash = "sha256:2b532bf7b12ba81603b273a406cfda35eda9b2329dc08dcd45ce45985ba8c631"}, + {file = "sigfig-1.3.17.tar.gz", hash = "sha256:490fccae9ffcaa15377b215b1d8ddba0f33453053dd8d3a299bf0c5c954429a0"}, ] [package.dependencies] -SortedContainers = "*" +sortedcontainers = ">=2.4.0,<3.0.0" [[package]] name = "six" @@ -5998,7 +6082,7 @@ httpx = "*" type = "git" url = "https://github.com/tradingstrategy-ai/telegram_bot_logger.git" reference = "patch-bleeding-edges" -resolved_reference = "92650efb2607f16a6f614cc83cebb630a5fa7d99" +resolved_reference = "f2319de7f5a18dd69d9e92b09dc95f40ec241a83" [[package]] name = "tenacity" @@ -6049,13 +6133,13 @@ files = [ [[package]] name = "tinycss2" -version = "1.3.0" +version = "1.4.0" description = "A tiny CSS parser" optional = false python-versions = ">=3.8" files = [ - {file = "tinycss2-1.3.0-py3-none-any.whl", hash = "sha256:54a8dbdffb334d536851be0226030e9505965bb2f30f21a4a82c55fb2a80fae7"}, - {file = "tinycss2-1.3.0.tar.gz", hash = "sha256:152f9acabd296a8375fbca5b84c961ff95971fcfc32e79550c8df8e29118c54d"}, + {file = "tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289"}, + {file = "tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7"}, ] [package.dependencies] @@ -6132,7 +6216,7 @@ tqdm = ">4.64" [[package]] name = "trading-strategy" -version = "0.24.3" +version = "0.24.4" description = "Algorithmic trading data for cryptocurrencies and DEXes like Uniswap, Aave and PancakeSwap" optional = false python-versions = ">=3.10,<3.13" @@ -6601,13 +6685,13 @@ files = [ [[package]] name = "webob" -version = "1.8.8" +version = "1.8.9" description = "WSGI request and response object" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "WebOb-1.8.8-py2.py3-none-any.whl", hash = "sha256:b60ba63f05c0cf61e086a10c3781a41fcfe30027753a8ae6d819c77592ce83ea"}, - {file = "webob-1.8.8.tar.gz", hash = "sha256:2abc1555e118fc251e705fc6dc66c7f5353bb9fbfab6d20e22f1c02b4b71bcee"}, + {file = "WebOb-1.8.9-py2.py3-none-any.whl", hash = "sha256:45e34c58ed0c7e2ecd238ffd34432487ff13d9ad459ddfd77895e67abba7c1f9"}, + {file = "webob-1.8.9.tar.gz", hash = "sha256:ad6078e2edb6766d1334ec3dee072ac6a7f95b1e32ce10def8ff7f0f02d56589"}, ] [package.extras] @@ -6747,13 +6831,13 @@ tests = ["PasteDeploy", 
"WSGIProxy2", "coverage", "pyquery", "pytest", "pytest-c [[package]] name = "werkzeug" -version = "3.0.4" +version = "3.0.6" description = "The comprehensive WSGI web application library." optional = true python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.4-py3-none-any.whl", hash = "sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c"}, - {file = "werkzeug-3.0.4.tar.gz", hash = "sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306"}, + {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, + {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, ] [package.dependencies] @@ -6871,13 +6955,13 @@ propcache = ">=0.2.0" [[package]] name = "yfinance" -version = "0.2.46" +version = "0.2.48" description = "Download market data from Yahoo! Finance API" optional = false python-versions = "*" files = [ - {file = "yfinance-0.2.46-py2.py3-none-any.whl", hash = "sha256:371860d532cae76605195678a540e29382bfd0607f8aa61695f753e714916ffc"}, - {file = "yfinance-0.2.46.tar.gz", hash = "sha256:a6e2a128915532a54b8f6614cfdb7a8c242d2386e05f95c89b15865b5d9c0352"}, + {file = "yfinance-0.2.48-py2.py3-none-any.whl", hash = "sha256:eda797145faa4536595eb629f869d3616e58ed7e71de36856b19f1abaef71a5b"}, + {file = "yfinance-0.2.48.tar.gz", hash = "sha256:1434cd8bf22f345fa27ef1ed82bfdd291c1bb5b6fe3067118a94e256aa90c4eb"}, ] [package.dependencies] @@ -6897,6 +6981,26 @@ requests = ">=2.31" nospam = ["requests-cache (>=1.0)", "requests-ratelimiter (>=0.3.1)"] repair = ["scipy (>=1.6.3)"] +[[package]] +name = "zelos-demeter" +version = "0.7.4" +description = "better DEFI backtesting tool" +optional = true +python-versions = ">=3.11" +files = [ + {file = "zelos-demeter-0.7.4.tar.gz", hash = "sha256:fe042e3b1912e65b38c16f940165483dc32cd94fa0750d6f9024a7d0ba1effb8"}, +] + +[package.dependencies] +db-dtypes = ">=1.2.0" +numpy = ">=1.26.4" +orjson = ">=3.9.15" +pandas = ">=2.2.0" +python-dateutil = ">=2.9.0.post0" +pytz = ">=2024.1" +six = ">=1.16.0" +tqdm = ">=4.66.2" + [[package]] name = "zope-deprecation" version = "5.0" @@ -7082,7 +7186,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -demeter = [] +demeter = ["zelos-demeter"] execution = ["beautifulsoup4", "colorama", "coloredlogs", "prompt-toolkit", "python-dotenv", "python-logging-discord-handler", "python-logstash-tradingstrategy", "telegram-bot-logger", "typer", "web3-ethereum-defi"] qstrader = ["trading-strategy-qstrader"] trendsspotter = ["fsspec", "gcsfs", "google-auth", "google-cloud-storage"] @@ -7091,4 +7195,4 @@ web-server = ["WebTest", "openapi-core", "pyramid", "pyramid-openapi3", "waitres [metadata] lock-version = "2.0" python-versions = ">=3.11,<=3.12" -content-hash = "919bc63839295199cc30563a25c6fe060037c87a6147789f2f6d0ac5b789846f" +content-hash = "b05820fae45d38087d307db60e37e4c241a7a73479f6b1730c3afe6515016618" diff --git a/pyproject.toml b/pyproject.toml index c964bad78..80efcf778 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,7 @@ pytest-timeout = "^2.3.1" # # TODO: Disabled. Install wiht pip until dependency version incompatibilies are solved. 
-# zelos-demeter = {version="^0.7.2", optional = true} + zelos-demeter = {version="^0.7.4", optional = true} # https://github.com/arynyklas/telegram_bot_logger/pull/1 telegram-bot-logger = {git = "https://github.com/tradingstrategy-ai/telegram_bot_logger.git", branch="patch-bleeding-edges", optional = true} diff --git a/tests/backtest/test_grid_search.py b/tests/backtest/test_grid_search.py index 3624d9a8d..4b848edcb 100644 --- a/tests/backtest/test_grid_search.py +++ b/tests/backtest/test_grid_search.py @@ -17,6 +17,9 @@ from tradeexecutor.strategy.pandas_trader.strategy_input import StrategyInput from tradeexecutor.strategy.parameters import StrategyParameters from tradeexecutor.visual.grid_search import visualise_single_grid_search_result_benchmark, visualise_grid_search_equity_curves +from tradeexecutor.visual.grid_search_advanced import calculate_rolling_metrics, BenchmarkMetric, visualise_grid_single_rolling_metric, visualise_grid_rolling_metric_heatmap + +from tradeexecutor.visual.grid_search_advanced import visualise_grid_rolling_metric_line_chart from tradingstrategy.candle import GroupedCandleUniverse from tradingstrategy.chain import ChainId from tradingstrategy.exchange import Exchange @@ -825,4 +828,183 @@ def create_indicators(parameters: StrategyParameters, indicators: IndicatorSet, assert len(results) == 2 for r in results: - assert isinstance(r.exception, BacktestExecutionFailed) \ No newline at end of file + assert isinstance(r.exception, BacktestExecutionFailed) + + + +def test_grid_search_visualisation_line_chart( + strategy_universe, + indicator_storage, + tmp_path, +): + """Advanced calculations and visualisation for grid search results. + """ + class Parameters: + cycle_duration = CycleDuration.cycle_1d + initial_cash = 10_000 + allocation = [0.50, 0.75, 0.99] + cycle_divider = [2, 3, 4] + foo_param = ["a", "b"] + + def _decide_trades_flip_buy_sell(input: StrategyInput) -> list[TradeExecution]: + # Generate some random trades + position_manager = input.get_position_manager() + parameters = input.parameters + pair = input.strategy_universe.get_single_pair() + cash = position_manager.get_current_cash() + if input.cycle % parameters.cycle_divider == 0: + return position_manager.open_spot(pair, cash * parameters.allocation) + else: + if position_manager.is_any_open(): + return position_manager.close_all() + return [] + + def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext): + # No indicators needed + return IndicatorSet() + + combinations = prepare_grid_combinations( + Parameters, + tmp_path, + strategy_universe=strategy_universe, + create_indicators=create_indicators, + execution_context=ExecutionContext(mode=ExecutionMode.unit_testing, grid_search=True), + ) + + assert len(combinations) == 18 + + grid_search_results = perform_grid_search( + _decide_trades_flip_buy_sell, + strategy_universe, + combinations, + trading_strategy_engine_version="0.5", + indicator_storage=indicator_storage, + verbose=False, + multiprocess=True, + ) + + # Calculate rolling sharpe for each month + # x-axis: time + # y-axis: sharpe + # variables as line charts: allocation=0.50, allocation=0.75, allocation=0.99 + # other variables are set to their fixed values + df = calculate_rolling_metrics( + grid_search_results, + visualised_parameters="allocation", + fixed_parameters={"cycle_divider": 2, "foo_param": "a"}, + benchmarked_metric=BenchmarkMetric.sharpe, + ) + + assert isinstance(df, 
pd.DataFrame)
+    assert len(df) > 0
+
+    # Check range is right
+    assert df.index[0] == pd.Timestamp("2021-06-1")
+    assert df.index[-1] == pd.Timestamp("2021-12-1")
+
+
+    # pull out some values
+    # (all negative sharpes, strategy does not make sense)
+    assert df.loc["2021-07-01"][0.50] < 0
+    assert df.loc["2021-07-01"][0.75] < 0
+    assert df.loc["2021-07-01"][0.99] < 0
+
+    # Draw line chart over time
+    fig = visualise_grid_single_rolling_metric(df)
+    assert isinstance(fig, Figure)
+
+    # Draw evolving series of charts as a subplot
+    fig = visualise_grid_rolling_metric_line_chart(
+        df,
+        range_start="2021-07-01",
+        range_end="2021-09-01",
+    )
+    assert isinstance(fig, Figure)
+
+
+def test_grid_search_visualisation_heatmap(
+    strategy_universe,
+    indicator_storage,
+    tmp_path,
+):
+    """Advanced calculations and visualisation for grid search results.
+    """
+
+    class Parameters:
+        cycle_duration = CycleDuration.cycle_1d
+        initial_cash = 10_000
+        allocation = [0.50, 0.75, 0.99]
+        cycle_divider = [2, 3, 4]
+        foo_param = ["a", "b"]
+
+    def _decide_trades_flip_buy_sell(input: StrategyInput) -> list[TradeExecution]:
+        # Generate some random trades
+        position_manager = input.get_position_manager()
+        parameters = input.parameters
+        pair = input.strategy_universe.get_single_pair()
+        cash = position_manager.get_current_cash()
+        if input.cycle % parameters.cycle_divider == 0:
+            return position_manager.open_spot(pair, cash * parameters.allocation)
+        else:
+            if position_manager.is_any_open():
+                return position_manager.close_all()
+        return []
+
+    def create_indicators(timestamp: datetime.datetime, parameters: StrategyParameters, strategy_universe: TradingStrategyUniverse, execution_context: ExecutionContext):
+        # No indicators needed
+        return IndicatorSet()
+
+    combinations = prepare_grid_combinations(
+        Parameters,
+        tmp_path,
+        strategy_universe=strategy_universe,
+        create_indicators=create_indicators,
+        execution_context=ExecutionContext(mode=ExecutionMode.unit_testing, grid_search=True),
+    )
+
+    assert len(combinations) == 18
+
+    grid_search_results = perform_grid_search(
+        _decide_trades_flip_buy_sell,
+        strategy_universe,
+        combinations,
+        trading_strategy_engine_version="0.5",
+        indicator_storage=indicator_storage,
+        verbose=False,
+        multiprocess=True,
+    )
+
+    # Calculate rolling sharpe for each month
+    # x-axis: time
+    # y-axis: sharpe
+    # variables as heatmap axes: allocation (0.50, 0.75, 0.99) and foo_param ("a", "b")
+    # other variables are set to their fixed values
+    df = calculate_rolling_metrics(
+        grid_search_results,
+        visualised_parameters=("allocation", "foo_param"),
+        fixed_parameters={"cycle_divider": 2},
+        benchmarked_metric=BenchmarkMetric.sharpe,
+    )
+
+    assert isinstance(df, pd.DataFrame)
+    assert len(df) > 0
+
+    # Check range is right
+    assert df.index[0] == pd.Timestamp("2021-06-1")
+    assert df.index[-1] == pd.Timestamp("2021-12-1")
+
+    assert df.columns[0] == (0.5, "a")
+
+    # pull out some values
+    # (all negative sharpes, strategy does not make sense)
+    assert df.loc["2021-07-01"][(0.5, 'a')] < 0
+    assert df.loc["2021-07-01"][(0.5, 'b')] < 0
+    assert df.loc["2021-07-01"][(0.75, 'b')] < 0
+
+    # Draw evolving series of charts as a subplot
+    fig = visualise_grid_rolling_metric_heatmap(
+        df,
+        range_start="2021-07-01",
+        range_end="2021-09-01",
+    )
+    assert isinstance(fig, Figure)
diff --git a/tradeexecutor/analysis/grid_search.py b/tradeexecutor/analysis/grid_search.py
index 522adc058..54ef59c4a 100644
--- a/tradeexecutor/analysis/grid_search.py
+++ b/tradeexecutor/analysis/grid_search.py
@@ 
-81,12 +81,12 @@ def clean(x):
 #            "Return": r.summary.return_percent,
 #            "Return2": r.summary.annualised_return_percent,
             #"Annualised profit": clean(r.metrics.loc["Expected Yearly"][0]),
-            "CAGR": clean(r.metrics.loc["Annualised return (raw)"][0]),
-            "Max DD": clean(r.metrics.loc["Max Drawdown"][0]),
-            "Sharpe": clean(r.metrics.loc["Sharpe"][0]),
-            "Sortino": clean(r.metrics.loc["Sortino"][0]),
+            "CAGR": clean(r.metrics.loc["Annualised return (raw)"].iloc[0]),
+            "Max DD": clean(r.metrics.loc["Max Drawdown"].iloc[0]),
+            "Sharpe": clean(r.metrics.loc["Sharpe"].iloc[0]),
+            "Sortino": clean(r.metrics.loc["Sortino"].iloc[0]),
             # "Combination": r.combination.get_label(),
-            "Time in market": clean(r.metrics.loc["Time in Market"][0]),
+            "Time in market": clean(r.metrics.loc["Time in Market"].iloc[0]),
             "Win rate": clean(r.get_win_rate()),
             "Avg pos": r.summary.average_trade,  # Average position
             "Med pos": r.summary.median_trade,  # Median position
@@ -226,7 +226,11 @@ def render_grid_search_result_table(results: pd.DataFrame | list[GridSearchResul
     def enum_to_value(x):
         return x.value if isinstance(x, Enum) else x
 
-    df = df.applymap(enum_to_value)
+    if hasattr(df, "map"):
+        # Pandas 2+
+        df = df.map(enum_to_value)
+    else:
+        df = df.applymap(enum_to_value)
 
     formatted = df.style.background_gradient(
         axis = 0,
diff --git a/tradeexecutor/analysis/optimiser.py b/tradeexecutor/analysis/optimiser.py
index 95bb653ef..4d50375fe 100644
--- a/tradeexecutor/analysis/optimiser.py
+++ b/tradeexecutor/analysis/optimiser.py
@@ -55,13 +55,13 @@ def profile_optimiser(result: OptimiserResult) -> pd.DataFrame:
     - Indexed by result id.
 
     - Durations
     """
-    sorted_result = sorted(result.results, key=lambda r: r.result.start_at)
+    sorted_result = sorted(result.results, key=lambda r: r.result.run_start_at)
     data = []
     r: OptimiserSearchResult
     for r in sorted_result:
         tc = r.result.get_trade_count()
         data.append({
-            "start_at": r.result.start_at,
+            "start_at": r.result.run_start_at,
             "backtest": r.result.get_backtest_duration(),
             "analysis": r.result.get_analysis_duration(),
             "delivery": r.result.get_delivery_duration(),
diff --git a/tradeexecutor/backtest/grid_search.py b/tradeexecutor/backtest/grid_search.py
index f892bc979..10f33f0da 100644
--- a/tradeexecutor/backtest/grid_search.py
+++ b/tradeexecutor/backtest/grid_search.py
@@ -407,11 +407,25 @@ class GridSearchResult:
     #:
     exception: Exception | None = None
 
-    #: When this test was started
-    start_at: datetime.datetime | None = None
+    #: When the backtest period started.
+    #:
+    backtest_start: datetime.datetime | None = None
+
+    #: When the backtest period ended.
+    #:
+    backtest_end: datetime.datetime | None = None
+
+    #: When this test was started.
+    #:
+    #: Wall clock time.
+    #:
+    run_start_at: datetime.datetime | None = None
 
-    #: When this test ended
-    backtest_end_at: datetime.datetime | None = None
+    #: When this test run ended.
+    #:
+    #: Wall clock time.
+    #:
+    run_end_at: datetime.datetime | None = None
 
     #: When we completed the analysis
     analysis_end_at: datetime.datetime | None = None
@@ -582,10 +596,10 @@ def get_trade_count(self) -> int:
         return self.summary.total_trades
 
     def get_backtest_duration(self) -> datetime.timedelta:
-        return self.backtest_end_at - self.start_at
+        return self.run_end_at - self.run_start_at
 
     def get_analysis_duration(self) -> datetime.timedelta:
-        return self.analysis_end_at - self.backtest_end_at
+        return self.analysis_end_at - self.run_end_at
 
     def get_delivery_duration(self) -> datetime.timedelta:
         return self.delivered_to_main_thread_at - self.analysis_end_at
@@ -882,6 +896,7 @@ def run_grid_combination_multiprocess(
     data_retention: GridSearchDataRetention,
     indicator_storage_path = DEFAULT_INDICATOR_STORAGE_PATH,
     ignore_wallet_errors: bool = False,
+    verbose: bool = True,
 ):
     """Multiprocess runner.
 
@@ -889,6 +904,7 @@
 
     :param indicator_storage_path:
         Override for unit testing
+
     """
 
     from tradeexecutor.monkeypatch import cloudpickle_patch  # Enable pickle patch that allows multiprocessing in notebooks
@@ -1035,6 +1051,7 @@ def perform_grid_search(
     execution_context: ExecutionContext = grid_search_execution_context,
     indicator_storage: DiskIndicatorStorage | None = None,
     ignore_wallet_errors=False,
+    verbose=True,
 ) -> List[GridSearchResult]:
     """Search different strategy parameters over a grid.
@@ -1067,6 +1084,9 @@
     :param trading_strategy_engine_version:
         Which version of engine we are using.
 
+    :param verbose:
+        Set to False to disable the progress bar.
+
     :return:
         Grid search results for different combinations.
@@ -1154,11 +1174,17 @@
         # Too wide for Datalore notebooks
         # label = ", ".join(p.name for p in combinations[0].searchable_parameters)
-        with tqdm(total=len(task_args), desc=f"Searching") as progress_bar:
+
+        if verbose:
+            progress_bar = tqdm(total=len(task_args))
             progress_bar.set_postfix({"processes": max_workers})
-            # Extract results from the parallel task queue
-            for task in tm.as_completed():
-                results.append(task.result)
+        else:
+            progress_bar = None
+
+        # Extract results from the parallel task queue
+        for task in tm.as_completed():
+            results.append(task.result)
+            if verbose:
                 progress_bar.update()
     else:
         #
@@ -1382,6 +1408,8 @@ def run_grid_search_backtest(
 
     analysis_end = datetime.datetime.utcnow()
 
+    period = state.get_trading_time_range()
+
     res = GridSearchResult(
         combination=combination,
         state=state,
@@ -1391,9 +1419,11 @@
         equity_curve=equity,
         returns=returns,
         initial_cash=state.portfolio.get_initial_cash(),
-        start_at=backtest_start,
-        backtest_end_at=backtest_end,
+        run_start_at=backtest_start,
+        run_end_at=backtest_end,
         analysis_end_at=analysis_end,
+        backtest_start=period[0],
+        backtest_end=period[1],
     )
 
     # Double check we have not broken QuantStats again
diff --git a/tradeexecutor/visual/equity_curve.py b/tradeexecutor/visual/equity_curve.py
index 1324814b1..4dfb1d4d6 100644
--- a/tradeexecutor/visual/equity_curve.py
+++ b/tradeexecutor/visual/equity_curve.py
@@ -256,14 +256,20 @@ def calculate_daily_returns(
 
 def visualise_equity_curve(
-    returns: pd.Series,
-    title="Equity curve",
-    line_width=1.5,
+    returns: pd.Series,
+    title="Equity curve",
+    line_width=1.5,
 ) -> Figure:
-    """Draw equity curve, drawdown and daily returns using quantstats.
+    """Draw equity curve, drawdown and daily returns using Quantstats.
 
     `See Quantstats README for more details `__.
+    See also
+
+    - :py:func:`tradeexecutor.visual.benchmark.visualise_equity_curve_benchmark`
+
+    - :py:func:`tradeexecutor.visual.grid_search.visualise_single_grid_search_result_benchmark`
+
     Example:
 
     .. code-block:: python
diff --git a/tradeexecutor/visual/grid_search.py b/tradeexecutor/visual/grid_search.py
index a63a5d711..d8f5eb9ed 100644
--- a/tradeexecutor/visual/grid_search.py
+++ b/tradeexecutor/visual/grid_search.py
@@ -1,277 +1,4 @@
-"""Visualise grid search results.
+"""A stub module for legacy compatibility."""
-
-- Different visualisation tools to compare grid search results
-"""
-from typing import List
-
-import pandas as pd
-import plotly.graph_objects as go
-from plotly.graph_objs import Figure, Scatter
-
-from tradeexecutor.analysis.curve import CurveType, DEFAULT_BENCHMARK_COLOURS
-from tradeexecutor.analysis.grid_search import _get_hover_template, order_grid_search_results_by_metric
-from tradeexecutor.analysis.multi_asset_benchmark import get_benchmark_data
-from tradeexecutor.backtest.grid_search import GridSearchResult
-from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse
-from tradeexecutor.visual.benchmark import visualise_equity_curves
-from tradingstrategy.types import USDollarAmount
-
-
-def visualise_single_grid_search_result_benchmark(
-    result: GridSearchResult,
-    strategy_universe: TradingStrategyUniverse,
-    initial_cash: USDollarAmount | None = None,
-    name="Picked search result",
-    log_y=False,
-    asset_count=3,
-) -> go.Figure:
-    """Draw one equity curve from grid search results.
-
-    - Compare the equity curve againt buy and hold assets from the trading universe
-
-    - Used to plot "the best" equity curve
-
-    - Use :func:`find_best_grid_search_results` to find some equity curves.
-
-    See also
-
-    - :py:func:`visualise_grid_search_equity_curves`
-
-    - :py:func:`tradeexecutor.visual.benchmark.visualise_equity_curves`
-
-    - :py:func:`tradeexecutor.analysis.multi_asset_benchmark.get_benchmark_data`
-
-    - :py:func:`tradeexecutor.analysis.grid_search.find_best_grid_search_results`
-
-    Example:
-
-    .. code-block:: python
-
-        from tradeexecutor.analysis.grid_search import find_best_grid_search_results
-        from tradeexecutor.visual.grid_search import visualise_single_grid_search_result_benchmark
-
-        # Show the equity curve of the best grid search performer
-        best_results = find_best_grid_search_results(grid_search_results)
-        fig = visualise_single_grid_search_result_benchmark(best_results.cagr[0], strategy_universe)
-        fig.show()
-
-    :param result:
-        Picked grid search result
-
-    :param strategy_universe:
-        Used to get benechmark indexes
-
-    :param name:
-        Chart title
-
-    :param initial_cash:
-        Not needed. Automatically filled in by grid search.
-
-        Legacy param.
-
-    :param asset_count:
-        Draw this many comparison buy-and-hold curves from well-known assets.
- - :return: - Plotly figure - """ - - assert isinstance(result, GridSearchResult) - assert isinstance(strategy_universe, TradingStrategyUniverse) - - # Get daily returns - equity = result.equity_curve - equity.attrs["name"] = result.get_truncated_label() - equity.attrs["curve"] = CurveType.equity - equity.attrs["colour"] = DEFAULT_BENCHMARK_COLOURS["Strategy"] - - if result.state is not None: - start_at = result.state.get_trading_time_range()[0] - else: - start_at = equity.index[0] - - benchmarks = get_benchmark_data( - strategy_universe, - cumulative_with_initial_cash=initial_cash or getattr(result, "initial_cash", None), # Legacy support hack - start_at=start_at, - max_count=asset_count, - ) - - benchmark_series = [v for k, v in benchmarks.items()] - - fig = visualise_equity_curves( - [equity] + benchmark_series, - name=name, - log_y=log_y, - start_at=start_at, - ) - - return fig - - - - -def visualise_grid_search_equity_curves( - results: List[GridSearchResult], - name: str | None = None, - benchmark_indexes: pd.DataFrame | None = None, - height=1200, - colour=None, - log_y=False, - alpha=0.7, -) -> Figure: - """Draw multiple equity curves in the same chart. - - - See how all grid searched strategies work - - - Benchmark against buy and hold of various assets - - - Benchmark against hold all cash - - .. note :: - - Only good up to ~hundreds results. If more than thousand result, rendering takes too long time. - - Example that draws equity curves of a grid search results. - - .. code-block:: python - - from tradeexecutor.visual.grid_search import visualise_grid_search_equity_curves - from tradeexecutor.analysis.multi_asset_benchmark import get_benchmark_data - - # Automatically create BTC and ETH buy and hold benchmark if present - # in the trading universe - benchmark_indexes = get_benchmark_data( - strategy_universe, - cumulative_with_initial_cash=ShiftedStrategyParameters.initial_cash, - ) - - fig = visualise_grid_search_equity_curves( - grid_search_results, - name="8h clock shift, stop loss added and adjusted momentum", - benchmark_indexes=benchmark_indexes, - log_y=False, - ) - fig.show() - - :param results: - Results from the grid search. - - :param benchmark_indexes: - List of other asset price series displayed on the timeline besides equity curve. - - DataFrame containing multiple series. - - - Asset name is the series name. - - Setting `colour` for `pd.Series.attrs` allows you to override the colour of the index - - :param height: - Chart height in pixels - - :param colour: - Colour of the equity curve e.g. "rgba(160, 160, 160, 0.5)". If provided, all equity curves will be drawn with this colour. - - :param start_at: - When the backtest started - - :param end_at: - When the backtest ended - - :param additional_indicators: - Additional technical indicators drawn on this chart. - - List of indicator names. - - The indicators must be plotted earlier using `state.visualisation.plot_indicator()`. - - **Note**: Currently not very useful due to Y axis scale - - :param log_y: - Use logarithmic Y-axis. - - Because we accumulate larger treasury over time, - the swings in the value will be higher later. - We need to use a logarithmic Y axis so that we can compare the performance - early in the strateg and late in the strategy. - - """ - def generate_broad_bluered_colors(num_colors, alpha): - """ - Generate a list of RGBA colors along a broad blue-purple-red color scale. - - Parameters: - num_colors (int): Number of colors to generate. - alpha (float): Alpha value for the colors (0 to 1). 
- - Returns: - list: List of RGBA color tuples. - """ - colors = [] - for i in range(num_colors): - ratio = i / (num_colors - 1) - if ratio < 0.7: - red = int(255 * (1.5 * ratio)) - blue = 255 - else: - red = 255 - blue = int(255 * (2 * (1 - ratio))) - green = max(0, int(255 * (1 - abs(ratio - 0.7) * 2))) - color = (red, green, blue, alpha) - colors.append(f"rgba{color}") - return colors - - - fig = Figure() - - colors = generate_broad_bluered_colors(len(results), alpha) - - results = order_grid_search_results_by_metric(results) - - for result in results: - curve = result.equity_curve - label = result.get_truncated_label() - template =_get_hover_template(result) - scatter = Scatter( - x=curve.index, - y=curve, - mode="lines", - name="", # Hides hover legend, use hovertext only - line=dict(color=colors.pop(0)), - showlegend=False, - hovertemplate=template, - hovertext=None, - ) - fig.add_trace(scatter) - - if benchmark_indexes is not None: - for benchmark_name, curve in benchmark_indexes.items(): - benchmark_colour = curve.attrs.get("colour", "black") - scatter = Scatter( - x=curve.index, - y=curve, - mode="lines", - name=benchmark_name, - line=dict(color=benchmark_colour), - showlegend=True, - ) - fig.add_trace(scatter) - - fig.update_layout(title=f"{name}", height=height) - if log_y: - fig.update_yaxes(title="Value $ (logarithmic)", showgrid=False, type="log") - else: - fig.update_yaxes(title="Value $", showgrid=False) - fig.update_xaxes(rangeslider={"visible": False}) - - # Move legend to the bottom so we have more space for - # time axis in narrow notebook views - # https://plotly.com/python/legend/ - fig.update_layout(legend=dict( - orientation="h", - yanchor="bottom", - y=1.02, - xanchor="right", - x=1 - )) - - return fig +from .grid_search_basic import * +from .grid_search_advanced import * \ No newline at end of file diff --git a/tradeexecutor/visual/grid_search_advanced.py b/tradeexecutor/visual/grid_search_advanced.py new file mode 100644 index 000000000..80e4c7d60 --- /dev/null +++ b/tradeexecutor/visual/grid_search_advanced.py @@ -0,0 +1,682 @@ +"""Plot evolving sharpe ratio in grid search results. 
+
+- Calculate rolling metrics using :py:func:`calculate_rolling_metrics`,
+  either for one-parameter or two-parameter visualisation
+
+- Visualise with :py:func:`visualise_grid_rolling_metric_heatmap` or :py:func:`visualise_grid_rolling_metric_line_chart`
+"""
+import enum
+from typing import Any
+import logging
+
+import numpy as np
+import pandas as pd
+from typing import TypeAlias
+
+from plotly.graph_objs import Figure
+import plotly.express as px
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from tradeexecutor.backtest.grid_search import GridSearchResult, GridCombination
+
+
+logger = logging.getLogger(__name__)
+
+
+#: A period like MS for month start
+HumanPeriod: TypeAlias = pd.DateOffset | str
+
+
+class BenchmarkMetric(enum.Enum):
+    sharpe = "sharpe"
+
+
+class BenchmarkVisualisationType(enum.Enum):
+    line_chart = "line_chart"
+    heatmap = "heatmap"
+
+
+def _calc_sharpe(
+    backtest_start_at: pd.Timestamp,
+    backtest_end_at: pd.Timestamp,
+    step: pd.Timedelta,
+    sharpe_period: pd.Timedelta,
+):
+    pass
+
+
+def check_inputs(
+    visualised_parameters: str | tuple[str, str],
+    fixed_parameters: dict,
+    combination: GridCombination,
+):
+    """Raise if we have a human error."""
+
+    if type(visualised_parameters) == str:
+        visualised_parameters = [visualised_parameters]
+    else:
+        visualised_parameters = list(visualised_parameters)
+
+    parameter_name_list = visualised_parameters + list(fixed_parameters.keys())
+
+    for p in combination.searchable_parameters:
+        if p.name not in parameter_name_list:
+            raise AssertionError(f"Visualisation logic missing coverage for parameter {p.name} - we have {parameter_name_list}")
+
+
+
+def prepare_comparisons(
+    visualised_parameters: str | tuple[str, str],
+    fixed_parameters: dict,
+    grid_search_results: list[GridSearchResult],
+) -> tuple[list[GridSearchResult], list[Any]]:
+    """Construct X axis.
+
+    - Get running values for the visualised parameter
+
+    - Discard grid search results that do not match otherwise fixed parameters
+
+    :return:
+        (Grid search results we need to visualise, unique values we are going to have)
+    """
+
+    # Get all fixed_parameter values that go to the x axis
+    x_axis = []
+    uniq = set()
+    for r in grid_search_results:
+        params = r.combination.parameters
+        param_map = {p.name: p.value for p in params}
+
+        if not all(param_map.get(key) == value for key, value in fixed_parameters.items()):
+            # This grid search result is not in the scope of visualisation,
+            # as we are looking at different fixed parameters
+            continue
+
+        x_axis.append(r)
+        if type(visualised_parameters) == tuple:
+            assert len(visualised_parameters) == 2
+            uniq.add(
+                (param_map.get(visualised_parameters[0]), param_map.get(visualised_parameters[1]))
+            )
+        else:
+            uniq.add(param_map.get(visualised_parameters))
+
+    uniq = sorted(list(uniq))
+
+    return x_axis, uniq
+
+
+def calculate_sharpe_at_timestamps(
+    index: pd.DatetimeIndex,
+    lookback: pd.Timedelta,
+    returns: pd.Series,
+    days_in_year=365,
+    risk_free_rate=0,
+) -> pd.Series:
+    """Calculate rolling sharpe at certain points of time."""
+
+    assert isinstance(returns, pd.Series)
+    assert isinstance(returns.index, pd.DatetimeIndex)
+
+    annualisation_factor = pd.Timedelta(days=days_in_year) // (returns.index[1] - returns.index[0])
+    period_rf_rate = (1 + risk_free_rate) ** (1/annualisation_factor) - 1
+    excess_returns = returns - period_rf_rate
+
+    data = []
+
+    for timestamp in index:
+        period_returns = excess_returns.loc[timestamp - lookback:timestamp]
+        mean = period_returns.mean()
+        std = period_returns.std()
+        annualized_mean = mean * annualisation_factor
+        annualized_std = std * np.sqrt(annualisation_factor)
+        sharpe_ratio = annualized_mean / annualized_std
+        data.append(sharpe_ratio)
+
+    return pd.Series(data, index=index)
+
+
+def crunch_1d(
+    visualised_parameter: str,
+    unique_visualise_parameters: list[Any],
+    benchmarked_results: list[GridSearchResult],
+    index: pd.DatetimeIndex,
+    lookback: pd.Timedelta,
+    visualised_metric: BenchmarkMetric,
+) -> pd.DataFrame:
+    """Calculate raw results.
+
+    TODO: Use rolling functions or not?
+    """
+
+    assert type(visualised_parameter) == str
+
+    data = {}
+
+    for uniq in unique_visualise_parameters:
+        logger.info("Calculating %s = %s", visualised_parameter, uniq)
+        found = False
+        for res in benchmarked_results:
+            if res.combination.get_parameter(visualised_parameter) == uniq:
+                returns = res.returns
+                sharpes = calculate_sharpe_at_timestamps(
+                    index=index,
+                    lookback=lookback,
+                    returns=returns,
+                )
+                data[uniq] = sharpes
+                found = True
+        assert found, f"Zero result match for {visualised_parameter} = {uniq}, we have {len(benchmarked_results)} results"
+
+    return pd.DataFrame(data)
+
+
+def crunch_2d(
+    visualised_parameters: tuple[str, str],
+    unique_visualise_parameters: list[Any],
+    benchmarked_results: list[GridSearchResult],
+    index: pd.DatetimeIndex,
+    lookback: pd.Timedelta,
+    visualised_metric: BenchmarkMetric,
+) -> pd.DataFrame:
+    """Calculate raw results.
+
+    TODO: Use rolling functions or not?
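+
+    The output DataFrame keys its columns by ``(param_1, param_2)`` tuples.
+    A minimal access sketch (parameter values are illustrative and mirror this patch's unit tests):
+
+    .. code-block:: python
+
+        # Pick the rolling sharpe series of one parameter combination
+        sharpe_series = df[(0.5, "a")]
+
+        # Or slice one timestamp across all combinations
+        row = df.loc["2021-07-01"]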
+ """ + + assert type(visualised_parameters) == tuple + assert len(visualised_parameters) == 2 + + data = {} + + param_name_1 = visualised_parameters[0] + param_name_2 = visualised_parameters[1] + + for uniq_1, uniq_2 in unique_visualise_parameters: + logger.info( + "Calculating %s = %s, %s = %s", + param_name_1, + uniq_1, + param_name_2, + uniq_2 + ) + found = False + for res in benchmarked_results: + if res.combination.get_parameter(param_name_1) == uniq_1 and res.combination.get_parameter(param_name_2) == uniq_2: + returns = res.returns + sharpes = calculate_sharpe_at_timestamps( + index=index, + lookback=lookback, + returns=returns, + ) + # Pandas DataFrame allows tuples as column keys + data[(uniq_1, uniq_2)] = sharpes + found = True + assert found, f"Zero result match for {param_name_1} = {uniq_1} and {param_name_2} = {uniq_2}, we have {len(benchmarked_results)} results" + + return pd.DataFrame(data) + + +def calculate_rolling_metrics( + grid_search_result: list[GridSearchResult], + visualised_parameters: str | tuple[str, str], + fixed_parameters: dict, + sample_freq: HumanPeriod="MS", + lookback=pd.Timedelta(days=3*30), + benchmarked_metric=BenchmarkMetric.sharpe, +) -> pd.DataFrame: + """Calculate rolling metrics for grid search. + + We can have two parameters e.g. + - N: size of traded basket + - M: number of different pick sizes + + For each N: + - Calc rolling sharpe using last 3 months of returns (do in pandas) + - This will give you an M-sized array or returns + + For each quarter: + - Look back 3 months and plot + - yaxis: sharpe ratios + - x-axis: array of Ns + + Example output if using a single visualised parameter: + + .. code-block:: text + + 0.50 0.75 0.99 + 2021-06-01 NaN NaN NaN + 2021-07-01 -7.565988 -5.788797 -7.554848 + 2021-08-01 -3.924643 -1.919256 -3.914840 + 2021-09-01 -1.807489 -1.050918 -1.798897 + 2021-10-01 -1.849303 -1.604062 -1.841385 + 2021-11-01 -3.792905 -3.924210 -3.784793 + 2021-12-01 -4.156751 -4.186192 -4.148683 + + Example of 2d heatmap output: + + .. code-block:: text + + 0.50 0.75 0.99 + a b a b a b + 2021-06-01 NaN NaN NaN NaN NaN NaN + 2021-07-01 -7.565988 -7.565988 -5.788797 -5.788797 -7.554848 -7.554848 + 2021-08-01 -3.924643 -3.924643 -1.919256 -1.919256 -3.914840 -3.914840 + 2021-09-01 -1.807489 -1.807489 -1.050918 -1.050918 -1.798897 -1.798897 + 2021-10-01 -1.849303 -1.849303 -1.604062 -1.604062 -1.841385 -1.841385 + 2021-11-01 -3.792905 -3.792905 -3.924210 -3.924210 -3.784793 -3.784793 + 2021-12-01 -4.156751 -4.156751 -4.186192 -4.186192 -4.148683 -4.148683 + + :parma visualised_parameters: + Single parameter name for a line chart, two parameter name tuple for a heatmap. + + :param sample_freq: + What is the frequency of calculating rolling value + + :param lookback: + For trailing sharpe, how far look back + + :return: + + DataFrame where + + - Index is timestamp, by `step` + - Each column is value of visualisation parameter + - Each row value is the visualised metric for that parameter and that timestamp + + The first row contains NaNs as it cannot be calculated due to lack of data. 
+ """ + + logger.info( + "calculate_rolling_metrics(), %d results", + len(grid_search_result), + ) + + assert benchmarked_metric == BenchmarkMetric.sharpe, "Only sharpe supported at the moment" + assert len(grid_search_result) > 0 + first_result = grid_search_result[0] + + # Different parmaeters may start trading at different times, + # so we copy the defined backtesting period from the first grid search + # backtest_start, backtest_end = first_result.universe_options.start_at, first_result.universe_options.end_at + + check_inputs( + visualised_parameters, + fixed_parameters, + first_result.combination, + ) + + benchmarked_results, unique_visualise_parameters = prepare_comparisons( + visualised_parameters, + fixed_parameters, + grid_search_result, + ) + + for res in benchmarked_results[0:3]: + logger.info("Example result: %s", res.combination) + + logger.info( + "We have %d unique combinations to analyse over %d results", + len(unique_visualise_parameters), + len(benchmarked_results), + ) + + range_start = first_result.backtest_start + range_end = first_result.backtest_end + + logger.info( + "Range is %s - %s", + range_start, + range_end, + ) + + assert range_end - range_start > pd.Timedelta("1d"), f"Range looks too short: {range_start} - {range_end}" + + # Prepare X axis + index = pd.date_range( + start=range_start, + end=range_end, + freq=sample_freq, + ) + + assert len(index) > 0, f"Could not generate index: {range_start} - {range_end}, freq {sample_freq}" + + if type(visualised_parameters) == tuple: + df = crunch_2d( + visualised_parameters=visualised_parameters, + unique_visualise_parameters=unique_visualise_parameters, + benchmarked_results=benchmarked_results, + index=index, + lookback=lookback, + visualised_metric=benchmarked_metric + ) + else: + df = crunch_1d( + visualised_parameter=visualised_parameters, + unique_visualise_parameters=unique_visualise_parameters, + benchmarked_results=benchmarked_results, + index=index, + lookback=lookback, + visualised_metric=benchmarked_metric + ) + + df.attrs["metric_name"] = benchmarked_metric.name + df.attrs["param_name"] = visualised_parameters + df.attrs["lookback"] = lookback + df.attrs["type"] = BenchmarkVisualisationType.heatmap if type(visualised_parameters) == tuple else BenchmarkVisualisationType.line_chart + + return df + + +def visualise_grid_single_rolling_metric( + df: pd.DataFrame, + width=None, + height=800, +) -> Figure: + """Create a single figure for a grid search parameter how results evolve over time. 
+
+    :param df:
+        Created by :py:func:`calculate_rolling_metrics`
+    """
+
+    assert isinstance(df, pd.DataFrame)
+
+    assert df.attrs["type"] == BenchmarkVisualisationType.line_chart
+
+    metric_name = df.attrs["metric_name"]
+    param_name = df.attrs["param_name"].replace("_", " ").capitalize()
+    lookback = df.attrs["lookback"]
+
+    # Rename columns for human readable labels
+    for col in list(df.columns):
+        df.rename(columns={col: f"{param_name} = {col}"}, inplace=True)
+
+    # Create figure
+    fig = go.Figure()
+
+    # Add traces for each column
+    for column in df.columns:
+        fig.add_trace(
+            go.Scatter(
+                x=df.index,
+                y=df[column],
+                name=column,
+                mode='lines',
+            )
+        )
+
+    # Update layout
+    fig.update_layout(
+        title=f"Rolling {metric_name} for {param_name} parameter, with lookback of {lookback}",
+        yaxis_title=metric_name,
+        xaxis_title='Date',
+        hovermode='x unified',
+        showlegend=True,
+        template='plotly_white',  # Clean white background
+        height=height,
+        width=width,
+    )
+
+    # Add range slider
+    # fig.update_xaxes(rangeslider_visible=True)
+
+    return fig
+
+
+def visualise_grid_rolling_metric_line_chart(
+    df: pd.DataFrame,
+    width=1200,
+    height_per_row=500,
+    extra_height_margin=100,
+    charts_per_row=3,
+    range_start=None,
+    range_end=None,
+) -> Figure:
+    """Create an "animation" showing how results for a single grid search parameter evolve over time, as a series of line charts.
+
+    :param df:
+        Created by :py:func:`calculate_rolling_metrics`
+
+    :param charts_per_row:
+        How many mini charts to display per Plotly row
+
+    :param range_start:
+        Visualise slice of full backtest period.
+
+        Inclusive.
+
+    :param range_end:
+        Visualise slice of full backtest period.
+
+        Inclusive.
+
+    :return:
+        A single figure with one subplot per index timestamp.
+    """
+
+    assert isinstance(df, pd.DataFrame)
+
+    if range_start is not None:
+        df = df.loc[range_start:range_end]
+
+    metric_name = df.attrs["metric_name"]
+    lookback = df.attrs["lookback"]
+    param_name = df.attrs["param_name"].replace("_", " ").capitalize()
+
+    total_rows = len(df.index) // charts_per_row + 1
+
+    if total_rows == 1:
+        charts_per_row = min(len(df.index), charts_per_row)
+
+    logger.info(
+        "visualise_grid_rolling_metric_line_chart(): entries %d, rows %d, cols %d",
+        len(df.index),
+        total_rows,
+        charts_per_row,
+    )
+
+    titles = []
+    for timestamp in df.index:
+        titles.append(
+            f"{timestamp}"
+        )
+
+    fig = make_subplots(
+        rows=total_rows, cols=charts_per_row,
+        subplot_titles=titles,
+        horizontal_spacing=0.05,
+        vertical_spacing=0.05
+    )
+
+    for idx, timestamp in enumerate(df.index):
+
+        # Get one row as one chart
+        row_series = df.loc[timestamp]
+
+        sub_fig = px.line(
+            row_series,
+            title=f"{timestamp}",
+        )
+
+        col = (idx % charts_per_row) + 1
+        row = (idx // charts_per_row) + 1
+        fig.add_trace(sub_fig.data[0], row=row, col=col)
+
+    height = height_per_row * total_rows + extra_height_margin
+
+    fig.update_layout(
+        height=height,
+        width=width,
+        title_text=f"{metric_name.capitalize()} for parameter {param_name} with lookback of {lookback}",
+        title_x=0.5,
+        showlegend=False
+    )
+
+    # You can also adjust the overall margins of the figure
+    fig.update_layout(
+        margin=dict(
+            l=0,  # left margin
+            r=0,  # right margin
+            t=extra_height_margin,  # top margin
+            b=0  # bottom margin
+        )
+    )
+    return fig
+
+
+def visualise_grid_rolling_metric_heatmap(
+    df: pd.DataFrame,
+    width=1200,
+    height_per_row=500,
+    extra_height_margin=100,
+    charts_per_row=3,
+    range_start=None,
+    range_end=None,
+    discrete_parameters=True,
+    scale=(-2, 2),
+    colorscale='RdYlGn',
+) -> 
Figure:
+    """Create an "animation" showing how results for two grid search parameters evolve over time, as a series of heatmaps.
+
+    TODO: Visual Studio Code ignores any given height.
+
+    :param df:
+        Created by :py:func:`calculate_rolling_metrics`
+
+    :param charts_per_row:
+        How many mini charts to display per Plotly row
+
+    :param range_start:
+        Visualise slice of full backtest period.
+
+        Inclusive.
+
+    :param range_end:
+        Visualise slice of full backtest period.
+
+        Inclusive.
+
+    :param discrete_parameters:
+        Measured parameters are category like, not linear.
+
+    :param scale:
+        Heatmap scale.
+
+    :return:
+        A single figure with one heatmap subplot per index timestamp.
+    """
+
+    assert isinstance(df, pd.DataFrame)
+
+    assert df.attrs["type"] == BenchmarkVisualisationType.heatmap
+
+    if range_start is not None:
+        df = df.loc[range_start:range_end]
+
+    metric_name = df.attrs["metric_name"]
+    param_1 = df.attrs["param_name"][0]
+    param_2 = df.attrs["param_name"][1]
+    param_name = f"""{param_1.replace("_", " ").capitalize()} (Y) and {param_2.replace("_", " ").capitalize()} (X)"""
+
+    lookback = df.attrs["lookback"]
+
+    total_rows = len(df.index) // charts_per_row + 2
+
+    if total_rows == 1:
+        charts_per_row = min(len(df.index), charts_per_row)
+
+    logger.info(
+        "visualise_grid_rolling_metric_heatmap(): entries %d, rows %d, cols %d",
+        len(df.index),
+        total_rows,
+        charts_per_row,
+    )
+
+    titles = []
+    for timestamp in df.index:
+        titles.append(
+            f"{timestamp}"
+        )
+
+    height = height_per_row * total_rows + extra_height_margin
+
+    fig = make_subplots(
+        rows=total_rows,
+        cols=charts_per_row,
+        subplot_titles=titles,
+        horizontal_spacing=0.05,
+        vertical_spacing=0.05,
+        row_heights=[1 / total_rows for _ in range(total_rows)],
+    )
+
+    for idx, timestamp in enumerate(df.index):
+
+        # Get one row as one chart
+        row_series = df.loc[timestamp]
+
+        index_levels = [sorted(row_series.index.get_level_values(i).unique())
+                        for i in range(row_series.index.nlevels)]
+
+        # Create a 2D array for the heatmap
+        z = np.zeros((len(index_levels[0]), len(index_levels[1])))
+
+        # Fill the array with values from the series
+        for row_idx, value in row_series.items():
+            i = index_levels[0].index(row_idx[0])
+            j = index_levels[1].index(row_idx[1])
+
+            # Clamp values to our scale range
+            value = max(scale[0], value)
+            value = min(scale[1], value)
+
+            z[i][j] = value
+
+        if discrete_parameters:
+            x = [str(i) for i in index_levels[1]]
+            y = [str(i) for i in index_levels[0]]
+        else:
+            x = index_levels[1]
+            y = index_levels[0]
+
+        trace = go.Heatmap(
+            z=z,
+            x=x,
+            y=y,
+            colorscale=colorscale,  # https://plotly.com/python/colorscales/
+            # text=[[f'{val:.1f}%' for val in row] for row in z],
+            text=[],
+            #texttemplate='%{text}',
+            textfont={"size": 12},
+            showscale=True if idx == 0 else False,
+            colorbar=None,
+            zmin=scale[0],
+            zmax=scale[1],
+        )
+
+        col = (idx % charts_per_row) + 1
+        row = (idx // charts_per_row) + 1
+        fig.add_trace(trace, row=row, col=col)
+
+    fig.update_layout(
+        height=height,
+        width=width,
+        title_text=f"{metric_name.capitalize()} for parameters {param_name} with lookback of {lookback}, height {height}",
+        title_x=0.5,
+        showlegend=False
+    )
+
+    # You can also adjust the overall margins of the figure
+    fig.update_layout(
+        margin=dict(
+            l=0,  # left margin
+            r=0,  # right margin
+            t=extra_height_margin,  # top margin
+            b=0  # bottom margin
+        ),
+        autosize=False,  # Respect height_per_row
+    )
+    return fig
+
diff --git a/tradeexecutor/visual/grid_search_basic.py b/tradeexecutor/visual/grid_search_basic.py
new file mode 100644 
+
+    :param df:
+        Created by :py:func:`calculate_rolling_metrics`
+
+    :param charts_per_row:
+        How many mini charts to display per Plotly row
+
+    :param range_start:
+        Visualise a slice of the full backtest period.
+
+        Inclusive.
+
+    :param range_end:
+        Visualise a slice of the full backtest period.
+
+        Inclusive.
+
+    :param discrete_parameters:
+        Measured parameters are category-like, not linear.
+
+    :param scale:
+        Heatmap scale.
+
+    :return:
+        A single figure with one heatmap per index timestamp.
+    """
+
+    assert isinstance(df, pd.DataFrame)
+
+    assert df.attrs["type"] == BenchmarkVisualisationType.heatmap
+
+    if range_start is not None:
+        df = df.loc[range_start:range_end]
+
+    metric_name = df.attrs["metric_name"]
+    param_1 = df.attrs["param_name"][0]
+    param_2 = df.attrs["param_name"][1]
+    param_name = f"""{param_1.replace("_", " ").capitalize()} (Y) and {param_2.replace("_", " ").capitalize()} (X)"""
+
+    lookback = df.attrs["lookback"]
+
+    # Number of subplot rows needed (ceiling division)
+    total_rows = (len(df.index) + charts_per_row - 1) // charts_per_row
+
+    if total_rows == 1:
+        charts_per_row = min(len(df.index), charts_per_row)
+
+    logger.info(
+        "visualise_grid_rolling_metric_heatmap(): entries %d, rows %d, cols %d",
+        len(df.index),
+        total_rows,
+        charts_per_row,
+    )
+
+    titles = [f"{timestamp}" for timestamp in df.index]
+
+    height = height_per_row * total_rows + extra_height_margin
+
+    fig = make_subplots(
+        rows=total_rows,
+        cols=charts_per_row,
+        subplot_titles=titles,
+        horizontal_spacing=0.05,
+        vertical_spacing=0.05,
+        row_heights=[1 / total_rows for _ in range(total_rows)],
+    )
+
+    for idx, timestamp in enumerate(df.index):
+
+        # Get one row as one chart
+        row_series = df.loc[timestamp]
+
+        index_levels = [
+            sorted(row_series.index.get_level_values(i).unique())
+            for i in range(row_series.index.nlevels)
+        ]
+
+        # Create a 2D array for the heatmap
+        z = np.zeros((len(index_levels[0]), len(index_levels[1])))
+
+        # Fill the array with values from the series
+        for row_idx, value in row_series.items():
+            i = index_levels[0].index(row_idx[0])
+            j = index_levels[1].index(row_idx[1])
+
+            # Clamp values to our scale range
+            value = max(scale[0], value)
+            value = min(scale[1], value)
+
+            z[i][j] = value
+
+        if discrete_parameters:
+            x = [str(i) for i in index_levels[1]]
+            y = [str(i) for i in index_levels[0]]
+        else:
+            x = index_levels[1]
+            y = index_levels[0]
+
+        trace = go.Heatmap(
+            z=z,
+            x=x,
+            y=y,
+            colorscale=colorscale,  # https://plotly.com/python/colorscales/
+            # text=[[f'{val:.1f}%' for val in row] for row in z],
+            text=[],
+            # texttemplate='%{text}',
+            textfont={"size": 12},
+            showscale=(idx == 0),  # Draw only one colour bar for the whole figure
+            colorbar=None,
+            zmin=scale[0],
+            zmax=scale[1],
+        )
+
+        col = (idx % charts_per_row) + 1
+        row = (idx // charts_per_row) + 1
+        fig.add_trace(trace, row=row, col=col)
+
+    fig.update_layout(
+        height=height,
+        width=width,
+        title_text=f"{metric_name.capitalize()} for parameters {param_name} with lookback of {lookback}",
+        title_x=0.5,
+        showlegend=False
+    )
+
+    # You can also adjust the overall margins of the figure
+    fig.update_layout(
+        margin=dict(
+            l=0,  # left margin
+            r=0,  # right margin
+            t=extra_height_margin,  # top margin
+            b=0  # bottom margin
+        ),
+        autosize=False,  # Respect height_per_row
+    )
+    return fig
+
diff --git a/tradeexecutor/visual/grid_search_basic.py b/tradeexecutor/visual/grid_search_basic.py
new file mode 100644
index 000000000..a63a5d711
--- /dev/null
+++ b/tradeexecutor/visual/grid_search_basic.py
@@ -0,0 +1,277 @@
+"""Visualise grid search results.
+
+- Different visualisation tools to compare grid search results
+"""
+from typing import List
+
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.graph_objs import Figure, Scatter
+
+from tradeexecutor.analysis.curve import CurveType, DEFAULT_BENCHMARK_COLOURS
+from tradeexecutor.analysis.grid_search import _get_hover_template, order_grid_search_results_by_metric
+from tradeexecutor.analysis.multi_asset_benchmark import get_benchmark_data
+from tradeexecutor.backtest.grid_search import GridSearchResult
+from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse
+from tradeexecutor.visual.benchmark import visualise_equity_curves
+from tradingstrategy.types import USDollarAmount
+
+
+def visualise_single_grid_search_result_benchmark(
+    result: GridSearchResult,
+    strategy_universe: TradingStrategyUniverse,
+    initial_cash: USDollarAmount | None = None,
+    name="Picked search result",
+    log_y=False,
+    asset_count=3,
+) -> go.Figure:
+    """Draw one equity curve from grid search results.
+
+    - Compare the equity curve against buying and holding assets from the trading universe
+
+    - Used to plot "the best" equity curve
+
+    - Use :func:`find_best_grid_search_results` to find some equity curves.
+
+    See also
+
+    - :py:func:`visualise_grid_search_equity_curves`
+
+    - :py:func:`tradeexecutor.visual.benchmark.visualise_equity_curves`
+
+    - :py:func:`tradeexecutor.analysis.multi_asset_benchmark.get_benchmark_data`
+
+    - :py:func:`tradeexecutor.analysis.grid_search.find_best_grid_search_results`
+
+    Example:
+
+    .. code-block:: python
+
+        from tradeexecutor.analysis.grid_search import find_best_grid_search_results
+        from tradeexecutor.visual.grid_search import visualise_single_grid_search_result_benchmark
+
+        # Show the equity curve of the best grid search performer
+        best_results = find_best_grid_search_results(grid_search_results)
+        fig = visualise_single_grid_search_result_benchmark(best_results.cagr[0], strategy_universe)
+        fig.show()
+
+    :param result:
+        Picked grid search result
+
+    :param strategy_universe:
+        Used to get benchmark indexes
+
+    :param name:
+        Chart title
+
+    :param initial_cash:
+        Not needed. Automatically filled in by grid search.
+
+        Legacy param.
+
+    :param asset_count:
+        Draw this many comparison buy-and-hold curves from well-known assets.
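+
+    :param log_y:
+        Use a logarithmic Y-axis for the equity curves.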
+
+    :return:
+        Plotly figure
+    """
+
+    assert isinstance(result, GridSearchResult)
+    assert isinstance(strategy_universe, TradingStrategyUniverse)
+
+    # Get daily returns
+    equity = result.equity_curve
+    equity.attrs["name"] = result.get_truncated_label()
+    equity.attrs["curve"] = CurveType.equity
+    equity.attrs["colour"] = DEFAULT_BENCHMARK_COLOURS["Strategy"]
+
+    if result.state is not None:
+        start_at = result.state.get_trading_time_range()[0]
+    else:
+        start_at = equity.index[0]
+
+    benchmarks = get_benchmark_data(
+        strategy_universe,
+        cumulative_with_initial_cash=initial_cash or getattr(result, "initial_cash", None),  # Legacy support hack
+        start_at=start_at,
+        max_count=asset_count,
+    )
+
+    benchmark_series = [v for k, v in benchmarks.items()]
+
+    fig = visualise_equity_curves(
+        [equity] + benchmark_series,
+        name=name,
+        log_y=log_y,
+        start_at=start_at,
+    )
+
+    return fig
+
+
+def visualise_grid_search_equity_curves(
+    results: List[GridSearchResult],
+    name: str | None = None,
+    benchmark_indexes: pd.DataFrame | None = None,
+    height=1200,
+    colour=None,
+    log_y=False,
+    alpha=0.7,
+) -> Figure:
+    """Draw multiple equity curves in the same chart.
+
+    - See how all grid searched strategies perform
+
+    - Benchmark against buying and holding various assets
+
+    - Benchmark against holding all cash
+
+    .. note::
+
+        Only good for up to a few hundred results. With more than a thousand results, rendering takes too long.
+
+    Example that draws the equity curves of grid search results:
+
+    .. code-block:: python
+
+        from tradeexecutor.visual.grid_search import visualise_grid_search_equity_curves
+        from tradeexecutor.analysis.multi_asset_benchmark import get_benchmark_data
+
+        # Automatically create BTC and ETH buy and hold benchmark if present
+        # in the trading universe
+        benchmark_indexes = get_benchmark_data(
+            strategy_universe,
+            cumulative_with_initial_cash=ShiftedStrategyParameters.initial_cash,
+        )
+
+        fig = visualise_grid_search_equity_curves(
+            grid_search_results,
+            name="8h clock shift, stop loss added and adjusted momentum",
+            benchmark_indexes=benchmark_indexes,
+            log_y=False,
+        )
+        fig.show()
+
+    :param results:
+        Results from the grid search.
+
+    :param benchmark_indexes:
+        List of other asset price series displayed on the timeline besides the equity curves.
+
+        DataFrame containing multiple series.
+
+        - Asset name is the series name.
+        - Setting `colour` in `pd.Series.attrs` allows you to override the colour of the index
+
+    :param height:
+        Chart height in pixels
+
+    :param colour:
+        Colour of the equity curves, e.g. "rgba(160, 160, 160, 0.5)". If provided, all equity curves will be drawn with this colour.
+
+    :param log_y:
+        Use a logarithmic Y-axis.
+
+        Because we accumulate a larger treasury over time,
+        the swings in the value will be higher later.
+        We need to use a logarithmic Y-axis so that we can compare the performance
+        early and late in the strategy.
+
+    """
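+    # The nested helper below spreads the ordered results over a broad
+    # blue-purple-red colour ramp, so results that are adjacent in the
+    # metric ordering get similar hues.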
+    def generate_broad_bluered_colors(num_colors, alpha):
+        """
+        Generate a list of RGBA colors along a broad blue-purple-red color scale.
+
+        Parameters:
+            num_colors (int): Number of colors to generate.
+            alpha (float): Alpha value for the colors (0 to 1).
+
+        Returns:
+            list: List of RGBA color tuples.
+        """
+        colors = []
+        for i in range(num_colors):
+            # Guard against division by zero when there is only one result
+            ratio = i / (num_colors - 1) if num_colors > 1 else 0
+            if ratio < 0.7:
+                red = min(255, int(255 * (1.5 * ratio)))  # Cap at 255, as 1.5 * ratio can exceed 1 near the crossover
+                blue = 255
+            else:
+                red = 255
+                blue = int(255 * (2 * (1 - ratio)))
+            green = max(0, int(255 * (1 - abs(ratio - 0.7) * 2)))
+            color = (red, green, blue, alpha)
+            colors.append(f"rgba{color}")
+        return colors
+
+    fig = Figure()
+
+    colors = generate_broad_bluered_colors(len(results), alpha)
+
+    results = order_grid_search_results_by_metric(results)
+
+    for result in results:
+        curve = result.equity_curve
+        label = result.get_truncated_label()
+        template = _get_hover_template(result)
+        scatter = Scatter(
+            x=curve.index,
+            y=curve,
+            mode="lines",
+            name="",  # Hides hover legend, use hovertext only
+            line=dict(color=colors.pop(0)),
+            showlegend=False,
+            hovertemplate=template,
+            hovertext=None,
+        )
+        fig.add_trace(scatter)
+
+    if benchmark_indexes is not None:
+        for benchmark_name, curve in benchmark_indexes.items():
+            benchmark_colour = curve.attrs.get("colour", "black")
+            scatter = Scatter(
+                x=curve.index,
+                y=curve,
+                mode="lines",
+                name=benchmark_name,
+                line=dict(color=benchmark_colour),
+                showlegend=True,
+            )
+            fig.add_trace(scatter)
+
+    fig.update_layout(title=f"{name}", height=height)
+    if log_y:
+        fig.update_yaxes(title="Value $ (logarithmic)", showgrid=False, type="log")
+    else:
+        fig.update_yaxes(title="Value $", showgrid=False)
+    fig.update_xaxes(rangeslider={"visible": False})
+
+    # Move legend to the bottom so we have more space for
+    # time axis in narrow notebook views
+    # https://plotly.com/python/legend/
+    fig.update_layout(legend=dict(
+        orientation="h",
+        yanchor="bottom",
+        y=1.02,
+        xanchor="right",
+        x=1
+    ))
+
+    return fig
diff --git a/tradeexecutor/visual/grid_search_visualisation.py b/tradeexecutor/visual/grid_search_visualisation.py
new file mode 100644
index 000000000..0f82a900a
--- /dev/null
+++ b/tradeexecutor/visual/grid_search_visualisation.py
@@ -0,0 +1,328 @@
+"""Plot evolving Sharpe ratio in grid search results."""
+import enum
+from typing import Any, TypeAlias
+import logging
+
+import numpy as np
+import pandas as pd
+
+from plotly.graph_objs import Figure
+import plotly.express as px
+import plotly.graph_objects as go
+
+from tradeexecutor.backtest.grid_search import GridSearchResult, GridCombination
+
+
+logger = logging.getLogger(__name__)
+
+
+#: A period like MS for month start
+HumanPeriod: TypeAlias = pd.DateOffset | str
+
+
+class BenchmarkMetric(enum.Enum):
+    sharpe = "sharpe"
+
+
+def _calc_sharpe(
+    backtest_start_at: pd.Timestamp,
+    backtest_end_at: pd.Timestamp,
+    step: pd.Timedelta,
+    sharpe_period: pd.Timedelta,
+):
+    pass
+
+
+def check_inputs(
+    visualised_parameter: str,
+    fixed_parameters: dict,
+    combination: GridCombination,
+):
+    """Raise if we have a human error."""
+    parameter_name_list = [visualised_parameter] + list(fixed_parameters.keys())
+
+    for p in combination.searchable_parameters:
+        if p.name not in parameter_name_list:
+            raise AssertionError(f"Visualisation logic missing coverage for parameter {p.name} - we have {parameter_name_list}")
+
+
+def prepare_comparisons(
+    visualised_parameter: str,
+    fixed_parameters: dict,
+    grid_search_results: list[GridSearchResult],
+) -> tuple[list[GridSearchResult], list[Any]]:
+    """Construct the X axis.
+
+    - Get the running values for the visualised parameter
+
+    - Discard grid search results that do not match the otherwise fixed parameters
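+
+    Example, a minimal sketch, ``rsi_days`` and ``stop_loss_pct`` being
+    illustrative parameter names:
+
+    .. code-block:: python
+
+        results, unique_values = prepare_comparisons(
+            visualised_parameter="rsi_days",
+            fixed_parameters={"stop_loss_pct": 0.95},
+            grid_search_results=grid_search_results,
+        )
+
+        # results contains only combinations run with stop_loss_pct = 0.95,
+        # unique_values contains the rsi_days values seen among them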
+
+    :return:
+        (Grid search results we need to visualise, unique values we are going to have)
+    """
+
+    # Get all fixed_parameter values that go to the x axis
+    x_axis = []
+    uniq = set()
+    for r in grid_search_results:
+        params = r.combination.parameters
+        param_map = {p.name: p.value for p in params}
+
+        if not all(param_map.get(key) == value for key, value in fixed_parameters.items()):
+            # This grid search result is not in the scope of the visualisation,
+            # as it was run with different fixed parameters
+            continue
+
+        x_axis.append(r)
+        uniq.add(param_map.get(visualised_parameter))
+
+    uniq = sorted(list(uniq))
+
+    return x_axis, uniq
+
+
+def calculate_sharpe_at_timestamps(
+    index: pd.DatetimeIndex,
+    lookback: pd.Timedelta,
+    returns: pd.Series,
+    days_in_year=365,
+    risk_free_rate=0,
+) -> pd.Series:
+    """Calculate rolling Sharpe ratio at certain points of time.
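+
+    A minimal usage sketch. Assumes ``daily_returns`` is a returns series with
+    a regular daily :py:class:`pd.DatetimeIndex`, as the annualisation factor
+    is derived from the spacing of the first two index entries:
+
+    .. code-block:: python
+
+        import pandas as pd
+
+        sample_points = pd.date_range("2021-06-01", "2021-12-01", freq="MS")
+        rolling_sharpe = calculate_sharpe_at_timestamps(
+            index=sample_points,
+            lookback=pd.Timedelta(days=90),
+            returns=daily_returns,
+        )
+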
+    """
+
+    assert isinstance(returns, pd.Series)
+    assert isinstance(returns.index, pd.DatetimeIndex)
+
+    annualisation_factor = pd.Timedelta(days=days_in_year) // (returns.index[1] - returns.index[0])
+    period_rf_rate = (1 + risk_free_rate) ** (1 / annualisation_factor) - 1
+    excess_returns = returns - period_rf_rate
+
+    data = []
+
+    for timestamp in index:
+        period_returns = excess_returns.loc[timestamp - lookback:timestamp]
+        mean = period_returns.mean()
+        std = period_returns.std()
+        annualized_mean = mean * annualisation_factor
+        annualized_std = std * np.sqrt(annualisation_factor)
+        sharpe_ratio = annualized_mean / annualized_std
+        data.append(sharpe_ratio)
+
+    return pd.Series(data, index=index)
+
+
+def crunch(
+    visualised_parameter: str,
+    unique_visualise_parameters: list[Any],
+    benchmarked_results: list[GridSearchResult],
+    index: pd.DatetimeIndex,
+    lookback: pd.Timedelta,
+    visualised_metric: BenchmarkMetric,
+) -> pd.DataFrame:
+    """Calculate the raw results.
+
+    TODO: Use rolling functions or not?
+    """
+
+    data = {}
+
+    for uniq in unique_visualise_parameters:
+        logger.info("Calculating %s = %s", visualised_parameter, uniq)
+        found = False
+        for res in benchmarked_results:
+            if res.combination.get_parameter(visualised_parameter) == uniq:
+                returns = res.returns
+                sharpes = calculate_sharpe_at_timestamps(
+                    index=index,
+                    lookback=lookback,
+                    returns=returns,
+                )
+                data[uniq] = sharpes
+                found = True
+        assert found, f"Zero result match for {visualised_parameter} = {uniq}, we have {len(benchmarked_results)} results"
+
+    return pd.DataFrame(data)
+
+
+def calculate_rolling_metrics(
+    grid_search_result: list[GridSearchResult],
+    visualised_parameter: str,
+    fixed_parameters: dict,
+    sample_freq: HumanPeriod = "MS",
+    lookback=pd.Timedelta(days=3 * 30),
+    benchmarked_metric=BenchmarkMetric.sharpe,
+) -> pd.DataFrame:
+    """Calculate rolling metrics for a grid search.
+
+    We can have two parameters, e.g.
+
+    - N: size of traded basket
+
+    - M: number of different pick sizes
+
+    For each N:
+
+    - Calculate the rolling Sharpe using the last 3 months of returns (done in pandas)
+
+    - This will give you an M-sized array of returns
+
+    For each quarter:
+
+    - Look back 3 months and plot
+
+    - Y-axis: Sharpe ratios
+
+    - X-axis: array of Ns
+
+    Example output:
+
+    .. code-block:: text
+
+                        0.50      0.75      0.99
+        2021-06-01       NaN       NaN       NaN
+        2021-07-01 -7.565988 -5.788797 -7.554848
+        2021-08-01 -3.924643 -1.919256 -3.914840
+        2021-09-01 -1.807489 -1.050918 -1.798897
+        2021-10-01 -1.849303 -1.604062 -1.841385
+        2021-11-01 -3.792905 -3.924210 -3.784793
+        2021-12-01 -4.156751 -4.186192 -4.148683
+
+    :param sample_freq:
+        The frequency at which the rolling value is sampled
+
+    :param lookback:
+        For the trailing Sharpe, how far to look back
+
+    :return:
+        DataFrame where
+
+        - Index is timestamp, sampled at `sample_freq`
+
+        - Each column is a value of the visualised parameter
+
+        - Each row value is the visualised metric for that parameter and that timestamp
+
+        The first row contains NaNs as it cannot be calculated due to lack of data.
+    """
+
+    logger.info(
+        "calculate_rolling_metrics(), %d results",
+        len(grid_search_result),
+    )
+
+    assert benchmarked_metric == BenchmarkMetric.sharpe, "Only Sharpe is supported at the moment"
+    assert len(grid_search_result) > 0
+    first_result = grid_search_result[0]
+
+    # Different parameters may start trading at different times,
+    # so we copy the defined backtesting period from the first grid search
+    # backtest_start, backtest_end = first_result.universe_options.start_at, first_result.universe_options.end_at
+
+    check_inputs(
+        visualised_parameter,
+        fixed_parameters,
+        first_result.combination,
+    )
+
+    benchmarked_results, unique_visualise_parameters = prepare_comparisons(
+        visualised_parameter,
+        fixed_parameters,
+        grid_search_result,
+    )
+
+    for res in benchmarked_results[0:3]:
+        logger.info("Example result: %s", res.combination)
+
+    logger.info(
+        "We have %d unique combinations to analyse over %d results",
+        len(unique_visualise_parameters),
+        len(benchmarked_results),
+    )
+
+    range_start = first_result.backtest_start
+    range_end = first_result.backtest_end
+
+    logger.info(
+        "Range is %s - %s",
+        range_start,
+        range_end,
+    )
+
+    assert range_end - range_start > pd.Timedelta("1d"), f"Range looks too short: {range_start} - {range_end}"
+
+    # Prepare the X axis
+    index = pd.date_range(
+        start=range_start,
+        end=range_end,
+        freq=sample_freq,
+    )
+
+    assert len(index) > 0, f"Could not generate index: {range_start} - {range_end}, freq {sample_freq}"
+
+    df = crunch(
+        visualised_parameter=visualised_parameter,
+        unique_visualise_parameters=unique_visualise_parameters,
+        benchmarked_results=benchmarked_results,
+        index=index,
+        lookback=lookback,
+        visualised_metric=benchmarked_metric,
+    )
+
+    df.attrs["metric_name"] = benchmarked_metric.name
+    df.attrs["param_name"] = visualised_parameter
+    df.attrs["lookback"] = lookback
+
+    return df
+
+
+def visualise_grid_single_rolling_metric(
+    df: pd.DataFrame,
+    width=None,
+    height=800,
+) -> Figure:
+    """Create an animation showing how the results for a single grid search parameter evolve over time.
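+
+    Example, a minimal sketch, with illustrative parameter names:
+
+    .. code-block:: python
+
+        df = calculate_rolling_metrics(
+            grid_search_results,
+            visualised_parameter="rsi_days",
+            fixed_parameters={"stop_loss_pct": 0.95},
+        )
+
+        fig = visualise_grid_single_rolling_metric(df)
+        fig.show()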
+
+    :param df:
+        Created by :py:func:`calculate_rolling_metrics`
+    """
+
+    assert isinstance(df, pd.DataFrame)
+
+    metric_name = df.attrs["metric_name"]
+    param_name = df.attrs["param_name"].replace("_", " ").capitalize()
+    lookback = df.attrs["lookback"]
+
+    # Rename columns for human-readable labels.
+    # Rename on a copy so we do not mutate the caller's DataFrame.
+    df = df.rename(columns={col: f"{param_name} = {col}" for col in df.columns})
+
+    # Create figure
+    fig = go.Figure()
+
+    # Add traces for each column
+    for column in df.columns:
+        fig.add_trace(
+            go.Scatter(
+                x=df.index,
+                y=df[column],
+                name=column,
+                mode='lines',
+            )
+        )
+
+    # Update layout
+    fig.update_layout(
+        title=f"Rolling {metric_name} for {param_name} parameter, with lookback of {lookback}",
+        yaxis_title=metric_name,
+        xaxis_title='Date',
+        hovermode='x unified',
+        showlegend=True,
+        template='plotly_white',  # Clean white background
+        height=height,
+        width=width,
+    )
+
+    # Add range slider
+    # fig.update_xaxes(rangeslider_visible=True)
+
+    return fig