11
11
sys .path .append (str (Path (__file__ ).parent / ".." ))
12
12
13
13
14
+ from hydra_zen import ZenStore , zen
15
+
14
16
from src .dynamics .env import DynamicEnv
15
17
from src .dynamics .response import LinearResponse
16
18
from src .dynamics .state import State
17
- from src .loader .credit import CreditData
19
+ from src .loader .credit import CreditData , Data
18
20
from src .models .lr import Lr , logistic_loss
19
21
20
22
@@ -29,10 +31,7 @@ class EpisodeRecord:
29
31
30
32
31
33
def repeated_risk_minimization (
32
- * ,
33
- env : DynamicEnv ,
34
- num_steps : int ,
35
- lr : Lr ,
34
+ * , env : DynamicEnv , num_steps : int , lr : Lr , l2_penalty : float
36
35
) -> EpisodeRecord :
37
36
"""Run repeated risk minimization for num_iters steps"""
38
37
# Track loss and accuracy before/after updating model on new distribution
@@ -84,10 +83,16 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
84
83
return env
85
84
86
85
87
- if __name__ == "__main__" :
86
+ store = ZenStore ()
87
+ data_store = store (group = "dataset" )
88
+ data_store (CreditData , seed = 0 , name = "credit" )
89
+
90
+
91
+ @store (name = "fescher" , hydra_defaults = ["_self_" , {"dataset" : "credit" }])
92
+ def main (dataset : Data ):
88
93
# We use the credit simulator, which is a strategic classification
89
94
# simulator based on the 'Kaggle Give Me Some Credit' (GMSC) dataset.
90
- initial_state = CreditData .as_state (seed = 0 )
95
+ initial_state = dataset .as_state ()
91
96
# The state of the environment is a dataset consisting of (1) financial
92
97
# features of individuals, e.g. DebtRatio, and (2) a binary label indicating
93
98
# whether an individual experienced financial distress in the subsequent two
@@ -128,9 +133,7 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
128
133
env = make_env (initial_state = initial_state , epsilon = epsilon )
129
134
logger .info (f"Running retraining for epsilon { epsilon :.2f} " )
130
135
record = repeated_risk_minimization (
131
- env = env ,
132
- lr = lr ,
133
- num_steps = num_steps ,
136
+ env = env , lr = lr , num_steps = num_steps , l2_penalty = l2_penalty
134
137
)
135
138
loss_starts .append (record .loss_start )
136
139
loss_ends .append (record .loss_end )
@@ -245,3 +248,12 @@ def make_env(*, initial_state: State, epsilon: float) -> DynamicEnv:
245
248
# ax.set_ylim([0.5, 0.75])
246
249
plt .subplots_adjust (hspace = 0.25 )
247
250
plt .show ()
251
+
252
+
253
+ if __name__ == "__main__" :
254
+ store .add_to_hydra_store ()
255
+ zen (main ).hydra_main (
256
+ config_name = "fescher" ,
257
+ version_base = "1.1" ,
258
+ config_path = None ,
259
+ )
0 commit comments