Skip to content

Commit

Permalink
Merge pull request #16 from ivelin/feature-cli
Browse files Browse the repository at this point in the history
Feature cli - improve dashboard scan and advanced tabs
  • Loading branch information
ivelin authored Feb 26, 2024
2 parents 097c783 + b46e723 commit c6bebd1
Show file tree
Hide file tree
Showing 12 changed files with 306 additions and 179 deletions.
5 changes: 4 additions & 1 deletion example.env
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ train_date_start='1991-01-01'
n_outer_train_loop=1000

# file name with stock tickers to work with in data/data-3rd-party dir
stocks_train_list="IBD50.csv"
stock_tickers_list="IBD50.csv"
# "all_stocks.csv"

logging_dir="tmp"
Expand All @@ -33,3 +33,6 @@ data_dir="data"
data_3rd_party="data-3rd-party"
# data subdir for forecast data storage
forecast_subdir="forecast/"

# Keep changes local or sync with remote HF Hub repo
local_mode=False
11 changes: 6 additions & 5 deletions forecast.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ set exv


./canswim.sh forecast
# ./canswim.sh forecast --forecast_start_date "2023-11-18"
# ./canswim.sh forecast --forecast_start_date "2023-12-02"
# ./canswim.sh forecast --forecast_start_date "2023-12-16"
# ./canswim.sh forecast --forecast_start_date "2024-01-13"
# ./canswim.sh forecast --forecast_start_date "2024-01-27"
#...# ./canswim.sh forecast --forecast_start_date "2023-11-18"
./canswim.sh forecast --forecast_start_date "2023-12-02"
./canswim.sh forecast --forecast_start_date "2023-12-16"
./canswim.sh forecast --forecast_start_date "2024-01-13"
./canswim.sh forecast --forecast_start_date "2024-01-27"
./canswim.sh forecast --forecast_start_date "2024-02-17"
20 changes: 13 additions & 7 deletions src/canswim/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
help="""Which %(prog)s task to run:
`dashboard` for stock charting and scans of recorded forecasts.
`gatherdata` to gather 3rd party stock market data and save to HF Hub.
`uploaddata` upload to HF Hub any interim changes to local data storage.
`downloaddata` download model training and forecast data from HF Hub to local data storage.
`uploaddata` upload to HF Hub any interim changes to local train and forecast data.
`modelsearch` to find and save optimal hyperparameters for model training.
`train` for continuous model training.
`finetune` to fine tune pretrained model on new stock market data.
Expand All @@ -48,6 +49,7 @@
choices=[
"dashboard",
"gatherdata",
"downloaddata",
"uploaddata",
"modelsearch",
"train",
Expand All @@ -57,18 +59,19 @@
)

parser.add_argument(
'--forecast_start_date',
"--forecast_start_date",
type=str,
required=False,
help="""Optional argument for the `forecast` task. Indicate forecast start date in YYYY-MM-DD format. If not specified, forecast will start from the end of the target series.""")
help="""Optional argument for the `forecast` task. Indicate forecast start date in YYYY-MM-DD format. If not specified, forecast will start from the end of the target series.""",
)

parser.add_argument(
'--new_model',
"--new_model",
type=bool,
required=False,
default=False,
help="""Optional argument for the `train` task. Whether to train a newly created model or continue training an existing pre-trained model.""")

help="""Optional argument for the `train` task. Whether to train a newly created model or continue training an existing pre-trained model.""",
)

args = parser.parse_args()

Expand All @@ -94,9 +97,10 @@


def signal_handler(sig, frame):
print('Ctrl+C - Exit')
print("Ctrl+C - Exit")
sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

match args.task:
Expand All @@ -106,6 +110,8 @@ def signal_handler(sig, frame):
model_search.main()
case "gatherdata":
gather_data.main()
case "downloaddata":
hfhub.download_data()
case "uploaddata":
hfhub.upload_data()
case "train":
Expand Down
17 changes: 10 additions & 7 deletions src/canswim/covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def prepare_earn_series(self, tickers=None):
earn_df.pop("fiscalDateEnding")

# convert earnings reporting time - Before Market Open / After Market Close - categories to numerical representation
earn_df["time"] = pd.Categorical(earn_df["time"], categories=["bmo", "amc"]).codes
earn_df["time"] = pd.Categorical(
earn_df["time"], categories=["bmo", "amc"]
).codes
# earn_df["time"] = (
# earn_df["time"]
# .replace(["bmo", "amc", "--", "dmh"], [0, 1, -1, -1], inplace=False)
Expand Down Expand Up @@ -137,7 +139,7 @@ def prepare_earn_series(self, tickers=None):
assert len(tes.gaps()) == 0
t_earn_series[t] = tes
except KeyError as e:
logger.exception(f"Skipping {t} due to error: ", e)
logger.warning(f"Skipping {t} due to error: {type(e)}: {e}")

return t_earn_series

Expand Down Expand Up @@ -217,8 +219,7 @@ def prepare_institutional_symbol_ownership_series(self, stock_price_series=None)
# .apply(lambda x: tuple(x.index))
# .reset_index(name="date")
# )
logger.warning(f"Skipping {t} due to error: \n{e}")
logger.warning(f"Skipping {t} due to error: {e}")
logger.warning(f"Skipping {t} due to error: {type(e)}: {e}")
# logger.info(
# f"Duplicated index rows: \n {t_iown.loc[t_iown.index == pd.Timestamp('1987-03-31')]}"
#
Expand All @@ -245,7 +246,7 @@ def stack_covariates(self, old_covs=None, new_covs=None, min_samples=1):
if len(stacked) >= min_samples:
stacked_covs[t] = stacked
except KeyError as e:
logger.warning(f"Skipping {t} covariates stack due to error: {e}")
logger.warning(f"Skipping {t} due to error: {type(e)}: {e}")
return stacked_covs

def df_index_to_biz_days(self, df=None):
Expand Down Expand Up @@ -286,7 +287,9 @@ def prepare_key_metrics(self, stock_price_series=None):
assert not kms_unique.index.has_duplicates
kms_loaded_df = kms_unique.copy()
# convert fiscal period - Q1/Q2/Q3/Q4 - categories to numerical representation
kms_loaded_df["period"] = pd.Categorical(kms_loaded_df["period"], categories=["_", "Q1", "Q2", "Q3", "Q4"]).codes
kms_loaded_df["period"] = pd.Categorical(
kms_loaded_df["period"], categories=["_", "Q1", "Q2", "Q3", "Q4"]
).codes
# kms_loaded_df["period"] = (
# kms_loaded_df["period"]
# .replace(["Q1", "Q2", "Q3", "Q4"], [1, 2, 3, 4], inplace=False)
Expand Down Expand Up @@ -324,7 +327,7 @@ def prepare_key_metrics(self, stock_price_series=None):
), f"found gaps in tmks series: \n{kms_ser_padded.gaps()}"
t_kms_series[t] = kms_ser_padded
except (KeyError, AssertionError) as e:
logger.exception(f"Skipping {t} due to error: ", e)
logger.warning(f"Skipping {t} due to error: {type(e)}: {e}")
# logger.info("t_kms_series:", t_kms_series)
return t_kms_series

Expand Down
62 changes: 41 additions & 21 deletions src/canswim/dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@

repo_id = "ivelin/canswim"


class CanswimPlayground:

def __init__(self):
self.canswim_model = CanswimModel()
self.hfhub = HFHub()
data_dir = os.getenv("data_dir", "data")
forecast_subdir = os.getenv(
"forecast_subdir", "forecast/"
)
forecast_subdir = os.getenv("forecast_subdir", "forecast/")
self.forecast_path = f"{data_dir}/{forecast_subdir}"
self.data_3rd_party = os.getenv("data-3rd-party", "data-3rd-party")
data_3rd_party = os.getenv("data-3rd-party", "data-3rd-party")
price_data = os.getenv("price_data", "all_stocks_price_hist_1d.parquet")
self.stocks_price_path = f"{data_dir}/{self.data_3rd_party}/{price_data}"

self.stocks_price_path = f"{data_dir}/{data_3rd_party}/{price_data}"
stock_tickers_list = os.getenv("stock_tickers_list", "all_stocks.csv")
self.stock_tickers_path = f"{data_dir}/{data_3rd_party}/{stock_tickers_list}"

def download_model(self):
"""Load model from HF Hub"""
Expand All @@ -43,22 +43,40 @@ def download_model(self):

def download_data(self):
"""Prepare time series for model forecast"""
# download raw data from hf hub
self.hfhub.download_data(repo_id=repo_id)
# load raw data from hf hub
start_date = pd.Timestamp.now() - BDay(
n=self.canswim_model.min_samples + self.canswim_model.train_history
)
self.canswim_model.load_data(start_date=start_date)
# prepare timeseries for forecast
self.canswim_model.prepare_forecast_data(start_date=start_date)


def initdb(self):
logger.info(f"Forecast path: {self.forecast_path}")
tickers_str = "'"+"','".join(self.canswim_model.targets_ticker_list)+"'"
duckdb.sql(f"CREATE TABLE forecast AS SELECT date, symbol, forecast_start_year, forecast_start_month, forecast_start_day, COLUMNS(\"close_quantile_\d+\.\d+\") FROM read_parquet('{self.forecast_path}/**/*.parquet', hive_partitioning = 1) WHERE symbol in ({tickers_str})")
duckdb.sql(f"CREATE TABLE close_price AS SELECT Date, Symbol, Close FROM read_parquet('{self.stocks_price_path}') WHERE symbol in ({tickers_str})")
duckdb.sql('SET enable_external_access = false; ')
duckdb.sql(
f"""
CREATE VIEW stock_tickers
AS SELECT * FROM read_csv('{self.stock_tickers_path}', header=True)
"""
)
# df_tickers = duckdb.sql("SELECT * from stock_tickers").df()
# logger.info(f"stock ticker list:\n {df_tickers}")
duckdb.sql(
f"""
CREATE VIEW forecast
AS SELECT date, symbol, forecast_start_year, forecast_start_month, forecast_start_day, COLUMNS(\"close_quantile_\d+\.\d+\")
FROM read_parquet('{self.forecast_path}/**/*.parquet', hive_partitioning = 1) as f
SEMI JOIN stock_tickers
ON f.symbol = stock_tickers.symbol;
"""
)
duckdb.sql(
f"""
CREATE VIEW close_price
AS SELECT Date, Symbol, Close
FROM read_parquet('{self.stocks_price_path}') as cp
SEMI JOIN stock_tickers
ON cp.symbol = stock_tickers.symbol;
"""
)
# access protected via read-only remote access tokens
# restricting access prevents sql views from working
# duckdb.sql("SET enable_external_access = false; ")


def main():
Expand All @@ -78,7 +96,10 @@ def main():
canswim_playground.initdb()

with gr.Tab("Charts"):
charts_tab = ChartTab(canswim_playground.canswim_model, forecast_path=canswim_playground.forecast_path)
charts_tab = ChartTab(
canswim_playground.canswim_model,
forecast_path=canswim_playground.forecast_path,
)
with gr.Tab("Scans"):
ScanTab(canswim_playground.canswim_model)
with gr.Tab("Advanced Queries"):
Expand All @@ -88,10 +109,9 @@ def main():
fn=charts_tab.plot_forecast,
inputs=[charts_tab.tickerDropdown, charts_tab.lowq],
outputs=[charts_tab.plotComponent],
queue=False,
)

demo.launch()
demo.queue().launch()


if __name__ == "__main__":
Expand Down
35 changes: 24 additions & 11 deletions src/canswim/dashboard/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,30 @@
import gradio as gr
import duckdb


class AdvancedTab:

def __init__(self, canswim_model: CanswimModel = None):
self.canswim_model = canswim_model
with gr.Row():
self.queryBox = gr.TextArea(value="""
SELECT f.symbol, min(f.date) as forecast_start_date, max(c.date) as prior_close_date, arg_max(c.close, c.date) as prior_close_price, min("close_quantile_0.2") as forecast_low_quantile, max("close_quantile_0.5") as forecast_mean_quantile
self.queryBox = gr.TextArea(
value="""
SELECT
f.symbol,
min(f.date) as forecast_start_date,
max(c.date) as prior_close_date,
arg_max(c.close, c.date) as prior_close_price,
min("close_quantile_0.2") as forecast_low_quantile,
max("close_quantile_0.5") as forecast_mean_quantile,
ROUND(100*(forecast_mean_quantile - prior_close_price) / prior_close_price) as reward_percent,
ROUND((forecast_mean_quantile - prior_close_price)/GREATEST(prior_close_price-forecast_low_quantile, 0.01),2) as reward_risk
FROM forecast f, close_price c
WHERE f.symbol = c.symbol
GROUP BY f.symbol, f.forecast_start_year, f.forecast_start_month, f.forecast_start_day, c.symbol
HAVING prior_close_date < forecast_start_date AND forecast_mean_quantile > prior_close_price AND (forecast_low_quantile > prior_close_price OR (forecast_mean_quantile - prior_close_price)/(prior_close_price-forecast_low_quantile) > 3)
AND (forecast_mean_quantile - prior_close_price) / prior_close_price > 0.2
""")
HAVING prior_close_date < forecast_start_date AND forecast_mean_quantile > prior_close_price
AND reward_risk> 3 AND reward_percent >= 20
"""
)

with gr.Row():
self.runBtn = gr.Button(value="Run Query", variant="primary")
Expand All @@ -25,18 +36,20 @@ def __init__(self, canswim_model: CanswimModel = None):

self.runBtn.click(
fn=self.scan_forecasts,
inputs=[self.queryBox, ],
inputs=[
self.queryBox,
],
outputs=[self.queryResult],
queue=False,
)
self.queryBox.submit(
fn=self.scan_forecasts,
inputs=[self.queryBox, ],
inputs=[
self.queryBox,
],
outputs=[self.queryResult],
queue=False,
)

def scan_forecasts(self, query):
# only run select queries
if query.strip().upper().startswith('SELECT'):
return duckdb.sql(query).df()
if query.strip().upper().startswith("SELECT"):
return duckdb.sql(query).df()
Loading

0 comments on commit c6bebd1

Please sign in to comment.