Skip to content

Commit

Permalink
Merge pull request #8 from wearepal/hz
Browse files Browse the repository at this point in the history
add hydra
  • Loading branch information
olliethomas authored Jul 23, 2024
2 parents 4601868 + e331685 commit f13652e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 16 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"polars>=1.2.0",
"loguru>=0.7.2",
"ranzen>=2.5.1",
"hydra-zen>=0.13.0",
]
classifiers = [
"Programming Language :: Python :: 3.10",
Expand Down
16 changes: 15 additions & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
# universal: false

-e file:.
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
cloudpickle==3.0.0
# via gymnasium
contourpy==1.2.1
Expand All @@ -22,6 +25,10 @@ fonttools==4.53.1
# via matplotlib
gymnasium==0.29.1
# via fescher
hydra-core==1.3.2
# via hydra-zen
hydra-zen==0.13.0
# via fescher
iniconfig==2.0.0
# via pytest
joblib==1.4.2
Expand All @@ -41,7 +48,11 @@ numpy==1.26.4
# via scikit-learn
# via scipy
# via seaborn
omegaconf==2.3.0
# via hydra-core
# via hydra-zen
packaging==24.1
# via hydra-core
# via matplotlib
# via pytest
pandas==2.2.2
Expand All @@ -58,9 +69,11 @@ pytest==8.2.2
python-dateutil==2.9.0.post0
# via matplotlib
# via pandas
python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e597ca5e3cc45ff76b435b537202aedd9c7
python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e5
pytz==2024.1
# via pandas
pyyaml==6.0.1
# via omegaconf
ranzen==2.5.1
# via fescher
scikit-learn==1.5.1
Expand All @@ -77,6 +90,7 @@ tqdm==4.66.4
# via fescher
typing-extensions==4.12.2
# via gymnasium
# via hydra-zen
# via ranzen
tzdata==2024.1
# via pandas
14 changes: 14 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
# universal: false

-e file:.
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
cloudpickle==3.0.0
# via gymnasium
contourpy==1.2.1
Expand All @@ -22,6 +25,10 @@ fonttools==4.53.1
# via matplotlib
gymnasium==0.29.1
# via fescher
hydra-core==1.3.2
# via hydra-zen
hydra-zen==0.13.0
# via fescher
joblib==1.4.2
# via scikit-learn
kiwisolver==1.4.5
Expand All @@ -39,7 +46,11 @@ numpy==1.26.4
# via scikit-learn
# via scipy
# via seaborn
omegaconf==2.3.0
# via hydra-core
# via hydra-zen
packaging==24.1
# via hydra-core
# via matplotlib
pandas==2.2.2
# via seaborn
Expand All @@ -54,6 +65,8 @@ python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
# via pandas
pyyaml==6.0.1
# via omegaconf
ranzen==2.5.1
# via fescher
scikit-learn==1.5.1
Expand All @@ -70,6 +83,7 @@ tqdm==4.66.4
# via fescher
typing-extensions==4.12.2
# via gymnasium
# via hydra-zen
# via ranzen
tzdata==2024.1
# via pandas
32 changes: 22 additions & 10 deletions run/rrm_credit_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
sys.path.append(str(Path(__file__).parent / ".."))


from hydra_zen import ZenStore, zen

from src.dynamics.env import DynamicEnv
from src.dynamics.response import LinearResponse
from src.dynamics.state import State
from src.loader.credit import CreditData
from src.loader.credit import CreditData, Data
from src.models.lr import Lr, logistic_loss


Expand All @@ -29,10 +31,7 @@ class EpisodeRecord:


def repeated_risk_minimization(
*,
env: DynamicEnv,
num_steps: int,
lr: Lr,
*, env: DynamicEnv, num_steps: int, lr: Lr, l2_penalty: float
) -> EpisodeRecord:
"""Run repeated risk minimization for num_iters steps"""
# Track loss and accuracy before/after updating model on new distribution
Expand Down Expand Up @@ -84,10 +83,16 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
return env


if __name__ == "__main__":
# Config store collecting the entries that are handed to Hydra at start-up
# (see the `store.add_to_hydra_store()` call in the entry point).
store = ZenStore()
# Sub-store whose entries are placed in the "dataset" config group.
data_store = store(group="dataset")
# Register CreditData, configured with seed=0, under the name "credit".
data_store(CreditData, seed=0, name="credit")


@store(name="fescher", hydra_defaults=["_self_", {"dataset": "credit"}])
def main(dataset: Data):
# We use the credit simulator, which is a strategic classification
# simulator based on the 'Kaggle Give Me Some Credit' (GMSC) dataset.
initial_state = CreditData.as_state(seed=0)
initial_state = dataset.as_state()
# The state of the environment is a dataset consisting of (1) financial
# features of individuals, e.g. DebtRatio, and (2) a binary label indicating
# whether an individual experienced financial distress in the subsequent two
Expand Down Expand Up @@ -128,9 +133,7 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
env = make_env(initial_state=initial_state, epsilon=epsilon)
logger.info(f"Running retraining for epsilon {epsilon:.2f}")
record = repeated_risk_minimization(
env=env,
lr=lr,
num_steps=num_steps,
env=env, lr=lr, num_steps=num_steps, l2_penalty=l2_penalty
)
loss_starts.append(record.loss_start)
loss_ends.append(record.loss_end)
Expand Down Expand Up @@ -245,3 +248,12 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
# ax.set_ylim([0.5, 0.75])
plt.subplots_adjust(hspace=0.25)
plt.show()


if __name__ == "__main__":
store.add_to_hydra_store()
zen(main).hydra_main(
config_name="fescher",
version_base="1.1",
config_path=None,
)
2 changes: 1 addition & 1 deletion src/dynamics/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __call__(
**kwargs: Any,
) -> DynamicEnv:
del kwargs
initial_state = unwrap_or(initial_state, default=CreditData.as_state())
initial_state = unwrap_or(initial_state, default=CreditData().as_state())
reward_fn = unwrap_or(reward_fn, default=LogisticReward(l2_penalty=0.0))
response_fn = unwrap_or(response_fn, default=LinearResponse(epsilon=1.0))
simulator = Simulator(
Expand Down
24 changes: 20 additions & 4 deletions src/loader/credit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Load and preprocess Kaggle credit dataset."""

from pathlib import Path
from typing import Protocol

import numpy as np
import polars as pl
Expand All @@ -12,6 +13,23 @@
__all__ = ["CreditData"]


class Data(Protocol):
    """Structural interface for dataset loaders.

    Being a ``typing.Protocol``, any class that provides these members
    satisfies the interface without inheriting from it.
    """

    # Construct the dataset; `seed` is optional and defaults to None.
    def __init__(self, seed: int | None = None) -> None: ...

    # Return the dataset packaged as a `State`.
    def as_state(self) -> State: ...

    # Feature array of the dataset.
    @property
    def features(self) -> FloatArray: ...

    # Label array of the dataset.
    @property
    def labels(self) -> IntArray: ...

    # Number of agents; presumably the number of rows in the dataset —
    # TODO confirm against the implementation.
    @property
    def num_agents(self) -> int: ...

    # Load the underlying data as a pair of arrays (presumably features and
    # labels). NOTE(review): both elements are annotated FloatArray although
    # `labels` is an IntArray — confirm the intended element types.
    def load(self) -> tuple[FloatArray, FloatArray]: ...


class CreditData:
"""Class to lazily load the credit dataset."""

Expand All @@ -21,10 +39,8 @@ def __init__(self, seed: int | None = None) -> None:
self._labels = None
self.seed = seed

@classmethod
def as_state(cls, seed: int | None = None) -> State:
data = cls(seed=seed)
return State(features=data.features, labels=data.labels)
def as_state(self) -> State:
return State(features=self.features, labels=self.labels)

@property
def features(self) -> FloatArray:
Expand Down

0 comments on commit f13652e

Please sign in to comment.