Skip to content

Commit f13652e

Browse files
authored
Merge pull request #8 from wearepal/hz
add hydra
2 parents 4601868 + e331685 commit f13652e

File tree

6 files changed

+73
-16
lines changed

6 files changed

+73
-16
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"polars>=1.2.0",
1616
"loguru>=0.7.2",
1717
"ranzen>=2.5.1",
18+
"hydra-zen>=0.13.0",
1819
]
1920
classifiers = [
2021
"Programming Language :: Python :: 3.10",

requirements-dev.lock

+15-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
# universal: false
1111

1212
-e file:.
13+
antlr4-python3-runtime==4.9.3
14+
# via hydra-core
15+
# via omegaconf
1316
cloudpickle==3.0.0
1417
# via gymnasium
1518
contourpy==1.2.1
@@ -22,6 +25,10 @@ fonttools==4.53.1
2225
# via matplotlib
2326
gymnasium==0.29.1
2427
# via fescher
28+
hydra-core==1.3.2
29+
# via hydra-zen
30+
hydra-zen==0.13.0
31+
# via fescher
2532
iniconfig==2.0.0
2633
# via pytest
2734
joblib==1.4.2
@@ -41,7 +48,11 @@ numpy==1.26.4
4148
# via scikit-learn
4249
# via scipy
4350
# via seaborn
51+
omegaconf==2.3.0
52+
# via hydra-core
53+
# via hydra-zen
4454
packaging==24.1
55+
# via hydra-core
4556
# via matplotlib
4657
# via pytest
4758
pandas==2.2.2
@@ -58,9 +69,11 @@ pytest==8.2.2
5869
python-dateutil==2.9.0.post0
5970
# via matplotlib
6071
# via pandas
61-
python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e597ca5e3cc45ff76b435b537202aedd9c7
72+
python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e5
6273
pytz==2024.1
6374
# via pandas
75+
pyyaml==6.0.1
76+
# via omegaconf
6477
ranzen==2.5.1
6578
# via fescher
6679
scikit-learn==1.5.1
@@ -77,6 +90,7 @@ tqdm==4.66.4
7790
# via fescher
7891
typing-extensions==4.12.2
7992
# via gymnasium
93+
# via hydra-zen
8094
# via ranzen
8195
tzdata==2024.1
8296
# via pandas

requirements.lock

+14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
# universal: false
1111

1212
-e file:.
13+
antlr4-python3-runtime==4.9.3
14+
# via hydra-core
15+
# via omegaconf
1316
cloudpickle==3.0.0
1417
# via gymnasium
1518
contourpy==1.2.1
@@ -22,6 +25,10 @@ fonttools==4.53.1
2225
# via matplotlib
2326
gymnasium==0.29.1
2427
# via fescher
28+
hydra-core==1.3.2
29+
# via hydra-zen
30+
hydra-zen==0.13.0
31+
# via fescher
2532
joblib==1.4.2
2633
# via scikit-learn
2734
kiwisolver==1.4.5
@@ -39,7 +46,11 @@ numpy==1.26.4
3946
# via scikit-learn
4047
# via scipy
4148
# via seaborn
49+
omegaconf==2.3.0
50+
# via hydra-core
51+
# via hydra-zen
4252
packaging==24.1
53+
# via hydra-core
4354
# via matplotlib
4455
pandas==2.2.2
4556
# via seaborn
@@ -54,6 +65,8 @@ python-dateutil==2.9.0.post0
5465
# via pandas
5566
pytz==2024.1
5667
# via pandas
68+
pyyaml==6.0.1
69+
# via omegaconf
5770
ranzen==2.5.1
5871
# via fescher
5972
scikit-learn==1.5.1
@@ -70,6 +83,7 @@ tqdm==4.66.4
7083
# via fescher
7184
typing-extensions==4.12.2
7285
# via gymnasium
86+
# via hydra-zen
7387
# via ranzen
7488
tzdata==2024.1
7589
# via pandas

run/rrm_credit_lr.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
sys.path.append(str(Path(__file__).parent / ".."))
1212

1313

14+
from hydra_zen import ZenStore, zen
15+
1416
from src.dynamics.env import DynamicEnv
1517
from src.dynamics.response import LinearResponse
1618
from src.dynamics.state import State
17-
from src.loader.credit import CreditData
19+
from src.loader.credit import CreditData, Data
1820
from src.models.lr import Lr, logistic_loss
1921

2022

@@ -29,10 +31,7 @@ class EpisodeRecord:
2931

3032

3133
def repeated_risk_minimization(
32-
*,
33-
env: DynamicEnv,
34-
num_steps: int,
35-
lr: Lr,
34+
*, env: DynamicEnv, num_steps: int, lr: Lr, l2_penalty: float
3635
) -> EpisodeRecord:
3736
"""Run repeated risk minimization for num_iters steps"""
3837
# Track loss and accuracy before/after updating model on new distribution
@@ -84,10 +83,16 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
8483
return env
8584

8685

87-
if __name__ == "__main__":
86+
store = ZenStore()
87+
data_store = store(group="dataset")
88+
data_store(CreditData, seed=0, name="credit")
89+
90+
91+
@store(name="fescher", hydra_defaults=["_self_", {"dataset": "credit"}])
92+
def main(dataset: Data):
8893
# We use the credit simulator, which is a strategic classification
8994
# simulator based on the 'Kaggle Give Me Some Credit' (GMSC) dataset.
90-
initial_state = CreditData.as_state(seed=0)
95+
initial_state = dataset.as_state()
9196
# The state of the environment is a dataset consisting of (1) financial
9297
# features of individuals, e.g. DebtRatio, and (2) a binary label indicating
9398
# whether an individual experienced financial distress in the subsequent two
@@ -128,9 +133,7 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
128133
env = make_env(initial_state=initial_state, epsilon=epsilon)
129134
logger.info(f"Running retraining for epsilon {epsilon:.2f}")
130135
record = repeated_risk_minimization(
131-
env=env,
132-
lr=lr,
133-
num_steps=num_steps,
136+
env=env, lr=lr, num_steps=num_steps, l2_penalty=l2_penalty
134137
)
135138
loss_starts.append(record.loss_start)
136139
loss_ends.append(record.loss_end)
@@ -245,3 +248,12 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
245248
# ax.set_ylim([0.5, 0.75])
246249
plt.subplots_adjust(hspace=0.25)
247250
plt.show()
251+
252+
253+
if __name__ == "__main__":
254+
store.add_to_hydra_store()
255+
zen(main).hydra_main(
256+
config_name="fescher",
257+
version_base="1.1",
258+
config_path=None,
259+
)

src/dynamics/registration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __call__(
9191
**kwargs: Any,
9292
) -> DynamicEnv:
9393
del kwargs
94-
initial_state = unwrap_or(initial_state, default=CreditData.as_state())
94+
initial_state = unwrap_or(initial_state, default=CreditData().as_state())
9595
reward_fn = unwrap_or(reward_fn, default=LogisticReward(l2_penalty=0.0))
9696
response_fn = unwrap_or(response_fn, default=LinearResponse(epsilon=1.0))
9797
simulator = Simulator(

src/loader/credit.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Load and preprocess Kaggle credit dataset."""
22

33
from pathlib import Path
4+
from typing import Protocol
45

56
import numpy as np
67
import polars as pl
@@ -12,6 +13,23 @@
1213
__all__ = ["CreditData"]
1314

1415

16+
class Data(Protocol):
17+
def __init__(self, seed: int | None = None) -> None: ...
18+
19+
def as_state(self) -> State: ...
20+
21+
@property
22+
def features(self) -> FloatArray: ...
23+
24+
@property
25+
def labels(self) -> IntArray: ...
26+
27+
@property
28+
def num_agents(self) -> int: ...
29+
30+
def load(self) -> tuple[FloatArray, FloatArray]: ...
31+
32+
1533
class CreditData:
1634
"""Class to lazily load the credit dataset."""
1735

@@ -21,10 +39,8 @@ def __init__(self, seed: int | None = None) -> None:
2139
self._labels = None
2240
self.seed = seed
2341

24-
@classmethod
25-
def as_state(cls, seed: int | None = None) -> State:
26-
data = cls(seed=seed)
27-
return State(features=data.features, labels=data.labels)
42+
def as_state(self) -> State:
43+
return State(features=self.features, labels=self.labels)
2844

2945
@property
3046
def features(self) -> FloatArray:

0 commit comments

Comments
 (0)