Skip to content

Commit b6aef84

Browse files
committed
ci: add tests and github action
Signed-off-by: ivelin <[email protected]>
1 parent bf676e2 commit b6aef84

File tree

4 files changed

+224
-0
lines changed

4 files changed

+224
-0
lines changed

.github/workflows/tests.yml

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.11"]
15+
16+
steps:
17+
- uses: actions/checkout@v3
18+
19+
- name: Set up Python ${{ matrix.python-version }}
20+
uses: actions/setup-python@v4
21+
with:
22+
python-version: ${{ matrix.python-version }}
23+
24+
- name: Install dependencies
25+
run: |
26+
python -m pip install --upgrade pip
27+
pip install pytest
28+
pip install -e .
29+
30+
- name: Create data directories
31+
run: |
32+
mkdir -p data/data-3rd-party
33+
mkdir -p data/forecast
34+
35+
- name: Run tests
36+
run: |
37+
python -m pytest tests/canswim/test_*.py -v

tests/canswim/test_forecast.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
from canswim.forecast import CanswimForecaster, get_next_open_market_day
5+
from unittest.mock import patch, MagicMock
6+
7+
@pytest.fixture
8+
def forecaster():
9+
with patch('canswim.forecast.CanswimForecaster.download_model'):
10+
return CanswimForecaster()
11+
12+
def test_get_next_open_market_day():
13+
"""Test get_next_open_market_day returns a business day"""
14+
next_day = get_next_open_market_day()
15+
assert isinstance(next_day, pd.Timestamp)
16+
assert next_day.dayofweek < 5 # Not weekend
17+
18+
def test_forecaster_initialization(forecaster):
19+
"""Test that CanswimForecaster initializes correctly"""
20+
assert isinstance(forecaster, CanswimForecaster)

tests/canswim/test_gather_data.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
from canswim.gather_data import MarketDataGatherer, get_latest_date
3+
from unittest.mock import patch, MagicMock
4+
import pandas as pd
5+
6+
@pytest.fixture
7+
def gatherer():
8+
return MarketDataGatherer()
9+
10+
def test_get_latest_date():
11+
"""Test get_latest_date function"""
12+
dates = pd.Series([
13+
pd.Timestamp('2023-01-01'),
14+
pd.Timestamp('2023-01-02'),
15+
pd.Timestamp('2023-01-03')
16+
])
17+
latest = get_latest_date(dates)
18+
assert latest == pd.Timestamp('2023-01-03')
19+
20+
def test_gatherer_initialization(gatherer):
21+
"""Test that MarketDataGatherer initializes correctly"""
22+
assert isinstance(gatherer, MarketDataGatherer)

