Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions datasets/downstream/prepare_PhysioNetP300.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
tmax=2
for sub in [2,3,4,5,6,7,9,11]:

path = "erp-based-brain-computer-interface-recordings-1.0.0/files/s{:02d}".format(sub)
#path = "erp-based-brain-computer-interface-recordings-1.0.0/files/s{:02d}".format(sub)
path="erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s{:02d}".format(sub) # based on the current state of the dataset
for file in os.listdir(path):
if not file.endswith(".edf"):continue
raw = mne.io.read_raw_edf(os.path.join(path, file))
Expand Down Expand Up @@ -44,7 +45,7 @@
# -- save
x = torch.tensor(d*1e3)
y = label
spath = dataset_fold+f'{y}/'
os.makedirs(path,exist_ok=True)
spath = spath + f'{i}.sub{sub}'
spath = os.path.join(dataset_fold, f'{y}')
os.makedirs(spath, exist_ok=True) # create the correct folder
spath = os.path.join(spath, f'{i}.sub{sub}.pt') # add .pt for PyTorch files
torch.save(x, spath)
6 changes: 3 additions & 3 deletions datasets/downstream/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ datasets/downstream/KaggleERN/test/Data_S25_Sess05.csv

PhysioP300 datasets can be downloaded from https://physionet.org/content/erpbci/1.0.0/ and saved into the `datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0` folder, which is organized as:
```
datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/s01/rc01.edf
datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s01/rc01.edf
...
datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/s11/rc01.edf
datasets/downstream/erp-based-brain-computer-interface-recordings-1.0.0/files/erpbci/1.0.0/s11/rc01.edf
...
Then run the following command to preprocess the data:
```preprocess
Expand All @@ -74,4 +74,4 @@ For preparing Sleep-EDF dataset, you can run the following command to preprocess
```preprocess
cd datasets/downstream
python prepare_sleep.py
```
```
42 changes: 32 additions & 10 deletions downstream/finetune_BENDR_PhysioP300.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,43 @@ def configure_optimizers(self):
global steps_per_epoch
global max_lr

class P300SubjectDataset(Dataset):
    """P300 epochs saved as .pt tensors, filtered to a subset of subjects.

    Expects `root` to contain two class folders, '0' and '1', each holding
    files named like '<i>.sub<subject>.pt' (as written by the prepare script).
    Returns (float32 tensor, int label) pairs.
    """

    def __init__(self, root, subjects):
        """
        root: path to PhysioNetP300 folder
        subjects: list of subject numbers to include, e.g. [1,2,3]
        """
        self.files = []
        self.labels = []

        # Match the exact token '.sub{s}.' (with the trailing dot that
        # precedes the '.pt' extension). A plain substring test
        # (f".sub{s}" in f) makes subject 1 also match subject 11's
        # files, leaking subject 11 into subject 1's split and breaking
        # leave-one-subject-out cross-validation.
        tags = tuple(f".sub{s}." for s in subjects)
        for label in ('0', '1'):
            folder = os.path.join(root, label)
            for fname in os.listdir(folder):
                if fname.endswith('.pt') and any(tag in fname for tag in tags):
                    self.files.append(os.path.join(folder, fname))
                    self.labels.append(int(label))

    def __len__(self):
        # Number of epoch files matching the requested subjects.
        return len(self.files)

    def __getitem__(self, idx):
        # Load the saved epoch tensor; cast to float32 for the model.
        x = torch.load(self.files[idx])
        y = self.labels[idx]
        return x.float(), y

batch_size=64
max_epochs = 100


all_subjects = [1,2,3,4,5,6,7,9,11]
for i,sub in enumerate(all_subjects):
sub_train = [f".sub{x}" for x in all_subjects if x!=sub]
sub_valid = [f".sub{sub}"]
print(sub_train, sub_valid)
train_dataset = torchvision.datasets.DatasetFolder(root="../datasets/downstream/PhysioNetP300", loader=torch.load, extensions=sub_train)
valid_dataset = torchvision.datasets.DatasetFolder(root="../datasets/downstream/PhysioNetP300", loader=torch.load, extensions=sub_valid)
for val_sub in (all_subjects):

train_subs = [s for s in all_subjects if s != val_sub] # all except the current
train_dataset = P300SubjectDataset(root="../datasets/downstream/PhysioNetP300", subjects=train_subs)
val_dataset = P300SubjectDataset(root="../datasets/downstream/PhysioNetP300", subjects=[val_sub])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, num_workers=8, shuffle=False)

valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
steps_per_epoch = math.ceil(len(train_loader))

# init model
Expand All @@ -204,7 +226,7 @@ def configure_optimizers(self):
trainer = pl.Trainer(accelerator='cuda',
max_epochs=max_epochs,
callbacks=callbacks,
logger=[pl_loggers.TensorBoardLogger('./logs/', name="BENDR_PhysioP300_tb", version=f"subject{sub}"),
logger=[pl_loggers.TensorBoardLogger('./logs/', name="BENDR_PhysioP300_tb", version=f"subject{val_sub}"),
pl_loggers.CSVLogger('./logs/', name="BENDR_PhysioP300_csv")])

trainer.fit(model, train_loader, valid_loader, ckpt_path='last')
trainer.fit(model, train_loader, valid_loader, ckpt_path='last')