Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mmai25_hackathon/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from torch.utils.data import Dataset, Sampler
from torch_geometric.data import DataLoader
from load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list

__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"]

Expand Down Expand Up @@ -110,6 +111,53 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class EchoDataset(BaseDataset):
"""Example subclass for an ECHO dataset."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.records = load_mimic_iv_echo_record_list(args.data_path)
self.subject_ids = self.records['subject_id'].tolist()

def __len__(self) -> int:
"""Return the number of samples in the dataset."""
return len(self.records)

def __getitem__(self, idx: int):
"""Return a single sample from the dataset."""
record = self.records[idx]
# Load and return the ECHO data for the given record
#print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}")
sample_path = record["echo_path"]
frames, meta = load_echo_dicom(sample_path)
#meta_filtered = {
# k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate") if k in meta
#}
return {'frames': frames, 'metadata': meta, 'subject_id': record['subject_id']}

def extra_repr(self) -> str:
"""Return any extra information about the dataset."""
return f"sample_size={len(self)}, subjects={len(set(self.subject_ids))}"

def __add__(self, other):
"""
Combine with another dataset.

Override in subclasses to implement multimodal aggregation.

Args:
other: Another dataset to combine with this one.

Initial Idea:
Use `__add__` to align and merge heterogeneous modalities into a single
dataset, keeping shared IDs synchronized.
Note: This is not mandatory; treat it as a sketch you can refine or replace.
"""
self.records = self.records.merge(other.records, on='subject_id', suffixes=('', '_other'), how='outer')
self.subject_ids = self.records['subject_id'].tolist()
return self


class BaseDataLoader(DataLoader):
"""
DataLoader for graph and non-graph data.
Expand Down
Loading