Skip to content

Commit 302deac

Browse files
author
zeisler
committed
Two-stage working example.
1 parent 93cde40 commit 302deac

11 files changed

+87
-290
lines changed

dvc.lock

+32-9
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,12 @@ stages:
44
cmd: python src/get_data.py --output_folder outs/data
55
deps:
66
- path: src/get_data.py
7-
md5: 0b9160f7475a6aaa50acb793b7527f68
8-
size: 991
7+
md5: c7452de42878b5e9e672d786d48e7c53
8+
size: 998
99
params:
1010
params.yaml:
1111
data:
12-
repo: iterative/dvc
13-
labels:
14-
- data-sync
15-
- experiments
16-
- plots
17-
since: 2021/01/01
18-
until: 2022/05/01
12+
data_date: 2022/08/08
1913
output_folder: outs/data
2014
metrics_file: outs/data_metrics.json
2115
outs:
@@ -118,3 +112,32 @@ stages:
118112
- path: outs/eval/plots/confusion_matrix.json
119113
md5: 7ad7650d9b00d4ae671f04ba20889b79
120114
size: 1274
115+
fit:
116+
cmd: python src/fit.py outs/data outs/fit outs/fit_metrics
117+
deps:
118+
- path: outs/data
119+
md5: 0223fa49fdf31fd4bb2f99cf8322f401.dir
120+
size: 136
121+
nfiles: 1
122+
- path: src/fit.py
123+
md5: ef5f1cb19e0a3d286c43877e6b459620
124+
size: 947
125+
params:
126+
params.yaml:
127+
fit:
128+
output_folder: outs/fit
129+
metrics_folder: outs/fit_metrics
130+
metrics_file: outs/fit_metrics.json
131+
epochs: 8
132+
outs:
133+
- path: outs/fit
134+
md5: d751713988987e9331980363e24189ce.dir
135+
size: 0
136+
nfiles: 0
137+
- path: outs/fit_metrics
138+
md5: d751713988987e9331980363e24189ce.dir
139+
size: 0
140+
nfiles: 0
141+
- path: outs/fit_metrics.json
142+
md5: eba383f3de32d0a53c5962f4119e6dcf
143+
size: 23

dvc.yaml

+10-46
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,22 @@ stages:
1818
- ${data.metrics_file}:
1919
cache: false
2020

21-
split-data:
21+
fit:
2222
cmd:
23-
python src/split_data.py
24-
${data.output_folder}
25-
${split.output_folder}
26-
${split.test_size}
27-
${split.metrics_file}
23+
python src/fit.py
24+
${data.output_folder}
25+
${fit.output_folder}
26+
${fit.metrics_folder}
2827
deps:
2928
- ${data.output_folder}
30-
- src/split_data.py
31-
outs:
32-
- ${split.output_folder}
33-
metrics:
34-
- ${split.metrics_file}:
35-
cache: false
36-
37-
train:
38-
cmd:
39-
python src/train.py
40-
${split.output_folder}
41-
${train.output_folder}
42-
deps:
43-
- ${split.output_folder}
44-
- src/train.py
29+
- src/fit.py
4530
params:
46-
- data.labels
47-
- train
31+
- fit
4832
outs:
49-
- ${train.output_folder}
33+
- ${fit.output_folder}
5034
metrics:
51-
- ${train.metrics_folder}.json:
35+
- ${fit.metrics_folder}.json:
5236
cache: false
5337
plots:
54-
- ${train.metrics_folder}:
38+
- ${fit.metrics_folder}:
5539
cache: false
56-
57-
eval:
58-
cmd:
59-
python src/eval.py
60-
${split.output_folder}/val.json
61-
${train.output_folder}
62-
${eval.output_folder}
63-
params:
64-
- data.labels
65-
deps:
66-
- ${split.output_folder}/val.json
67-
- ${train.output_folder}
68-
- src/eval.py
69-
- src/inference.py
70-
plots:
71-
- ${eval.output_folder}/plots/confusion_matrix.json:
72-
cache: false
73-
template: confusion
74-
x: actual
75-
y: predicted

outs/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
/data
22
/split
33
/train
4+
/fit

outs/fit_metrics.json

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"num_trades": 2
3+
}

params.yaml

+5-18
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,13 @@
11
data:
2-
repo: iterative/dvc
3-
4-
labels:
5-
- data-sync
6-
- experiments
7-
- plots
8-
since: 2021/01/01
9-
until: 2022/05/01
2+
data_date: 2022/08/08
103

114
output_folder: outs/data
125
metrics_file: outs/data_metrics.json
136

14-
split:
15-
output_folder: outs/split
16-
metrics_file: outs/split_metrics.json
17-
test_size: 0.1
7+
fit:
8+
output_folder: outs/fit
9+
metrics_folder: outs/fit_metrics
10+
metrics_file: outs/fit_metrics.json
1811

19-
train:
20-
output_folder: outs/train
21-
metrics_folder: outs/train_metrics
22-
pretrained_model: MoritzLaurer/xtremedistil-l6-h256-mnli-fever-anli-ling-binary
2312
epochs: 8
2413

25-
eval:
26-
output_folder: outs/eval

src/eval.py

-39
This file was deleted.

src/fit.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import json
2+
import os
3+
import pandas as pd
4+
from collections import Counter
5+
from pathlib import Path
6+
7+
import fire
8+
import yaml
9+
from loguru import logger
10+
11+
12+
@logger.catch(reraise=True)
13+
def fit(data_output_folder, fit_output_folder, fit_metrics_folder):
14+
base_dir = os.path.dirname(__file__) + "/.."
15+
16+
with open(base_dir+"/params.yaml") as f:
17+
params = yaml.safe_load(f)["fit"]
18+
19+
input_folder = base_dir + "/" + data_output_folder
20+
output_folder = base_dir + "/" + fit_output_folder
21+
22+
Path(output_folder).mkdir(parents=True, exist_ok=True)
23+
24+
trades = pd.read_csv(input_folder + "/trades.csv")
25+
26+
27+
metrics = {"num_trades": len(trades)}
28+
29+
output_metrics_folder = base_dir + "/" + fit_metrics_folder
30+
Path(output_metrics_folder).mkdir(parents=True, exist_ok=True)
31+
32+
Path(base_dir+"/"+params["metrics_file"]).write_text(json.dumps(metrics, indent=4))
33+
34+
if __name__ == "__main__":
35+
fire.Fire(fit)

src/get_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def get_data(output_folder):
2525

2626
logger.info(f"\nSome log.")
2727

28-
trades = pd.DataFrame(dict(ops_code=["WN COMDTY 1", "CTUSD30Y"], start_time=pd.Timestamp("2022-08-08"), dayfrac=1, quantity=100000, quantity_units="dv01_usd"))
28+
trades = pd.DataFrame(dict(ops_code=["WN COMDTY 1", "CTUSD30Y"], start_time=pd.Timestamp(params["data_date"]), dayfrac=1, quantity=100000, quantity_units="dv01_usd"))
2929

3030
trades_file = output_folder + "/trades.csv"
3131
trades.to_csv(trades_file)

src/inference.py

-22
This file was deleted.

src/split_data.py

-72
This file was deleted.

0 commit comments

Comments
 (0)