diff --git a/pyproject.toml b/pyproject.toml index c27f940..14b1f68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "polars>=1.2.0", "loguru>=0.7.2", "ranzen>=2.5.1", + "hydra-zen>=0.13.0", ] classifiers = [ "Programming Language :: Python :: 3.10", diff --git a/requirements-dev.lock b/requirements-dev.lock index 6d47cb3..63685ce 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,6 +10,9 @@ # universal: false -e file:. +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf cloudpickle==3.0.0 # via gymnasium contourpy==1.2.1 @@ -22,6 +25,10 @@ fonttools==4.53.1 # via matplotlib gymnasium==0.29.1 # via fescher +hydra-core==1.3.2 + # via hydra-zen +hydra-zen==0.13.0 + # via fescher iniconfig==2.0.0 # via pytest joblib==1.4.2 @@ -41,7 +48,11 @@ numpy==1.26.4 # via scikit-learn # via scipy # via seaborn +omegaconf==2.3.0 + # via hydra-core + # via hydra-zen packaging==24.1 + # via hydra-core # via matplotlib # via pytest pandas==2.2.2 @@ -58,9 +69,11 @@ pytest==8.2.2 python-dateutil==2.9.0.post0 # via matplotlib # via pandas -python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e597ca5e3cc45ff76b435b537202aedd9c7 +python-type-stubs @ git+https://github.com/wearepal/python-type-stubs@95a26e5 pytz==2024.1 # via pandas +pyyaml==6.0.1 + # via omegaconf ranzen==2.5.1 # via fescher scikit-learn==1.5.1 @@ -77,6 +90,7 @@ tqdm==4.66.4 # via fescher typing-extensions==4.12.2 # via gymnasium + # via hydra-zen # via ranzen tzdata==2024.1 # via pandas diff --git a/requirements.lock b/requirements.lock index 16ef884..32eb1a5 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,6 +10,9 @@ # universal: false -e file:. +antlr4-python3-runtime==4.9.3 + # via hydra-core + # via omegaconf cloudpickle==3.0.0 # via gymnasium contourpy==1.2.1 @@ -22,6 +25,10 @@ fonttools==4.53.1 # via matplotlib gymnasium==0.29.1 # via fescher +hydra-core==1.3.2 + # via hydra-zen +hydra-zen==0.13.0 + # via fescher joblib==1.4.2 # via scikit-learn kiwisolver==1.4.5 @@ -39,7 +46,11 @@ numpy==1.26.4 # via scikit-learn # via scipy # via seaborn +omegaconf==2.3.0 + # via hydra-core + # via hydra-zen packaging==24.1 + # via hydra-core # via matplotlib pandas==2.2.2 # via seaborn @@ -54,6 +65,8 @@ python-dateutil==2.9.0.post0 # via pandas pytz==2024.1 # via pandas +pyyaml==6.0.1 + # via omegaconf ranzen==2.5.1 # via fescher scikit-learn==1.5.1 @@ -70,6 +83,7 @@ tqdm==4.66.4 # via fescher typing-extensions==4.12.2 # via gymnasium + # via hydra-zen # via ranzen tzdata==2024.1 # via pandas diff --git a/run/rrm_credit_lr.py b/run/rrm_credit_lr.py index 2a6f753..0552962 100644 --- a/run/rrm_credit_lr.py +++ b/run/rrm_credit_lr.py @@ -11,10 +11,12 @@ sys.path.append(str(Path(__file__).parent / "..")) +from hydra_zen import ZenStore, zen + from src.dynamics.env import DynamicEnv from src.dynamics.response import LinearResponse from src.dynamics.state import State -from src.loader.credit import CreditData +from src.loader.credit import CreditData, Data from src.models.lr import Lr, logistic_loss @@ -29,10 +31,7 @@ class EpisodeRecord: def repeated_risk_minimization( - *, - env: DynamicEnv, - num_steps: int, - lr: Lr, + *, env: DynamicEnv, num_steps: int, lr: Lr, l2_penalty: float ) -> EpisodeRecord: """Run repeated risk minimization for num_iters steps""" # Track loss and accuracy before/after updating model on new distribution @@ -84,10 +83,16 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv: return env -if __name__ == "__main__": +store = ZenStore() +data_store = store(group="dataset") +data_store(CreditData, seed=0, name="credit") + + +@store(name="fescher", hydra_defaults=["_self_", {"dataset": "credit"}]) +def main(dataset: Data): # We use the credit simulator, which is a strategic classification # simulator based on the 'Kaggle Give Me Some Credit' (GMSC) dataset. - initial_state = CreditData.as_state(seed=0) + initial_state = dataset.as_state() # The state of the environment is a dataset consisting of (1) financial # features of individuals, e.g. DebtRatio, and (2) a binary label indicating # whether an individual experienced financial distress in the subsequent two @@ -128,9 +133,7 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv: env = make_env(initial_state=initial_state, epsilon=epsilon) logger.info(f"Running retraining for epsilon {epsilon:.2f}") record = repeated_risk_minimization( - env=env, - lr=lr, - num_steps=num_steps, + env=env, lr=lr, num_steps=num_steps, l2_penalty=l2_penalty ) loss_starts.append(record.loss_start) loss_ends.append(record.loss_end) @@ -245,3 +248,12 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv: # ax.set_ylim([0.5, 0.75]) plt.subplots_adjust(hspace=0.25) plt.show() + + +if __name__ == "__main__": + store.add_to_hydra_store() + zen(main).hydra_main( + config_name="fescher", + version_base="1.1", + config_path=None, + ) diff --git a/src/dynamics/registration.py b/src/dynamics/registration.py index dc7ef38..ad67a98 100644 --- a/src/dynamics/registration.py +++ b/src/dynamics/registration.py @@ -91,7 +91,7 @@ def __call__( **kwargs: Any, ) -> DynamicEnv: del kwargs - initial_state = unwrap_or(initial_state, default=CreditData.as_state()) + initial_state = unwrap_or(initial_state, default=CreditData().as_state()) reward_fn = unwrap_or(reward_fn, default=LogisticReward(l2_penalty=0.0)) response_fn = unwrap_or(response_fn, default=LinearResponse(epsilon=1.0)) simulator = Simulator( diff --git a/src/loader/credit.py b/src/loader/credit.py index cc032d8..ac2d028 100644 --- a/src/loader/credit.py +++ b/src/loader/credit.py @@ -1,6 +1,7 @@ """Load and preprocess Kaggle credit dataset.""" from pathlib import Path +from typing import Protocol import numpy as np import polars as pl @@ -12,6 +13,23 @@ __all__ = ["CreditData"] +class Data(Protocol): + def __init__(self, seed: int | None = None) -> None: ... + + def as_state(self) -> State: ... + + @property + def features(self) -> FloatArray: ... + + @property + def labels(self) -> IntArray: ... + + @property + def num_agents(self) -> int: ... + + def load(self) -> tuple[FloatArray, FloatArray]: ... + + class CreditData: """Class to lazily load the credit dataset.""" @@ -21,10 +39,8 @@ def __init__(self, seed: int | None = None) -> None: self._labels = None self.seed = seed - @classmethod - def as_state(cls, seed: int | None = None) -> State: - data = cls(seed=seed) - return State(features=data.features, labels=data.labels) + def as_state(self) -> State: + return State(features=self.features, labels=self.labels) @property def features(self) -> FloatArray: