diff --git a/stormworkflow/post/analyze_ensemble.py b/stormworkflow/post/analyze_ensemble.py index f12ea8b..4690fe6 100644 --- a/stormworkflow/post/analyze_ensemble.py +++ b/stormworkflow/post/analyze_ensemble.py @@ -112,6 +112,8 @@ def _analyze(tracks_dir, analyze_dir, mann_coef): storm_name = None + validate_surrogate_model = False #if True, split 70/30 + if log_space: output_directory = analyze_dir / f'log_k{k_neighbors}_p{idw_order}_n{mann_coef}' else: @@ -163,14 +165,21 @@ def _analyze(tracks_dir, analyze_dir, mann_coef): ) ) - if len(numpy.unique(perturbations['type'][:])) == 1: +# if len(numpy.unique(perturbations['type'][:])) == 1: + + if validate_surrogate_model: perturbations['type'][:] = numpy.random.choice( ['training', 'validation'], size=len(perturbations.run), p=[0.7, 0.3] ) LOGGER.info('dividing 70/30% for training/testing the model') - training_perturbations = perturbations.sel(run=perturbations['type'] == 'training') - validation_perturbations = perturbations.sel(run=perturbations['type'] == 'validation') + training_perturbations = perturbations.sel(run=perturbations['type'] == 'training') + validation_perturbations = perturbations.sel(run=perturbations['type'] == 'validation') + else: + training_perturbations = perturbations.sel(run=perturbations['type'] == 'training') + validation_perturbations = perturbations.sel(run=perturbations['type'] == 'training') + LOGGER.info('using all members for training the model') + make_validation_plot = False if make_perturbations_plot: plot_perturbations( @@ -222,7 +231,8 @@ def _analyze(tracks_dir, analyze_dir, mann_coef): validation_set = subset.sel(run=validation_perturbations['run']) LOGGER.info(f'total {training_set.shape} training samples') - LOGGER.info(f'total {validation_set.shape} validation samples') + if validate_surrogate_model: + LOGGER.info(f'total {validation_set.shape} validation samples') if node_status_mask == 'always_wet': training_set_adjusted = training_set.copy(deep=True) diff --git a/stormworkflow/slurm/post.sbatch b/stormworkflow/slurm/post.sbatch index 81d039d..0df1afe 100644 --- a/stormworkflow/slurm/post.sbatch +++ b/stormworkflow/slurm/post.sbatch @@ -1,6 +1,6 @@ #!/bin/bash #SBATCH --parsable -#SBATCH --time=05:00:00 +#SBATCH --time=08:00:00 #SBATCH --nodes=1 #SBATCH --exclusive diff --git a/stormworkflow/slurm/schism.sbatch b/stormworkflow/slurm/schism.sbatch index fef7e43..a2d0590 100644 --- a/stormworkflow/slurm/schism.sbatch +++ b/stormworkflow/slurm/schism.sbatch @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH --parsable #SBATCH --exclusive -#SBATCH --time=03:00:00 +#SBATCH --time=05:00:00 set -ex