Skip to content
Open
195 changes: 195 additions & 0 deletions mmai25_hackathon/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from torch.utils.data import Dataset, Sampler
from torch_geometric.data import DataLoader

from .load_data.cxr import load_chest_xray_image, load_mimic_cxr_metadata
from .load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list
from .load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list

__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"]


Expand Down Expand Up @@ -101,13 +105,190 @@ class CXRDataset(BaseDataset):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.records = load_mimic_cxr_metadata(args.data_path)
self.subject_ids = self.records["subject_id"].tolist()

def __len__(self):
return len(self.records)

def __getitem__(self, sample_ID: int):
record_idx = self.records[self.records.subject_id == sample_ID]
samples = []
for idx in record_idx:
path = idx["cxr_path"]
image = load_chest_xray_image(path)
item = {"image": image, "subject_id": record_idx["subject_id"]}
samples.append(item)
return samples

def modality(self) -> str:
"""Return the modality of the dataset."""
return "CXR"


class ECGDataset(BaseDataset):
    """Example subclass for an ECG dataset.

    Wraps the MIMIC-IV ECG record list; one sample is the set of ECG
    recordings belonging to a single subject.
    """

    def __init__(self, *args, **kwargs):
        """Load the ECG record list and cache the subject IDs.

        Accepts ``data_path`` either as a keyword argument or as the first
        positional argument; all arguments are also forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        # BUG FIX: `args` is a tuple, so the original `args.data_path` raised
        # AttributeError. Pull the path from kwargs or the first positional arg.
        data_path = kwargs.get("data_path", args[0] if args else None)
        # DataFrame expected to hold (at least) subject_id and hea_path columns.
        self.records = load_mimic_iv_ecg_record_list(data_path)
        self.subject_ids = self.records["subject_id"].tolist()

    def __len__(self):
        """Return the number of ECG records in the dataset."""
        return len(self.records)

    def __getitem__(self, sample_ID: int):
        """Return all ECG samples belonging to one subject.

        Args:
            sample_ID: subject_id whose recordings should be loaded.

        Returns:
            list[dict]: one dict per matching record with keys ``"signals"``,
            ``"fields"`` and ``"subject_id"``. Empty list if no records match.
        """
        rows = self.records[self.records.subject_id == sample_ID]
        samples = []
        # BUG FIX: iterating a DataFrame yields column *names*; iterate the
        # matching rows instead, and store the row's scalar subject_id rather
        # than the whole filtered Series (as the original did).
        for _, row in rows.iterrows():
            signals, fields = load_ecg_record(row["hea_path"])
            samples.append({"signals": signals, "fields": fields, "subject_id": row["subject_id"]})
        return samples

    def __repr__(self) -> str:
        """Return a string representation of the dataset."""
        return f"{self.__class__.__name__}({self.extra_repr()})"

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        return f"sample_size={len(self.subject_ids)}"

    def modality(self) -> str:
        """Return the modality of the dataset."""
        return "ECG"

    def __add__(self, other):
        """
        Merge another modality's records into this dataset, aligned on subject_id.

        The outer join keeps subjects present in either dataset; clashing
        column names on the incoming side receive the ``_other`` suffix.

        Args:
            other: Another dataset exposing a ``records`` DataFrame with a
                ``subject_id`` column.

        Returns:
            self, with ``records`` and ``subject_ids`` updated in place.

        Note: sketch of multimodal aggregation — not mandatory; refine or
        replace as needed.
        """
        self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer")
        self.subject_ids = self.records["subject_id"].tolist()
        # TODO: alternatively, align a single sample from `other` by looking up
        # the matching subject_id indices here and calling __getitem__ on them.
        return self


class EchoDataset(BaseDataset):
    """Example subclass for an ECHO dataset.

    Wraps the MIMIC-IV echo record list; one sample is the set of echo
    studies (DICOM cines) belonging to a single subject.
    """

    def __init__(self, *args, **kwargs):
        """Load the echo record list and cache the subject IDs.

        Accepts ``data_path`` either as a keyword argument or as the first
        positional argument; all arguments are also forwarded to the base class.
        """
        super().__init__(*args, **kwargs)
        # BUG FIX: `args` is a tuple, so the original `args.data_path` raised
        # AttributeError. Pull the path from kwargs or the first positional arg.
        data_path = kwargs.get("data_path", args[0] if args else None)
        # DataFrame expected to hold (at least) subject_id and echo_path columns.
        self.records = load_mimic_iv_echo_record_list(data_path)
        self.subject_ids = self.records["subject_id"].tolist()

    def __len__(self) -> int:
        """Return the number of echo records in the dataset."""
        return len(self.records)

    def __getitem__(self, sample_ID: int):
        """Return all echo samples belonging to one subject.

        Args:
            sample_ID: subject_id whose studies should be loaded.

        Returns:
            list[dict]: one dict per matching record with keys ``"frames"``,
            ``"metadata"`` and ``"subject_id"``. Empty list if none match.
        """
        rows = self.records[self.records.subject_id == sample_ID]
        samples = []
        # BUG FIX: iterating a DataFrame yields column *names* (strings), so
        # the original `idx["echo_path"]` indexed a string. Iterate rows.
        for _, row in rows.iterrows():
            frames, meta = load_echo_dicom(row["echo_path"])
            samples.append({"frames": frames, "metadata": meta, "subject_id": row["subject_id"]})
        return samples

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}"

    def modality(self) -> str:
        """Return the modality of the dataset."""
        # NOTE(review): lowercase "echo" is inconsistent with "CXR"/"ECG";
        # kept as-is because MultimodalDataset uses this value as a dict key.
        return "echo"

    def __add__(self, other):
        """
        Merge another modality's records into this dataset, aligned on subject_id.

        The outer join keeps subjects present in either dataset; clashing
        column names on the incoming side receive the ``_other`` suffix.

        Args:
            other: Another dataset exposing a ``records`` DataFrame with a
                ``subject_id`` column.

        Returns:
            self, with ``records`` and ``subject_ids`` updated in place.

        Note: sketch of multimodal aggregation — not mandatory; refine or
        replace as needed.
        """
        self.records = self.records.merge(other.records, on="subject_id", suffixes=("", "_other"), how="outer")
        self.subject_ids = self.records["subject_id"].tolist()
        return self


class MultimodalDataset(BaseDataset):
    """Example subclass for a multimodal dataset.

    Holds one dataset per modality and indexes them jointly: index ``idx``
    maps to a subject ID, and each modality contributes its samples for that
    subject.
    """

    def __init__(self, datasets: list[BaseDataset], *args, **kwargs):
        """Build the joint subject index over all modality datasets.

        Args:
            datasets: modality datasets, each exposing ``subject_ids``,
                ``modality()`` and subject-ID-based ``__getitem__``.

        Raises:
            ValueError: if any element is not a BaseDataset instance.
        """
        super().__init__(*args, **kwargs)
        if not all(isinstance(ds, BaseDataset) for ds in datasets):
            raise ValueError("All elements in datasets must be instances of BaseDataset.")
        self.datasets = datasets
        # Union of all subject IDs across modalities; sorted so positional
        # indexing in __getitem__ is deterministic (set order is not).
        self.subject_ids = sorted(set().union(*(ds.subject_ids for ds in datasets)))
        print(f"MultimodalDataset initialized with {len(self.subject_ids)} unique subjects.")

    def __len__(self) -> int:
        """Return the number of unique subjects across all modalities."""
        # BUG FIX: the original returned len(self.dataset), but only
        # self.datasets (plural) exists, so it raised AttributeError.
        return len(self.subject_ids)

    def __getitem__(self, idx: int):
        """Return one subject's samples from every modality.

        Args:
            idx: positional index into the joint subject list.

        Returns:
            dict: modality name -> list of samples for that subject. A subject
            absent from a modality yields an empty list under that key.
        """
        subject_ID = self.subject_ids[idx]
        return {ds.modality(): ds[subject_ID] for ds in self.datasets}

    def extra_repr(self) -> str:
        """Return any extra information about the dataset."""
        # BUG FIX: the original delegated to non-existent self.dataset.
        modalities = [ds.modality() for ds in self.datasets]
        return f"subjects={len(self.subject_ids)}, modalities={modalities}"


class BaseDataLoader(DataLoader):
Expand All @@ -132,6 +313,20 @@ class BaseDataLoader(DataLoader):
Note: This is not a hard requirement. Consider it a future-facing idea you can evolve.
"""

def __init__(
self,
dataset: BaseDataset,
batch_size: int = 1,
shuffle: bool = False,
follow_batch: list = None,
exclude_keys: list = None,
**kwargs,
):
super().__init__(dataset, batch_size, shuffle, follow_batch, exclude_keys, **kwargs)

# collate_fn=lambda data_list: Batch.from_data_list(
# data_list, follow_batch),


class MultimodalDataLoader(BaseDataLoader):
"""Example dataloader for handling multiple data modalities."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Data: # type: ignore[name-defined]
return self._graphs[idx]

ds = GraphDataset()
loader = BaseDataLoader(ds, batch_size=2, shuffle=False)
ds = GraphDataset() # type: ignore[arg-type]
loader = BaseDataLoader(ds, batch_size=2, shuffle=False) # type: ignore[arg-type]

total_graphs = 0
for batch in loader:
Expand Down