-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Data augmentation #7
base: main
Are you sure you want to change the base?
Changes from all commits
826b81f
a0b0d4c
e240e2a
4a5ae22
3ed90f2
79e51f5
9fb460c
f9bceb4
6da9b54
6cc66ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,15 +6,21 @@ | |||||||||||||||||
# - getting requirements info when all dependencies are not installed. | ||||||||||||||||||
with safe_import_context() as import_ctx: | ||||||||||||||||||
import numpy as np | ||||||||||||||||||
|
||||||||||||||||||
from numpy import concatenate | ||||||||||||||||||
from torch import as_tensor | ||||||||||||||||||
from skorch.helper import to_numpy | ||||||||||||||||||
from braindecode.augmentation import ChannelsDropout, SmoothTimeMask | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def gen_seed(): | ||||||||||||||||||
# Iterator that generates random seeds for reproducibility reasons | ||||||||||||||||||
seed = 0 | ||||||||||||||||||
while True: | ||||||||||||||||||
yield seed | ||||||||||||||||||
seed += 1 | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def channels_dropout( | ||||||||||||||||||
X, y, n_augmentation, seed=0, probability=0.5, p_drop=0.2 | ||||||||||||||||||
X, y, n_augmentation, probability=0.5, p_drop=0.2 | ||||||||||||||||||
): | ||||||||||||||||||
""" | ||||||||||||||||||
Function to apply channels dropout to X raw data | ||||||||||||||||||
|
@@ -27,8 +33,6 @@ def channels_dropout( | |||||||||||||||||
The labels. | ||||||||||||||||||
n_augmentation : int | ||||||||||||||||||
Number of augmentation to apply and increase the size of the dataset. | ||||||||||||||||||
seed : int | ||||||||||||||||||
Random seed. | ||||||||||||||||||
probability : float | ||||||||||||||||||
Probability of applying the tranformation. | ||||||||||||||||||
p_drop : float | ||||||||||||||||||
|
@@ -43,51 +47,58 @@ def channels_dropout( | |||||||||||||||||
The labels. | ||||||||||||||||||
|
||||||||||||||||||
""" | ||||||||||||||||||
transform = ChannelsDropout(probability=probability, random_state=seed) | ||||||||||||||||||
|
||||||||||||||||||
seed = gen_seed() | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For a reproducible sequence of random numbers, the best is to use a random number generator, that will be used to sample different number for each draw, but with a reproducible order. |
||||||||||||||||||
X_augm = to_numpy(X) | ||||||||||||||||||
y_augm = y | ||||||||||||||||||
for i in range(n_augmentation): | ||||||||||||||||||
for _ in range(n_augmentation): | ||||||||||||||||||
transform = ChannelsDropout( | ||||||||||||||||||
probability=probability, | ||||||||||||||||||
random_state=next(seed) | ||||||||||||||||||
) | ||||||||||||||||||
Comment on lines
+55
to
+58
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
X_tr, _ = transform.operation( | ||||||||||||||||||
as_tensor(X).float(), None, p_drop=p_drop | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
X_tr = X_tr.numpy() | ||||||||||||||||||
X_augm = concatenate((X_augm, X_tr)) | ||||||||||||||||||
y_augm = concatenate((y_augm, y)) | ||||||||||||||||||
X_augm = np.concatenate((X_augm, X_tr)) | ||||||||||||||||||
y_augm = np.concatenate((y_augm, y)) | ||||||||||||||||||
|
||||||||||||||||||
return X_augm, y_augm | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def smooth_timemask( | ||||||||||||||||||
X, y, n_augmentation, sfreq, seed=0, probability=0.5, second=0.1 | ||||||||||||||||||
X, y, n_augmentation, sfreq, probability=0.8, second=0.2 | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
): | ||||||||||||||||||
""" | ||||||||||||||||||
Function to apply smooth time mask to X raw data | ||||||||||||||||||
and concatenate it to the original data. | ||||||||||||||||||
""" | ||||||||||||||||||
|
||||||||||||||||||
transform = SmoothTimeMask( | ||||||||||||||||||
probability=probability, | ||||||||||||||||||
mask_len_samples=int(sfreq * second), | ||||||||||||||||||
random_state=seed, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
seed_generator = gen_seed() | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
X_torch = as_tensor(np.array(X)).float() | ||||||||||||||||||
y_torch = as_tensor(y).float() | ||||||||||||||||||
param_augm = transform.get_augmentation_params(X_torch, y_torch) | ||||||||||||||||||
mls = param_augm["mask_len_samples"] | ||||||||||||||||||
msps = param_augm["mask_start_per_sample"] | ||||||||||||||||||
|
||||||||||||||||||
X_augm = to_numpy(X) | ||||||||||||||||||
y_augm = y | ||||||||||||||||||
|
||||||||||||||||||
for i in range(n_augmentation): | ||||||||||||||||||
mls = int(sfreq * second) | ||||||||||||||||||
for _ in range(n_augmentation): | ||||||||||||||||||
seed = next(seed_generator) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
transform = SmoothTimeMask( | ||||||||||||||||||
probability=probability, | ||||||||||||||||||
mask_len_samples=mls, | ||||||||||||||||||
random_state=rng | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
param_augm = transform.get_augmentation_params(X_torch, y_torch) | ||||||||||||||||||
mls = param_augm["mask_len_samples"] | ||||||||||||||||||
msps = param_augm["mask_start_per_sample"] | ||||||||||||||||||
|
||||||||||||||||||
X_tr, _ = transform.operation( | ||||||||||||||||||
X_torch, None, mask_len_samples=mls, mask_start_per_sample=msps | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
X_tr = X_tr.numpy() | ||||||||||||||||||
X_augm = concatenate((X_augm, X_tr)) | ||||||||||||||||||
y_augm = concatenate((y_augm, y)) | ||||||||||||||||||
X_augm = np.concatenate((X_augm, X_tr)) | ||||||||||||||||||
y_augm = np.concatenate((y_augm, y)) | ||||||||||||||||||
|
||||||||||||||||||
return X_augm, y_augm |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,14 +2,14 @@ | |||||||||||||
|
||||||||||||||
|
||||||||||||||
with safe_import_context() as import_ctx: | ||||||||||||||
from numpy import array | ||||||||||||||
|
||||||||||||||
import numpy as np | ||||||||||||||
from sklearn.dummy import DummyClassifier | ||||||||||||||
from sklearn.pipeline import make_pipeline | ||||||||||||||
from sklearn.pipeline import FunctionTransformer | ||||||||||||||
|
||||||||||||||
from sklearn.model_selection import train_test_split | ||||||||||||||
from sklearn.metrics import balanced_accuracy_score as BAS | ||||||||||||||
from sklearn.metrics import accuracy_score | ||||||||||||||
|
||||||||||||||
from skorch.helper import SliceDataset, to_numpy | ||||||||||||||
from benchmark_utils.dataset import split_windows_train_test | ||||||||||||||
|
@@ -31,6 +31,7 @@ class Objective(BaseObjective): | |||||||||||||
parameters = { | ||||||||||||||
'evaluation_process, subject, subject_test, session_test': [ | ||||||||||||||
('intra_subject', 1, None, None), | ||||||||||||||
('inter_subject', None, 3, None), | ||||||||||||||
], | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
|
@@ -55,7 +56,9 @@ def set_data(self, dataset, sfreq): | |||||||||||||
|
||||||||||||||
dataset = data_split_subject[str(self.subject)] | ||||||||||||||
X = SliceDataset(dataset, idx=0) | ||||||||||||||
y = array(list(SliceDataset(dataset, idx=1))) | ||||||||||||||
y = np.array(list(SliceDataset(dataset, idx=1)))-1 | ||||||||||||||
# we have to susbtract 1 to the labels for compatibility reasons | ||||||||||||||
# with the deep learning solvers | ||||||||||||||
Comment on lines
+59
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By convention, we put comments before the code, not after.
Suggested change
|
||||||||||||||
|
||||||||||||||
# maybe we need to do here different process for each subjects | ||||||||||||||
|
||||||||||||||
|
@@ -64,6 +67,8 @@ def set_data(self, dataset, sfreq): | |||||||||||||
self.X_test, self.y_test = X_test, y_test | ||||||||||||||
|
||||||||||||||
elif self.evaluation_process == 'inter_subject': | ||||||||||||||
# the evaluation proccess here is to leave one subject out | ||||||||||||||
# to test on it and train on the rest of the subjects | ||||||||||||||
|
||||||||||||||
sujet_test = self.subject_test | ||||||||||||||
data_subject_test = data_split_subject[str(sujet_test)] | ||||||||||||||
|
@@ -82,6 +87,8 @@ def set_data(self, dataset, sfreq): | |||||||||||||
self.y_test = splitted_data['y_test'] | ||||||||||||||
|
||||||||||||||
elif self.evaluation_process == 'inter_session': | ||||||||||||||
# the evaluation proccess here is to leave one session out | ||||||||||||||
# to test on it and train on the rest of the sessions | ||||||||||||||
|
||||||||||||||
data_subject = data_split_subject[str(self.subject)] | ||||||||||||||
data_split_session = data_subject.split('session') | ||||||||||||||
|
@@ -103,9 +110,9 @@ def set_data(self, dataset, sfreq): | |||||||||||||
self.sfreq = sfreq | ||||||||||||||
|
||||||||||||||
return dict( | ||||||||||||||
X_train=X_train, y_train=y_train, | ||||||||||||||
X_test=X_test, y_test=y_test, | ||||||||||||||
sfreq=sfreq, | ||||||||||||||
X_train=self.X_train, y_train=self.y_train, | ||||||||||||||
X_test=self.X_test, y_test=self.y_test, | ||||||||||||||
sfreq=self.sfreq, | ||||||||||||||
Comment on lines
112
to
+115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def evaluate_result(self, model): | ||||||||||||||
|
@@ -125,9 +132,14 @@ def evaluate_result(self, model): | |||||||||||||
value: error on the testing set. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
score_train = model.score(self.X_train, self.y_train) | ||||||||||||||
score_test = model.score(self.X_test, self.y_test) | ||||||||||||||
bl_acc = BAS(self.y_test, model.predict(self.X_test)) | ||||||||||||||
# we compute here the predictions so | ||||||||||||||
# that we don't compute it for each score | ||||||||||||||
y_pred_train = model.predict(self.X_train) | ||||||||||||||
y_pred_test = model.predict(self.X_test) | ||||||||||||||
tomMoral marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
score_train = accuracy_score(self.y_train, y_pred_train) | ||||||||||||||
score_test = accuracy_score(self.y_test, y_pred_test) | ||||||||||||||
bl_acc = BAS(self.y_test, y_pred_test) | ||||||||||||||
|
||||||||||||||
return dict( | ||||||||||||||
score_test=score_test, | ||||||||||||||
|
@@ -161,6 +173,8 @@ def get_objective(self): | |||||||||||||
sfreq: sampling frequency to allow filtering the data. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
X_train, X_test, y_train, y_test = self.get_split(self.X, self.y) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work? I don't see how you can get the split here but maybe I am missing something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I think I mixed here the branch named CV and this one, where I was beginning to add our work on cross validation. |
||||||||||||||
|
||||||||||||||
return dict( | ||||||||||||||
X=self.X_train, | ||||||||||||||
y=self.y_train, | ||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,8 +15,23 @@ | |||||||||
|
||||||||||
|
||||||||||
class Solver(AugmentedBCISolver): | ||||||||||
''' | ||||||||||
You can choose an augmentation parameter from the following list: | ||||||||||
- IdentityTransform | ||||||||||
- ChannelsDropout | ||||||||||
- SmoothTimeMask | ||||||||||
|
||||||||||
Running the benchmark with -n = n_augmentation | ||||||||||
you will get a cuvre of solver's score | ||||||||||
with respect to the number of augmentation which corresponds | ||||||||||
to how much the dataset has been multiplied. | ||||||||||
''' | ||||||||||
name = "CSPLDA" | ||||||||||
parameters = { | ||||||||||
"augmentation": [ | ||||||||||
"SmoothTimeMask", | ||||||||||
"ChannelsDropout", | ||||||||||
], | ||||||||||
Comment on lines
+31
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All the augmentations are passed through
Suggested change
|
||||||||||
"n_components": [8], | ||||||||||
**AugmentedBCISolver.parameters | ||||||||||
} | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -13,7 +13,7 @@ | |||||
SmoothTimeMask, | ||||||
) | ||||||
from braindecode.models import ShallowFBCSPNet | ||||||
from numpy import linspace, pi | ||||||
from numpy import linspace | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
from skorch.callbacks import LRScheduler | ||||||
|
||||||
|
||||||
|
@@ -33,7 +33,7 @@ class Solver(BaseSolver): | |||||
"lr": [0.0625 * 0.01], | ||||||
"weight_decay": [0], | ||||||
"batch_size": [64], | ||||||
"n_epochs": [4], | ||||||
"n_epochs": [1], | ||||||
"proba": [0.5], | ||||||
|
||||||
} | ||||||
|
@@ -96,25 +96,25 @@ def set_objective(self, X, y, sfreq): | |||||
mask_len_samples=int(sfreq * second), | ||||||
random_state=seed, | ||||||
) | ||||||
for second in linspace(0.1, 2, 10) | ||||||
for second in linspace(0.1, 2, 3) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you decrease the number of samples? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was only to have faster results |
||||||
] | ||||||
|
||||||
elif self.augmentation == "ChannelDropout": | ||||||
transforms = [ | ||||||
ChannelsDropout( | ||||||
probability=self.proba, p_drop=prob, random_state=seed | ||||||
) | ||||||
for prob in linspace(0, 1, 10) | ||||||
for prob in linspace(0, 1, 3) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
] | ||||||
|
||||||
elif self.augmentation == "FTSurrogate": | ||||||
transforms = [ | ||||||
FTSurrogate( | ||||||
probability=self.proba, | ||||||
phase_noise_magnitude=prob, | ||||||
phase_noise_magnitude=phase_freq, | ||||||
random_state=seed, | ||||||
) | ||||||
for prob in linspace(0, 2 * pi, 10) | ||||||
for phase_freq in linspace(0, 1, 3) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
] | ||||||
else: | ||||||
transforms = [IdentityTransform()] | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,6 +19,9 @@ class Solver(AugmentedBCISolver): | |||||||||||
|
||||||||||||
name = "TGSPSVM" | ||||||||||||
parameters = { | ||||||||||||
"augmentation": [ | ||||||||||||
"SmoothTimeMask", | ||||||||||||
], | ||||||||||||
"covariances_estimator": ["oas"], | ||||||||||||
Comment on lines
+22
to
25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
"tangentspace_metric": ["riemann"], | ||||||||||||
"svm_kernel": ["linear"], | ||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.