tests/canswim/test_targets.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import pytest
2+
import pandas as pd
3+
import numpy as np
4+
from canswim.targets import Targets
5+
from unittest.mock import patch, MagicMock
6+
from darts import TimeSeries
7+
8+
@pytest.fixture
9+
def mock_stock_data():
10+
# Create sample stock price data
11+
dates = pd.date_range(start='2023-01-01', end='2023-01-10', freq='B')
12+
data = {
13+
'AAPL': pd.DataFrame({
14+
'Close': np.random.uniform(150, 160, size=len(dates)),
15+
'Open': np.random.uniform(150, 160, size=len(dates)),
16+
'High': np.random.uniform(155, 165, size=len(dates)),
17+
'Low': np.random.uniform(145, 155, size=len(dates)),
18+
'Volume': np.random.randint(1000000, 2000000, size=len(dates))
19+
}, index=dates)
20+
}
21+
return data
22+
23+
@pytest.fixture
24+
def targets():
25+
return Targets()
26+
27+
def test_targets_initialization(targets):
28+
"""Test that Targets class initializes correctly"""
29+
assert isinstance(targets, Targets)
30+
31+
@patch('canswim.targets.Targets.load_stock_prices')
32+
def test_load_data(mock_load_prices, targets, mock_stock_data):
33+
"""Test load_data method with mock stock prices"""
34+
# Setup test data
35+
stock_tickers = {'AAPL'}
36+
start_date = pd.Timestamp('2023-01-01')
37+
38+
# Mock load_stock_prices to set the stock_price_dict
39+
def mock_load():
40+
targets.stock_price_dict = mock_stock_data
41+
mock_load_prices.side_effect = mock_load
42+
43+
# Call load_data with required parameters
44+
targets.load_data(stock_tickers=stock_tickers, start_date=start_date)
45+
46+
# Verify the data was loaded correctly
47+
assert hasattr(targets, 'stock_price_dict')
48+
assert isinstance(targets.stock_price_dict, dict)
49+
assert 'AAPL' in targets.stock_price_dict
50+
51+
# Verify DataFrame structure
52+
df = targets.stock_price_dict['AAPL']
53+
assert isinstance(df, pd.DataFrame)
54+
assert all(col in df.columns for col in ['Close', 'Open', 'High', 'Low', 'Volume'])
55+
56+
def test_pyarrow_filters(targets):
57+
"""Test pyarrow_filters property returns correct filters"""
58+
# Setup test data
59+
stock_tickers = {'AAPL', 'MSFT'}
60+
start_date = pd.Timestamp('2023-01-01')
61+
62+
# Call load_data to set up the required instance variables
63+
with patch('canswim.targets.Targets.load_stock_prices'):
64+
targets.load_data(stock_tickers=stock_tickers, start_date=start_date)
65+
66+
# Get filters
67+
filters = targets.pyarrow_filters
68+
69+
# Verify filter structure
70+
assert isinstance(filters, list)
71+
assert len(filters) == 2
72+
assert filters[0] == ("Symbol", "in", stock_tickers)
73+
assert filters[1] == ("Date", ">=", start_date)
74+
75+
@patch('canswim.targets.TimeSeries.from_dataframe')
76+
@patch('canswim.targets.MissingValuesFiller.transform')
77+
def test_prepare_stock_price_series(mock_transform, mock_from_df, targets, mock_stock_data):
78+
"""Test prepare_stock_price_series method"""
79+
# Setup test data
80+
train_date_start = pd.Timestamp('2023-01-05')
81+
targets.stock_price_dict = mock_stock_data
82+
83+
# Mock TimeSeries creation and transformation
84+
mock_series = MagicMock()
85+
mock_series.gaps.return_value = [] # No gaps after filling
86+
mock_series.end_time.return_value = pd.Timestamp('2023-01-10')
87+
mock_transform.return_value = mock_series
88+
mock_from_df.return_value = mock_series
89+
90+
# Call the method
91+
result = targets.prepare_stock_price_series(train_date_start=train_date_start)
92+
93+
# Verify results
94+
assert isinstance(result, dict)
95+
assert 'AAPL' in result
96+
assert mock_from_df.called
97+
assert mock_transform.called
98+
mock_series.slice.assert_called_with(train_date_start, mock_series.end_time())
99+
100+
def test_prepare_data(targets, mock_stock_data):
101+
"""Test prepare_data method with univariate target"""
102+
# Setup mock TimeSeries
103+
dates = pd.date_range(start='2023-01-01', end='2023-01-10', freq='B')
104+
mock_series = MagicMock()
105+
mock_series.univariate_component.return_value = TimeSeries.from_dataframe(
106+
pd.DataFrame({'Close': np.random.uniform(150, 160, size=len(dates))}, index=dates)
107+
)
108+
109+
stock_price_series = {'AAPL': mock_series}
110+
target_columns = 'Close'
111+
112+
# Call prepare_data
113+
targets.prepare_data(stock_price_series=stock_price_series, target_columns=target_columns)
114+
115+
# Verify results
116+
assert hasattr(targets, 'target_series')
117+
assert isinstance(targets.target_series, dict)
118+
assert 'AAPL' in targets.target_series
119+
mock_series.univariate_component.assert_called_with(target_columns)
120+
121+
def test_prepare_data_multivariate(targets, mock_stock_data):
122+
"""Test prepare_data method with multivariate targets"""
123+
# Setup mock TimeSeries
124+
dates = pd.date_range(start='2023-01-01', end='2023-01-10', freq='B')
125+
mock_series = MagicMock()
126+
mock_series.columns = ['Open', 'Close', 'Volume', 'Extra']
127+
mock_series.drop_columns.return_value = TimeSeries.from_dataframe(
128+
pd.DataFrame({
129+
'Open': np.random.uniform(150, 160, size=len(dates)),
130+
'Close': np.random.uniform(150, 160, size=len(dates)),
131+
'Volume': np.random.randint(1000000, 2000000, size=len(dates))
132+
}, index=dates)
133+
)
134+
135+
stock_price_series = {'AAPL': mock_series}
136+
target_columns = ['Open', 'Close', 'Volume']
137+
138+
# Call prepare_data
139+
targets.prepare_data(stock_price_series=stock_price_series, target_columns=target_columns)
140+
141+
# Verify results
142+
assert hasattr(targets, 'target_series')
143+
assert isinstance(targets.target_series, dict)
144+
assert 'AAPL' in targets.target_series
145+
mock_series.drop_columns.assert_called()

0 commit comments

Comments
 (0)