Skip to content

Commit

Permalink
add functions for validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Koki committed Nov 26, 2021
1 parent 3542d41 commit 38b3aba
Show file tree
Hide file tree
Showing 5 changed files with 1,093 additions and 1,013 deletions.
31 changes: 31 additions & 0 deletions datasets/cmapss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,34 @@ def leave_one_out(target='run-to-failure',
subsets[subsets.fold == i].reset_index(drop=True)) for i in range(4)]

return train_test_sets


def generate_validation_sets(method='leave-one-out', n_splits=5, seed=123, outdir=None):
validation_sets = []

if method == 'kfold':
raise NotImplementedError

elif method == 'leave-one-out':
validation_sets = leave_one_out(target='run-to-failure',
health_censor_aug=1000,
seed=seed)

if outdir is not None:
for i, (train_data, test_data) in enumerate(validation_sets):
train_data.to_csv(outdir + f'/train_{i}.csv.gz', index=False)
test_data.to_csv(outdir + f'/test_{i}.csv.gz', index=False)

return validation_sets


def load_validation_sets(filepath, method='leave-one-out', n_splits=5):
if method == 'kfold':
return [(pd.read_csv(filepath + f'/train_{i}.csv.gz'),
pd.read_csv(filepath + f'/test_{i}.csv.gz'))
for i in range(n_splits)]

elif method == 'leave-one-out':
return [(pd.read_csv(filepath + f'/train_{i}.csv.gz'),
pd.read_csv(filepath + f'/test_{i}.csv.gz'))
for i in range(4)]
Loading

0 comments on commit 38b3aba

Please sign in to comment.