diff --git a/datasets/downstream/prepare_PhysioNetP300.py b/datasets/downstream/prepare_PhysioNetP300.py
index b70e56e..dca285c 100644
--- a/datasets/downstream/prepare_PhysioNetP300.py
+++ b/datasets/downstream/prepare_PhysioNetP300.py
@@ -14,7 +14,8 @@
 tmax=2
 
 for sub in [2,3,4,5,6,7,9,11]:
-    path = "erp-based-brain-computer-interface-recordings-1.0.0/files/s{:02d}".format(sub)
+    #path = "erp-based-brain-computer-interface-recordings-1.0.0/files/s{:02d}".format(sub)
+    path="erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s{:02d}".format(sub) # based on the current state of the dataset
     for file in os.listdir(path):
         if not file.endswith(".edf"):continue
         raw = mne.io.read_raw_edf(os.path.join(path, file))
@@ -44,7 +45,7 @@
             # -- save
             x = torch.tensor(d*1e3)
             y = label
-            spath = dataset_fold+f'{y}/'
-            os.makedirs(path,exist_ok=True)
-            spath = spath + f'{i}.sub{sub}'
+            spath = os.path.join(dataset_fold, f'{y}')
+            os.makedirs(spath, exist_ok=True) # create the correct folder
+            spath = os.path.join(spath, f'{i}.sub{sub}.pt') # add .pt for PyTorch files
             torch.save(x, spath)
diff --git a/datasets/downstream/readme.md b/datasets/downstream/readme.md
index 83ca667..32be82a 100644
--- a/datasets/downstream/readme.md
+++ b/datasets/downstream/readme.md
@@ -60,9 +60,9 @@ datasets/downstream/KaggleERN/test/Data_S25_Sess05.csv
 
 PhysioP300 datasets can be downloaded from https://physionet.org/content/erpbci/1.0.0/ and save into the `datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0` folder, which organized as:
 ```
-datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/s01/rc01.edf
+datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s01/rc01.edf
 ...
-datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/s11/rc01.edf
+datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s11/rc01.edf
 ...
 Then run the following command to preprocess the data:
 ```preprocess
@@ -74,4 +74,4 @@ For preparing Sleep-EDF dataset, you can run the following command to preprocess
 ```preprocess
 cd datasets/downstream
 python prepare_sleep.py
-```
\ No newline at end of file
+```
diff --git a/downstream/finetune_BENDR_PhysioP300.py b/downstream/finetune_BENDR_PhysioP300.py
index 64f0320..92e52ed 100644
--- a/downstream/finetune_BENDR_PhysioP300.py
+++ b/downstream/finetune_BENDR_PhysioP300.py
@@ -177,21 +177,44 @@ def configure_optimizers(self):
 global steps_per_epoch
 global max_lr
 
+class P300SubjectDataset(Dataset):
+    """Dataset over the `<root>/<label>/*.sub<subject>.pt` tensors of the selected subjects."""
+    def __init__(self, root, subjects):
+        """
+        root: path to PhysioNetP300 folder
+        subjects: list of subject numbers to include, e.g. [1,2,3]
+        """
+        self.files = []
+        self.labels = []
+
+        for label in ['0','1']:
+            folder = os.path.join(root, label)
+            for f in os.listdir(folder):
+                if any(f.endswith(f".sub{s}.pt") for s in subjects):  # exact suffix match: ".sub1" must not also select ".sub11" files
+                    self.files.append(os.path.join(folder, f))
+                    self.labels.append(int(label))
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        x = torch.load(self.files[idx])
+        y = self.labels[idx]
+        x = x.float() # Convert to float32
+        return x, y
+
 batch_size=64
 
 max_epochs = 100
 
 
 all_subjects = [1,2,3,4,5,6,7,9,11]
-for i,sub in enumerate(all_subjects):
-    sub_train = [f".sub{x}" for x in all_subjects if x!=sub]
-    sub_valid = [f".sub{sub}"]
-    print(sub_train, sub_valid)
-    train_dataset = torchvision.datasets.DatasetFolder(root="../datasets/downstream/PhysioNetP300", loader=torch.load, extensions=sub_train)
-    valid_dataset = torchvision.datasets.DatasetFolder(root="../datasets/downstream/PhysioNetP300", loader=torch.load, extensions=sub_valid)
+for val_sub in (all_subjects):
+    train_subs = [s for s in all_subjects if s != val_sub] # all except the current
+    train_dataset = P300SubjectDataset(root="../datasets/downstream/PhysioNetP300", subjects=train_subs)
+    val_dataset = P300SubjectDataset(root="../datasets/downstream/PhysioNetP300", subjects=[val_sub])
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
-    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
-
+    valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
     steps_per_epoch = math.ceil(len(train_loader))
 
     # init model
@@ -204,7 +227,7 @@ def configure_optimizers(self):
     trainer = pl.Trainer(accelerator='cuda',
                 max_epochs=max_epochs,
                 callbacks=callbacks,
-                logger=[pl_loggers.TensorBoardLogger('./logs/', name="BENDR_PhysioP300_tb", version=f"subject{sub}"),
+                logger=[pl_loggers.TensorBoardLogger('./logs/', name="BENDR_PhysioP300_tb", version=f"subject{val_sub}"),
                         pl_loggers.CSVLogger('./logs/', name="BENDR_PhysioP300_csv")])
 
-    trainer.fit(model, train_loader, valid_loader, ckpt_path='last')
\ No newline at end of file
+    trainer.fit(model, train_loader, valid_loader, ckpt_path='last')