diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8b462e5..44a6f6f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,64 +14,62 @@ on: jobs: test: - name: Test (${{ matrix.os }}, python version ${{ matrix.python-version }}) - runs-on: ${{ matrix.os }} + env: + DROPBOX_APP_KEY: ${{ secrets.DROPBOX_APP_KEY }} + DROPBOX_APP_SECRET: ${{ secrets.DROPBOX_APP_SECRET }} + DROPBOX_REFRESH_TOKEN: ${{ secrets.DROPBOX_REFRESH_TOKEN }} + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} strategy: matrix: - # os: [ubuntu-latest, windows-latest] - os: [ubuntu-latest] - python-version: ["3.10", "3.11", "3.12"] # list of Python versions to test - include: - - os: ubuntu-latest - path: ~/.cache/pip - # - os: windows-latest - # path: ~\AppData\Local\pip\Cache - + python-version: ["3.10", "3.11", "3.12"] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python using Miniconda - uses: conda-incubator/setup-miniconda@v3 - with: - auto-update-conda: true - python-version: ${{ matrix.python-version }} - miniconda-version: latest - - name: Cache pip dependencies - id: cache_pip - uses: actions/cache@v4 + - name: Set up Python + uses: actions/setup-python@v5 with: - path: ${{ matrix.path }} - key: ${{ runner.os }}-python${{ matrix.python-version }}-pip-20250908-${{ hashFiles('**/pyproject.toml', '**/requirements*.txt') }} - restore-keys: | - ${{ runner.os }}-python${{ matrix.python-version }}-pip-20250908- + python-version: ${{ matrix.python-version }} + cache: "pip" + cache-dependency-path: | + pyproject.toml + requirements*.txt - - name: Install from pyproject (single list) + - name: Install dependencies env: PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu run: | python -m pip install --upgrade pip - python -m pip install -e . 
+ python -m pip install -e .[dev] # PyG compiled extensions (need the wheel index) python -m pip install \ -f https://data.pyg.org/whl/torch-2.6.0+cpu.html \ torch_scatter==2.1.2 torch_sparse==0.6.18 torch_cluster==1.6.3 torch_spline_conv==1.2.2 - shell: bash -l {0} - # - name: Run tests with thread limits - # id: run_tests - # run: | - # export OMP_NUM_THREADS=1 - # export MKL_NUM_THREADS=1 - # export NUMEXPR_NUM_THREADS=1 - # pytest --cov=mmai25_hackathon - # shell: bash -l {0} + - name: Download datasets from Dropbox (optional) + if: ${{ env.DROPBOX_REFRESH_TOKEN != '' }} + continue-on-error: true + run: | + python -m tests.dropbox_download \ + "/MMAI25Hackathon" \ + "MMAI25Hackathon" \ + --app-key "$DROPBOX_APP_KEY" \ + --app-secret "$DROPBOX_APP_SECRET" \ + --refresh-token "$DROPBOX_REFRESH_TOKEN" \ + --unzip + + - name: Run tests + env: + OMP_NUM_THREADS: "1" + MKL_NUM_THREADS: "1" + NUMEXPR_NUM_THREADS: "1" + run: pytest --cov=mmai25_hackathon - # - name: Determine coverage - # run: | - # coverage xml - # shell: bash -l {0} + - name: Generate coverage XML + run: coverage xml - # - name: Report coverage - # uses: codecov/codecov-action@v4 - # with: - # token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload coverage to Codecov (optional) + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@v4 + with: + token: ${{ env.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09583ed..8a19213 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,48 +2,29 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - - id: check-added-large-files - args: ["--maxkb=300"] - - id: fix-byte-order-marker - - id: check-case-conflict - - id: check-merge-conflict + - id: check-yaml - id: end-of-file-fixer - - id: forbid-new-submodules - - id: mixed-line-ending - id: trailing-whitespace - - id: debug-statements - - id: check-yaml - - id: requirements-txt-fixer - - repo: 
https://github.com/pycqa/flake8.git - rev: 6.1.0 - hooks: - - id: flake8 - args: [ --config=setup.cfg ] + - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 25.1.0 hooks: - id: black - language_version: python3 - additional_dependencies: [ 'click==8.0.4' ] + args: ["--line-length=120"] + - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 6.0.1 hooks: - id: isort - name: isort - entry: python -m isort - args: [ --settings-path, ./pyproject.toml ] - language: system - types: [ python ] -# - repo: https://github.com/astral-sh/ruff-pre-commit -# # Ruff version. -# rev: v0.12.11 -# hooks: -# # Run the linter. -# - id: ruff-check -# args: [--fix] -# # Run the formatter. -# - id: ruff-format + args: ["--profile=black", "--line-length=120"] + + - repo: https://github.com/PyCQA/flake8 + rev: 7.3.0 + hooks: + - id: flake8 + additional_dependencies: [] + args: ["--max-line-length=120"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.7.1 + rev: v1.18.1 hooks: - id: mypy diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f08bfe4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 mmai-hackathon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 3baa76e..d1fc8c1 100644 --- a/README.md +++ b/README.md @@ -16,127 +16,86 @@ This repository provides the base source code for the MultimodalAI'25 workshop H ## Installation -### Prerequisite +The steps below are linear and work with `venv`, `conda`, or `uv`. Pick one method and follow it end‑to‑end. -Before installing other dependencies, install pykale with all optional dependencies (full extras) from git: +### 1) Clone and create an environment ```bash -pip install "git+https://github.com/pykale/pykale@main[full]" -``` - -You can set up your development environment using one of the following methods: `venv`, `conda`, or `uv`. +git clone https://github.com/pykale/mmai-hackathon.git +cd mmai-hackathon -### Main Installation Steps +# conda (recommended) +conda create -n mmai-hackathon python=3.11 -y +conda activate mmai-hackathon -1. **Clone the repository:** - ```bash - git clone https://github.com/pykale/mmai-hackathon.git - cd mmai-hackathon - ``` +# venv (alternative) +# python3 -m venv .venv && source .venv/bin/activate -2. **Set up a virtual environment (recommended):** +# uv (alternative) +# uv venv .venv && source .venv/bin/activate +``` - ```bash - python3 -m venv .venv - source .venv/bin/activate - ``` +### 2) Install dependencies (with tests) -3. **Install dependencies:** +```bash - ```bash - pip install --upgrade pip - # Install pykale with all optional dependencies (full extras) from git first - pip install "git+https://github.com/pykale/pykale@main[full]" - pip install -e . 
- ``` +# Recommended for development and testing (includes pytest, coverage, linters) +pip install -e .[dev] -#### Installing torch-geometric (pyg) and its extensions +# If you only need runtime dependencies (not recommended for contributors): +# pip install -e . +``` -To install torch-geometric (`pyg`) and its required extensions (such as `torch-scatter`, `torch-sparse`, etc.), use the following command with the appropriate URL for your PyTorch and CUDA version: +If you use features that depend on PyG (graph loaders, SMILES), install torch‑geometric wheels that match your Torch/CUDA. +The snippet below detects your installed Torch and CUDA, constructs the correct find‑links URL, and installs the wheels: ```bash -pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cpu.html +# Inspect Torch / CUDA (optional) +python - <<'PYINFO' +import torch +print('Torch:', torch.__version__) +print('CUDA version:', torch.version.cuda) +print('CUDA available:', torch.cuda.is_available()) +PYINFO + +# Install PyG wheels matching your Torch/CUDA +PYG_INDEX=$(python - <<'PYG' +import torch +torch_ver = torch.__version__.split('+')[0] +cuda = torch.version.cuda +if cuda: + cu_tag = f"cu{cuda.replace('.', '')}" +else: + cu_tag = 'cpu' +print(f"https://data.pyg.org/whl/torch-{torch_ver}+{cu_tag}.html") +PYG +) +echo "Using PyG wheel index: $PYG_INDEX" +pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f "$PYG_INDEX" ``` -Replace the URL with the one matching your PyTorch and CUDA version. For more details and the latest URLs, see the official torch-geometric installation guide: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html - ---- - -You can also use the following environment-specific guides: - -### Using conda (Anaconda/Miniconda) - -1. 
**Create and activate a conda environment:** - - ```bash - conda create -n mmai-hackathon python=3.10 - conda activate mmai-hackathon - ``` - -2. **Install dependencies:** - - ```bash - pip install -e . - ``` - -### Using uv (Ultra-fast Python package manager) - -Assuming `uv` is already installed: - -1. **Create and activate a uv virtual environment:** +More details: https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html - ```bash - uv venv .venv - source .venv/bin/activate - ``` +### 3) (Optional) Pre‑commit hooks -2. **Install dependencies:** - - ```bash - uv pip install -e . - ``` - ---- - -1. **Clone the repository:** - - ```bash - git clone https://github.com/pykale/mmai-hackathon.git - cd mmai-hackathon - ``` - -2. **Set up a virtual environment (recommended):** - - ```bash - python3 -m venv .venv - source .venv/bin/activate - ``` - -3. **Install dependencies:** - - ```bash - pip install --upgrade pip - pip install -e . - ``` - -4. **(Optional) Install pre-commit hooks:** - - ```bash - pre-commit install - ``` +```bash +pre-commit install +``` -5. **Run tests:** +### 4) Run tests - ```bash - pytest - ``` +```bash +pytest +``` ## Notes - The project restricts Python versions to 3.10–3.12 as specified in `.python-version` and `pyproject.toml`. - For more information about the dependencies, see `pyproject.toml`. +Tip: Integration tests optionally use real data. In CI, datasets are downloaded with `python -m tests.dropbox_download "/MMAI25Hackathon" "MMAI25Hackathon" --unzip` when a Dropbox token is configured. + ## Authors - Shuo Zhou () diff --git a/mmai25_hackathon/dataset.py b/mmai25_hackathon/dataset.py index 5beec6a..25b4ebe 100644 --- a/mmai25_hackathon/dataset.py +++ b/mmai25_hackathon/dataset.py @@ -7,26 +7,34 @@ We provided two base classes, but feel free to modify them as needed. Classes: - - BaseDataset: Template for custom datasets, supports multimodal aggregation. 
- - BaseDataLoader: Alias for torch_geometric.data.DataLoader for graph/non-graph batching. + BaseDataset: Template for custom datasets, supports multimodal aggregation. + BaseDataLoader: Template for custom dataloaders based on torch_geometric.data.DataLoader for graph/non-graph batching. + BaseSampler: Template for custom samplers, e.g., for multimodal sampling. """ -from torch.utils.data import Dataset -from torch_geometric.data import DataLoader as PyGDataLoader +from torch.utils.data import Dataset, Sampler +from torch_geometric.data import DataLoader + +__all__ = ["BaseDataset", "BaseDataLoader", "BaseSampler"] class BaseDataset(Dataset): """ - Base dataset class for creating custom datasets. - - The arguments and methods defined here can be customized as needed. + Template base class for building datasets. - The goal is to have easy to extend dataset class for various modalities that - can also be combined to obtain multimodal datasets + Subclasses must implement `__len__` and `__getitem__`. Optionally override `extra_repr()` + and `__add__()` (for multimodal aggregation) if needed. `prepare_data()` can be used + as a class method to handle data downloading, preprocessing, and splitting if necessary. Args: *args: Positional arguments for dataset initialization. **kwargs: Keyword arguments for dataset initialization. + + Initial Idea: + Support composing modality-specific datasets via the `+` operator, e.g., + `mm_ds = ecg_ds + image_ds [+ text_ds]`. Subclasses implementing `__add__` + should align samples (by index/ID) and return a combined dataset. + Note: This is not a strict requirement, just a starting idea you can adapt or improve. """ def __init__(self, *args, **kwargs): @@ -51,41 +59,75 @@ def extra_repr(self) -> str: def __add__(self, other): """ - Aggregate data with heterogenous modalities. 
- - Note: - This is an optional idea that we imagined, but feel free to ignore it - if there are any better ways you may thought of to better integrate different modalities. - - For example, we may have: - ```python - dataset1 = ECGDataset(...) - dataset2 = ImageDataset(...) - ... - datasetN = TextDataset(...) - ``` - - One way we imagine to combine them is by using the `+` operator, - such that all we need to do is: - ```python - multimodal_dataset = dataset1 + dataset2 + ... + datasetN - # If we only have dataset1 and datasetN, we can simply do - bimodal_dataset = dataset1 + datasetN - ``` + Combine with another dataset. + + Override in subclasses to implement multimodal aggregation. + + Args: + other: Another dataset to combine with this one. + + Initial Idea: + Use `__add__` to align and merge heterogeneous modalities into a single + dataset, keeping shared IDs synchronized. + Note: This is not mandatory; treat it as a sketch you can refine or replace. """ raise NotImplementedError("Subclasses may implement __add__ method if needed.") + @classmethod + def prepare_data(cls, *args, **kwargs): + """ + Prepare data for the dataset. Possible use case: + 1. Downloading data from a remote source. + 2. Preprocessing raw data into a format suitable for the dataset. + 3. Any other setup tasks required before the dataset can be used. An example + could be dataset subsetting to train/val/test splits. + 4. Returns the dataset object given the prepared data and available splits. + + You may skip this method if you feel that it is not necessary for your ideal use case. + + Args: + *args: Positional arguments for data preparation. + **kwargs: Keyword arguments for data preparation. + + Returns: + Union[BaseDataset, Dict[str, BaseDataset]]: The prepared dataset or a dictionary + of datasets for different splits (e.g., train, val, test). 
+ """ + raise NotImplementedError("Subclasses may implement prepare_data class method if needed.") + -class BaseDataLoader(PyGDataLoader): +class BaseDataLoader(DataLoader): """ - A base dataloader directly inheriting from `torch_geometric.data.DataLoader` without any - modification. This is to ensure that both graph and non-graph data can be handled seamlessly. + DataLoader for graph and non-graph data. + + Directly inherits from `torch_geometric.data.DataLoader`. Use it like + `torch.utils.data.DataLoader`. + + Args: + dataset (BaseDataset): The dataset from which to load data. + batch_size (int): How many samples per batch to load. Default: 1. + shuffle (bool): Whether to reshuffle the data at every epoch. Default: False. + follow_batch (list): Creates assignment batch vectors for each key in the list. Default: None. + exclude_keys (list): Keys to exclude. Default: None. + **kwargs: Additional arguments forwarded to `torch.utils.data.DataLoader`. + + Initial Idea: + A future `MultimodalDataLoader` can accept a tuple of modality datasets and yield + batches like `{"ecg": ..., "image": ...}`. Missing modalities are simply absent + in that batch, keeping iteration simple and robust. + Note: This is not a hard requirement. Consider it a future-facing idea you can evolve. + """ + + +class BaseSampler(Sampler): + """ + Base sampler to extend for custom sampling strategies. Args: - dataset (BaseDataset): The dataset from which to load the data. - batch_size (int, optional): How many samples per batch to load. Default: 1 - shuffle (bool, optional): If set to True, the data will be reshuffled at every epoch. Default: False - follow_batch (List[str], optional): Creates assignment batch vectors for each key in the list. Default: None - exclude_keys (List[str], optional): Will exclude each key in the list. Default: None - **kwargs (optional): Additional arguments of torch.utils.data.DataLoader. + data_source (Sized): The dataset to sample from. 
+ + Initial Idea: + A `MultimodalSampler` can coordinate indices across modality datasets to ensure + balanced or paired sampling before passing to `BaseDataLoader`. + Note: This is optional and meant as a design hint, not a constraint. """ diff --git a/mmai25_hackathon/load_data/cxr.py b/mmai25_hackathon/load_data/cxr.py index 3bd7e7e..a0c1f05 100644 --- a/mmai25_hackathon/load_data/cxr.py +++ b/mmai25_hackathon/load_data/cxr.py @@ -1,90 +1,188 @@ -import glob +""" +Chest x-ray (CXR) loading utilities for MIMIC-CXR. + +Functions: +load_mimic_cxr_metadata(cxr_path, filter_rows=None) + Scans the dataset directory for a metadata CSV, loads it, optionally filters rows, + and adds a `cxr_path` column pointing to each JPEG file (by DICOM ID). Returns a + `pd.DataFrame`. Raises `FileNotFoundError` if the dataset/CSV is missing and + `KeyError` if no suitable DICOM ID column is found. + +load_chest_xray_image(path, to_gray=True) + Opens a chest x-ray image with PIL. Converts to grayscale ("L") when `to_gray=True`, + otherwise returns RGB. Returns a `PIL.Image.Image`. Raises `FileNotFoundError` if the + image path does not exist. + +Preview CLI: +`python -m mmai25_hackathon.load_data.cxr --data-path /path/to/mimic-cxr-jpg-...` +Loads metadata, prints a preview, and opens a sample image. 
+""" + +import logging import os from pathlib import Path +from typing import Dict, Optional, Sequence, Union import pandas as pd from PIL import Image - -# ---- Configure your dataset root ---- -DATA_PATH = r"your_data_path_here" -CXR_DIR = "mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0/files" -FILES_PATH = os.path.join(DATA_PATH, CXR_DIR) - - -# ----------------------------- -# 1) Build paths from metadata -# ----------------------------- -def get_cxr_paths(base_path: str, csv_path: str | None = None): - """Return DataFrame with a resolved `path` to each JPG.""" - base = Path(base_path) - if not base.exists(): - raise FileNotFoundError(f"Base path not found: {base}") - - # Auto-find metadata CSV if not provided - if csv_path is None: - candidates = [] - for pat in [ - os.path.join(DATA_PATH, "**", "*metadata*.csv"), - os.path.join(DATA_PATH, "**", "mimic-cxr*-metadata*.csv"), - os.path.join(DATA_PATH, "**", "mimic-cxr-2.0.0-metadata.csv"), - ]: - candidates.extend(glob.glob(pat, recursive=True)) - if not candidates: - raise FileNotFoundError("Could not auto-find a metadata CSV. Please pass csv_path explicitly.") - csv_path = min(candidates, key=len) - - # Read CSV - df = pd.read_csv(csv_path) - - # Detect dicom id column - col_candidates = [c for c in df.columns] - lower_map = {c.lower(): c for c in col_candidates} - id_col = None - for key in ("dicom_id", "dicom", "image_id"): - if key in lower_map: - id_col = lower_map[key] - break - if id_col is None: - raise KeyError("No suitable ID column found. 
Expected one of: 'dicom_id', 'dicom', 'image_id'.") - - # Scan all JPGs once - jpg_map = {p.stem: p for p in base.rglob("*.jpg")} - - # Map id -> jpg path - df["path"] = df[id_col].astype(str).str.strip().map(lambda x: str(jpg_map.get(x, ""))) - - # Keep only matches - before = len(df) - df = df[df["path"] != ""].copy() - after = len(df) - print(f"Matched {after}/{before} rows using ID column '{id_col}'.") - - return df - - -# ----------------------------- -# 2) Load a single image -# ----------------------------- -def load_cxr_image(path: str, to_gray: bool = True): - """Load a single CXR image as PIL Image (grayscale by default).""" - if not path or not os.path.exists(path): +from sklearn.utils._param_validation import validate_params + +from .tabular import read_tabular + +__all__ = ["load_mimic_cxr_metadata", "load_chest_xray_image"] + +METADATA_PATTERNS = ("*metadata*.csv", "*mimic-cxr*-metadata*.csv", "mimic-cxr-2.0.0-metadata.csv") +DICOM_ID_COLUMN_CANDIDATES = ("dicom_id", "dicom", "image_id") + + +@validate_params({"cxr_path": [Path, str], "filter_rows": [None, dict]}, prefer_skip_nested_validation=True) +def load_mimic_cxr_metadata( + cxr_path: Union[Path, str], filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None +) -> pd.DataFrame: + """ + Loads the MIMIC CXR metadata and maps available DICOM IDs to their corresponding image file paths. + + High-level steps: + - Validate dataset root, ensure `files/` exists and find a metadata CSV by pattern. + - Load metadata via `read_tabular`, optionally applying `filter_rows`; normalise columns to lower case. + - Identify a DICOM ID column, compute absolute `cxr_path` by rglob under `files/`. + - Keep only rows that successfully map to an existing `.jpg`. + - Return the filtered DataFrame. + + Args: + cxr_path (Union[Path, str]): The root directory of the MIMIC CXR dataset. + filter_rows (dict, optional): A dictionary to filter rows in the metadata DataFrame. 
+ Keys are column names and values are the values to filter by. Default: None. + + Returns: + pd.DataFrame: A DataFrame containing the metadata with an additional column `cxr_path` + that provides the full path to each image file. Rows without a corresponding image file are excluded. + + Raises: + FileNotFoundError: If the specified `cxr_path` does not exist or if the metadata CSV cannot be found. + KeyError: If no suitable DICOM ID column is found in the metadata CSV. + + Examples: + >>> df_metadata = load_mimic_cxr_metadata("path/to/mimic-cxr") + >>> print(df_metadata.head()[["subject_id", "cxr_path"]]) + subject_id cxr_path + 0 101 mimic-iv/mimic-cxr-jpg-chest-radiographs-with-... + 1 101 mimic-iv/mimic-cxr-jpg-chest-radiographs-with-... + 2 101 mimic-iv/mimic-cxr-jpg-chest-radiographs-with-... + 3 101 mimic-iv/mimic-cxr-jpg-chest-radiographs-with-... + 4 101 mimic-iv/mimic-cxr-jpg-chest-radiographs-with-... + """ + if isinstance(cxr_path, str): + cxr_path = Path(cxr_path) + + if not cxr_path.exists(): + raise FileNotFoundError(f"MIMIC CXR path not found: {cxr_path}") + + if not (cxr_path / "files").exists(): + raise FileNotFoundError(f"Expected 'files' subdirectory not found under: {cxr_path}") + + # find metadata csv given patterns + metadata_path = None + for pat in METADATA_PATTERNS: + for subpath in cxr_path.rglob(pat): + # Stop if multiple found, take the first + if metadata_path is None: + metadata_path = subpath + break + + if metadata_path is None: + raise FileNotFoundError(f"Metadata could not be found in {cxr_path} given patterns: {METADATA_PATTERNS}") + + logger = logging.getLogger(f"{__name__}.load_mimic_cxr_metadata") + logger.info("Loading CXR metadata from: %s", metadata_path) + df_metadata = read_tabular(metadata_path, filter_rows=filter_rows) + df_metadata.columns = df_metadata.columns.str.lower() + dicom_id_col = df_metadata.columns.intersection(DICOM_ID_COLUMN_CANDIDATES) + + if len(dicom_id_col) == 0: + raise KeyError(f"No suitable DICOM ID 
column found. Expected one of: {DICOM_ID_COLUMN_CANDIDATES}") + dicom_id_col = dicom_id_col[0] # take the first match + + logger.info("Using DICOM ID column: %s", dicom_id_col) + logger.info("Found %d metadata entries in: %s", len(df_metadata), metadata_path) + logger.info("Mapping DICOM IDs to image files under: %s", cxr_path / "files") + + # image path column + df_metadata["cxr_path"] = ( + df_metadata[dicom_id_col] + .astype(str) + .str.strip() + .map(lambda x: str(next((cxr_path / "files").rglob(f"{x}.jpg"), ""))) + ) + + df_metadata = df_metadata[df_metadata["cxr_path"] != ""].copy() + logger.info("Mapped %d metadata entries to existing image files.", len(df_metadata)) + + return df_metadata + + +@validate_params({"path": [Path, str], "to_gray": ["boolean"]}, prefer_skip_nested_validation=True) +def load_chest_xray_image(path: Union[str, Path], to_gray: bool = True) -> Image.Image: + """ + Loads a chest X-ray image from the specified path. + + High-level steps: + - Validate the image path exists. + - Open image with PIL and convert to `L` (grayscale) or `RGB` based on `to_gray`. + - Return the converted `Image`. + + Args: + path (Union[str, Path]): The file path to the chest X-ray image. + to_gray (bool): If True, convert the image to grayscale. Default is True. + + Returns: + Image.Image: The loaded chest X-ray image. + + Raises: + FileNotFoundError: If the specified image file does not exist. 
+ + Examples: + >>> image = load_chest_xray_image("path/to/image.jpg", to_gray=True) + >>> image.show() + """ + if isinstance(path, Path): + path = str(path) + + if not os.path.exists(path): raise FileNotFoundError(f"Image not found: {path}") - img = Image.open(path) - return img.convert("L") if to_gray else img.convert("RGB") + logger = logging.getLogger(f"{__name__}.load_chest_xray_image") + logger.info("Loading image: %s", path) -# --------- -# Example -# --------- -if __name__ == "__main__": - # Option A: Let the function auto-discover your metadata CSV under DATA_PATH - df = get_cxr_paths(FILES_PATH) + img = Image.open(path) + img = img.convert("L") if to_gray else img.convert("RGB") + logger.info("Loaded image size: %s, mode: %s", img.size, img.mode) + + return img - # Option B: Provide the exact CSV path if you know it - # csv_file = os.path.join(DATA_PATH, "mimic-cxr-2.0.0-metadata.csv") - # df = get_cxr_paths(FILES_PATH, csv_file) - print(df.head()) - if not df.empty: - img = load_cxr_image(df.iloc[0]["path"]) # PIL Image - # img.show() # uncomment to preview +if __name__ == "__main__": + import argparse + + # Example script: + # python -m mmai25_hackathon.load_data.cxr --data-path mimic-iv/mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0 + parser = argparse.ArgumentParser(description="Load MIMIC CXR metadata and images.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the MIMIC CXR dataset directory.", + default="MMAI25Hackathon/mimic-iv/mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0", + ) + args = parser.parse_args() + + print("Loading MIMIC CXR metadata...") + metadata = load_mimic_cxr_metadata(args.data_path) + print(metadata.head()[["subject_id", "cxr_path"]]) + + # Example of loading an image + if not metadata.empty: + print() + print(f"Loading first chest x-ray image from: {metadata.iloc[0]['cxr_path']}") + example_path = metadata.iloc[0]["cxr_path"] + image = load_chest_xray_image(example_path) + 
image.show() diff --git a/mmai25_hackathon/load_data/ecg.py b/mmai25_hackathon/load_data/ecg.py index 4ec33e8..d36f987 100644 --- a/mmai25_hackathon/load_data/ecg.py +++ b/mmai25_hackathon/load_data/ecg.py @@ -1,72 +1,181 @@ +""" +MIMIC-IV Electrocardiogram (ECG) loading utilities. + +Functions: +load_mimic_iv_ecg_record_list(ecg_path, filter_rows=None) + Loads `record_list.csv`, verifies dataset layout (expects a `files/` subdirectory), constructs absolute + `ecg_path` for each record from the CSV `path` column, derives the corresponding `.hea` and `.dat` paths, + filters rows if provided, and keeps only rows where both files exist. Returns a `pd.DataFrame` with the + added columns: `ecg_path`, `hea_path`, and `dat_path`. + +load_ecg_record(hea_path) + Reads an ECG record with `wfdb.rdsamp` using the provided `.hea` file path (the stem is passed to WFDB). + Returns `(signals, metadata)` where `signals` is a `np.ndarray` shaped `(T, L)` (time samples x leads), + and `metadata` is a dict of WFDB header fields (e.g., `fs`, `n_sig`). + +Preview CLI: +`python -m mmai25_hackathon.load_data.ecg --data-path /path/to/mimic-iv-ecg-...` +Prints a preview of the record list (including `hea_path` and `dat_path`), then loads one example record +to report the array shape and selected metadata (sampling frequency and number of leads). 
+""" + +import logging import os from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import numpy as np import pandas as pd +import wfdb +from sklearn.utils._param_validation import validate_params -# ---- Configure your dataset root ---- -DATA_PATH = r"your_data_path_here" -ECG_DIR = "mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0" -FILES_PATH = os.path.join(DATA_PATH, ECG_DIR) +from .tabular import read_tabular +__all__ = ["load_mimic_iv_ecg_record_list", "load_ecg_record"] -# ----------------------------- -# 1) Build paths from record_list -# ----------------------------- -def get_ecg_paths(base_path: str, csv_path: str): + +@validate_params({"ecg_path": [str, Path], "filter_rows": [None, dict]}, prefer_skip_nested_validation=True) +def load_mimic_iv_ecg_record_list( + ecg_path: Union[str, Path], filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None +) -> pd.DataFrame: """ - Return a DataFrame with resolved ECG file paths from record_list.csv. + Load the MIMIC-IV-ECG `record_list.csv` file as a DataFrame and maps available `.dat` and `.hea` files + to their respective paths. The path must contain a `files` subdirectory with the ECG files. + + High-level steps: + - Validate dataset root, ensure `files/` and `record_list.csv` exist. + - Load CSV via `read_tabular`, optionally applying `filter_rows`. + - Strip `path`, resolve absolute `ecg_path`, and derive `hea_path`/`dat_path`. + - Keep only rows for which both `.hea` and `.dat` exist. + - Return the filtered DataFrame. + + Args: + ecg_path (Union[str, Path]): The root directory of the MIMIC-IV-ECG dataset. + filter_rows (dict, optional): A dictionary to filter rows in the metadata DataFrame. + Keys are column names and values are the values to filter by. Default: None. + + Returns: + pd.DataFrame: A DataFrame containing the contents of `record_list.csv` with additional columns: + - `ecg_path`: Full path to the original ECG file. 
+ - `hea_path`: Full path to the corresponding `.hea` file. + - `dat_path`: Full path to the corresponding `.dat` file. + Only rows with both `.hea` and `.dat` files present are included. + + Raises: + FileNotFoundError: If the specified `ecg_path` does not exist, `files` subdirectory is missing, + or if the `record_list.csv` file cannot be found. + + Examples: + >>> df = load_mimic_iv_ecg_record_list("path/to/mimic-iv-ecg") + >>> print(df.head()) + subject_id hea_path dat_path + 0 101 mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... + 1 101 mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... + 2 101 mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... + 3 102 mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... + 4 102 mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... mimic-iv/mimic-iv-ecg-diagnostic-electrocardio... 
""" - base = Path(base_path) - if not base.exists(): - raise FileNotFoundError(f"Base path not found: {base}") + if isinstance(ecg_path, str): + ecg_path = Path(ecg_path) + + if not ecg_path.exists(): + raise FileNotFoundError(f"MIMIC-IV-ECG base path not found: {ecg_path}") + + if not (ecg_path / "files").exists(): + raise FileNotFoundError(f"Expected 'files' subdirectory not found under: {ecg_path}") + + records_path = ecg_path / "record_list.csv" + if not records_path.exists(): + raise FileNotFoundError(f"'record_list.csv' not found in: {ecg_path}") - df = pd.read_csv(csv_path) - if "path" not in df.columns: - raise KeyError("CSV must contain a 'path' column.") + df_records = read_tabular(records_path, filter_rows=filter_rows) - abs_heas, abs_dats = [], [] - for rel in df["path"].astype(str): - abs_heas.append(str(base / f"{rel}.hea")) - abs_dats.append(str(base / f"{rel}.dat")) + logger = logging.getLogger(f"{__name__}.load_mimic_iv_ecg_record_list") + logger.info("Loaded %d records from %s", len(df_records), records_path) + logger.info("Mapping ECG file paths under: %s", ecg_path / "files") - df["hea_path"] = abs_heas - df["dat_path"] = abs_dats + df_records["path"] = df_records["path"].astype(str).str.strip() + df_records["ecg_path"] = df_records["path"].map(lambda rel_path: str(ecg_path / rel_path)) + df_records["hea_path"] = df_records["ecg_path"].map(lambda x: str(Path(x).with_suffix(".hea"))) + df_records["dat_path"] = df_records["ecg_path"].map(lambda x: str(Path(x).with_suffix(".dat"))) - before = len(df) - df = df[df["hea_path"].map(os.path.exists) & df["dat_path"].map(os.path.exists)].copy() - after = len(df) - print(f"Matched {after}/{before} records with both .hea and .dat present.") + existing_hea = df_records["hea_path"].map(os.path.exists) + existing_dat = df_records["dat_path"].map(os.path.exists) - return df + # Only collect records with both .hea and .dat present + available_files = existing_hea & existing_dat + logger.info("Found %d records 
with both .hea and .dat files present.", available_files.sum()) + return df_records[available_files].copy() + + +@validate_params({"hea_path": [str, Path]}, prefer_skip_nested_validation=True) +def load_ecg_record(hea_path: Union[str, Path]) -> Tuple[np.ndarray, Dict[str, Any]]: + """ + Load an ECG record given a .hea file path using wfdb. + + High-level steps: + - Coerce `hea_path` to `Path` and validate the file exists. + - Call `wfdb.rdsamp` with the stem (path without suffix). + - Return the sampled `signals` and `fields` metadata. + + Args: + hea_path (Union[str, Path]): The path to the .hea file. + + Returns: + Tuple[np.ndarray, Dict[str, Any]]: A tuple containing: + - signals (np.ndarray): The ECG signal data with shape (signal_length, num_leads). + - fields (Dict[str, Any]): A dictionary of metadata fields from the .hea file. + + Raises: + FileNotFoundError: If the specified .hea file does not exist. + + Examples: + >>> signals, fields = load_ecg_record("path/to/record.hea") + >>> print(signals.shape) # (signal_length, num_leads) e.g. 
+ (5000, 12) + >>> print(fields["fs"]) # Sampling frequency + 500 + """ + if isinstance(hea_path, str): + hea_path = Path(hea_path) + if not hea_path.exists(): + raise FileNotFoundError(f"ECG .hea path not found: {hea_path}") -# ----------------------------- -# 2) Load a single ECG record -# ----------------------------- -def load_ecg_record(hea_path: str): - """Load an ECG record given a .hea file path using wfdb.""" - if not hea_path or not os.path.exists(hea_path): - raise FileNotFoundError(f".hea not found: {hea_path}") + logger = logging.getLogger(f"{__name__}.load_ecg_record") + logger.info("Loading ECG record from: %s", hea_path) + signals, fields = wfdb.rdsamp(hea_path.with_suffix("").as_posix()) - try: - import wfdb - except ImportError as e: - raise ImportError("Install wfdb with: pip install wfdb") from e + logger.info("Loaded ECG signals with shape: %s", signals.shape) + logger.info("Metadata fields: %s", list(fields.keys())) - rec = os.path.splitext(hea_path)[0] # drop extension - signals, fields = wfdb.rdsamp(rec) return signals, fields -# --------- -# Example -# --------- if __name__ == "__main__": - csv_file = os.path.join(FILES_PATH, "record_list.csv") - df = get_ecg_paths(FILES_PATH, csv_file) - print(df.head()) - - if not df.empty: - sig, meta = load_ecg_record(df.iloc[0]["hea_path"]) - print("Signals shape:", sig.shape) - print("Sampling freq:", meta.get("fs")) + import argparse + + # Example script: + # python -m mmai25_hackathon.load_data.ecg --data-path mimic-iv/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0 + parser = argparse.ArgumentParser(description="Load MIMIC-IV-ECG metadata and records.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the MIMIC-IV-ECG dataset root (should contain 'files' subdirectory).", + default="MMAI25Hackathon/mimic-iv/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0", + ) + args = parser.parse_args() + + print("Loading MIMIC-IV-ECG record list...") + records = 
load_mimic_iv_ecg_record_list(args.data_path) + print(records.head()[["subject_id", "hea_path", "dat_path"]]) + + # Example of loading a record + if not records.empty: + print() + print(f"Loading first ECG record from: {records.iloc[0]['hea_path']}") + signals, fields = load_ecg_record(records.iloc[0]["hea_path"]) + print(f"Loaded ECG signals with shape: {signals.shape}") + print(f"Metadata fields: {list(fields.keys())}") + print(f"Sampling frequency: {fields.get('fs', 'N/A')} Hz") + print(f"Number of leads: {fields.get('n_sig', 'N/A')}") diff --git a/mmai25_hackathon/load_data/echo.py b/mmai25_hackathon/load_data/echo.py index 7297d31..8f2b691 100644 --- a/mmai25_hackathon/load_data/echo.py +++ b/mmai25_hackathon/load_data/echo.py @@ -1,113 +1,182 @@ +""" +MIMIC-IV Echocardiogram (ECHO) loading utilities. + +Functions: +load_mimic_iv_echo_record_list(echo_path, filter_rows=None) + Loads `echo-record-list.csv`, verifies dataset layout (expects `files/`), constructs absolute + `echo_path` for each DICOM from `dicom_filepath`, filters rows if provided, and drops paths that + do not exist. Returns a `pd.DataFrame`. + +load_echo_dicom(path) + Reads an ECHO DICOM (cine or single-frame) with pydicom. Returns `(frames, metadata)` where + `frames` is `np.ndarray` shaped `(T, H, W)` (or `(1, H, W)`), rescaled via `RescaleSlope` and + `RescaleIntercept`, and `metadata` is a dict of DICOM keywords to values. + +Preview CLI: +`python -m mmai25_hackathon.load_data.echo --data-path /path/to/mimic-iv-echo-...` +Prints a preview of the record list and loads one example DICOM to report shape and selected metadata. 
+""" + +import logging import os from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import numpy as np import pandas as pd +from pydicom import dcmread +from sklearn.utils._param_validation import validate_params -# ---- Configure your dataset root ---- -DATA_PATH = r"your_data_path_here" -ECHO_DIR = "mimic-iv-echo-0.1.physionet.org" -FILES_PATH = os.path.join(DATA_PATH, ECHO_DIR) +from .tabular import read_tabular +__all__ = ["load_mimic_iv_echo_record_list", "load_echo_dicom"] -# ----------------------------- -# 1) Build paths from echo-record-list.csv -# ----------------------------- -def get_echo_paths(base_path: str, csv_path: str): + +@validate_params({"echo_path": [str, Path], "filter_rows": [None, dict]}, prefer_skip_nested_validation=True) +def load_mimic_iv_echo_record_list( + echo_path: Union[str, Path], filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None +) -> pd.DataFrame: """ - Return a DataFrame with resolved ECHO DICOM paths. + Load the MIMIC-IV-ECHO `echo-record-list.csv` file as a DataFrame and maps DICOM file paths. + The path must contain a `files` subdirectory with the DICOM files. + + High-level steps: + - Validate dataset root, ensure `files/` and `echo-record-list.csv` exist. + - Load CSV via `read_tabular`, optionally applying `filter_rows`. + - Strip `dicom_filepath`, resolve absolute `echo_path` under the root. + - Keep only rows whose `echo_path` exists on disk. + - Return the filtered DataFrame. + + Args: + echo_path (str): The root directory of the MIMIC-IV-ECHO dataset. + filter_rows (dict, optional): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. + + Returns: + pd.DataFrame: A DataFrame containing the contents of `echo-record-list.csv`. + + Raises: + FileNotFoundError: If the specified `echo_path` does not exist or if the CSV file cannot be found. 
+ + Examples: + >>> df = load_mimic_iv_echo_record_list("path/to/mimic-iv-echo") + >>> print(df.head()) + subject_id study_id acquisition_datetime dicom_filepath echo_path + 0 101 133 03/10/2204 13:14 files/p100/p101/s133/133.dcm mimic-iv/mimic-iv-echo-0.1.physionet.org/files... + 1 101 231 03/10/2204 13:17 files/p100/p101/s231/231.dcm mimic-iv/mimic-iv-echo-0.1.physionet.org/files... + 2 101 378 03/10/2204 13:18 files/p100/p101/s378/378.dcm mimic-iv/mimic-iv-echo-0.1.physionet.org/files... + 3 102 484 03/10/2204 13:18 files/p100/p102/s484/484.dcm mimic-iv/mimic-iv-echo-0.1.physionet.org/files... + 4 102 548 03/10/2204 13:18 files/p100/p102/s548/548.dcm mimic-iv/mimic-iv-echo-0.1.physionet.org/files... """ - base = Path(base_path) - if not base.exists(): - raise FileNotFoundError(f"Base path not found: {base}") + if isinstance(echo_path, str): + echo_path = Path(echo_path) + + if not echo_path.exists(): + raise FileNotFoundError(f"MIMIC-IV-ECHO path not found: {echo_path}") - df = pd.read_csv(csv_path) - if "dicom_filepath" not in df.columns: - raise KeyError("CSV must contain a 'dicom_filepath' column (e.g., files/.../133.dcm)") + if not (echo_path / "files").exists(): + raise FileNotFoundError(f"Expected 'files' subdirectory not found under: {echo_path}") - df["dicom_filepath"] = df["dicom_filepath"].astype(str).str.strip() - df["dcm_path"] = df["dicom_filepath"].map(lambda rp: str(base / rp)) + records_path = echo_path / "echo-record-list.csv" + if not records_path.exists(): + raise FileNotFoundError(f"'echo-record-list.csv' not found in: {echo_path}") - before = len(df) - df = df[df["dcm_path"].map(os.path.exists)].copy() - after = len(df) - print(f"Matched {after}/{before} DICOM files present on disk.") - return df + logger = logging.getLogger(f"{__name__}.load_mimic_iv_echo_record_list") + logger.info("Loading ECHO record list from: %s", records_path) + df_records = read_tabular(records_path, filter_rows=filter_rows) + logger.info("Loaded %d records from %s", 
len(df_records), records_path) + logger.info("Mapping ECHO DICOM file paths under: %s", echo_path / "files") -# ----------------------------- -# 2) Load a single ECHO DICOM -# ----------------------------- -def load_echo_dicom(dcm_path: str): + df_records["dicom_filepath"] = df_records["dicom_filepath"].astype(str).str.strip() + df_records["echo_path"] = df_records["dicom_filepath"].map(lambda rel_path: str(echo_path / rel_path)) + + existing_files = df_records["echo_path"].map(os.path.exists) + logger.info("Found %d records with existing DICOM files.", existing_files.sum()) + + return df_records[existing_files].copy() + + +@validate_params({"path": [str, Path]}, prefer_skip_nested_validation=True) +def load_echo_dicom(path: Union[str, Path]) -> Tuple[np.ndarray, Dict[str, Any]]: """ Load an ECHO DICOM (supports multi-frame cine) using pydicom. - Returns: - frames: np.ndarray with shape (T, H, W) for multi-frame, or (1, H, W) for single image - meta: dict with handy fields (Rows, Columns, NumberOfFrames, FrameTime, CineRate, etc.) + High-level steps: + - Coerce `path` to `Path` and validate it exists. + - Read DICOM via `pydicom.dcmread` and extract `pixel_array`. + - If 2D, expand dims to shape `(1, H, W)`. + - Apply rescale using `RescaleSlope` and `RescaleIntercept` if present. + - Collect metadata from DICOM elements into a dictionary. + - Return `(frames, metadata)`. - Requires: pip install pydicom + Args: + path (Union[str, Path]): The path to the ECHO DICOM file. + + Returns: + Tuple[np.ndarray, Dict[str, Any]]: A tuple containing: + - frames: np.ndarray with shape (T, H, W) for multi-frame, or (1, H, W) for single image + - metadata: metadata from the DICOM file as a dictionary (e.g., Rows, Columns, NumberOfFrames, FrameTime, CineRate, etc.) 
+ + Examples: + >>> frames, meta = load_echo_dicom("path/to/echo.dcm") + >>> print("Frames shape:", frames.shape) + Frames shape: (58, 708, 1016, 3) + >>> print("Meta:", {k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate")}) + Meta: {'NumberOfFrames': '58', 'Rows': 708, 'Columns': 1016, 'FrameTime': '33.6842', 'CineRate': '30'} """ - if not dcm_path or not os.path.exists(dcm_path): - raise FileNotFoundError(f"DICOM not found: {dcm_path}") - - try: - import pydicom - except ImportError as e: - raise ImportError("pydicom is required. Install with: pip install pydicom") from e - - ds = pydicom.dcmread(dcm_path) - arr = ds.pixel_array - - if arr.ndim == 2: - arr = arr[None, ...] - - slope = float(getattr(ds, "RescaleSlope", 1.0)) - inter = float(getattr(ds, "RescaleIntercept", 0.0)) - arr = (arr * slope + inter).astype(arr.dtype) - - meta = { - "Rows": int(getattr(ds, "Rows", arr.shape[-2])), - "Columns": int(getattr(ds, "Columns", arr.shape[-1])), - "NumberOfFrames": int(getattr(ds, "NumberOfFrames", arr.shape[0])), - "FrameTime_ms": float(getattr(ds, "FrameTime", 0.0)) if hasattr(ds, "FrameTime") else None, - "CineRate": int(getattr(ds, "CineRate", 0)) if hasattr(ds, "CineRate") else None, - "PhotometricInterpretation": getattr(ds, "PhotometricInterpretation", None), - "BitsAllocated": int(getattr(ds, "BitsAllocated", 0)) if hasattr(ds, "BitsAllocated") else None, - "PixelRepresentation": int(getattr(ds, "PixelRepresentation", 0)) - if hasattr(ds, "PixelRepresentation") - else None, - "StudyInstanceUID": getattr(ds, "StudyInstanceUID", None), - "SeriesInstanceUID": getattr(ds, "SeriesInstanceUID", None), - "SOPInstanceUID": getattr(ds, "SOPInstanceUID", None), - "Modality": getattr(ds, "Modality", None), - "SeriesDescription": getattr(ds, "SeriesDescription", None), - } - return arr, meta - - -# --------- -# Example -# --------- + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"ECHO 
DICOM not found: {path}") + + logger = logging.getLogger(f"{__name__}.load_echo_dicom") + logger.info("Loading ECHO DICOM: %s", path) + + echo = dcmread(path) + frames = echo.pixel_array # shape (num_frames, height, width) or (height, width) + if frames.ndim == 2: + frames = np.expand_dims(frames, 0) # single image to (1, H, W) + + logger.info("Loaded frames with shape: %s", frames.shape) + logger.info("Adjusting pixel values using RescaleSlope and RescaleIntercept if present.") + intercept = float(getattr(echo, "RescaleIntercept", 0.0)) + slope = float(getattr(echo, "RescaleSlope", 1.0)) + frames = frames * slope + intercept + + metadata = {elem.keyword: elem.value for elem in echo if elem.keyword} + logger.info("Extracted %d metadata fields from DICOM.", len(metadata)) + + return frames, metadata + + if __name__ == "__main__": - # Point to your echo-record-list.csv - csv_file = os.path.join(FILES_PATH, "echo-record-list.csv") - - df = get_echo_paths(FILES_PATH, csv_file) - print(df.head()) - - if not df.empty: - frames, meta = load_echo_dicom(df.iloc[0]["dcm_path"]) - print("Frames shape:", frames.shape) - print( - "Meta:", - { - k: meta[k] - for k in ( - "NumberOfFrames", - "Rows", - "Columns", - "FrameTime_ms", - "CineRate", - ) - }, - ) + import argparse + + # Example script: + # python -m mmai25_hackathon.load_data.echo --data-path mimic-iv/mimic-iv-echo-0.1.physionet.org + parser = argparse.ArgumentParser(description="Load MIMIC-IV-ECHO metadata and DICOM files.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the MIMIC-IV-ECHO dataset directory (containing 'files' subdirectory).", + default="MMAI25Hackathon/mimic-iv/mimic-iv-echo-0.1.physionet.org", + ) + args = parser.parse_args() + + print("Loading MIMIC-IV-ECHO record list...") + records = load_mimic_iv_echo_record_list(args.data_path) + print(records.head()[["subject_id", "study_id", "echo_path"]]) + + # Example of loading a DICOM file + if not records.empty: + print() + 
print(f"Loading first ECHO DICOM from: {records.iloc[0]['echo_path']}") + example_path = records.iloc[0]["echo_path"] + frames, meta = load_echo_dicom(example_path) + meta_filtered = { + k: meta[k] for k in ("NumberOfFrames", "Rows", "Columns", "FrameTime", "CineRate") if k in meta + } + print(f"Loaded frames shape: {frames.shape}") + print(f"Metadata sample: {meta_filtered}") diff --git a/mmai25_hackathon/load_data/ehr.py b/mmai25_hackathon/load_data/ehr.py index 8fc45e3..66fbafb 100644 --- a/mmai25_hackathon/load_data/ehr.py +++ b/mmai25_hackathon/load_data/ehr.py @@ -1,195 +1,272 @@ +""" +MIMIC-IV Electronic Health Record (EHR) loading and merging utilities. + +Functions: +load_mimic_iv_ehr( + ehr_path, module='hosp'|'icu'|'both', tables=None, index_cols=None, + subset_cols=None, filter_rows=None, merge=True, join='inner', raise_errors=True +) + Discovers and loads CSV tables from the selected MIMIC-IV module(s). Optionally selects + per-table columns, filters rows, and merges tables by overlapping key columns (via + `merge_multiple_dataframes`). Returns a dict of DataFrames when `merge=False` or a single + merged DataFrame when `merge=True`. Raises `FileNotFoundError` if the dataset/subfolders are + missing and `ValueError` for invalid table names or disjoint merge components. + +Notes: +- Column selection and row filtering are delegated to `read_tabular`. +- `index_cols` are used as merge keys only (the DataFrame index is not set by this helper). + +Preview CLI: +`python -m mmai25_hackathon.load_data.ehr --data-path /path/to/mimic-iv-3.1` +Loads a small example (e.g., ICU stays + admissions), merges on `subject_id, hadm_id`, and prints a preview. 
+""" + +import logging from pathlib import Path -from typing import Dict, Literal, Optional, Tuple +from typing import Dict, List, Literal, Optional, Sequence, Union import pandas as pd +from sklearn.utils._param_validation import StrOptions, validate_params -# ---- Configure your dataset root ---- -DATA_PATH = r"your_data_path_here" -HOSP_DIR = "mimic-iv-3.1/hosp" -ICU_DIR = "mimic-iv-3.1/icu" - -# -------------------------------------- -# 1) File maps for HOSP and ICU -# -------------------------------------- -HOSP_TABLES = [ - "admissions.csv", - "diagnoses_icd.csv", - "drgcodes.csv", - "emar.csv", - "emar_detail.csv", - "hcpcsevents.csv", - "labevents.csv", - "microbiologyevents.csv", - "omr.csv", - "patients.csv", - "pharmacy.csv", - "poe.csv", - "poe_detail.csv", - "prescriptions.csv", - "procedures_icd.csv", - "provider.csv", - "services.csv", - "transfers.csv", - "d_hcpcs.csv", - "d_icd_diagnoses.csv", - "d_icd_procedures.csv", - "d_labitems.csv", -] - -ICU_TABLES = [ - "caregiver.csv", - "chartevents.csv", - "d_items.csv", - "datetimeevents.csv", - "icustays.csv", - "ingredientevents.csv", - "inputevents.csv", - "outputevents.csv", - "procedureevents.csv", -] - - -# -------------------------------------- -# 2) Helper to read CSV (with filtering) -# -------------------------------------- -def _read_csv( - filepath: Path, - keep_cols: Optional[Tuple[str, ...]] = None, - dtypes: Optional[Dict[str, str]] = None, - filters: Optional[Dict[str, object]] = None, -) -> pd.DataFrame: - """ - Wrapper around pd.read_csv with optional column selection and filtering. 
+from .tabular import merge_multiple_dataframes, read_tabular + +__all__ = ["load_mimic_iv_ehr"] + +MIMIC_IV_EHR_AVAILABLE_TABLES = { + "hosp": ( + "admissions", + "diagnoses_icd", + "drgcodes", + "emar", + "emar_detail", + "hcpcsevents", + "labevents", + "microbiologyevents", + "omr", + "patients", + "pharmacy", + "poe", + "poe_detail", + "prescriptions", + "procedures_icd", + "provider", + "services", + "transfers", + "d_hcpcs", + "d_icd_diagnoses", + "d_icd_procedures", + "d_labitems", + ), + "icu": ( + "caregiver", + "chartevents", + "d_items", + "datetimeevents", + "icustays", + "ingredientevents", + "inputevents", + "outputevents", + "procedureevents", + ), +} + + +@validate_params( + { + "ehr_path": [str, Path], + "module": [StrOptions({"hosp", "icu", "both"})], + "tables": [None, "array-like"], + "index_cols": [None, list, str], + "subset_cols": [None, dict], + "filter_rows": [None, dict], + "merge": ["boolean"], + "join": [StrOptions({"inner", "outer", "left", "right"})], + "raise_errors": ["boolean"], + }, + prefer_skip_nested_validation=True, +) +def load_mimic_iv_ehr( + ehr_path: Union[str, Path], + module: Literal["hosp", "icu", "both"] = "hosp", + tables: Optional[Sequence[str]] = None, + index_cols: Optional[Union[List[str], str]] = None, + subset_cols: Optional[Dict[str, Sequence[str]]] = None, + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, + merge: bool = True, + join: Literal["inner", "outer", "left", "right"] = "inner", + raise_errors: bool = True, +) -> Union[Dict[str, pd.DataFrame], pd.DataFrame]: """ - if not filepath.exists(): - raise FileNotFoundError(filepath) + Query, load, and aggregate MIMIC-IV EHR data from specified module(s) and tables. - df = pd.read_csv(filepath, usecols=keep_cols, dtype=dtypes) + High-level steps: + - Coerce `ehr_path` to `Path` and validate root; select modules (`hosp`, `icu`, or both) and check subfolders (optional). 
+ - Discover available `tables` under selected modules (or validate requested names). + - Load each table via `read_tabular` with optional `subset_cols`, `index_cols`, and `filter_rows`. + - When `merge=False`, return a dict of DataFrames keyed by table name. + - When `merge=True`, call `merge_multiple_dataframes` on loaded tables; if multiple disjoint components remain, raise. + - Return the merged DataFrame. - if filters: - for col, allowed in filters.items(): - df = df[df[col].isin(allowed)] + Args: + ehr_path (Union[str, Path]): Path to the root folder containing `hosp` and/or `icu` subfolders. + module (Literal['hosp', 'icu', 'both']): Module(s) to load data from. The 'hosp' module contains hospital-wide + data, while the 'icu' module contains intensive care unit-specific data. Default: 'hosp'. + tables (Optional[Sequence[str]]): Specific sequences of tables to load. If None, all available tables in the selected + module(s) will be loaded. Default: None. + index_cols (Optional[List[str]]): Columns to use as keys for merging tables. If None and merge=True, will try to do + naive concatenation. Default: None. + subset_cols (Optional[Dict[str, List[str]]]): Per-table column selection. If provided, only these columns + will be loaded from each table. Default: None. + filter_rows (Optional[Dict[str, Union[Sequence, pd.Index]]]): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. + merge (bool): Whether to merge the loaded tables into components based on shared keys. Default: True. + join (str): Merge strategy to use when merging tables given merge=True. Options include 'inner', 'outer', 'left', and 'right'. + Default: 'inner'. + raise_errors (bool): If True, will raise an error for the following criterias: + 1. `modules` not found in `ehr_path`. Will fetch existing ones if False. + 2. `subset_cols` or `index_cols` provided but none of the specified columns are found in the DataFrame. + 3. 
`filter_rows` provided but none of the specified values are found in the DataFrame. - return df + Returns: + Union[Dict[str, pd.DataFrame], pd.DataFrame]: If merge is False, returns a dictionary of DataFrames + keyed by table names. If merge is True, returns a single merged DataFrame. + Raises: + FileNotFoundError: If `ehr_path` does not exist or if the specified `modules` subfolder is not found + and `raise_errors` is True. + ValueError: If no available tables are found for the specified `modules`, if any of the requested `tables` + are not available, or if merging results in multiple components with exclusive keys. -# -------------------------------------- -# 3) Loader -# -------------------------------------- -def get_tabular_mimic( - base_path: str, - domain: Literal["hosp", "icu", "both"] = "hosp", - tables: Optional[Tuple[str, ...]] = None, - keep_cols: Optional[Dict[str, Tuple[str, ...]]] = None, - dtypes: Optional[Dict[str, Dict[str, str]]] = None, - filters: Optional[Dict[str, Dict[str, object]]] = None, -) -> Dict[str, pd.DataFrame]: - """ - Load tabular data from MIMIC-IV-3.1. - - Parameters - ---------- - base_path : str - Root folder containing `hosp` and/or `icu` subfolders. - domain : {'hosp','icu','both'} - Which domain to load from. - tables : tuple of str | None - Which CSV files to load. If None, load all in the domain(s). - keep_cols, dtypes : dict - Per-table settings for column selection and dtypes. - filters : dict - Row filters, e.g., {"icustays": {"subject_id": [123]}}. - - Returns - ------- - dict[str, pd.DataFrame] - Dict of {table_name: DataFrame}. 
- """ - base = Path(base_path) - if not base.exists(): - raise FileNotFoundError(f"Base path not found: {base}") - - selected_domains = [] - if domain in ["hosp", "both"]: - selected_domains.append((HOSP_DIR, HOSP_TABLES)) - if domain in ["icu", "both"]: - selected_domains.append((ICU_DIR, ICU_TABLES)) - - out = {} - for dom, file_list in selected_domains: - dom_path = base / dom - for fname in file_list: - table_name = fname.replace(".csv", "") - if tables and table_name not in tables: - continue - - filepath = dom_path / fname - cols = keep_cols.get(table_name) if keep_cols else None - dtmap = dtypes.get(table_name) if dtypes else None - fltrs = filters.get(table_name) if filters else None - - df = _read_csv(filepath, keep_cols=cols, dtypes=dtmap, filters=fltrs) - print(f"Loaded {table_name}: {len(df)} rows, {len(df.columns)} columns.") - out[table_name] = df - - return out - - -# -------------------------------------- -# 4) Merge helper -# -------------------------------------- -def merge_tables( - frames: Dict[str, pd.DataFrame], - how: str = "inner", - on: Optional[Tuple[str, ...]] = None, - plan: Optional[Tuple[Tuple[str, str, Tuple[str, ...]], ...]] = None, -) -> pd.DataFrame: - """ - Merge multiple tables. - - Parameters - ---------- - frames : dict[str, pd.DataFrame] - how : str - Merge type. - on : tuple of str - Keys for merge if same for all. - plan : sequence of (left, right, keys) - Explicit multi-step merge plan. - - Returns - ------- - pd.DataFrame + Examples: + >>> # Load specific tables from both modules and merge them on 'subject_id' and 'hadm_id' + >>> df = load_mimic_iv_ehr( + ... ehr_path="path/to/mimic-iv-3.1", + ... module="both", + ... tables=["icustays", "admissions"], + ... index_cols=["subject_id", "hadm_id"], + ... subset_cols={"icustays": ["first_careunit"], "admissions": ["admittime"]}, + ... merge=True, + ... join="inner", + ... 
) + >>> print(df.head()) + subject_id hadm_id admittime first_careunit + 0 101 1 24/02/2196 14:38 Neuro Stepdown + 1 101 2 17/09/2153 17:08 Neuro Surgical Intensive Care Unit (Neuro SICU) + 2 101 3 18/08/2134 02:02 Neuro Intermediate + 3 102 4 13/11/2111 23:39 Trauma SICU (TSICU) + 4 102 5 04/08/2113 18:46 Trauma SICU (TSICU) + 5 103 6 12/12/2132 01:43 Trauma SICU (TSICU) """ - if plan: - df = frames[plan[0][0]] - for left, right, keys in plan: - df = df.merge(frames[right], how=how, on=keys) - return df - else: - # Simple reduce-style merge - keys = on or ("subject_id",) - dfs = list(frames.values()) - df = dfs[0] - for other in dfs[1:]: - df = df.merge(other, how=how, on=keys) - return df - - -# --------- -# Example -# --------- + if isinstance(ehr_path, str): + ehr_path = Path(ehr_path) + + if not ehr_path.exists(): + raise FileNotFoundError(f"MIMIC-IV EHR path not found: '{ehr_path}'") + + # Check if hosp and/or icu directories exist, expect to be validated + # later will add sklearn params validation + selected_modules = ["hosp", "icu"] if module == "both" else [module] + + # need to check availability of selected modules + # sklearn param validation doesn't support this + for mod in selected_modules: + if not (ehr_path / mod).exists() and raise_errors: + raise FileNotFoundError(f"Expected subfolder '{mod}' not found in {ehr_path}") + + # generate dictionary of available tables to load given selected modules, tables, and paths + available_tables = {} + for mod in selected_modules: + if tables is None: + for table in MIMIC_IV_EHR_AVAILABLE_TABLES[mod]: + path = ehr_path / mod / f"{table}.csv" + if path.exists(): + available_tables[table] = path + continue + + for table in tables: + path = ehr_path / mod / f"{table}.csv" + if table in MIMIC_IV_EHR_AVAILABLE_TABLES[mod] and path.exists(): + available_tables[table] = path + + logger = logging.getLogger(f"{__name__}.load_mimic_iv_ehr") + logger.info("Selected modules: %s", selected_modules) + logger.info("Available 
tables to load: %s", list(available_tables.keys())) + + # Validate we have at least one table to load + if len(available_tables) == 0: + raise ValueError(f"No available tables found for modules: {selected_modules}") + + # Check available tables if any missing + if tables is not None: + missing_tables = set(tables) - set(available_tables.keys()) + if missing_tables: + raise ValueError(f"The following requested tables are not available: {missing_tables}") + + # Load tables + logger.info("Loading tables from: %s", ehr_path) + dfs = { + table: read_tabular( + path, + (subset_cols or {}).get(table, None), + index_cols, + filter_rows, + raise_errors=raise_errors, + ) + for table, path in available_tables.items() + } + + if not merge: + return dfs + + logger.info("Merging tables on keys: %s using '%s' join", index_cols, join) + aggregated_dfs = merge_multiple_dataframes( + list(dfs.values()), dfs_name=list(dfs.keys()), index_cols=index_cols, join=join + ) + + if len(aggregated_dfs) != 1: + # Find exclusive keys between aggregated dataframes + all_keys = [set(keys) for keys, _ in aggregated_dfs] + exclusive_keys = set.union(*all_keys) - set.intersection(*all_keys) + raise ValueError( + f"Merging resulted in multiple components with exclusive keys: {exclusive_keys}. " + "Consider using a different set of index_cols or setting merge=False when " + "loading tables with exclusive/disjoint keys." 
+ ) + + _, merged_df = aggregated_dfs[0] + logger.info("Merged DataFrame shape: %s", merged_df.shape) + + return merged_df + + if __name__ == "__main__": - # Example: load ICU stays + admissions - data = get_tabular_mimic( - DATA_PATH, - domain="both", - tables=("icustays", "admissions"), - keep_cols={ - "icustays": ("subject_id", "hadm_id", "stay_id"), - "admissions": ("subject_id", "hadm_id"), + import argparse + + # Example script given the relative path to folder mimic-iv + # containing mimic-iv-3.1 that has hosp and icu subfolders: + # python -m mmai25_hackathon.load_data.ehr --data-path mimic-iv/mimic-iv-3.1 + parser = argparse.ArgumentParser(description="Fetch MIMIC-IV EHR data example.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the MIMIC-IV EHR root directory (mimic-iv-3.1).", + default="MMAI25Hackathon/mimic-iv/mimic-iv-3.1", + ) + args = parser.parse_args() + + print("Loading MIMIC-IV EHR data example...") + dfs_new = load_mimic_iv_ehr( + ehr_path=args.data_path, + module="both", + tables=["icustays", "admissions"], + index_cols=["subject_id", "hadm_id"], + subset_cols={ + "icustays": ["first_careunit"], + "admissions": ["admittime"], }, + filter_rows={"subject_id": [101]}, + merge=True, + join="inner", ) - merged = merge_tables(data, how="inner", on=("subject_id", "hadm_id")) - print(merged.head()) + print(dfs_new.head()) diff --git a/mmai25_hackathon/load_data/molecule.py b/mmai25_hackathon/load_data/molecule.py index 5d3477f..c7f1998 100644 --- a/mmai25_hackathon/load_data/molecule.py +++ b/mmai25_hackathon/load_data/molecule.py @@ -1,29 +1,60 @@ """ -Molecular data utilities for handling SMILES strings and graph conversion. +Molecular (SMILES) loading and graph conversion utilities. Functions: - - fetch_smiles_from_dataframe: Extract SMILES strings from DataFrame or CSV. - - smiles_to_graph: Convert SMILES to a molecular graph (uses PyG's from_smiles). 
- -Note: - - The function smiles_to_graph is a wrapper for PyG's native from_smiles implementation, provided for clarity in this hackathon context. +fetch_smiles_from_dataframe(df, smiles_col, index_col=None) + Fetches SMILES strings from a DataFrame or CSV. Uses `read_tabular` when a path is provided. Optionally sets an index, + and returns a one-column DataFrame named `"smiles"` (index reset if `index_col` is None). + +smiles_to_graph(smiles, with_hydrogen=False, kekulize=False) + Converts a SMILES string into a PyTorch Geometric `Data` object via `torch_geometric.utils.smiles.from_smiles`. + Returns a graph `Data` with typical keys: `x` (node features), `edge_index` (COO connectivity), and `edge_attr`. + Flags `with_hydrogen` and `kekulize` are forwarded to the underlying conversion. + +Preview CLI: +`python -m mmai25_hackathon.load_data.molecule --data-path /path/to/dataset.csv` +Reads the CSV, prints a small preview of the SMILES column, and converts the first few entries to graphs, printing each +graph’s summary (e.g., number of nodes/edges and feature sizes). """ -from typing import Union +import logging +from typing import Dict, Optional, Sequence, Union import pandas as pd +from sklearn.utils._param_validation import validate_params from torch_geometric.data import Data from torch_geometric.utils.smiles import from_smiles +from .tabular import read_tabular + +__all__ = ["fetch_smiles_from_dataframe", "smiles_to_graph"] -def fetch_smiles_from_dataframe(df: Union[pd.DataFrame, str], smiles_col: str, index_col: str = None) -> pd.DataFrame: + +@validate_params( + {"df": [pd.DataFrame, str], "smiles_col": [str], "index_col": [None, str], "filter_rows": [None, dict]}, + prefer_skip_nested_validation=True, +) +def fetch_smiles_from_dataframe( + df: Union[pd.DataFrame, str], + smiles_col: str, + index_col: str = None, + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, +) -> pd.DataFrame: """ Fetches SMILES strings from a DataFrame or CSV file. 
Will read the CSV if a path is provided. + High-level steps: + - If `df` is a path, load via `read_tabular` selecting `smiles_col` and optional `index_col`; apply `filter_rows`. + - If `df` is a DataFrame and `filter_rows` is provided, apply row filters where columns exist. + - Validate `smiles_col` exists; optionally set DataFrame index. + - Return a one-column DataFrame named `"smiles"` (index preserved if set). + Args: df (Union[pd.DataFrame, str]): DataFrame or path to CSV file. smiles_col (str): Column name for SMILES representations. index_col (str, optional): Column to set as index. Default: None. + filter_rows (dict, optional): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. Returns: pd.DataFrame: A single column DataFrame containing the SMILES strings with name `"smiles"`. @@ -44,21 +75,34 @@ def fetch_smiles_from_dataframe(df: Union[pd.DataFrame, str], smiles_col: str, i 3 CC(=O)O """ if isinstance(df, str): - df = pd.read_csv(df) + df = read_tabular(df, subset_cols=smiles_col, index_cols=index_col, filter_rows=filter_rows) + else: + for col, valid_vals in (filter_rows or {}).items(): + if col in df.columns: + df = df[df[col].isin(valid_vals)] if smiles_col not in df.columns: raise ValueError(f"Column '{smiles_col}' not found in DataFrame.") + logger = logging.getLogger(f"{__name__}.fetch_smiles_from_dataframe") + if index_col is not None: df = df.set_index(index_col) + logger.info("Setting index column to '%s'.", index_col) - return df[smiles_col].to_frame("smiles") + logger.info("Fetched %d SMILES strings from column '%s'.", len(df), smiles_col) + return df[smiles_col].to_frame("smiles").reset_index(drop=index_col is None) +@validate_params({"smiles": [str], "with_hydrogen": [bool], "kekulize": [bool]}, prefer_skip_nested_validation=True) def smiles_to_graph(smiles: str, with_hydrogen: bool = False, kekulize: bool = False) -> Data: """ Converts a SMILES string to a molecular 
graph representation. + High-level steps: + - Forward to `torch_geometric.utils.smiles.from_smiles` with `with_hydrogen` and `kekulize` flags. + - Return the resulting `torch_geometric.data.Data` graph. + Args: smiles (str): The SMILES string to convert. with_hydrogen (bool): Store hydrogens in the graph if True. Default: False @@ -79,20 +123,26 @@ def smiles_to_graph(smiles: str, with_hydrogen: bool = False, kekulize: bool = F >>> print(graph) Data(x=[3, 9], edge_index=[2, 4], edge_attr=[4, 3], smiles='CCO') """ + logger = logging.getLogger(f"{__name__}.smiles_to_graph") + logger.info("Converting SMILES to graph: %s", smiles) return from_smiles(smiles, with_hydrogen, kekulize) if __name__ == "__main__": import argparse - # Example script: python -m mmai25_hackathon.load_data.molecule dataset.csv - + # Example script: python -m mmai25_hackathon.load_data.molecule --data-path MMAI25Hackathon/molecule-protein-interaction/dataset.csv parser = argparse.ArgumentParser(description="Process SMILES strings.") - parser.add_argument("csv_path", type=str, help="Path to the CSV file containing SMILES strings.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the CSV file containing SMILES strings.", + default="MMAI25Hackathon/molecule-protein-interaction/dataset.csv", + ) args = parser.parse_args() # Take from Peizhen's csv file for DrugBAN training - df = fetch_smiles_from_dataframe(args.csv_path, smiles_col="SMILES") + df = fetch_smiles_from_dataframe(args.data_path, smiles_col="SMILES") for i, smiles in enumerate(df["smiles"].head(5), 1): graph = smiles_to_graph(smiles) print(i, graph) diff --git a/mmai25_hackathon/load_data/omics.py b/mmai25_hackathon/load_data/omics.py deleted file mode 100644 index c0d54da..0000000 --- a/mmai25_hackathon/load_data/omics.py +++ /dev/null @@ -1,361 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Literal, Optional - -import numpy as np -import pandas as pd -import torch -import 
torch.nn.functional as F -from torch_geometric.data import Data -from torch_sparse import SparseTensor - - -# ====================================================================== -# load_multiomics: build per-modality sample-graph(s) from CSV features -# ====================================================================== -def load_multiomics( - *, - feature_csvs: List[str], - labels_csv: Optional[str] = None, - featname_csvs: Optional[List[str]] = None, - mode: Literal["train", "val", "test"] = "train", - num_classes: Optional[int] = None, - pre_transform: Optional[Any] = None, - target_pre_transform: Optional[Any] = None, - edge_per_node: int = 10, - metric: Literal["cosine"] = "cosine", - eps: float = 1e-8, - equal_weight: bool = False, - ref: Optional[Dict[str, Any]] = None, - rng: Optional[np.random.Generator] = None, -) -> Dict[str, Any]: - """ - Build one PYG `Data` per omics modality by: - 1) Loading feature matrices (CSV; delimiter=",") into numpy arrays. - 2) (Optional) Applying `pre_transform` to each X (e.g., standardization). - 3) Computing pairwise similarities (currently cosine) between samples. - 4) Fitting a *global* similarity cutoff so each node keeps ~k neighbors. - 5) Thresholding, symmetrizing (max), adding self-loops, row-normalizing. - 6) Packaging `edge_index`, `edge_weight`, dense `x`, and `SparseTensor adj_t`. - - Labels (optional): - - If `labels_csv` is provided, loads (N,) or (N,C) integer labels. - - If 1D labels and `num_classes` is set, converts to one-hot (float). - - If `mode=="train"`, computes `train_sample_weight` per sample. - - Parameters - ---------- - feature_csvs : list[str] - Paths to feature CSVs. Each is expected to be shape (N, F_m) after reading via - `np.loadtxt(..., delimiter=",")`. (Make sure your file really is comma-separated.) - labels_csv : str | None - Optional path to labels CSV (comma-separated). Shape (N,) or (N, C). - If (N,1) it will be squeezed to (N,). 
- featname_csvs : list[str] | None - Optional per-modality feature-name CSVs (no header). If omitted, names become - `m{mi}_f{j}`. - mode : {"train","val","test"} - Drives whether to *fit* thresholds (train) or *reuse* from `ref` if present. - num_classes : int | None - If labels are 1D ints and this is provided, labels are one-hot encoded. - pre_transform : callable | None - Function applied to each X (numpy). Signature: `X -> X_transformed`. - target_pre_transform : callable | None - Function applied to `y` (numpy). Useful for label re-mapping. - edge_per_node : int - Desired avg number of neighbors per node (k) used to fit a global cutoff. - metric : {"cosine"} - Similarity metric. Currently only "cosine" is implemented. - eps : float - Numerical epsilon for normalization / division safeguards. - equal_weight : bool - If True, assigns uniform sample weights (1/N). Otherwise class-frequency based. - (See NOTE in `_sample_weight` docstring.) - ref : dict | None - If provided, may contain pre-fitted thresholds: `{"sim_thresholds": [thr_m0, ...]}`. - When `mode!="train"` we try to reuse `ref["sim_thresholds"][mi]` for each modality. - If missing, we refit on the fly. - rng : np.random.Generator | None - Reserved for future stochastic steps (currently unused). Defaults to a fixed seed. - - Returns - ------- - out : dict - { - "data_list": [Data_m0, Data_m1, ...], # one PYG Data per modality - "fit_": {"sim_thresholds": [thr_m0, thr_m1, ...]} # fitted/used thresholds - } - - Graph construction details - -------------------------- - - Similarity: S = cosine(X_i, X_j). Diagonal is 1. - - Fit cutoff: - * Mask diag to -inf, flatten off-diagonals, pick global (k*N)-th largest. - * This yields a *single* threshold per modality. 
- - Adjacency: - * Mask: M = (S >= thr), zero diag, A0 = S * M - * Symmetrize: A = max(A0, A0^T) - * Add self-loops: A <- A + I - * Row-normalize: A[i,:] /= sum(A[i,:]) - - Saved on each `Data` object - --------------------------- - x : torch.FloatTensor [N, F_m] - y : torch.FloatTensor [N] or [N, C] or None - mode : str - feat_names : np.ndarray[object] # feature names list for modality m - edge_index : torch.LongTensor [2, E] - edge_weight : torch.FloatTensor [E] - adj_t : torch_sparse.SparseTensor (N x N) - train_sample_weight : torch.FloatTensor [N] (only when mode=="train") - - Notes & gotchas - --------------- - - All modalities must share the same sample count N (checked). - - Labels (if provided) must also have length N. - - `np.loadtxt` is strict; malformed CSVs (headers/strings) will fail. Use pandas - to read & convert to numeric if your inputs are not purely numeric. - - The global cutoff approach keeps *on average* k neighbors per node; some nodes - may have more/less depending on the similarity distribution and symmetrization. 
- """ - if rng is None: - rng = np.random.default_rng(12345) - - # ----------------------------- - # (1) Load labels, if present - # ----------------------------- - y_np = None - if labels_csv is not None: - y_np = _load_labels(labels_csv) - - # ----------------------------------------------- - # (2) Load features + optional pre_transform/name - # ----------------------------------------------- - X_list_np, featnames_all = [], [] - for mi, fpath in enumerate(feature_csvs): - # Expect pure numeric CSV with comma delimiter - Xi = np.loadtxt(fpath, delimiter=",") - if pre_transform is not None: - Xi = pre_transform(Xi) # e.g., StandardScaler().fit_transform(Xi) - X_list_np.append(Xi) - - # Feature names - if featname_csvs and mi < len(featname_csvs) and featname_csvs[mi]: - names = pd.read_csv(featname_csvs[mi], header=None).iloc[:, 0].astype(str).tolist() - else: - names = [f"m{mi}_f{j}" for j in range(Xi.shape[1])] - featnames_all.append(names) - - # Consistency: N must be same across modalities - nset = {X.shape[0] for X in X_list_np} - if len(nset) != 1: - raise ValueError(f"All modalities must have same sample count; got {nset}") - N = next(iter(nset)) - - # Labels length must match N - if y_np is not None and y_np.shape[0] != N: - raise ValueError(f"labels length ({y_np.shape[0]}) != samples ({N})") - - # Optional label pre-transform (e.g., relabeling, one-hot already, etc.) 
- if y_np is not None and target_pre_transform is not None: - y_np = target_pre_transform(y_np) - - # Convert labels to torch; one-hot if requested - if y_np is not None: - y_t = torch.as_tensor(y_np) - if y_t.ndim == 1 and (num_classes is not None): - y_t = F.one_hot(y_t.to(torch.long), num_classes=num_classes).float() - else: - y_t = None - - # --------------------------------------------- - # (3) Build per-modality adjacency + Data objs - # --------------------------------------------- - sim_thresholds: List[float] = [] - data_list: List[Data] = [] - - allow_fit = mode == "train" - applied_ref = ref is not None and "sim_thresholds" in ref and isinstance(ref["sim_thresholds"], list) - - for mi, X_np in enumerate(X_list_np): - X = torch.as_tensor(X_np, dtype=torch.float32) - - # ---- Fit or reuse threshold for this modality ---- - if allow_fit: - S = _pairwise_cosine(X, eps=eps) - S_no_diag = S.clone() - S_no_diag.fill_diagonal_(float("-inf")) - thr = _fit_global_cutoff(S_no_diag, edge_per_node) - else: - if applied_ref and ref is None: - raise ValueError("ref with 'sim_thresholds' required when mode != 'train'") - elif applied_ref and ref is not None: - sim_thresholds = ref["sim_thresholds"] - if mi < len(sim_thresholds): - thr = float(sim_thresholds[mi]) - else: - # Fallback: compute from current data (useful for standalone val/test) - S = _pairwise_cosine(X, eps=eps) - S_no_diag = S.clone() - S_no_diag.fill_diagonal_(float("-inf")) - thr = _fit_global_cutoff(S_no_diag, edge_per_node) - - sim_thresholds.append(thr) - - # ---- Build row-normalized adjacency ---- - S = _pairwise_cosine(X, eps=eps) - M = (S >= thr).float() - M.fill_diagonal_(0.0) - A = _symmetrize_max(S * M) # undirected - A = _add_I_and_row_normalize(A, eps=eps) # add self-loops + normalize rows - - # Convert to (edge_index, edge_weight) - Asp = A.to_sparse() - ei = Asp.indices() # [2, E] - ew = Asp.values() # [E] - - # ---- Sample weights (train only) ---- - if y_t is not None and mode == 
"train": - if y_t.ndim == 2: # one-hot - labels_train = torch.argmax(y_t, dim=1).cpu().numpy() - n_cls = y_t.shape[1] - else: # integer - labels_train = y_t.to(torch.long).cpu().numpy() - n_cls = int(labels_train.max()) + 1 if labels_train.size > 0 else (num_classes or 0) - sw = _sample_weight(labels_train, n_cls, equal_weight) - train_sample_weight = torch.as_tensor(sw, dtype=torch.float32) - else: - train_sample_weight = None - - # ---- Package PYG Data object for this modality ---- - data = Data( - x=X, - y=y_t, - mode=mode, - feat_names=np.asarray(featnames_all[mi], dtype=object), - edge_index=ei, - edge_weight=ew, - adj_t=SparseTensor(row=ei[0], col=ei[1], value=ew, sparse_sizes=(N, N)), - ) - if train_sample_weight is not None: - data.train_sample_weight = train_sample_weight - - data_list.append(data) - - return {"data_list": data_list, "fit_": {"sim_thresholds": sim_thresholds}} - - -# ========================= -# ---- Helper functions ---- -# ========================= -def _load_labels(path: str) -> np.ndarray: - """ - Load labels from a comma-separated CSV using `np.loadtxt`. - - Returns - ------- - y : np.ndarray - - If shape is (N,1), squeezed to (N,) and cast to int. - - If shape is (N,C), returned as-is (e.g., multi-label or already one-hot). - - Tip - --- - If your label file has headers or non-numeric tokens, use pandas to read and - then extract a numeric array instead of `np.loadtxt`. - """ - y = np.loadtxt(path, delimiter=",") - if y.ndim == 2 and y.shape[1] == 1: - y = y.reshape(-1) - return y.astype(int) if y.ndim == 1 else y - - -def _pairwise_cosine(X: torch.Tensor, *, eps: float = 1e-8) -> torch.Tensor: - """ - Dense cosine similarity between rows of X (N x F) -> (N x N) with diag=1. - - Implementation - -------------- - - L2-normalize rows, then return Xn @ Xn^T. 
- """ - Xn = F.normalize(X, p=2, dim=1, eps=eps) - return Xn @ Xn.T - - -def _fit_global_cutoff(S_no_diag: torch.Tensor, k: int) -> float: - """ - Choose a *single* similarity cutoff so each node keeps ~k neighbors on average. - - Parameters - ---------- - S_no_diag : torch.Tensor [N, N] - Similarity matrix with diagonal pre-set to -inf. - k : int - Target avg neighbors per node. Internally clipped to [1, N-1]. - - Method - ------ - - Flatten off-diagonal entries, take the global (k*N)-th largest. - Implemented as kthvalue on the negated array. - - Returns - ------- - thr : float - Scalar similarity threshold. - """ - N = S_no_diag.shape[0] - k_eff = max(min(k, max(N - 1, 1)), 1) # 1 <= k_eff <= N-1 - flat = S_no_diag.reshape(-1) - valid = flat.isfinite() - flat = flat[valid] - if flat.numel() == 0: - return float("inf") # degenerate case (N<=1) - pos = max(min(k_eff * N - 1, flat.numel() - 1), 0) - val = torch.kthvalue(-flat, k=pos + 1).values - return float(-val.item()) - - -def _symmetrize_max(A: torch.Tensor) -> torch.Tensor: - """Symmetrize adjacency by elementwise maximum (undirected graph).""" - return torch.maximum(A, A.T) - - -def _add_I_and_row_normalize(A: torch.Tensor, *, eps: float = 1e-8) -> torch.Tensor: - """ - Add self-loops and row-normalize so each row sums to ~1. - - Returns - ------- - A_norm : torch.Tensor [N, N] - Row-stochastic adjacency. - """ - N = A.shape[0] - A = A + torch.eye(N, dtype=A.dtype, device=A.device) - row_sum = A.sum(dim=1, keepdim=True).clamp_min(eps) - return A / row_sum - - -def _sample_weight(labels: np.ndarray, num_classes: int, equal_weight: bool) -> np.ndarray: - """ - Per-sample weight vector for (imbalanced) classification. - - Behavior - -------- - - If `equal_weight=True`: uniform weights 1/N. - - Else: weights are **proportional to class frequency** (NOT inverse): - w_i = count[label_i] / sum(count) - This gives larger weights to majority classes. 
- - NOTE - ---- - If what you want is *inverse* frequency (common choice to rebalance): - inv = 1.0 / np.maximum(count, 1) - w = inv[labels] - w = w / w.sum() - Replace the body accordingly if desired. - """ - if labels.size == 0: - return np.array([], dtype=np.float32) - if equal_weight: - return np.ones(len(labels), dtype=np.float32) / len(labels) - count = np.bincount(labels, minlength=num_classes if num_classes else labels.max() + 1) - return (count[labels] / np.sum(count)).astype(np.float32) diff --git a/mmai25_hackathon/load_data/protein.py b/mmai25_hackathon/load_data/protein.py index 0f8770f..9a6b008 100644 --- a/mmai25_hackathon/load_data/protein.py +++ b/mmai25_hackathon/load_data/protein.py @@ -1,15 +1,32 @@ """ -Protein sequence utilities for reading and integer encoding. +Protein sequence loading and integer-encoding utilities. Functions: - - fetch_protein_sequences_from_dataframe: Extract protein sequences from DataFrame or CSV. - - protein_sequence_to_integer_encoding: Convert a sequence to integer-encoded array. +fetch_protein_sequences_from_dataframe(df, prot_seq_col, index_col=None) + Fetches protein sequences from a DataFrame or CSV (uses `read_tabular` if a path is given). Optionally sets an index + and returns a one-column DataFrame named "protein_sequence" (index reset if `index_col` is None). + +protein_sequence_to_integer_encoding(sequence, max_length=1200) + Converts an amino-acid sequence to a fixed-length integer array using an A–Z alphabet (excluding 'J'); 0 is reserved + for padding/unknown. Returns a NumPy array of shape `(max_length,)`. + +Preview CLI: +`python -m mmai25_hackathon.load_data.protein --data-path /path/to/proteins.csv` +Reads the CSV, prints a small preview of the protein sequences, and encodes the first few entries, printing each array’s +shape and the count of unknown (0) tokens. 
""" -from typing import Union +import logging +from numbers import Integral +from typing import Dict, Optional, Sequence, Union import numpy as np import pandas as pd +from sklearn.utils._param_validation import Interval, validate_params + +from .tabular import read_tabular + +__all__ = ["fetch_protein_sequences_from_dataframe", "protein_sequence_to_integer_encoding"] # Generate character set for protein sequences between A-Z (except J) CHARPROTSET = [chr(i) for i in range(ord("A"), ord("Z") + 1) if chr(i) != "J"] @@ -17,16 +34,31 @@ CHARPROTSET = {letter: idx for idx, letter in enumerate(CHARPROTSET, 1)} +@validate_params( + {"df": [pd.DataFrame, str], "prot_seq_col": [str], "index_col": [None, str], "filter_rows": [None, dict]}, + prefer_skip_nested_validation=True, +) def fetch_protein_sequences_from_dataframe( - df: Union[pd.DataFrame, str], prot_seq_col: str, index_col: str = None + df: Union[pd.DataFrame, str], + prot_seq_col: str, + index_col: str = None, + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, ) -> pd.DataFrame: """ Fetches protein sequences from a DataFrame or CSV file. Will read the CSV if a path is provided. + High-level steps: + - If `df` is a path, load via `read_tabular` selecting `prot_seq_col` and optional `index_col`; apply `filter_rows`. + - If `df` is a DataFrame and `filter_rows` is provided, apply row filters where columns exist. + - Validate `prot_seq_col` exists; optionally set DataFrame index. + - Return a one-column DataFrame named `"protein_sequence"` (index preserved if set). + Args: df (Union[pd.DataFrame, str]): DataFrame or path to CSV file. prot_seq_col (str): Column name for protein sequences. index_col (str, optional): Column to set as index. Default: None. + filter_rows (dict, optional): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. 
Returns: pd.DataFrame: A single column DataFrame containing the protein sequences with name `"protein_sequence"`. @@ -47,7 +79,11 @@ def fetch_protein_sequences_from_dataframe( 3 TTPSYVAFTDTER """ if isinstance(df, str): - df = pd.read_csv(df) + df = read_tabular(df, subset_cols=prot_seq_col, index_cols=index_col, filter_rows=filter_rows) + else: + for col, valid_vals in (filter_rows or {}).items(): + if col in df.columns: + df = df[df[col].isin(valid_vals)] if prot_seq_col not in df.columns: raise ValueError(f"Column '{prot_seq_col}' not found in DataFrame.") @@ -55,13 +91,23 @@ def fetch_protein_sequences_from_dataframe( if index_col is not None: df = df.set_index(index_col) - return df[prot_seq_col].to_frame("protein_sequence") + logger = logging.getLogger(f"{__name__}.fetch_protein_sequences_from_dataframe") + logger.info("Fetched %d protein sequences from column '%s'.", len(df), prot_seq_col) + return df[prot_seq_col].to_frame("protein_sequence").reset_index(drop=index_col is None) +@validate_params( + {"sequence": [str], "max_length": [Interval(Integral, 1, None, closed="left")]}, prefer_skip_nested_validation=True +) def protein_sequence_to_integer_encoding(sequence: str, max_length: int = 1200) -> np.ndarray: """ Converts a protein sequence into an integer-encoded representation. + High-level steps: + - Allocate a zero-initialised array of length `max_length` (dtype `uint64`). + - For each character up to `max_length`, map A–Z (excluding 'J') using a lookup (unknown→0). + - Return the encoded array. + Args: sequence (str): The protein sequence to encode. max_length (int): The maximum length of the output array. 
@@ -83,20 +129,29 @@ def protein_sequence_to_integer_encoding(sequence: str, max_length: int = 1200) for i, char in enumerate(sequence[:max_length]): # If character is not in CHARPROTSET, it will be skipped and assumed to be unknown encoded_sequence[i] = CHARPROTSET.get(char, 0) + + logger = logging.getLogger(f"{__name__}.protein_sequence_to_integer_encoding") + logger.info("Encoded sequence: %s", encoded_sequence) + logger.info("Original sequence length: %d, Encoded length: %d", len(sequence), len(encoded_sequence)) + logger.info("Unknown characters (encoded as 0) count: %d", np.sum(encoded_sequence == 0)) return encoded_sequence if __name__ == "__main__": import argparse - # Example script: python -m mmai25_hackathon.load_data.protein dataset.csv - + # Example script: python -m mmai25_hackathon.load_data.protein --data-path MMAI25Hackathon/molecule-protein-interaction/dataset.csv parser = argparse.ArgumentParser(description="Process protein sequences.") - parser.add_argument("csv_path", type=str, help="Path to the CSV file containing protein sequences.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the CSV file containing protein sequences.", + default="MMAI25Hackathon/molecule-protein-interaction/dataset.csv", + ) args = parser.parse_args() # Take from Peizhen's csv file for DrugBAN training - df = fetch_protein_sequences_from_dataframe(args.csv_path, prot_seq_col="Protein") + df = fetch_protein_sequences_from_dataframe(args.data_path, prot_seq_col="Protein") for i, prot_seq in enumerate(df["protein_sequence"].head(5), 1): integer_encoding = protein_sequence_to_integer_encoding(prot_seq) print(i, integer_encoding) diff --git a/mmai25_hackathon/load_data/supervised_labels.py b/mmai25_hackathon/load_data/supervised_labels.py index e535b53..0424382 100644 --- a/mmai25_hackathon/load_data/supervised_labels.py +++ b/mmai25_hackathon/load_data/supervised_labels.py @@ -1,36 +1,64 @@ """ -Labels handling utilities for supervised learning. 
- -This module provides functions to fetch supervision labels from CSV files or pandas DataFrames, supporting both single-column and multi-column labels for regression or classification tasks. It also includes utilities for one-hot encoding categorical labels. +Supervised labels loading and encoding utilities. Functions: - - fetch_supervised_labels_from_dataframe: Fetch labels from a DataFrame or CSV file, supporting index columns and single/multi-column labels. - - one_hot_encode_labels: One-hot encode categorical labels in a DataFrame, supporting single or multiple columns. - -Examples: - >>> df = pd.DataFrame({"id": [1, 2, 3], "label": [0, 1, 0]}) - >>> labels = fetch_supervised_labels_from_dataframe(df, label_col="label", index_col="id") - >>> one_hot_labels = one_hot_encode_labels(labels) +fetch_supervised_labels_from_dataframe(df, label_col, index_col=None) + Fetches labels from a DataFrame or CSV (uses `read_tabular` when a path is provided). Supports single- or multi-column + labels for classification/regression. Optionally sets an index. Returns a DataFrame named "label" for a single column + or the original column names for multiple columns. + +one_hot_encode_labels(labels, columns="label") + One-hot encodes categorical label columns using `pandas.get_dummies`. Supports single or multiple columns and returns + a `pd.DataFrame` with `float32` dtypes. + +Preview CLI: +`python -m mmai25_hackathon.load_data.supervised_labels --data-path /path/to/labels.csv` +Reads the CSV (expects a label column named `Y` in this demo), prints the first five labels, then prints a preview of +one-hot–encoded labels. 
""" -from typing import Sequence, Union +from typing import Dict, Optional, Sequence, Union import numpy as np import pandas as pd +from sklearn.utils._param_validation import validate_params + +from .tabular import read_tabular +__all__ = ["fetch_supervised_labels_from_dataframe", "one_hot_encode_labels"] + +@validate_params( + { + "df": [pd.DataFrame, str], + "label_col": [str, "array-like"], + "index_col": [None, str], + "filter_rows": [None, dict], + }, + prefer_skip_nested_validation=True, +) def fetch_supervised_labels_from_dataframe( df: Union[pd.DataFrame, str], label_col: Union[str, Sequence[str]], index_col: str = None, + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, ) -> pd.DataFrame: """ Fetches supervision labels from a DataFrame or CSV file. Will read the CSV if a path is provided. + High-level steps: + - If `df` is a path, load via `read_tabular` selecting `label_col` and optional `index_col`; apply `filter_rows`. + - If `df` is a DataFrame and `filter_rows` is provided, apply row filters where columns exist. + - Validate that the requested label column(s) are present. + - If a single column, optionally set index and return DataFrame named `"label"`. + - If multiple columns, return the DataFrame as-is. + Args: df (Union[pd.DataFrame, str]): DataFrame or path to CSV file. label_col (Union[str, Sequence[str]]): Column name or sequence of column names for labels. index_col (str, optional): Column to set as index. Default: None. + filter_rows (dict, optional): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. 
Returns: pd.DataFrame: A DataFrame containing the labels with name `"label"` if a single column is provided, @@ -49,24 +77,35 @@ def fetch_supervised_labels_from_dataframe( 3 0 """ if isinstance(df, str): - df = pd.read_csv(df) + df = read_tabular(df, subset_cols=label_col, index_cols=index_col, filter_rows=filter_rows) + else: + # Apply filter_rows to provided DataFrame for consistency + for col, valid_vals in (filter_rows or {}).items(): + if col in df.columns: + df = df[df[col].isin(valid_vals)] if label_col not in df.columns: raise ValueError(f"Column '{label_col}' not found in DataFrame.") + if isinstance(label_col, (list, tuple)) and len(label_col) > 1: + return df + if index_col is not None: df = df.set_index(index_col) - if isinstance(label_col, Sequence) and len(label_col) > 1: - return df[list(label_col)] - - return df[label_col].to_frame("label") + return df[label_col].to_frame("label").reset_index(drop=index_col is None) +@validate_params({"labels": [pd.DataFrame], "columns": [str, "array-like"]}, prefer_skip_nested_validation=True) def one_hot_encode_labels(labels: pd.DataFrame, columns: Union[Sequence[str], str] = "label") -> pd.DataFrame: """ One-hot encodes categorical labels in a DataFrame. + High-level steps: + - Coerce `columns` to a list of column names. + - Call `pandas.get_dummies` with `dtype=np.float32` on the specified columns. + - Return the one-hot encoded DataFrame. + Args: labels (pd.DataFrame): DataFrame containing the labels to be one-hot encoded. columns (Union[Sequence[str], str]): Column name or sequence of column names to be one-hot encoded. Default: "label". 
@@ -92,14 +131,18 @@ def one_hot_encode_labels(labels: pd.DataFrame, columns: Union[Sequence[str], st if __name__ == "__main__": import argparse - # Example script: python -m mmai25_hackathon.load_data.supervised_labels dataset.csv - + # Example script: python -m mmai25_hackathon.load_data.supervised_labels --data-path MMAI25Hackathon/molecule-protein-interaction/dataset.csv parser = argparse.ArgumentParser(description="Process supervision labels for regression/classification.") - parser.add_argument("csv_path", type=str, help="Path to the CSV file containing supervision labels.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the CSV file containing supervision labels.", + default="MMAI25Hackathon/molecule-protein-interaction/dataset.csv", + ) args = parser.parse_args() # Take from Peizhen's csv file for DrugBAN training - df = fetch_supervised_labels_from_dataframe(args.csv_path, label_col="Y") + df = fetch_supervised_labels_from_dataframe(args.data_path, label_col="Y") for i, label in enumerate(df["label"].head(5), 1): print(i, label) diff --git a/mmai25_hackathon/load_data/tabular.py b/mmai25_hackathon/load_data/tabular.py index 7005cdb..d64a245 100644 --- a/mmai25_hackathon/load_data/tabular.py +++ b/mmai25_hackathon/load_data/tabular.py @@ -1,67 +1,182 @@ """ -Tabular data utilities for reading, merging, and graph conversion. +Tabular utilities for loading a CSV and merging multiple DataFrames by overlapping key columns. Functions: - - read_tabular: Load a single CSV file into a DataFrame. - - merge_multiple_dataframes: Merge DataFrames by join keys, handling column collisions with suffixes. - - tabular_to_graph: Convert a DataFrame to a graph using row/sample similarity. +read_tabular(path, subset_cols=None, index_cols=None, filter_rows=None, sep=",", raise_errors=True) + Thin wrapper around `pandas.read_csv`. 
Optionally selects key columns first (order-preserving), + keeps only requested columns, and filters rows via a `{column: allowed_values}` mapping. + Note: does **not** set the DataFrame index; `index_cols` are treated as key columns only. + +merge_multiple_dataframes(dfs, dfs_name=None, index_cols=None, join="outer") + Greedily merges a sequence of DataFrames into connected components based on overlapping key columns. + First merges frames that share the same subset of keys, then merges groups whose key sets overlap. + Returns a list of `(keys_tuple, merged_df)`. Name collisions get suffixes from `dfs_name` or `_df{i}`. + +Preview CLI: +`python -m mmai25_hackathon.load_data.tabular --data-path BASE_PATH --index-cols ... --subset-cols ... --join outer` +Recursively loads `*.csv`, then groups/merges and prints a preview for each component. """ +import logging from pathlib import Path -from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union -import numpy as np import pandas as pd -import torch -from sklearn.metrics import pairwise_kernels -from torch_geometric.data import Data - -from ..utils import find_global_cutoff, symmetrize_matrix - - -def read_tabular(path: Union[str, Path], sep: str = ",") -> pd.DataFrame: +from sklearn.utils._param_validation import StrOptions, validate_params + +__all__ = ["read_tabular", "merge_multiple_dataframes"] + + +@validate_params( + { + "path": [Path, str], + "subset_cols": [None, list, str], + "index_cols": [None, list, str], + "filter_rows": [None, dict], + "sep": [str], + "raise_errors": ["boolean"], + }, + prefer_skip_nested_validation=True, +) +def read_tabular( + path: Union[str, Path], + subset_cols: Optional[Union[List[str], str]] = None, + index_cols: Optional[Union[List[str], str]] = None, + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, + sep: str = ",", + raise_errors: bool = True, +) -> pd.DataFrame: """ 
Reads a single tabular text file into a DataFrame. + If `subset_cols` and/or `index_cols` are provided, will select only those columns + that exist in the DataFrame. The order will be `index_cols` followed by `subset_cols`. + + High-level steps: + - Read CSV via `pandas.read_csv` with separator `sep`. + - Normalise `subset_cols`/`index_cols` to lists; compute intersections with available columns. + - When `raise_errors` and none of the requested columns exist, raise `ValueError`. + - Order columns with `index_cols` first then `subset_cols`; filter rows per `filter_rows` where possible. + - Return the resulting DataFrame. + Args: path (Union[str, Path]): Path to the tabular text file. + subset_cols (Optional[Union[List[str], str]]): If provided, will select only these columns if they exist in the DataFrame. + Default: None. + index_cols (Optional[Union[List[str], str]]): If provided, will select these columns as the index of the DataFrame. + Default: None. + filter_rows (Optional[Dict[str, Union[Sequence, pd.Index]]]): If provided, will filter the rows + based on the specified column values. The keys are column names and the values are sequences + of acceptable values for filtering. Will be ignored if not found in the DataFrame. Default: None. sep (str): Value separator in the tabular text file. Default: ",". + raise_errors (bool): If True, will raise an error if none of the specified `subset_cols` + or `index_cols` are found in the DataFrame. Default: True. Returns: - pd.DataFrame: The loaded DataFrame. + pd.DataFrame: The loaded DataFrame with complete or partial columns selected. + + Raises: + ValueError: If `raise_errors` is True and none of the specified `subset_cols` + or `index_cols` are found in the DataFrame. Examples: >>> df = read_tabular("data.csv") >>> print(df.head()) """ # Just provide thin wrapper around pd.read_csv for function availability. 
- return pd.read_csv(path, sep=sep) + logger = logging.getLogger(f"{__name__}.read_tabular") + logger.info("Reading tabular data from: %s", path) + + df = pd.read_csv(path, sep=sep) + + if isinstance(subset_cols, str): + subset_cols = [subset_cols] + + if isinstance(index_cols, str): + index_cols = [index_cols] + + selected_index_cols = pd.Index([]) + if index_cols is not None: + # preserves the order of index_cols as provided when sort=False + logger.info("Selecting index columns: %s", index_cols) + selected_index_cols = df.columns.intersection(index_cols, sort=False) + logger.info("Found index columns in DataFrame: %s", selected_index_cols.to_list()) + + selected_subset_cols = df.columns.difference(selected_index_cols, sort=False) + if subset_cols is not None: + logger.info("Selecting subset columns: %s", subset_cols) + selected_subset_cols = df.columns.intersection(subset_cols, sort=False) + logger.info("Found subset columns in DataFrame: %s", selected_subset_cols.to_list()) + + if raise_errors and subset_cols is not None and len(selected_subset_cols) == 0: + raise ValueError(f"No valid subset_cols found in DataFrame for: {subset_cols}") + + if raise_errors and index_cols is not None and len(selected_index_cols) == 0: + raise ValueError(f"No valid index_cols found in DataFrame for: {index_cols}") + + # Reorder the dataframe to have index_cols first then subset_cols + logger.info( + "Final selected columns: %s", + selected_index_cols.to_list() + selected_subset_cols.to_list(), + ) + selected_cols = selected_index_cols.union(selected_subset_cols, sort=False) + + df = df if len(selected_cols) == 0 else df[df.columns.intersection(selected_cols, sort=False)] + logger.info("Loaded DataFrame shape: %s", df.shape) + + logger.info("Applying row filters: %s", filter_rows) + for col, valid_vals in (filter_rows or {}).items(): + if col in df.columns: + logger.info( + "Filtering rows on column '%s' with %d valid values.", + col, + len(valid_vals), + ) + df = 
df[df[col].isin(valid_vals)] + + return df +@validate_params( + { + "dfs": ["array-like"], + "dfs_name": [None, "array-like"], + "index_cols": [None, list, str], + "join": [StrOptions({"outer", "inner", "left", "right"})], + }, + prefer_skip_nested_validation=True, +) def merge_multiple_dataframes( dfs: Sequence[pd.DataFrame], dfs_name: Optional[Sequence[str]] = None, index_cols: Optional[List[str]] = None, - join: str = "outer", + join: Literal["outer", "inner", "left", "right"] = "outer", ) -> List[Tuple[Tuple[str, ...], pd.DataFrame]]: """ Merge a sequence of DataFrames by shared keys until disjoint components remain. - - If `index_cols` is None/empty, return a single component with all frames concatenated - column-wise: [((), concat_df)]. - - Otherwise: (1) merge frames that share the same subset of `index_cols`, then - (2) greedily merge groups whose key sets overlap (prefer larger overlaps, then - smaller combined size). Column collisions get suffixes from `dfs_name` - (or `_df{i}` if not provided). + High-level steps: + - Validate `join` option and `dfs_name` length. + - If `dfs` is empty, return []. If `index_cols` is falsy, concatenate columns and return single component. + - Group frames by the exact subset of provided keys they contain; if none, return []. + - Merge within each group using suffixes determined by `dfs_name`. + - Greedily merge groups whose key sets overlap until no overlaps remain. + - Return components as `(sorted_keys_tuple, DataFrame)` pairs. Args: - dfs: Input DataFrames. - dfs_name: Optional names (same length as `dfs`) used to derive merge suffixes. - index_cols: Candidate join keys. - join: One of {"outer", "inner", "left", "right"}. + dfs (Sequence[pd.DataFrame]): Sequences of dataframes to merge. + dfs_name (Optional[Sequence[str]]): Optional names (same length as `dfs`) used to derive merge suffixes. + index_cols (Optional[List[str]]): Key columns to use for merging. If None/empty, + will concatenate all frames. Default: None. 
+ subset_cols (Optional[List[str]]): If provided, will pre select these columns from each dataframe + if any of them exist in the dataframe. Default: None. + join (Literal["outer", "inner", "left", "right"]): Dataframe merging strategy. Default: "outer". Returns: - List of components as (sorted_keys_tuple, merged_dataframe). + List[Tuple[Tuple[str, ...], pd.DataFrame]]: A list of tuples where each tuple contains: + - A tuple of key column names used for merging that component. + - The merged DataFrame for that component. + - Accounts for column collisions by adding suffixes and disjoining index columns. Examples: >>> df1 = pd.DataFrame({"id": [1, 2], "a": [10, 20]}) @@ -77,10 +192,6 @@ def merge_multiple_dataframes( >>> # The first component merges df1 & df2 on 'id'; non-key collisions would get >>> # suffixes '_X' and '_Y'. The second component is just df3 keyed by 'site'. """ - valid_joins = {"outer", "inner", "left", "right"} - if join not in valid_joins: - raise ValueError(f"`join` must be one of {valid_joins}, got {join!r}") - if dfs_name is not None and len(dfs_name) != len(dfs): raise ValueError( f"Length of `dfs_name` must match length of `dfs`. Found {len(dfs_name)} and {len(dfs)} respectively." 
@@ -89,26 +200,31 @@ def merge_multiple_dataframes( if not dfs: return [] + logger = logging.getLogger(f"{__name__}.merge_multiple_dataframes") + # Concatenate-only mode if not index_cols: + logger.info("No index_cols provided; concatenating all DataFrames column-wise.") return [((), pd.concat(list(dfs), axis="columns", join=join))] # Prepare suffix labels labels = [f"_{name}" for name in (dfs_name or [f"df{i}" for i in range(len(dfs))])] # Bucket frames by the exact subset of keys they actually contain - frames_by_subset: Dict[Tuple[str, ...], List[Tuple[pd.DataFrame, str]]] = {} + logger.info("Merging DataFrames by overlapping keys: %s", index_cols) + df_by_subset = {} # type: ignore[var-annotated] for df, label in zip(dfs, labels): subset = tuple(col for col in index_cols if col in df.columns) if subset: - frames_by_subset.setdefault(subset, []).append((df, label)) + df_by_subset.setdefault(subset, []).append((df, label)) - if not frames_by_subset: + if not df_by_subset: return [] # Merge within each exact key-subset first - groups: List[Tuple[FrozenSet[str], pd.DataFrame, str]] = [] # (keys, df, last_suffix) - for subset, items in frames_by_subset.items(): + logger.info("Found %d groups by exact key subsets.", len(df_by_subset)) + groups = [] # (keys, df, last_suffix) + for subset, items in df_by_subset.items(): (merged_df, left_suffix), *rest = items for df, right_suffix in rest: merged_df = merged_df.merge( @@ -121,6 +237,7 @@ def merge_multiple_dataframes( groups.append((frozenset(subset), merged_df, left_suffix)) # Greedy pairwise merging across groups until no overlaps remain + logger.info("Merging %d groups by overlapping keys.", len(groups)) while True: best = None best_score = None # (-overlap_size, combined_cells) @@ -148,107 +265,62 @@ def merge_multiple_dataframes( groups[i] = (keys_i | keys_j, merged_df, sfx_j) del groups[j] - # Materialize as (sorted_keys_tuple, DataFrame) - components: List[Tuple[Tuple[str, ...], pd.DataFrame]] = [ - 
(tuple(sorted(keys)), df) for (keys, df, _sfx) in sorted(groups, key=lambda g: tuple(sorted(g[0]))) + # Materialize as (sorted_keys_tuple, DataFrame), sorted by keys for consistency + merged_key_df_pairs = [ + (tuple(sorted(keys)), df) for (keys, df, _) in sorted(groups, key=lambda g: tuple(sorted(g[0]))) ] - return components - - -def tabular_to_graph( - df: pd.DataFrame, edge_per_node: int = 10, metric: str = "cosine", threshold: Optional[float] = None -) -> Data: - """ - Convert tabular data from DataFrame into a graph representation. The nodes are derived from the rows - or samples in the DataFrame, and edges are created based on similarity between the rows. - Args: - df (pd.DataFrame): DataFrame containing the features. Assumes the subsets (i.e., train/test/val) - separation is done separately and `df` contains all samples. - edge_per_node (int): Number of edges to create per node based on similarity. Default: 10. - metric (str): Metric to compute similarity between rows. Default: "cosine". - threshold (Optional[float]): Predefined thresholds for edge creation. If None, it will be estimated - from the data given `edge_per_node`. Default: None. + logging.info("Final merged components: %d", len(merged_key_df_pairs)) + for keys, df in merged_key_df_pairs: + logging.info("Merged keys: %s, shape: %s", keys, df.shape) - Returns: - torch_geometric.data.Data: The resulting graph representation containing: - - `x`: Node feature matrix with shape [num_nodes, num_node_features]. - - `edge_index`: Graph connectivity in COO format with shape [2, num_edges]. - - `edge_attr`: Edge feature matrix with shape [num_edges, num_edge_features]. - - `num_nodes`: Number of nodes in the graph. - - `num_edges`: Number of edges in the graph. - - `feature_names`: List of feature names corresponding to columns in `df`. - - `metric`: The metric used for similarity computation. - - `threshold`: The threshold used for edge creation. 
- """ - # Assumes all columns can be casted to float32 - features = df.to_numpy(dtype=np.float32) - similarity_matrix = pairwise_kernels(features, metric=metric) - if threshold is None: - threshold = find_global_cutoff(similarity_matrix, edge_per_node) - - # We do not need to remove self-loops as they will be removed in GNN layers if needed - adjacency_matrix = symmetrize_matrix(similarity_matrix >= threshold, method="maximum") - # Get edges where we have shape [2, num_edges] - edge_index = np.vstack(np.nonzero(adjacency_matrix)).astype(np.int64) - # We use the sparsified similarity value as the edge weights - edge_weight = similarity_matrix[adjacency_matrix] - - # Cast to torch tensors - x = torch.from_numpy(features) - edge_index = torch.from_numpy(edge_index) - edge_weight = torch.from_numpy(edge_weight) - - return Data( - x, - edge_index, - edge_weight=edge_weight, - feature_names=df.columns.to_list(), - metric=metric, - threshold=threshold, - ) + return merged_key_df_pairs if __name__ == "__main__": - # separate function for loading tabular data - # 1. read_tabular (only single csv) - # 2. merge multiple dataframes (optimize the query greedily) - # 3. concat multiple dataframes - import argparse - # Example script (assuming folder mimic-iv-3.1 is in the current directory) - # python -m mmai25_hackathon.load_data.tabular mimic-iv-3.1 --index-cols subject_id hadm_id charttime --join outer + # Example script (assuming folder mimic-iv/mimic-iv-3.1 is in the current directory) + # python -m mmai25_hackathon.load_data.tabular --data-path mimic-iv/mimic-iv-3.1 --index-cols subject_id hadm_id --subset-cols language --join outer # NOTE: Expect increase in row count given we are doing outer join and each dataframes may or will have # different relational structures (i.e., admissions to icustays in MIMIC-IV has one-to-many # relationship w.r.t. 
subject_id and hadm_id) parser = argparse.ArgumentParser(description="Read and aggregate tabular CSV files.") - parser.add_argument("base_path", help="Base path for the CSV files.") - parser.add_argument("--index-cols", nargs="+", default=None, help="Columns to use as index.") + parser.add_argument( + "--data-path", + help="Data path for the CSV files.", + default="MMAI25Hackathon/mimic-iv/mimic-iv-3.1", + ) + parser.add_argument( + "--index-cols", + nargs="+", + default=["subject_id", "hadm_id"], + help="Columns to use as index.", + ) + parser.add_argument("--subset-cols", nargs="+", default=["language"], help="Columns to subset.") parser.add_argument("--join", default="outer", help="Join type for merging DataFrames.") args = parser.parse_args() # Recursive glob for CSV files - csv_files = list(Path(args.base_path).rglob("*.csv")) + csv_files = list(Path(args.data_path).rglob("*.csv")) # Load multiple dataframes - dfs = [pd.read_csv(f) for f in csv_files] - df_names = [f.stem for f in csv_files] + dfs = [ + read_tabular( + f, + index_cols=args.index_cols, + subset_cols=args.subset_cols, + raise_errors=False, + ) + for f in csv_files + ] + dfs_name = [f.stem for f in csv_files] # Merge dataframes - components = merge_multiple_dataframes(dfs, dfs_name=df_names, index_cols=args.index_cols, join=args.join) + components = merge_multiple_dataframes(dfs, dfs_name=dfs_name, index_cols=args.index_cols, join=args.join) for keys, comp_df in components: print(f"Component keys: {keys}") print(comp_df.head()) print() - - # Quick test for tabular_to_graph - from sklearn.datasets import make_classification - - # Create a synthetic dataset - X, _ = make_classification(n_samples=100, n_features=20, random_state=42) - df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])]) - graph = tabular_to_graph(df, edge_per_node=5, metric="cosine") - print(graph) diff --git a/mmai25_hackathon/load_data/text.py b/mmai25_hackathon/load_data/text.py index 4c26fa6..19d8bfd 100644 
--- a/mmai25_hackathon/load_data/text.py +++ b/mmai25_hackathon/load_data/text.py @@ -1,157 +1,230 @@ -import os +""" +MIMIC-IV clinical notes (free-text) loading utilities. + +Functions: +load_mimic_iv_notes(note_path, subset='radiology', include_detail=False, subset_cols=None) + Loads the selected notes CSV (`radiology.csv` or `discharge.csv`), verifies required ID columns + (`note_id`, `subject_id`), optionally merges `_detail.csv` when `include_detail=True`, + applies optional `subset_cols`, strips/filters empty `text`, and returns a `pd.DataFrame` indexed by + [`note_id`, `subject_id`]. + +extract_text_from_note(note, include_metadata=False) + Extracts the `text` field from a single note `pd.Series`. When `include_metadata=True`, returns + `(text, metadata_dict)` where `metadata_dict` is the note’s fields excluding `text`. + +Preview CLI: +`python -m mmai25_hackathon.load_data.text --data-path /path/to/mimic-iv-note-.../note --subset radiology --note-id 12345678` +Prints a preview of the loaded notes (columns like `note_id`, `subject_id`, `hadm_id`, `note_type`, `text`) and then +retrieves the note matching the provided `note_id`, printing its full text and selected metadata (e.g., `subject_id`, +`hadm_id`, `note_type`). 
+""" + +import logging from pathlib import Path -from typing import Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import pandas as pd +from sklearn.utils._param_validation import StrOptions, validate_params -# ---- Configure your dataset root ---- -DATA_PATH = r"your_data_path_here" -TEXT_DIR = "mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note" -NOTE_PATH = os.path.join(DATA_PATH, TEXT_DIR) +from .tabular import merge_multiple_dataframes, read_tabular +__all__ = ["load_mimic_iv_notes", "extract_text_from_note"] -# ----------------------------- -# 1) Load notes (radiology or discharge) -# ----------------------------- -def get_text_notes( - base_note_path: str, +REQUIRED_ID_COLS = ["note_id", "subject_id"] + + +@validate_params( + { + "note_path": [str, Path], + "subset": [StrOptions({"radiology", "discharge"})], + "include_detail": ["boolean"], + "subset_cols": [None, list], + "filter_rows": [None, dict], + }, + prefer_skip_nested_validation=True, +) +def load_mimic_iv_notes( + note_path: Union[str, Path], subset: Literal["radiology", "discharge"] = "radiology", include_detail: bool = False, - keep_cols: Optional[Tuple[str, ...]] = ( - "note_id", - "subject_id", - "hadm_id", - "note_type", - "note_seq", - "charttime", - "storetime", - "text", - ), + subset_cols: Optional[List[str]] = ["hadm_id", "note_type", "note_seq", "charttime", "storetime", "text"], + filter_rows: Optional[Dict[str, Union[Sequence, pd.Index]]] = None, ) -> pd.DataFrame: """ - Load free-text clinical notes. - - Parameters - ---------- - base_note_path : str - Folder that contains the 4 CSVs (radiology.csv, radiology_detail.csv, discharge.csv, discharge_detail.csv). - subset : {'radiology','discharge'} - Which note family to load. - include_detail : bool - If True, left-join the corresponding *_detail.csv on ['note_id','subject_id'] and add detail columns. 
- keep_cols : tuple of str | None - Columns to keep from the main notes CSV. If None, keep all columns. - - Returns - ------- - pd.DataFrame - Notes DataFrame. If include_detail=True, extra columns from *_detail are merged (field_name, field_value, field_ordinal). + Load de-identified free-text clinical notes for a selected MIMIC-IV subset and + optionally merge the corresponding detail CSV. + + High-level steps: + - Validate the input directory exists; resolve ``note_path`` to ``Path``. + - Load ``.csv`` via ``read_tabular`` with ``subset_cols`` plus required IDs; apply ``filter_rows`` when provided. + - Ensure required ID columns (``note_id``, ``subject_id``) are present. + - When ``include_detail`` is True, load ``_detail.csv`` (also applying ``filter_rows``), validate IDs, and left‑merge on IDs. + - If ``text`` exists, ``str.strip`` and drop empty rows; otherwise log a warning. + - Return the resulting ``pd.DataFrame``. + + Args: + note_path (Union[str, Path]): Directory containing the notes CSV files + (for example: ``.../mimic-iv-note-.../note``). + subset (Literal['radiology', 'discharge']): Which note subset to load. Default: ``'radiology'``. + include_detail (bool): If True, left-join ``_detail.csv`` on ``['note_id', 'subject_id']``. + Default: False. + subset_cols (Optional[List[str]]): Columns to load from the main notes CSV in addition to the + required ID columns. Defaults to a small set including ``'text'``. + filter_rows (dict, optional): A dictionary to filter rows in the DataFrame. + Keys are column names and values are the values to filter by. Default: None. + + Returns: + pd.DataFrame: Notes for the requested subset. When ``text`` exists, values are trimmed and + empty rows removed; when ``include_detail=True``, columns from the detail CSV may be present. + + Raises: + FileNotFoundError: If ``note_path`` or the main CSV (``.csv``) is missing, or if + ``include_detail=True`` and ``_detail.csv`` is missing. 
+ KeyError: If the required ID columns ``['note_id', 'subject_id']`` are absent from the main + or (when requested) detail CSV. + + Examples: + >>> from mmai25_hackathon.load_data.text import load_mimic_iv_notes + >>> base = "MMAI25Hackathon/mimic-iv/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note" + >>> df = load_mimic_iv_notes(base, subset="radiology", include_detail=True) + >>> df.head()[["note_id", "subject_id", "note_type", "text"]] + note_id subject_id hadm_id note_type text + 0 1 101 1 DS EXAMINATION: CHEST (PA AND LAT)INDICATION: ___... + 1 2 101 2 DS EXAMINATION: LIVER OR GALLBLADDER US (SINGLE O... + 2 3 101 3 DS INDICATION: ___ HCV cirrhosis c/b ascites, hiv... + 3 4 102 4 DS EXAMINATION: Ultrasound-guided paracentesis.IN... + 4 5 102 5 DS EXAMINATION: Ultrasound-guided paracentesis.IN... """ - base = Path(base_note_path) - if not base.exists(): - raise FileNotFoundError(f"Notes folder not found: {base}") + if isinstance(note_path, str): + note_path = Path(note_path) - main_name = f"{subset}.csv" - detail_name = f"{subset}_detail.csv" + if not note_path.exists(): + raise FileNotFoundError(f"Notes folder not found: {note_path}") - main_csv = base / main_name - detail_csv = base / detail_name + subset_path = note_path / f"{subset}.csv" - if not main_csv.exists(): - raise FileNotFoundError(f"Missing main notes CSV: {main_csv}") + if not subset_path.exists(): + raise FileNotFoundError(f"Missing main notes CSV: {subset_path}") - df = pd.read_csv(main_csv) + logger = logging.getLogger(f"{__name__}.load_mimic_iv_notes") + logger.info("Loading notes from: %s", subset_path) + df_notes = read_tabular(subset_path, subset_cols=subset_cols, index_cols=REQUIRED_ID_COLS, filter_rows=filter_rows) + logger.info("Loaded %d notes from: %s", len(df_notes), subset_path) - # Ensure required ID columns exist - required_ids = {"note_id", "subject_id"} - if not required_ids.issubset(set(df.columns)): - raise KeyError(f"{main_name} must contain columns: 
{required_ids}") + id_cols_available = df_notes.columns.intersection(REQUIRED_ID_COLS).to_list() + if len(id_cols_available) < len(REQUIRED_ID_COLS): + raise KeyError(f"{subset_path} must contain columns: {REQUIRED_ID_COLS}. Found: {id_cols_available}") - # Keep requested columns if provided and present - if keep_cols is not None: - cols_present = [c for c in keep_cols if c in df.columns] - # Guarantee IDs stay even if not in keep_cols - for col in ["note_id", "subject_id"]: - if col not in cols_present and col in df.columns: - cols_present.insert(0, col) - df = df[cols_present].copy() + detail_path = note_path / f"{subset}_detail.csv" + if include_detail and not detail_path.exists(): + raise FileNotFoundError(f"Missing detail notes CSV: {detail_path}") - # Optionally join detail if include_detail: - if not detail_csv.exists(): - raise FileNotFoundError(f"Requested detail join but missing: {detail_csv}") - det = pd.read_csv(detail_csv) - # minimal check - if not required_ids.issubset(set(det.columns)): - raise KeyError(f"{detail_name} must contain columns: {required_ids}") - # Typical detail columns: field_name, field_value, field_ordinal - df = df.merge(det, how="left", on=["note_id", "subject_id"]) - - # Report simple stats - total = len(df) - has_text = "text" in df.columns - if has_text: - nonempty = (df["text"].astype(str).str.strip() != "").sum() - print(f"Loaded {total} {subset} notes ({nonempty} with non-empty text).") + logger.info("Including detail from: %s", detail_path) + df_detail = read_tabular(detail_path, index_cols=REQUIRED_ID_COLS, filter_rows=filter_rows) + logger.info("Loaded %d detail rows from: %s", len(df_detail), detail_path) + id_cols_available = df_detail.columns.intersection(REQUIRED_ID_COLS).to_list() + if len(id_cols_available) < len(REQUIRED_ID_COLS): + raise KeyError(f"{detail_path} must contain columns: {REQUIRED_ID_COLS}. 
Found: {id_cols_available}") + logger.info("Merging detail into main notes on: %s", REQUIRED_ID_COLS) + df_notes = merge_multiple_dataframes((df_notes, df_detail), ("notes", "detail"), REQUIRED_ID_COLS, "left") + # Unpack df_notes from list of (paired_keys, df_notes) + df_notes = df_notes[0] + _, df_notes = df_notes + + text_included = "text" in df_notes.columns + if text_included: + df_notes["text"] = df_notes["text"].astype(str).str.strip() + available_texts = df_notes["text"] != "" + df_notes = df_notes[available_texts].copy() + logger.info("After filtering, %d notes have non-empty text.", len(df_notes)) else: - print(f"Loaded {total} {subset} note rows (no 'text' column in selection).") + logger.warning("The loaded notes do not include a 'text' column.") - return df + return df_notes -# ----------------------------- -# 2) Fetch text for a given note_id -# ----------------------------- -def load_text_note(df: pd.DataFrame, note_id: int, return_meta: bool = False): +@validate_params({"note": [pd.Series], "include_metadata": ["boolean"]}, prefer_skip_nested_validation=True) +def extract_text_from_note(note: pd.Series, include_metadata: bool = False) -> Union[str, Tuple[str, Dict[str, Any]]]: """ - Retrieve the free-text for a given note_id from a DataFrame returned by get_text_notes(). - - Parameters - ---------- - df : pd.DataFrame - DataFrame returned by get_text_notes(). - note_id : int - Note identifier to look up. - return_meta : bool - If True, return (text, row_dict). Otherwise return text only. - - Returns - ------- - text : str | (str, dict) - The note text; optionally with the note's metadata as a dictionary. + Extracts the text from a note Series, optionally returning metadata. + + High-level steps: + - Validate that the input Series contains a ``'text'`` field; otherwise, raise ``KeyError``. + - When ``include_metadata`` is False, return only the note text. 
+ - When ``include_metadata`` is True, return a tuple of ``(text, metadata_dict)`` where + ``metadata_dict`` is the note’s fields excluding ``'text'``. + + Args: + note (pd.Series): A pandas Series representing a note, expected to contain a 'text' column. + include_metadata (bool): If True, return a tuple of (text, metadata_dict). Default is False. + + Returns: + Union[str, Tuple[str, Dict[str, Any]]]: The note text; optionally with the note's metadata as a dictionary. + + Raises: + KeyError: If the 'text' column is not present in the note Series. + + Examples: + >>> note = pd.Series( + ... {"note_id": 1, "subject_id": 101, "text": "Patient is stable.", "note_type": "Discharge summary"} + ... ) + >>> extract_text_from_note(note) + 'Patient is stable.' + >>> extract_text_from_note(note, include_metadata=True) + ('Patient is stable.', {'note_id': 1, 'subject_id': 101, 'note_type': 'Discharge summary'}) """ - if "note_id" not in df.columns: - raise KeyError("DataFrame must contain 'note_id' column.") - rows = df[df["note_id"] == note_id] - if rows.empty: - raise KeyError(f"note_id {note_id} not found.") - - row = rows.iloc[0] - if "text" not in df.columns: - raise KeyError("The DataFrame does not include a 'text' column. 
Re-load with keep_cols including 'text'.") - txt = str(row["text"]) - if return_meta: - meta = row.to_dict() - return txt, meta - return txt - - -# --------- -# Example -# --------- -if __name__ == "__main__": - # Radiology notes - radi_df = get_text_notes(NOTE_PATH, subset="radiology", include_detail=False) - print(radi_df.head(2)) + if "text" not in note: + raise KeyError("The note does not include a 'text' column.") + + logger = logging.getLogger(f"{__name__}.extract_text_from_note") + logger.info("Extracting text from note with ID: %s", note.get("note_id", "unknown")) - if not radi_df.empty: - sample_text = load_text_note(radi_df, note_id=int(radi_df.iloc[0]["note_id"])) - print("Radiology sample text (truncated):", sample_text[:200], "...") + text = note["text"] + if not include_metadata: + return text - # Discharge notes - disc_df = get_text_notes(NOTE_PATH, subset="discharge", include_detail=True) # include detail join - print(disc_df.head(2)) + logger.info("Including metadata in the output.") - if not disc_df.empty: - sample_text = load_text_note(disc_df, note_id=int(disc_df.iloc[0]["note_id"])) - print("Discharge sample text (truncated):", sample_text[:200], "...") + metadata = note.drop("text").to_dict() + return text, metadata + + +if __name__ == "__main__": + import argparse + + # Example script: + # python -m mmai25_hackathon.load_data.text --data-path mimic-iv/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note --subset radiology --note-id 1 + parser = argparse.ArgumentParser(description="Load MIMIC-IV free-text clinical notes.") + parser.add_argument( + "--data-path", + type=str, + help="Path to the MIMIC-IV notes directory (containing CSV files).", + default="MMAI25Hackathon/mimic-iv/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note", + ) + parser.add_argument( + "--subset", + type=str, + choices=["radiology", "discharge"], + help="Which note subset to load (radiology or discharge).", + default="radiology", + ) + 
parser.add_argument("--note-id", type=int, help="The note_id of the note to retrieve.", default=1) + args = parser.parse_args() + + print(f"Loading {args.subset} notes from: {args.data_path}") + data = load_mimic_iv_notes(args.data_path, subset=args.subset, include_detail=True) + print(data.head()[["note_id", "subject_id", "hadm_id", "note_type", "text"]]) + print() + + print(f"Retrieving text for note_id={args.note_id}") + try: + text, metadata = extract_text_from_note( + data.loc[data["note_id"] == args.note_id].squeeze(), include_metadata=True + ) + print("Note text:") + print(text) + print("Metadata:") + print(metadata) + except KeyError as e: + print(e) diff --git a/mmai25_hackathon/prep/__init__.py b/mmai25_hackathon/prep/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/mmai25_hackathon/prep/omics.py b/mmai25_hackathon/prep/omics.py deleted file mode 100644 index 2b57ec4..0000000 --- a/mmai25_hackathon/prep/omics.py +++ /dev/null @@ -1,508 +0,0 @@ -""" -Multi-omics preprocessing (linear pipeline) -=========================================== - -What this script does ---------------------- -Given: - 1) A *label table* (CSV/TSV) with sample IDs as the first column (index) - and at least one column containing class labels. - 2) One or more *omics tables* (CSV/TSV), each with features as rows and - samples as columns (common in bio datasets). - -It will: - - Load & filter labels (drop NA; optionally keep a subset of label values). - - Load omics, transpose to [samples x features], align to the label index. - - Drop labels that are missing from **all** modalities. - - Clean missing values per modality (drop columns with >=10% NA, then mean-impute). - - (Optional) Min–max normalize features per modality independently. - - (Optional) Remove features below a variance threshold (per modality). - - Map string labels → integers 0..K-1. 
- - Save: - save_path/label.csv (mapped integer labels; see NOTE below) - save_path/{name}_names.csv (feature names for each modality) - save_path/{name}_feat.csv (feature matrix per modality) - -IMPORTANT shape assumptions ---------------------------- -- Label table: index = sample IDs (unique), column `label_column_name` holds labels. -- Each omics file: index = feature IDs, columns = sample IDs BEFORE transpose. - After loading we call `.transpose()` so final shape is [samples x features]. - -File format notes ------------------ -- All readers use `index_col=0`. Your first column must be the sample (for labels) - or feature (for omics) index. -- `sep` defaults to tab (`"\t"`). Change `sep` if you have commas. - -Saving notes (be careful!) --------------------------- -- This script currently writes labels as a Series *without* index or header - (see `labels.to_csv(..., index=False, header=False)`). That **discards sample IDs**. - Keep this if your downstream code expects a plain vector. If you need the sample - IDs preserved, change to `index=True, header=True` and update downstream. 
- -Quick start (example) ---------------------- -data_paths = [ - ("/path/to/mRNA.tsv", "mRNA"), - ("/path/to/methylation.tsv", "methylation"), - ("/path/to/miRNA.tsv", "miRNA"), -] -run_preprocessing_pipeline( - label_path="/path/to/labels.tsv", - data_paths=data_paths, - save_path="./BRCA/processed/", - label_column_name="PAM50Call_RNAseq", - label_column_values=None, - clean_missing=True, - normalize=True, - var_threshold=[None, 1e-6, 1e-6] # one per modality (or None to skip) -) -""" - -import os -from typing import Dict, List, Optional, Tuple - -import pandas as pd - -# ----------------------------- -# Small, reusable helper funcs -# ----------------------------- - - -def load_label_table( - label_path: str, - sep: str = "\t", - label_column_name: Optional[str] = None, - label_column_values: Optional[List[str]] = None, -) -> pd.Series: - """ - Load label table and return a 1D Series of labels indexed by sample ID. - - Parameters - ---------- - label_path : str - Path to the label file (CSV/TSV). The first column must be sample IDs - (will be used as the index). - sep : str, default="\t" - Field separator used in the file. - label_column_name : str - Name of the column that contains class labels. Required. - label_column_values : list[str] | None - If provided, keep only rows whose label ∈ this set (useful to focus on - selected classes, e.g., five PAM50 subtypes). - - Returns - ------- - labels : pd.Series - 1D Series (index = sample IDs, values = label strings), sorted by index. - - Behavior - -------- - - Drops rows with missing labels (NA in `label_column_name`). - - Optionally filters labels to a user-specified subset. - - Prints class counts for sanity-checking. - - Common pitfalls - --------------- - - If your sample IDs are not in the first column, set `index_col=0` will be wrong. - Fix the file or change the reader accordingly. 
- """ - label_df = pd.read_csv(label_path, sep=sep, index_col=0) - if label_column_name is None: - raise ValueError("label_column_name must be provided for label filtering.") - - # Keep only samples with non-null labels - label_df = label_df[label_df[label_column_name].notnull()] - labels = label_df[label_column_name] - - print(f"Labels size after removing missing values: {labels.shape}") - - # (Optional) filter label values to a given subset - if label_column_values is not None: - labels = labels[labels.isin(label_column_values)] - print(f"Labels size after value filtering: {labels.shape}") - - # Sort by index to keep stable order - labels = labels.sort_index(axis=0) - print(f"\nInput class labels:\n{labels.value_counts(dropna=True)}\n") - return labels - - -def load_single_omics(path: str, name: str, sep: str = "\t") -> Tuple[str, pd.DataFrame]: - """ - Load one omics modality. - - Steps - ----- - - Reads a matrix with `index_col=0` (features as rows, samples as columns). - - Transposes to shape [samples x features] (machine-learning friendly). - - Assigns index/column names for readability. - - Parameters - ---------- - path : str - Path to the omics file (CSV/TSV). First column must be feature IDs. - name : str - Short identifier for this modality (e.g., "mRNA", "methylation"). - sep : str - Field separator. - - Returns - ------- - (name, df) : tuple[str, pd.DataFrame] - Name and the processed DataFrame (index = sample IDs, columns = feature IDs). - """ - df = pd.read_csv(path, sep=sep, index_col=0).transpose() - df.index.names = ["sample"] - df.columns.names = ["feature"] - - # Sanity checks (non-fatal; change to asserts if you want strictness) - # - Ensure no duplicate sample IDs - if not df.index.is_unique: - print(f"[WARN] Duplicate sample IDs found in {name}. Consider deduplication.") - - # - Ensure no duplicate feature names - if not df.columns.is_unique: - print(f"[WARN] Duplicate feature names found in {name}. 
Consider deduplication.") - - return name, df - - -def remove_missing_labels(df: pd.DataFrame, valid_sample_index: pd.Index) -> pd.DataFrame: - """ - Keep only rows (samples) present in `valid_sample_index`, then sort rows. - - Why this matters - ---------------- - Ensures every sample in each modality has a corresponding label entry and - aligns row order for consistent saving/merging downstream. - """ - df = df[df.index.isin(valid_sample_index)] - return df.sort_index(axis=0) - - -def clean_missing_values(df: pd.DataFrame) -> pd.DataFrame: - """ - Handle missing values per modality. - - Procedure - --------- - 1) Drop columns (features) with ≥10% missing values (keep columns with >=90% non-NA). - 2) Fill any remaining NA by column mean (simple imputation). - 3) Error if any NA remains (guards against silent failures). - - Tunables - -------- - - Adjust the 10% threshold by changing `min_non_na` if needed. - """ - min_non_na = int(len(df.index) * 0.9) # keep columns with >=90% observed - df = df.dropna(axis=1, thresh=min_non_na) - df = df.fillna(df.mean()) - - if df.isna().any().any(): - raise ValueError("The modality contains missing values. Please handle them before proceeding.") - return df - - -def normalize_minmax(df: pd.DataFrame) -> pd.DataFrame: - """ - Min–max normalize each feature independently: (x - min) / (max - min). - - Notes - ----- - - If a column is constant (max == min), denominator = 0. - We temporarily replace the 0 denom with NA and then fill NA with 0.0, - effectively leaving that feature as all zeros (safe default). - - When to use - ----------- - - Useful when features have different scales. If your downstream method - is scale-invariant (e.g., tree models), normalization may be optional. 
- """ - col_min = df.min() - col_max = df.max() - denom = (col_max - col_min).replace(0, pd.NA) - df_norm = (df - col_min) / denom - # Replace NA/inf from constant columns with 0.0 - df_norm = df_norm.fillna(0.0) - return df_norm - - -def remove_low_variance_features(df: pd.DataFrame, threshold: Optional[float]) -> pd.DataFrame: - """ - Optionally drop columns with variance < threshold. - - Parameters - ---------- - threshold : float | None - If None, do nothing. Otherwise, retain only features whose variance >= threshold. - Typical tiny thresholds: 1e-8, 1e-6, etc. - - Tip - --- - - Use per-modality thresholds to reflect different numeric scales before normalization. - """ - if threshold is None: - return df - return df.loc[:, df.var() >= threshold] - - -def print_omics_shapes(omics: List[Tuple[str, pd.DataFrame]], title: str): - """ - Utility to log modality shapes for quick inspection. - """ - print(f"\nOmic modality shape ({title}):") - for name, df in omics: - print(f" - {name} shape: {df.shape}") - print() - - -def check_label_indices_availability( - labels: pd.Series, omics: List[Tuple[str, pd.DataFrame]] -) -> Tuple[bool, List[str]]: - """ - Check whether each sample in labels is present in at least one modality. - - Returns - ------- - is_all_available : bool - True if every labeled sample is present in ≥1 modality. - labels_to_remove : list[str] - Labeled sample IDs that do not appear in any modality (to be dropped). - - Why this matters - ---------------- - Keeps the label vector consistent with the available data. Otherwise you'd - carry labels for samples you never actually feed to a model. 
- """ - combined_indices = set() - for _, df in omics: - combined_indices.update(df.index) - - not_found = [sid for sid in labels.index if sid not in combined_indices] - print(f"Number of labels not in combined omics indices: {len(not_found)}") - return (len(not_found) == 0), not_found - - -def map_labels_to_int(labels: pd.Series) -> Tuple[pd.Series, Dict[str, int]]: - """ - Map unique label strings to integers 0..K-1 (order = first appearance order). - - Returns - ------- - mapped : pd.Series - Same index as `labels`, values are ints in [0, K-1], name="Class". - mapping : dict[str, int] - Dictionary of {original_label -> int_code} for reproducibility/logging. - - Notes - ----- - - If you need a deterministic order (e.g., alphabetical), change the - enumeration order to `for i, label in enumerate(sorted(unique))`. - """ - unique = labels.unique() - mapping = {label: i for i, label in enumerate(unique)} - mapped = labels.map(mapping).rename("Class") - print(f"Mapped class labels: {mapping}") - print(f"\nMapped class labels:\n{mapped.value_counts(dropna=False)}\n") - return mapped, mapping - - -def save_processed( - omics: List[Tuple[str, pd.DataFrame]], - labels: Optional[pd.Series], - data_paths: List[Tuple[str, str]], - label_path: Optional[str], - save_label: bool = True, - save_path: Optional[str] = None, -): - """ - Save processed outputs to disk. - - Parameters - ---------- - omics : list[(name, df)] - Each df is [samples x features], aligned to the final label index. - labels : pd.Series | None - Integer-mapped labels (index = samples). If None or `save_label=False`, - the label file will not be saved. - data_paths : list[(orig_path, name)] - Original (path, name) pairs, used here only to keep naming consistent. - label_path : str | None - Kept for parity with the class interface (not used for writing). - save_label : bool - Whether to write `label.csv`. - save_path : str | None - Directory to save files. Defaults to current directory. 
- - Writes - ------ - - label.csv - CURRENTLY written **without** index/header to keep compatibility with - some pipelines that expect a raw vector. Change to include index if needed. - - {name}_names.csv - 1D list of feature names (as they appear in the df.columns) *excluding* - the 1st column if you later decide to prepend something. Right now it - simply dumps the columns[1:], mirroring your original behavior. - (See NOTE below.) - - {name}_feat.csv - The full feature matrix [samples x features], written without index/header. - - NOTE about `{name}_names.csv` - ----------------------------- - Your original code uses `names = list(df.columns[1:])`. That drops the first - feature name. Keep as-is if this is intentional for downstream compatibility. - If not intentional, change to `names = list(df.columns)`. - - Tip - --- - Add versioning to `save_path` (e.g., include a timestamp or config hash). - """ - print("Saving the processed data...") - if save_path is None: - save_path = "./" - - # ---- Save labels (optional) ---- - if save_label and labels is not None and label_path is not None: - label_out = os.path.join(save_path, "label.csv") - os.makedirs(os.path.dirname(label_out), exist_ok=True) - # WARNING: This drops sample IDs and the column name. - labels.to_csv(label_out, index=False, header=False) - - # ---- Save each modality ---- - for (orig_path, name), (_, df) in zip(data_paths, omics): - name_out_path = os.path.join(save_path, f"{name}_names.csv") - feat_out_path = os.path.join(save_path, f"{name}_feat.csv") - os.makedirs(os.path.dirname(name_out_path), exist_ok=True) - - # NOTE: Preserve your original slicing behavior. 
- names = list(df.columns[1:]) - name_df = pd.DataFrame(names) - name_df.to_csv(name_out_path, index=False, header=False) - - # Save feature matrix (no index/header) for compact downstream loading - df.to_csv(feat_out_path, index=False, header=False) - - -# ----------------------------- -# Linear pipeline -# ----------------------------- - - -def run_preprocessing_pipeline( - label_path: str, - data_paths: List[Tuple[str, str]], - save_path: str, - label_column_name: str = "PAM50Call_RNAseq", - label_column_values: Optional[List[str]] = None, - clean_missing: bool = True, - normalize: bool = True, - var_threshold: Optional[List[Optional[float]]] = None, -): - """ - Run the full preprocessing pipeline as a linear script. - - Parameters - ---------- - label_path : str - Path to the label file (CSV/TSV). Must have sample IDs in the first column. - data_paths : list[(path, name)] - List of (omics_file_path, modality_name) tuples. - Each omics file must have features as rows and samples as columns (pre-transpose). - save_path : str - Directory to write processed outputs. - label_column_name : str - Column in the label file that contains class labels to predict. - label_column_values : list[str] | None - If provided, restrict to these label classes (others are dropped). - clean_missing : bool - If True, drop high-NA features and mean-impute remaining NA per modality. - normalize : bool - If True, min–max normalize features per modality. - var_threshold : list[float|None] | None - Per-modality variance thresholds. Must match `len(data_paths)` if provided. - Use None to skip variance filtering for a modality. - - Output files - ------------ - - save_path/label.csv - - save_path/{name}_names.csv - - save_path/{name}_feat.csv - - Logging - ------- - Prints shapes and class counts at each key step for traceability. - - Notes - ----- - - The order of saved rows follows the sorted label index. - - If you later join modalities, ensure consistent row order across files. 
- """ - # ----- User-configurable inputs (same semantics as your class version) ----- - sep = "\t" # adjust if using CSV with commas - - # ----------------- Step 1: Load labels ----------------- - labels = load_label_table( - label_path, - sep=sep, - label_column_name=label_column_name, - label_column_values=label_column_values, - ) - - # ----------------- Step 2: Load omics ------------------ - omics: List[Tuple[str, pd.DataFrame]] = [] - for path, name in data_paths: - omics.append(load_single_omics(path, name, sep=sep)) - print_omics_shapes(omics, "Raw/Input omic modalities") - - # ----------------- Step 3: Remove samples w/ missing labels ----------------- - # Align each modality to the (possibly filtered) label index - omics = [(name, remove_missing_labels(df, labels.index)) for name, df in omics] - print_omics_shapes(omics, "After missing label removal") - - # ----------------- Step 4: Drop labels not present in ANY modality ---------- - is_ok, labels_to_remove = check_label_indices_availability(labels, omics) - if labels_to_remove: - # If a sample has a label but appears in none of the modalities, drop it. - labels = labels.drop(labels_to_remove) - print(f"Are all labels available in omic modalities? 
{is_ok}\n") - - # After dropping such labels, re-align all modalities again - omics = [(name, remove_missing_labels(df, labels.index)) for name, df in omics] - - # ----------------- Step 5: Clean missing values (per modality) -------------- - if clean_missing: - omics = [(name, clean_missing_values(df)) for name, df in omics] - print_omics_shapes(omics, "After missing value removal") - - # ----------------- Step 6: Normalize (per modality) ------------------------- - if normalize: - omics = [(name, normalize_minmax(df)) for name, df in omics] - - # ----------------- Step 7: Remove low-variance features --------------------- - if var_threshold is not None: - if len(var_threshold) != len(omics): - raise ValueError("var_threshold must be a list with the same length as data_paths.") - pruned = [] - for (name, df), thr in zip(omics, var_threshold): - pruned.append((name, remove_low_variance_features(df, thr))) - omics = pruned - print_omics_shapes(omics, "After low-variance feature removal") - - # ----------------- Step 8: Map labels to integers --------------------------- - labels_mapped, mapping = map_labels_to_int(labels) - # Consider saving `mapping` as JSON alongside outputs if reproducibility is critical. - - # ----------------- Step 9: Save outputs ------------------------------------ - save_processed( - omics=omics, - labels=labels_mapped, - data_paths=data_paths, - label_path=label_path, - save_label=True, - save_path=save_path, - ) - - print("Done.") diff --git a/mmai25_hackathon/utils.py b/mmai25_hackathon/utils.py deleted file mode 100644 index 66a9c4e..0000000 --- a/mmai25_hackathon/utils.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -General utilities for different use cases. - -Functions: - - find_global_cutoff: Find a cutoff for similarity matrices to ensure k neighbors per row. - - symmetrize_matrix: Symmetrize a square matrix using various aggregation methods. 
-""" - -import numpy as np -from sklearn.utils._param_validation import Integral, Interval, StrOptions, validate_params -from sklearn.utils.validation import check_symmetric - -VALID_SYMMETRIZATIONS = {"average", "maximum", "minimum", "lower", "upper"} - - -@validate_params( - { - "similarity_matrix": ["array-like"], - "k": [Interval(Integral, 1, None, closed="left")], - }, - prefer_skip_nested_validation=True, -) -def find_global_cutoff(similarity_matrix: np.ndarray, k: int) -> float: - """ - Find an approximately correct global cutoff such that each row has at least `k` - neighbors above the cutoff. - - Args: - similarity_matrix (np.ndarray): A 2D numpy array representing the similarity matrix. - k (int): The minimum number of neighbors each row should have above the cutoff. - If k is outside [1, num_rows - 1], k will be clipped to this range. - - Returns: - float: The global cutoff value. - - Raises: - ValueError: If the similarity matrix is not symmetric. - - Examples: - >>> from sklearn.datasets import make_swiss_roll - >>> from sklearn.metrics import pairwise_kernels - >>> X, _ = make_swiss_roll(n_samples=100, noise=0.1, random_state=42) - >>> similarity_matrix = pairwise_kernels(X, metric="cosine") - >>> cutoff = find_global_cutoff(similarity_matrix, k=5) - >>> print(round(cutoff, 4)) - 0.9676 - """ - # Validate the similarity matrix is symmetric - similarity_matrix = check_symmetric(similarity_matrix, raise_exception=True) - # Get number of samples and clip k to valid range - num_samples = len(similarity_matrix) - - # No valid edges can be formed - if num_samples < 2: - return np.inf - - k = np.clip(k, 1, num_samples - 1) - - # Expect flattened array of shape (num_samples * (num_samples - 1) / 2,) - # which contains all pairwise similarities without duplicates - triu_sim_matrix = similarity_matrix[np.triu_indices(num_samples, k=1)] - finite_sim_matrix = triu_sim_matrix[np.isfinite(triu_sim_matrix)] - - # Fetch targeted number of undirected edges - # Given 
initial value (num_samples * k) // 2 - # We will clip it to be at least 1 and at most len(finite_sim_matrix) - # Subtracting 1 to convert it to a zero-based index - edge_target = np.clip((num_samples * k) // 2, 1, len(finite_sim_matrix)) - - # Get the cutoff value - index = len(finite_sim_matrix) - edge_target - return np.partition(finite_sim_matrix, index)[index] - - -@validate_params( - {"matrix": ["array-like"], "method": [StrOptions(VALID_SYMMETRIZATIONS)]}, - prefer_skip_nested_validation=True, -) -def symmetrize_matrix(matrix: np.ndarray, method: str = "average") -> np.ndarray: - """ - Symmetrizes a square matrix using the specified method. - - Args: - matrix (np.ndarray): A square numpy array to be symmetrized. - method (str): The method to use for symmetrization. Options are: - - "sum": matrix + matrix.T - - "average": (matrix + matrix.T) / 2 - - "maximum": np.maximum(matrix, matrix.T) - - "minimum": np.minimum(matrix, matrix.T) - - "lower": np.tril(matrix) + np.tril(matrix, -1).T - - "upper": np.triu(matrix) + np.triu(matrix, 1).T - - Returns: - np.ndarray: The symmetrized matrix. - - Raises: - ValueError: If the input matrix is not square or if an invalid method is provided. - - Examples: - >>> mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> symmetrize_matrix(mat, method="average") - array([[1. , 3. , 5. ], - [3. , 5. , 7. ], - [5. , 7. , 9. 
]]) - """ - if matrix.shape[0] != matrix.shape[1]: - raise ValueError("Input matrix must be square.") - - if method == "maximum": - return np.maximum(matrix, matrix.mT) - if method == "minimum": - return np.minimum(matrix, matrix.mT) - if method == "lower": - return np.tril(matrix) + np.tril(matrix, -1).mT - if method == "upper": - return np.triu(matrix) + np.triu(matrix, 1).mT - - matrix = matrix + matrix.mT - - # default to average - return 0.5 * matrix if method == "average" else matrix - - -# Quick run of the function -if __name__ == "__main__": - # Run with python -m mmai25_hackathon.utils - - from sklearn.datasets import make_swiss_roll - from sklearn.metrics import pairwise_kernels - - # Create a synthetic dataset - X, _ = make_swiss_roll(n_samples=100, noise=0.1, random_state=42) - - # Compute the similarity matrix and symmetrize it - similarity_matrix = pairwise_kernels(X, metric="cosine") - similarity_matrix = symmetrize_matrix(similarity_matrix) - - # Find the global cutoff - cutoff = find_global_cutoff(similarity_matrix, k=5) - print("Global cutoff:", round(cutoff, 4)) diff --git a/pyproject.toml b/pyproject.toml index c4fcbd9..2705c2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,13 @@ +############################################# +# Build system # +############################################# +[build-system] +requires = ["setuptools>=66", "wheel"] +build-backend = "setuptools.build_meta" + +############################################# +# Project metadata # +############################################# [project] name = "mmai-hackathon" version = "0.1.0" @@ -5,13 +15,15 @@ description = "Base source code for MMAI workshop Hackathon." 
readme = "README.md" requires-python = ">=3.10, <3.13" dependencies = [ - "black>=25.1.0", - "isort>=6.0.1", - "pre-commit>=4.3.0", - "pykale[full]@git+https://github.com/pykale/pykale@main", - "pytest>=8.4.1", - "pytest-cov>=6.2.1", + "numpy>=2.0.0", + "pandas>=2.3.2", + "pillow>=11.3.0", + "pydicom>=3.0.1", + "rdkit>=2025.3.6", + "scikit-learn>=1.6.1", + "torch>=2.6.0", "torch-geometric>=2.6.0", + "wfdb>=4.3.0", ] authors = [ { name = "Shuo Zhou", email = "shuo.zhou@sheffield.ac.uk" }, @@ -21,39 +33,45 @@ authors = [ { name = "L. M. Riza Rizky", email = "l.m.rizky@sheffield.ac.uk" }, ] -[tool.isort] -known_first_party = [ - "mmai25_hackathon", - "tests", +[project.optional-dependencies] +dev = [ + "black>=25.1.0", + "dropbox>=12.0.2", + "isort>=6.0.1", + "flake8>=7.0.0", + "mypy>=1.10.0", + "pre-commit>=4.3.0", + "pytest>=8.4.1", + "pytest-cov>=6.2.1", ] -profile = "black" -line_length = 120 -force_sort_within_sections = "False" -order_by_type = "False" - -[tool.mypy] -disable_error_code = ["assignment", "call-overload", "attr-defined"] +############################################# +# Tooling configuration # +############################################# [tool.black] # https://github.com/psf/black line-length = 120 target-version = ["py311"] -[tool.ruff] -line-length = 119 - -[tool.ruff.lint] -extend-select = ["I"] - -[tool.ruff.format] -docstring-code-format = true -docstring-code-line-length = "dynamic" +[tool.isort] +known_first_party = ["mmai25_hackathon", "tests"] +profile = "black" +line_length = 120 +force_sort_within_sections = false +order_by_type = false -[tool.uv.sources] -pykale = { git = "https://github.com/pykale/pykale", rev = "main" } +[tool.mypy] +disable_error_code = ["assignment", "call-overload", "attr-defined"] [tool.pytest.ini_options] log_cli = true log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" +addopts = "--cov=mmai25_hackathon 
--cov-report term-missing --cov-report html"
+
+[tool.setuptools]
+py-modules = ["mmai25_hackathon"]
+
+[tool.uv.sources]
+pykale = { git = "https://github.com/pykale/pykale", rev = "main" }
diff --git a/tests/dropbox_download.py b/tests/dropbox_download.py
new file mode 100644
index 0000000..9bffa08
--- /dev/null
+++ b/tests/dropbox_download.py
@@ -0,0 +1,74 @@
+"""Download a Dropbox folder for CI integration tests.
+
+This small CLI downloads a Dropbox folder as a zip to ``<local_folder>.zip`` and
+optionally extracts it beside the destination folder, removing the zip afterwards.
+
+Usage examples
+--------------
+- Using environment secrets (preferred in CI):
+  ``python -m tests.dropbox_download "/MMAI25Hackathon" "MMAI25Hackathon" --unzip``
+  where ``DROPBOX_APP_KEY``, ``DROPBOX_APP_SECRET`` and ``DROPBOX_REFRESH_TOKEN`` are set in the environment.
+- Providing the credentials explicitly:
+  ``python -m tests.dropbox_download "/remote/path" "local_dir" --app-key KEY --refresh-token TOKEN --unzip``
+"""
+
+import argparse
+import os
+import zipfile
+
+import dropbox
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download a folder from Dropbox.")
+    parser.add_argument("dropbox_folder", type=str, help="Path to the Dropbox folder to download.")
+    parser.add_argument("local_folder", type=str, help="Local directory to save downloaded files.")
+    parser.add_argument("--unzip", action="store_true", help="Unzip and remove the zip afterwards.")
+
+    # Auth via refresh token (recommended for headless/CI)
+    parser.add_argument(
+        "--app-key",
+        dest="app_key",
+        default=os.getenv("DROPBOX_APP_KEY"),
+        help="Dropbox app key (or set env DROPBOX_APP_KEY).",
+    )
+    parser.add_argument(
+        "--app-secret",
+        dest="app_secret",
+        default=os.getenv("DROPBOX_APP_SECRET"),
+        help="Dropbox app secret (omit if your refresh token was issued via PKCE).",
+    )
+    parser.add_argument(
+        "--refresh-token",
+        dest="refresh_token",
+        default=os.getenv("DROPBOX_REFRESH_TOKEN"),
+        help="Dropbox OAuth2 refresh token (or set env 
DROPBOX_REFRESH_TOKEN).", + ) + + args = parser.parse_args() + + # Basic validation + if not args.app_key: + parser.error("Missing --app-key (or env DROPBOX_APP_KEY).") + if not args.refresh_token: + parser.error("Missing --refresh-token (or env DROPBOX_REFRESH_TOKEN).") + # app_secret is optional if your refresh token was created with PKCE + + # Create the Dropbox client using refresh token auth + dbx = dropbox.Dropbox( + oauth2_refresh_token=args.refresh_token, + app_key=args.app_key, + app_secret=args.app_secret, # may be None if using PKCE + ) + + # Ensure local folder exists + os.makedirs(args.local_folder, exist_ok=True) + zip_path = args.local_folder + ".zip" + dbx.files_download_zip_to_file(zip_path, args.dropbox_folder) + + if args.unzip: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(os.path.join(args.local_folder, "..")) + + print(f"Unzipped to {args.local_folder}") + os.remove(zip_path) + print(f"Removed zip file {zip_path}") diff --git a/tests/load_data/test_cxr.py b/tests/load_data/test_cxr.py new file mode 100644 index 0000000..739486c --- /dev/null +++ b/tests/load_data/test_cxr.py @@ -0,0 +1,139 @@ +"""Tests for MIMIC-IV Chest X-ray (CXR) loading utilities. + +This suite validates two public APIs: + +- ``load_mimic_cxr_metadata(path, ...)``: parses the CXR metadata CSV, verifies + the expected dataset layout (``files/`` subfolder), resolves absolute image + paths, applies optional row filtering, and returns a ``pd.DataFrame`` with a + ``cxr_path`` column pointing to existing ``.jpg`` files. +- ``load_chest_xray_image(path, to_gray=True)``: opens a chest X-ray image at + the given path, optionally converting to grayscale. + +Prerequisite +------------ +The tests assume the real dataset is available during CI under the fixed path: +``${PWD}/MMAI25Hackathon/mimic-iv/mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0``. 
+If that directory or its ``files/`` subfolder is missing, the tests are skipped +rather than failing. +""" + +import logging +from pathlib import Path + +import pandas as pd +import pytest + +from mmai25_hackathon.load_data.cxr import ( + load_chest_xray_image, + load_mimic_cxr_metadata, +) + +# Fixed dataset path (fetched during CI) +CXR_ROOT = Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0" + + +@pytest.fixture(scope="module") +def cxr_root() -> Path: + """Return the dataset root or skip the module if missing.""" + if not CXR_ROOT.exists(): + pytest.skip(f"Dataset root not found: {CXR_ROOT}") + files_dir = CXR_ROOT / "files" + if not files_dir.exists(): + pytest.skip(f"Dataset 'files' subdir not found under: {CXR_ROOT}") + return CXR_ROOT + + +@pytest.fixture(scope="module") +def metadata_df(cxr_root: Path) -> pd.DataFrame: + """Load metadata once for the module to speed up tests.""" + return load_mimic_cxr_metadata(cxr_root) + + +@pytest.mark.parametrize("use_str_path", [True, False]) +def test_metadata_and_image_loading(caplog: pytest.LogCaptureFixture, cxr_root: Path, use_str_path: bool): + # Ensure INFO logs from the loader are captured if emitted + caplog.set_level(logging.INFO) + + path_arg = str(cxr_root) if use_str_path else cxr_root + df = load_mimic_cxr_metadata(path_arg) + + # Basic metadata checks + assert isinstance(df, pd.DataFrame), "Expected a DataFrame from load_mimic_cxr_metadata" + assert not df.empty, "Metadata DataFrame is unexpectedly empty" + assert "cxr_path" in df.columns, "Expected column 'cxr_path' to be present" + + # If the implementation logs mapping info, it should be visible here. + # We don't require it, but if present we assert the message contains 'Mapped'. 
+ if caplog.records: + assert any("Mapped" in rec.getMessage() for rec in caplog.records), "Expected mapping log message" + + # Sample one image and check both grayscale and RGB branches + sample_path = Path(str(df.iloc[0]["cxr_path"])) # type: ignore[index] + assert sample_path.is_absolute(), "cxr_path should be absolute" + assert sample_path.exists(), f"Image path does not exist: {sample_path}" + + img_gray = load_chest_xray_image(sample_path) + assert img_gray.mode == "L" + + img_rgb = load_chest_xray_image(sample_path, to_gray=False) + assert img_rgb.mode == "RGB" + + +def test_paths_are_absolute_and_exist_on_head(metadata_df: pd.DataFrame): + head_paths = metadata_df["cxr_path"].astype(str).head(10).tolist() + + for p in head_paths: + pth = Path(p) + assert pth.is_absolute(), f"Path is not absolute: {pth}" + assert pth.exists(), f"Resolved path does not exist: {pth}" + assert pth.suffix.lower() == ".jpg" + + +def test_filter_rows_train_subset_is_consistent(cxr_root: Path, metadata_df: pd.DataFrame): + # Use an ID that exists in the dataset. 101 is commonly present in the public splits. 
+ train_df = load_mimic_cxr_metadata(cxr_root, filter_rows={"subject_id": [101]}) + + assert not train_df.empty, "Filtered subject_id is unexpectedly empty" + assert set(train_df["subject_id"].unique()) == {101}, "Filtered subject_id has unexpected values" + + # The filtered set must be a subset of the unfiltered rows with subject_id==101 + all_train = metadata_df[metadata_df["subject_id"] == 101] + assert set(train_df["cxr_path"]).issubset(set(all_train["cxr_path"])), "Filtered rows mismatch" + + +def test_loading_nonexistent_image_raises(cxr_root: Path): + missing = cxr_root / "files" / "__definitely_not_here__.jpg" + with pytest.raises(FileNotFoundError): + load_chest_xray_image(missing) + + +def test_invalid_metadata_path_raises(cxr_root: Path): + with pytest.raises(FileNotFoundError): + load_mimic_cxr_metadata(cxr_root / "nonexistent_dir") + + +def test_missing_dicom_id_column_raises(tmp_path: Path): + """If required DICOM ID columns are absent, the loader must raise ``KeyError``.""" + df = pd.DataFrame({"subject_id": [1, 2], "study_id": [10, 20], "other_column": ["A", "B"]}) + (tmp_path / "files").mkdir() + (tmp_path / "metadata.csv").write_text(df.to_csv(index=False)) + + with pytest.raises(KeyError): + load_mimic_cxr_metadata(tmp_path, filter_rows={"subject_id": [1]}) + + +def test_files_folder_missing_raises(tmp_path: Path): + """If ``files/`` is missing, the loader must raise ``FileNotFoundError``.""" + df = pd.DataFrame({"subject_id": [1, 2], "study_id": [10, 20], "dicom_id": ["img1", "img2"]}) + (tmp_path / "metadata.csv").write_text(df.to_csv(index=False)) + + with pytest.raises(FileNotFoundError): + load_mimic_cxr_metadata(tmp_path, filter_rows={"subject_id": [1]}) + + +def test_metadata_not_found_raises(tmp_path: Path): + """If no metadata CSV is present, the loader must raise ``FileNotFoundError``.""" + (tmp_path / "files").mkdir() + + with pytest.raises(FileNotFoundError): + load_mimic_cxr_metadata(tmp_path, filter_rows={"subject_id": [1]}) diff 
--git a/tests/load_data/test_ecg.py b/tests/load_data/test_ecg.py new file mode 100644 index 0000000..74bebe7 --- /dev/null +++ b/tests/load_data/test_ecg.py @@ -0,0 +1,237 @@ +"""Tests for MIMIC-IV Electrocardiogram (ECG) loading utilities. + +This suite validates the public APIs in ``mmai25_hackathon.load_data.ecg``: + +- ``load_mimic_iv_ecg_record_list(ecg_path, ...)``: parses ``record_list.csv``, verifies + the expected dataset layout (``files/`` subfolder), resolves absolute ``.hea``/``.dat`` paths, + applies optional row filtering, and returns a ``pd.DataFrame`` containing only rows with both files. +- ``load_ecg_record(hea_path)``: reads an ECG record via ``wfdb.rdsamp`` given a ``.hea`` path. + +Prerequisite +------------ +The tests assume the real dataset may be available under the fixed path: +``${PWD}/MMAI25Hackathon/mimic-iv/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0``. +If that directory, its ``files/`` subfolder, or ``record_list.csv`` is missing, the integration +tests are skipped. Unit-level behavior and error handling are still validated via temporary data. 
+""" + +import logging +from pathlib import Path +from typing import Dict, Tuple + +import numpy as np +import pandas as pd +import pytest + +# Ensure wfdb is available; otherwise, skip this module's tests at collection time +pytest.importorskip("wfdb") + +from mmai25_hackathon.load_data.ecg import load_ecg_record, load_mimic_iv_ecg_record_list # noqa: E402 + +# Fixed dataset path (if available locally or fetched during CI) +ECG_ROOT = Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0" + + +@pytest.fixture(scope="module") +def ecg_root() -> Path: + """Return the dataset root or skip the module-level integration tests if missing.""" + if not ECG_ROOT.exists(): + pytest.skip(f"Dataset root not found: {ECG_ROOT}") + files_dir = ECG_ROOT / "files" + if not files_dir.exists(): + pytest.skip(f"Dataset 'files' subdir not found under: {ECG_ROOT}") + records_csv = ECG_ROOT / "record_list.csv" + if not records_csv.exists(): + pytest.skip(f"'record_list.csv' not found under: {ECG_ROOT}") + return ECG_ROOT + + +@pytest.fixture(scope="module") +def record_df(ecg_root: Path) -> pd.DataFrame: + """Load the record list once for the module to speed up tests.""" + return load_mimic_iv_ecg_record_list(ecg_root) + + +@pytest.mark.parametrize("use_str_path", [True, False]) +def test_record_list_and_signal_loading(caplog: pytest.LogCaptureFixture, ecg_root: Path, use_str_path: bool) -> None: + # Capture loader INFO logs if emitted + caplog.set_level(logging.INFO) + + path_arg = str(ecg_root) if use_str_path else ecg_root + df = load_mimic_iv_ecg_record_list(path_arg) + + # Basic metadata checks + assert isinstance(df, pd.DataFrame), f"Expected a DataFrame from load_mimic_iv_ecg_record_list, got {type(df)!r}" + assert not df.empty, ( + "Record list DataFrame is unexpectedly empty; check that 'record_list.csv' contains rows and that" + " corresponding '.hea' and '.dat' files exist." 
+ ) + for col in ("ecg_path", "hea_path", "dat_path"): + assert col in df.columns, f"Expected column '{col}' to be present; available columns: {list(df.columns)}" + + # Optional assertion on log messages if available + if caplog.records: + assert any( + "Mapping ECG file paths" in rec.getMessage() or "Found" in rec.getMessage() for rec in caplog.records + ), "Expected mapping or discovery log message" + + # Sample one record, ensure paths are absolute/exist, and load signals + sample_hea = Path(str(df.iloc[0]["hea_path"])) # type: ignore[index] + sample_dat = Path(str(df.iloc[0]["dat_path"])) # type: ignore[index] + assert sample_hea.is_absolute(), f"hea_path should be absolute, got: {sample_hea}" + assert sample_dat.is_absolute(), f"dat_path should be absolute, got: {sample_dat}" + assert sample_hea.exists(), f"Expected .hea file to exist, missing: {sample_hea}" + assert sample_dat.exists(), f"Expected .dat file to exist, missing: {sample_dat}" + assert ( + sample_hea.suffix.lower() == ".hea" + ), f"hea_path must end with .hea, got suffix: {sample_hea.suffix} (path={sample_hea})" + assert ( + sample_dat.suffix.lower() == ".dat" + ), f"dat_path must end with .dat, got suffix: {sample_dat.suffix} (path={sample_dat})" + + signals, fields = load_ecg_record(sample_hea) + assert isinstance(signals, np.ndarray), f"signals should be np.ndarray, got {type(signals)!r}" + assert signals.ndim == 2, f"signals should be 2D (T, L), got shape {signals.shape}" + assert signals.shape[1] > 0, f"signals should have >0 leads, got shape {signals.shape}" + assert isinstance(fields, dict), f"fields should be dict, got {type(fields)!r}" + # WFDB commonly provides sampling frequency under 'fs' + if "fs" in fields: + assert float(fields["fs"]) > 0, f"Sampling frequency 'fs' must be > 0, got {fields['fs']!r}" + + +def test_paths_are_absolute_and_exist_on_head(record_df: pd.DataFrame) -> None: + head_hea = record_df["hea_path"].astype(str).head(10).tolist() + head_dat = 
record_df["dat_path"].astype(str).head(10).tolist() + + for hp, dp in zip(head_hea, head_dat): + hp = Path(hp) + dp = Path(dp) + assert hp.is_absolute(), f"hea_path is not absolute: {hp}" + assert dp.is_absolute(), f"dat_path is not absolute: {dp}" + assert hp.exists(), f"Resolved .hea does not exist: {hp}" + assert dp.exists(), f"Resolved .dat does not exist: {dp}" + assert hp.suffix.lower() == ".hea", f"hea_path must end with .hea, got {hp.suffix} (path={hp})" + assert dp.suffix.lower() == ".dat", f"dat_path must end with .dat, got {dp.suffix} (path={dp})" + + +def test_filter_rows_subject_subset_is_consistent(ecg_root: Path, record_df: pd.DataFrame) -> None: + # Choose a subject_id present in the loaded record list to make the test robust + some_subject = int(record_df["subject_id"].iloc[0]) + filtered = load_mimic_iv_ecg_record_list(ecg_root, filter_rows={"subject_id": [some_subject]}) + + assert ( + not filtered.empty + ), f"Filtered subject_id is unexpectedly empty; subject={some_subject} not found in record_list.csv" + assert set(filtered["subject_id"].unique()) == { + some_subject + }, f"Filtered subject_id has unexpected values; expected only {some_subject}, got {set(filtered['subject_id'].unique())}" + + # The filtered set must be a subset of the unfiltered rows for that subject + all_rows = record_df[record_df["subject_id"] == some_subject] + assert set(filtered["hea_path"]).issubset( + set(all_rows["hea_path"]) + ), "Filtered rows mismatch for 'hea_path'; filtered set must be a subset of unfiltered rows for the same subject" + assert set(filtered["dat_path"]).issubset( + set(all_rows["dat_path"]) + ), "Filtered rows mismatch for 'dat_path'; filtered set must be a subset of unfiltered rows for the same subject" + + +def test_loading_nonexistent_hea_raises() -> None: + with pytest.raises(FileNotFoundError): + load_ecg_record(Path("/definitely/not/here.hea")) + + +def test_invalid_ecg_base_path_raises(tmp_path: Path) -> None: + with 
pytest.raises(FileNotFoundError): + load_mimic_iv_ecg_record_list(tmp_path / "missing_root") + + +def test_files_folder_missing_raises(tmp_path: Path) -> None: + (tmp_path / "record_list.csv").write_text(pd.DataFrame({"path": ["p100/p101/s133/133"]}).to_csv(index=False)) + with pytest.raises(FileNotFoundError): + load_mimic_iv_ecg_record_list(tmp_path) + + +def test_record_list_not_found_raises(tmp_path: Path) -> None: + (tmp_path / "files").mkdir() + with pytest.raises(FileNotFoundError): + load_mimic_iv_ecg_record_list(tmp_path) + + +def test_mapping_and_filtering_with_temp_csv(tmp_path: Path) -> None: + # Create folder layout: /files and a synthetic record_list.csv + files_dir = tmp_path / "files" + files_dir.mkdir() + + # Synthetic CSV with whitespace in 'path' to exercise .str.strip() + df = pd.DataFrame( + { + "subject_id": [101, 101, 102], + "study_id": [133, 999, 200], + # one valid pair (both .hea/.dat), one missing .dat, one missing .hea + "path": [ + " files/p100/p101/s133/133 ", + " files/p100/p101/s999/999 ", + " files/p200/p201/s200/200 ", + ], + } + ) + (tmp_path / "record_list.csv").write_text(df.to_csv(index=False)) + + # Create files for the first entry only + rec1 = files_dir / "p100" / "p101" / "s133" + rec1.mkdir(parents=True) + (rec1 / "133.hea").write_text("dummy header") + (rec1 / "133.dat").write_bytes(b"\x00\x01\x02") + + # Second: create only .hea + rec2 = files_dir / "p100" / "p101" / "s999" + rec2.mkdir(parents=True) + (rec2 / "999.hea").write_text("dummy header") + + # Third: create only .dat + rec3 = files_dir / "p200" / "p201" / "s200" + rec3.mkdir(parents=True) + (rec3 / "200.dat").write_bytes(b"\x00\x01\x02") + + out = load_mimic_iv_ecg_record_list(tmp_path, filter_rows={"subject_id": [101, 102]}) + + # Only the first row should remain (both files present) + assert len(out) == 1, f"Expected only one mapped row with both .hea/.dat present; got {len(out)} rows." 
+ row = out.iloc[0] + assert int(row["subject_id"]) == 101, f"Unexpected subject_id in mapped row: {row['subject_id']!r}" + assert Path(row["hea_path"]).is_absolute(), f"hea_path should be absolute, got: {row['hea_path']}" + assert Path(row["dat_path"]).is_absolute(), f"dat_path should be absolute, got: {row['dat_path']}" + assert Path(row["hea_path"]).exists(), f"Mapped .hea missing on disk: {row['hea_path']}" + assert Path(row["dat_path"]).exists(), f"Mapped .dat missing on disk: {row['dat_path']}" + assert Path(row["hea_path"]).suffix == ".hea", f"hea_path must end with .hea: {row['hea_path']}" + assert Path(row["dat_path"]).suffix == ".dat", f"dat_path must end with .dat: {row['dat_path']}" + + +@pytest.mark.parametrize("as_str", [True, False]) +def test_load_ecg_record_invokes_wfdb_with_stem(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, as_str: bool) -> None: + # Prepare a dummy .hea file to pass the existence check + hea_file = tmp_path / "example" / "rec.hea" + hea_file.parent.mkdir(parents=True) + hea_file.write_text("header") + + captured_arg: Dict[str, str] = {} + + def fake_rdsamp(arg: str) -> Tuple[np.ndarray, Dict[str, float]]: # type: ignore[override] + captured_arg["arg"] = arg + return np.zeros((10, 3), dtype=float), {"fs": 500.0} + + # Monkeypatch wfdb.rdsamp to avoid requiring a valid WFDB record on disk + import wfdb as _wfdb + + monkeypatch.setattr(_wfdb, "rdsamp", fake_rdsamp, raising=True) + + arg_in = str(hea_file) if as_str else hea_file + sig, meta = load_ecg_record(arg_in) + + # Expect the stem path (without suffix) to be passed to rdsamp + assert captured_arg.get("arg") and ( + captured_arg["arg"].endswith("/example/rec") or captured_arg["arg"].endswith("\\example\\rec") + ), f"wfdb.rdsamp should receive the stem path; got arg={captured_arg!r}" + assert sig.shape == (10, 3), f"Unexpected dummy signal shape, expected (10,3), got {sig.shape}" + assert meta["fs"] == 500.0, f"Unexpected dummy fs, expected 500.0, got {meta.get('fs')!r}" 
diff --git a/tests/load_data/test_echo.py b/tests/load_data/test_echo.py new file mode 100644 index 0000000..649c3cb --- /dev/null +++ b/tests/load_data/test_echo.py @@ -0,0 +1,167 @@ +"""Tests for MIMIC-IV Echocardiogram (ECHO) loading utilities. + +This suite validates the public APIs in ``mmai25_hackathon.load_data.echo``: + +- ``load_mimic_iv_echo_record_list(echo_path, ...)``: parses ``echo-record-list.csv``, verifies + the expected dataset layout (``files/`` subfolder), resolves absolute DICOM paths from + ``dicom_filepath``, applies optional row filtering, and returns a ``pd.DataFrame`` containing only + rows with existing files. +- ``load_echo_dicom(path)``: reads an ECHO DICOM via ``pydicom.dcmread`` and returns frames (T,H,W) and metadata. + +Prerequisite +------------ +The tests assume the real dataset may be available under the fixed path: +``${PWD}/MMAI25Hackathon/mimic-iv/mimic-iv-echo-0.1.physionet.org``. +If that directory, its ``files/`` subfolder, or ``echo-record-list.csv`` is missing, the integration +tests are skipped. Unit-level behavior and error handling are still validated via temporary data. 
+""" + +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import pytest + +# Ensure pydicom is available; otherwise, skip this module's tests at collection time +pytest.importorskip("pydicom") + +from mmai25_hackathon.load_data.echo import load_echo_dicom, load_mimic_iv_echo_record_list # noqa: E402 + +ECHO_ROOT = Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-iv-echo-0.1.physionet.org" + + +@pytest.fixture(scope="module") +def echo_root() -> Path: + if not ECHO_ROOT.exists(): + pytest.skip(f"Dataset root not found: {ECHO_ROOT}") + files_dir = ECHO_ROOT / "files" + if not files_dir.exists(): + pytest.skip(f"Dataset 'files' subdir not found: {ECHO_ROOT}") + records_csv = ECHO_ROOT / "echo-record-list.csv" + if not records_csv.exists(): + pytest.skip(f"'echo-record-list.csv' not found under: {ECHO_ROOT}") + return ECHO_ROOT + + +@pytest.fixture(scope="module") +def echo_df(echo_root: Path) -> pd.DataFrame: + return load_mimic_iv_echo_record_list(echo_root) + + +@pytest.mark.parametrize("use_str_path", [True, False]) +def test_record_list_mapping_and_paths(caplog: pytest.LogCaptureFixture, echo_root: Path, use_str_path: bool) -> None: + caplog.set_level(logging.INFO) + + arg = str(echo_root) if use_str_path else echo_root + df = load_mimic_iv_echo_record_list(arg) + + assert isinstance(df, pd.DataFrame), f"Expected DataFrame, got {type(df)!r}" + assert not df.empty, "ECHO record list is empty; ensure 'echo-record-list.csv' has rows and DICOM files exist." 
+ assert "echo_path" in df.columns, f"Missing 'echo_path' column; columns: {list(df.columns)}" + + # Inspect one record + p = Path(str(df.iloc[0]["echo_path"])) # type: ignore[index] + assert p.is_absolute(), f"echo_path should be absolute, got: {p}" + assert p.exists(), f"Resolved DICOM does not exist: {p}" + assert p.suffix.lower() == ".dcm", f"Expected .dcm suffix, got {p.suffix} (path={p})" + + # Optional logging checks + if caplog.records: + assert any( + "Mapping ECHO DICOM" in r.getMessage() or "Found" in r.getMessage() for r in caplog.records + ), "Expected mapping/discovery log messages in INFO logs" + + +def test_invalid_echo_base_path_raises(tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError): + load_mimic_iv_echo_record_list(tmp_path / "missing") + + +def test_files_folder_missing_raises(tmp_path: Path) -> None: + (tmp_path / "echo-record-list.csv").write_text( + pd.DataFrame({"dicom_filepath": ["files/p100/p101/s133/133.dcm"]}).to_csv(index=False) + ) + with pytest.raises(FileNotFoundError): + load_mimic_iv_echo_record_list(tmp_path) + + +def test_record_list_not_found_raises(tmp_path: Path) -> None: + (tmp_path / "files").mkdir() + with pytest.raises(FileNotFoundError): + load_mimic_iv_echo_record_list(tmp_path) + + +def test_mapping_and_filtering_with_temp_csv(tmp_path: Path) -> None: + files_dir = tmp_path / "files" + files_dir.mkdir() + + df = pd.DataFrame( + { + "subject_id": [101, 102, 102], + "study_id": [133, 200, 201], + "dicom_filepath": [ + " files/p100/p101/s133/133.dcm ", # valid + " files/p200/p201/s200/200.dcm ", # missing + " files/p200/p201/s201/201.dcm ", # valid + ], + } + ) + (tmp_path / "echo-record-list.csv").write_text(df.to_csv(index=False)) + + # Create only 133.dcm and 201.dcm + (files_dir / "p100" / "p101" / "s133").mkdir(parents=True) + (files_dir / "p100" / "p101" / "s133" / "133.dcm").write_bytes(b"DICOM") + (files_dir / "p200" / "p201" / "s201").mkdir(parents=True) + (files_dir / "p200" / "p201" / "s201" / 
"201.dcm").write_bytes(b"DICOM") + + out = load_mimic_iv_echo_record_list(tmp_path, filter_rows={"subject_id": [101, 102]}) + assert len(out) == 2, f"Expected two existing DICOMs, got {len(out)}" + assert set(out["study_id"]) == {133, 201}, f"Unexpected study_ids mapped: {set(out['study_id'])}" + for p in out["echo_path"].astype(str): + pth = Path(p) + assert pth.is_absolute(), f"echo_path should be absolute, got {pth}" + assert pth.exists(), f"Mapped DICOM missing on disk: {pth}" + + +def test_loading_nonexistent_dicom_raises(tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError): + load_echo_dicom(tmp_path / "not_here.dcm") + + +@pytest.mark.parametrize("as_str", [True, False]) +def test_load_echo_dicom_with_monkeypatch(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, as_str: bool) -> None: + # Create a dummy file to pass existence check + dcm = tmp_path / "a" / "b.dcm" + dcm.parent.mkdir(parents=True) + dcm.write_bytes(b"DICOM") + + class FakeDicom: + # Start with single-frame (H,W) to exercise branch + def __init__(self) -> None: + self.pixel_array = np.ones((8, 6), dtype=np.float32) + self.RescaleSlope = 2.0 + self.RescaleIntercept = -1.0 + + def __iter__(self): + # Simulate iteration over DICOM elements with keyword/value + class E: + def __init__(self, k: str, v: Any) -> None: + self.keyword, self.value = k, v + + yield E("Rows", 8) + yield E("Columns", 6) + yield E("NumberOfFrames", 1) + + # Monkeypatch the symbol actually used by the loader (echo.dcmread) + from mmai25_hackathon.load_data import echo as echo_mod + + monkeypatch.setattr(echo_mod, "dcmread", lambda p: FakeDicom(), raising=True) + + path_arg = str(dcm) if as_str else dcm + frames, meta = load_echo_dicom(path_arg) + assert frames.shape == (1, 8, 6), f"Expected (1,H,W) after expand, got {frames.shape}" + # Check rescale applied: 1*2-1 = 1 + assert float(frames.mean()) == pytest.approx(1.0), f"Unexpected rescaled mean: {frames.mean()}" + assert {"Rows", 
"Columns"}.issubset(meta.keys()), f"Missing expected metadata keys: {meta.keys()}" diff --git a/tests/load_data/test_ehr.py b/tests/load_data/test_ehr.py new file mode 100644 index 0000000..372c88e --- /dev/null +++ b/tests/load_data/test_ehr.py @@ -0,0 +1,174 @@ +"""Tests for MIMIC-IV Electronic Health Record (EHR) utilities. + +This suite validates the public API in ``mmai25_hackathon.load_data.ehr``: + +- ``load_mimic_iv_ehr(ehr_path, ...)``: discovers available tables for selected module(s), + loads CSVs with optional column selection and row filtering, and merges tables by overlapping keys. + +Prerequisite +------------ +No external dataset required. The tests use synthetic CSVs under temporary directories to validate +behavior, including error handling for missing modules/tables and merge semantics. +""" + +from pathlib import Path + +import pandas as pd +import pytest + +from mmai25_hackathon.load_data.ehr import load_mimic_iv_ehr + + +def _write_csv(path: Path, df: pd.DataFrame) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(df.to_csv(index=False)) + + +def test_invalid_ehr_base_path_raises(tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError): + load_mimic_iv_ehr(tmp_path / "missing") + + +def test_missing_selected_module_subfolder_raises(tmp_path: Path) -> None: + # Create only 'hosp' + (tmp_path / "hosp").mkdir() + with pytest.raises(FileNotFoundError): + load_mimic_iv_ehr(tmp_path, module="icu") + + +def test_no_available_tables_raises(tmp_path: Path) -> None: + (tmp_path / "hosp").mkdir() + with pytest.raises(ValueError): + load_mimic_iv_ehr(tmp_path, module="hosp") + + +def test_requested_missing_tables_raises(tmp_path: Path) -> None: + # Create one table only + hosp = tmp_path / "hosp" + hosp.mkdir() + _write_csv(hosp / "admissions.csv", pd.DataFrame({"subject_id": [1], "hadm_id": [10], "admittime": ["x"]})) + + with pytest.raises(ValueError): + load_mimic_iv_ehr( + tmp_path, module="hosp", 
tables=["admissions", "transfers"], index_cols=["subject_id", "hadm_id"] + ) + + +@pytest.mark.parametrize("as_str", [True, False]) +def test_merge_success_on_shared_keys(tmp_path: Path, as_str: bool) -> None: + hosp = tmp_path / "hosp" + icu = tmp_path / "icu" + hosp.mkdir() + icu.mkdir() + + _write_csv( + hosp / "admissions.csv", + pd.DataFrame( + { + "subject_id": [101, 102], + "hadm_id": [1, 2], + "admittime": ["t1", "t2"], + } + ), + ) + _write_csv( + icu / "icustays.csv", + pd.DataFrame( + { + "subject_id": [101, 102], + "hadm_id": [1, 2], + "first_careunit": ["A", "B"], + } + ), + ) + + root_arg = str(tmp_path) if as_str else tmp_path + df = load_mimic_iv_ehr( + root_arg, + module="both", + tables=["admissions", "icustays"], + index_cols=["subject_id", "hadm_id"], + subset_cols={"admissions": ["admittime"], "icustays": ["first_careunit"]}, + filter_rows={"subject_id": [101]}, + merge=True, + join="inner", + ) + + assert not df.empty, "Merged EHR DataFrame is unexpectedly empty" + assert set(["subject_id", "hadm_id", "admittime", "first_careunit"]).issubset( + df.columns + ), f"Missing expected columns after merge; got: {list(df.columns)}" + assert set(df["subject_id"]) == {101}, f"Filter_rows not applied as expected: {set(df['subject_id'])}" + + +def test_merge_multiple_components_raises(tmp_path: Path) -> None: + hosp = tmp_path / "hosp" + icu = tmp_path / "icu" + hosp.mkdir() + icu.mkdir() + + _write_csv(hosp / "patients.csv", pd.DataFrame({"subject_id": [1, 2], "gender": ["M", "F"]})) + _write_csv(icu / "caregiver.csv", pd.DataFrame({"icustay_id": [10, 20], "role": ["x", "y"]})) + + with pytest.raises(ValueError): + load_mimic_iv_ehr( + tmp_path, + module="both", + tables=["patients", "caregiver"], + index_cols=["subject_id", "icustay_id"], + merge=True, + ) + + +def test_merge_false_returns_dict(tmp_path: Path) -> None: + hosp = tmp_path / "hosp" + hosp.mkdir() + _write_csv(hosp / "admissions.csv", pd.DataFrame({"subject_id": [1], "hadm_id": [10], 
"admittime": ["t"]})) + + dfs = load_mimic_iv_ehr(tmp_path, module="hosp", tables=["admissions"], merge=False) + assert isinstance(dfs, dict), f"Expected dict of DataFrames, got {type(dfs)!r}" + assert set(dfs.keys()) == {"admissions"}, f"Unexpected keys in result: {set(dfs.keys())}" + assert not dfs["admissions"].empty, "admissions DataFrame unexpectedly empty" + + +def test_autodiscover_tables_when_none(tmp_path: Path) -> None: + hosp = tmp_path / "hosp" + hosp.mkdir() + _write_csv(hosp / "admissions.csv", pd.DataFrame({"subject_id": [7], "hadm_id": [70], "admittime": ["t"]})) + + # With tables=None and module='hosp', admissions should be discovered (covers available_tables[table]=path branch) + df = load_mimic_iv_ehr(tmp_path, module="hosp", tables=None, index_cols=["subject_id", "hadm_id"], merge=True) + assert isinstance(df, pd.DataFrame) and not df.empty, "Autodiscovered admissions table should load successfully" + + +# Optional real dataset integration +EHR_ROOT = Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-iv-3.1" + + +@pytest.fixture(scope="module") +def ehr_root() -> Path: + if not EHR_ROOT.exists(): + pytest.skip(f"EHR root not found: {EHR_ROOT}") + return EHR_ROOT + + +def test_integration_load_and_merge_real_ehr_if_available(ehr_root: Path) -> None: + # Try a small, common pair of tables + hosp_adm = ehr_root / "hosp" / "admissions.csv" + icu_stays = ehr_root / "icu" / "icustays.csv" + if not hosp_adm.exists() or not icu_stays.exists(): + pytest.skip("Required EHR tables (admissions or icustays) not found; skipping integration test") + + df = load_mimic_iv_ehr( + ehr_root, + module="both", + tables=["admissions", "icustays"], + index_cols=["subject_id", "hadm_id"], + subset_cols={"admissions": ["admittime"], "icustays": ["first_careunit"]}, + merge=True, + join="inner", + ) + # We can't guarantee non-emptiness in all subsets, but can assert type and columns when present + assert isinstance(df, pd.DataFrame), "Expected merged DataFrame" + if 
not df.empty: + assert {"subject_id", "hadm_id"}.issubset(df.columns), f"Merged keys missing in columns: {df.columns}" diff --git a/tests/load_data/test_molecule.py b/tests/load_data/test_molecule.py new file mode 100644 index 0000000..9dfd97e --- /dev/null +++ b/tests/load_data/test_molecule.py @@ -0,0 +1,91 @@ +"""Tests for molecular SMILES utilities. + +This suite validates the public APIs in ``mmai25_hackathon.load_data.molecule``: + +- ``fetch_smiles_from_dataframe(df_or_path, ...)``: extracts a SMILES column from a DataFrame or CSV path, + optionally setting an index, and returns a single-column DataFrame named ``smiles``. +- ``smiles_to_graph(smiles, ...)``: converts a SMILES string to a PyG ``Data`` graph; flags forwarded. + +Prerequisite +------------ +Optional real-data integration uses: +``${PWD}/MMAI25Hackathon/molecule-protein-interaction/dataset.csv``. +If unavailable, integration tests are skipped; unit tests still validate core behavior and conversion via monkeypatching. 
+""" + +from pathlib import Path + +import pandas as pd +import pytest + +# Ensure torch_geometric is available; otherwise, skip this module's tests +pytest.importorskip("torch_geometric") +from torch_geometric.data import Data # noqa: E402 + +from mmai25_hackathon.load_data.molecule import fetch_smiles_from_dataframe, smiles_to_graph # noqa: E402 + +# Optional real dataset path for integration-style checks +MOLECULE_DATASET_CSV = Path.cwd() / "MMAI25Hackathon" / "molecule-protein-interaction" / "dataset.csv" + + +@pytest.fixture(scope="module") +def molecule_csv() -> Path: + if not MOLECULE_DATASET_CSV.exists(): + pytest.skip(f"Molecule dataset CSV not found: {MOLECULE_DATASET_CSV}") + return MOLECULE_DATASET_CSV + + +def test_fetch_smiles_from_dataframe_and_csv(tmp_path: Path) -> None: + df = pd.DataFrame({"id": [1, 2], "SMILES": ["CCO", "C1=CC=CC=C1"], "extra": [0, 1]}) + csv = tmp_path / "molecules.csv" + csv.write_text(df.to_csv(index=False)) + + # From DataFrame with index + out_df = fetch_smiles_from_dataframe(df, smiles_col="SMILES", index_col="id") + assert list(out_df.columns) == ["id", "smiles"], f"Expected columns ['id', 'smiles'], got {list(out_df.columns)}" + assert out_df.shape[0] == 2, f"Unexpected number of rows: {out_df.shape}" + + # From CSV + out_csv = fetch_smiles_from_dataframe(str(csv), smiles_col="SMILES") + assert list(out_csv.columns) == ["smiles"], f"Expected columns ['smiles'], got {list(out_csv.columns)}" + + # Missing column error + with pytest.raises(ValueError): + fetch_smiles_from_dataframe(df, smiles_col="missing") + + +def test_smiles_to_graph_monkeypatched(monkeypatch: pytest.MonkeyPatch) -> None: + # Monkeypatch torch_geometric.utils.smiles.from_smiles to return a tiny Data object + from mmai25_hackathon.load_data import molecule as mol_mod + + def fake_from_smiles(s: str, with_h: bool, kek: bool) -> Data: # type: ignore[override] + return Data(x=None, edge_index=None, edge_attr=None, smiles=s, with_h=with_h, kek=kek) + + 
monkeypatch.setattr(mol_mod, "from_smiles", fake_from_smiles, raising=True) + + g = smiles_to_graph("CCO", with_hydrogen=True, kekulize=False) + assert isinstance(g, Data), f"Expected Data, got {type(g)!r}" + assert getattr(g, "smiles", None) == "CCO", f"SMILES not propagated: {getattr(g, 'smiles', None)!r}" + assert ( + getattr(g, "with_h", None) is True and getattr(g, "kek", None) is False + ), f"Flags not forwarded correctly: with_h={getattr(g, 'with_h', None)}, kek={getattr(g, 'kek', None)}" + + +def test_fetch_smiles_from_real_dataset(molecule_csv: Path) -> None: + out = fetch_smiles_from_dataframe(str(molecule_csv), smiles_col="SMILES") + assert not out.empty, f"No SMILES rows loaded from {molecule_csv}" + assert list(out.columns) == ["smiles"], f"Expected single column 'smiles', got {list(out.columns)}" + first = out.iloc[0]["smiles"] + assert isinstance(first, str) and len(first) > 0, f"First SMILES is invalid: {first!r}" + + +def test_fetch_smiles_dataframe_filter_rows() -> None: + df = pd.DataFrame({"id": [1, 2, 3], "SMILES": ["CCO", "CCC", "CCN"]}) + out = fetch_smiles_from_dataframe( + df, + smiles_col="SMILES", + index_col="id", + filter_rows={"id": [2]}, + ) + assert out.shape[0] == 1, f"Expected 1 row after filtering, got shape {out.shape}" + assert 2 in out["id"].values, f"Expected remaining id to be 2, got {out['id'].tolist()}" diff --git a/tests/load_data/test_protein.py b/tests/load_data/test_protein.py new file mode 100644 index 0000000..3a74efa --- /dev/null +++ b/tests/load_data/test_protein.py @@ -0,0 +1,90 @@ +"""Tests for protein sequence utilities. + +This suite validates the public APIs in ``mmai25_hackathon.load_data.protein``: + +- ``fetch_protein_sequences_from_dataframe(df_or_path, ...)``: extracts a protein sequence column from a DataFrame + or CSV path, optionally setting an index, and returns a single-column DataFrame named ``protein_sequence``. 
+- ``protein_sequence_to_integer_encoding(sequence, ...)``: integer-encodes an amino-acid sequence to fixed length.
+
+Prerequisite
+------------
+Optional real-data integration uses:
+``${PWD}/MMAI25Hackathon/molecule-protein-interaction/dataset.csv``.
+If unavailable, integration tests are skipped; unit tests still validate core behavior.
+"""
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from mmai25_hackathon.load_data.protein import (
+    fetch_protein_sequences_from_dataframe,
+    protein_sequence_to_integer_encoding,
+)
+
+# Optional real dataset path for integration-style checks
+PROTEIN_DATASET_CSV = Path.cwd() / "MMAI25Hackathon" / "molecule-protein-interaction" / "dataset.csv"
+
+
+@pytest.fixture(scope="module")
+def protein_csv() -> Path:
+    if not PROTEIN_DATASET_CSV.exists():
+        pytest.skip(f"Protein dataset CSV not found: {PROTEIN_DATASET_CSV}")
+    return PROTEIN_DATASET_CSV
+
+
+def test_fetch_protein_sequences_from_dataframe_and_csv(tmp_path: Path) -> None:
+    df = pd.DataFrame({"id": [1, 2], "Protein": ["MKTAYI", "GAVLIL"], "extra": [0, 1]})
+    csv = tmp_path / "proteins.csv"
+    csv.write_text(df.to_csv(index=False))
+
+    out_df = fetch_protein_sequences_from_dataframe(df, prot_seq_col="Protein", index_col="id")
+    assert list(out_df.columns) == [
+        "id",
+        "protein_sequence",
+    ], f"Expected columns ['id', 'protein_sequence'], got {list(out_df.columns)}"
+    assert out_df.shape[0] == 2, f"Unexpected number of rows: {out_df.shape}"
+
+    out_csv = fetch_protein_sequences_from_dataframe(str(csv), prot_seq_col="Protein")
+    assert list(out_csv.columns) == [
+        "protein_sequence"
+    ], f"Expected single column 'protein_sequence', got {list(out_csv.columns)}"
+
+    with pytest.raises(ValueError):
+        fetch_protein_sequences_from_dataframe(df, prot_seq_col="missing")
+
+
+def test_protein_sequence_to_integer_encoding_properties() -> None:
+    seq = "MKTAY?"  # '?' should map to 0 (unknown)
+    enc = protein_sequence_to_integer_encoding(seq, max_length=5)
+
+    assert (
+        isinstance(enc, np.ndarray) and enc.ndim == 1
+    ), f"Expected 1D numpy array, got type={type(enc)!r}, shape={getattr(enc, 'shape', None)}"
+    assert enc.dtype == np.uint64, f"Expected dtype uint64, got {enc.dtype}"
+    assert len(enc) == 5, f"Expected length 5 (truncation), got {len(enc)}"
+    assert (enc == 0).sum() >= 0, "Unknown characters should be encoded as 0"
+
+
+def test_fetch_proteins_from_real_dataset(protein_csv: Path) -> None:
+    out = fetch_protein_sequences_from_dataframe(str(protein_csv), prot_seq_col="Protein")
+    assert not out.empty, f"No Protein rows loaded from {protein_csv}"
+    assert list(out.columns) == [
+        "protein_sequence"
+    ], f"Expected single column 'protein_sequence', got {list(out.columns)}"
+    first = out.iloc[0]["protein_sequence"]
+    assert isinstance(first, str) and len(first) > 0, f"First protein sequence is invalid: {first!r}"
+
+
+def test_fetch_protein_sequences_dataframe_filter_rows() -> None:
+    df = pd.DataFrame({"id": [1, 2, 3], "Protein": ["MKT", "GAV", "TTT"]})
+    out = fetch_protein_sequences_from_dataframe(
+        df,
+        prot_seq_col="Protein",
+        index_col="id",
+        filter_rows={"id": [3]},
+    )
+    assert out.shape[0] == 1, f"Expected 1 row after filtering, got shape {out.shape}"
+    assert 3 in out["id"].values, f"Expected remaining id to be 3, got {out['id'].tolist()}"
diff --git a/tests/load_data/test_supervised_labels.py b/tests/load_data/test_supervised_labels.py
new file mode 100644
index 0000000..2793e81
--- /dev/null
+++ b/tests/load_data/test_supervised_labels.py
@@ -0,0 +1,97 @@
+"""Tests for supervised labels utilities.
+
+This suite validates the public APIs in ``mmai25_hackathon.load_data.supervised_labels``:
+
+- ``fetch_supervised_labels_from_dataframe(df_or_path, ...)``: extracts label column(s) from a DataFrame or CSV path,
+  optionally setting an index, and returns a DataFrame named ``label`` (single column) or original names (multi).
+- ``one_hot_encode_labels(labels, ...)``: one-hot encodes categorical label columns to ``float32`` dtypes.
+
+Prerequisite
+------------
+Optional real-data integration uses:
+``${PWD}/MMAI25Hackathon/molecule-protein-interaction/dataset.csv``.
+If unavailable, integration tests are skipped; unit tests still validate core behavior.
+"""
+
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from mmai25_hackathon.load_data.supervised_labels import (
+    fetch_supervised_labels_from_dataframe,
+    one_hot_encode_labels,
+)
+
+# Optional real dataset path for integration-style checks
+LABELS_DATASET_CSV = Path.cwd() / "MMAI25Hackathon" / "molecule-protein-interaction" / "dataset.csv"
+
+
+@pytest.fixture(scope="module")
+def labels_csv() -> Path:
+    if not LABELS_DATASET_CSV.exists():
+        pytest.skip(f"Labels dataset CSV not found: {LABELS_DATASET_CSV}")
+    return LABELS_DATASET_CSV
+
+
+def test_fetch_supervised_labels_single_column(tmp_path: Path) -> None:
+    df = pd.DataFrame({"id": [1, 2], "Y": [0, 1], "extra": ["a", "b"]})
+    csv = tmp_path / "labels.csv"
+    csv.write_text(df.to_csv(index=False))
+
+    out_df = fetch_supervised_labels_from_dataframe(df, label_col="Y", index_col="id")
+    assert list(out_df.columns) == ["id", "label"], f"Expected columns ['id', 'label'], got {list(out_df.columns)}"
+    assert out_df.shape[0] == 2, f"Unexpected number of rows: {out_df.shape}"
+
+    out_csv = fetch_supervised_labels_from_dataframe(str(csv), label_col="Y")
+    assert list(out_csv.columns) == ["label"], f"Expected single column 'label', got {list(out_csv.columns)}"
+
+    with pytest.raises(ValueError):
+        fetch_supervised_labels_from_dataframe(df, label_col="missing")
+
+
+def test_one_hot_encode_labels_single_and_multi_columns() -> None:
+    df = pd.DataFrame({"label": ["cat", "dog", "cat", "mouse"]})
+    oh = one_hot_encode_labels(df)
+    assert {
+        "label_cat",
+        "label_dog",
+        "label_mouse",
+    }.issubset(oh.columns), f"Missing expected one-hot columns: {list(oh.columns)}"
+    assert all(oh.dtypes == np.float32), f"One-hot dtypes must be float32, got {oh.dtypes}"
+
+    df2 = pd.DataFrame({"a": ["x", "y"], "b": ["u", "v"]})
+    oh2 = one_hot_encode_labels(df2, columns=["a", "b"])
+    assert any(c.startswith("a_") for c in oh2.columns), f"Expected one-hot columns for 'a': {list(oh2.columns)}"
+    assert any(c.startswith("b_") for c in oh2.columns), f"Expected one-hot columns for 'b': {list(oh2.columns)}"
+
+
+def test_fetch_supervised_labels_dataframe_filter_rows() -> None:
+    df = pd.DataFrame({"id": [1, 2, 3], "Y": [0, 1, 0]})
+    out = fetch_supervised_labels_from_dataframe(
+        df,
+        label_col="Y",
+        index_col="id",
+        filter_rows={"id": [1, 3]},
+    )
+    assert out.shape[0] == 2, f"Expected 2 rows after filtering, got shape {out.shape}"
+    assert set(out["id"].values) == {1, 3}, f"Unexpected remaining ids: {out['id'].tolist()}"
+
+
+def test_fetch_supervised_labels_with_filter_rows(tmp_path: Path) -> None:
+    # Create a small CSV with an ID column to filter on
+    df = pd.DataFrame({"id": [1, 2, 3], "Y": [0, 1, 0]})
+    csv = tmp_path / "labels.csv"
+    csv.write_text(df.to_csv(index=False))
+
+    # Apply a filter to keep only id==2; include index_col so filter can apply
+    out = fetch_supervised_labels_from_dataframe(str(csv), label_col="Y", index_col="id", filter_rows={"id": [2]})
+    assert out.shape[0] == 1, f"Expected 1 row after filtering, got shape {out.shape}"
+    assert 2 in out["id"].values, f"Expected remaining id to be 2, got {out['id'].tolist()}"
+
+
+def test_fetch_labels_from_real_dataset(labels_csv: Path) -> None:
+    out = fetch_supervised_labels_from_dataframe(str(labels_csv), label_col="Y")
+    assert not out.empty, f"No label rows loaded from {labels_csv}"
+    assert list(out.columns) == ["label"], f"Expected single column 'label', got {list(out.columns)}"
diff --git a/tests/load_data/test_tabular.py b/tests/load_data/test_tabular.py
new file mode 100644
index 0000000..feb13a7
--- /dev/null
+++ b/tests/load_data/test_tabular.py
@@ -0,0 +1,148 @@
+"""Tests for tabular utilities: ``read_tabular`` and ``merge_multiple_dataframes``.
+
+This suite validates the public APIs in ``mmai25_hackathon.load_data.tabular``:
+
+- ``read_tabular(path, ...)``: loads a CSV, optionally selects/indexes columns, and applies row filtering.
+- ``merge_multiple_dataframes(dfs, ...)``: merges frames by overlapping key columns, or concatenates when no keys.
+
+Prerequisite
+------------
+Optional real-data integration uses:
+``${PWD}/MMAI25Hackathon/mimic-iv/mimic-iv-3.1``.
+If unavailable, the integration test is skipped; unit tests still validate selection, filtering,
+merge grouping, suffix behavior, and error handling using synthetic CSVs and DataFrames.
+"""
+
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from mmai25_hackathon.load_data.tabular import merge_multiple_dataframes, read_tabular
+
+
+def test_read_tabular_selects_and_filters(tmp_path: Path) -> None:
+    df = pd.DataFrame({"id": [1, 2, 3], "a": [10, 20, 30], "b": [0.1, 0.2, 0.3]})
+    p = tmp_path / "data.csv"
+    p.write_text(df.to_csv(index=False))
+
+    # Select subset/index cols in order and filter rows
+    out = read_tabular(p, subset_cols=["b", "missing"], index_cols="id", filter_rows={"id": [1, 3]})
+    assert list(out.columns) == ["id", "b"], f"Unexpected columns: {list(out.columns)}"
+    assert out["id"].tolist() == [1, 3], f"Row filter not applied as expected: {out['id'].tolist()}"
+
+
+def test_read_tabular_raises_on_invalid_selection_when_requested(tmp_path: Path) -> None:
+    df = pd.DataFrame({"id": [1], "x": [5]})
+    p = tmp_path / "d.csv"
+    p.write_text(df.to_csv(index=False))
+
+    with pytest.raises(ValueError):
+        read_tabular(p, subset_cols=["missing"], raise_errors=True)
+
+    with pytest.raises(ValueError):
+        read_tabular(p, index_cols=["missing"], raise_errors=True)
+
+
+def test_merge_multiple_dataframes_concat_when_no_keys() -> None:
+    df1 = pd.DataFrame({"a": [1, 2]})
+    df2 = pd.DataFrame({"b": [10, 20]})
+    comps = merge_multiple_dataframes([df1, df2], index_cols=None)
+    assert (
+        len(comps) == 1 and comps[0][0] == ()
+    ), f"Expected single concat component, got {[(k, d.shape) for k, d in comps]}"
+    assert list(comps[0][1].columns) == ["a", "b"], f"Unexpected columns after concat: {list(comps[0][1].columns)}"
+
+
+def test_merge_multiple_dataframes_by_overlap_keys() -> None:
+    df1 = pd.DataFrame({"id": [1, 2], "a": [10, 20]})
+    df2 = pd.DataFrame({"id": [1, 2], "b": [0.1, 0.2]})
+    df3 = pd.DataFrame({"site": ["A", "B"], "c": [5, 6]})
+    comps = merge_multiple_dataframes(
+        [df1, df2, df3], dfs_name=["X", "Y", "Z"], index_cols=["id", "site"], join="inner"
+    )
+    # Should produce two components keyed by 'id' and 'site'
+    keys_list = [keys for keys, _ in comps]
+    assert ("id",) in keys_list and ("site",) in keys_list, f"Unexpected key components: {keys_list}"
+
+
+def test_merge_multiple_dataframes_invalid_join() -> None:
+    with pytest.raises(ValueError):
+        merge_multiple_dataframes([pd.DataFrame()], index_cols=["id"], join="bad")
+
+
+def test_merge_multiple_dataframes_empty_input() -> None:
+    assert merge_multiple_dataframes([]) == [], "Expected empty list for empty input"
+
+
+def test_merge_multiple_dataframes_labels_length_mismatch() -> None:
+    with pytest.raises(ValueError):
+        merge_multiple_dataframes([pd.DataFrame(), pd.DataFrame()], dfs_name=["a"], index_cols=["id"])
+
+
+def test_merge_no_subsets_returns_empty() -> None:
+    # Provide index_cols that none of the DataFrames contain to cover df_by_subset empty branch
+    df1 = pd.DataFrame({"x": [1]})
+    df2 = pd.DataFrame({"y": [2]})
+    comps = merge_multiple_dataframes([df1, df2], index_cols=["id"])  # no 'id' in frames
+    assert comps == [], f"Expected empty components when no frames share the provided keys; got {comps}"
+
+
+def test_merge_greedy_overlap_path() -> None:
+    # Three groups: ('id',), ('id','site'), and ('site') to exercise greedy overlap selection and merge
+    df_id = pd.DataFrame({"id": [1, 2], "a": [10, 20]})
+    df_id_site = pd.DataFrame({"id": [1, 2], "site": ["A", "B"], "b": [0.1, 0.2]})
+    df_site = pd.DataFrame({"site": ["A", "B"], "c": [5, 6]})
+
+    comps = merge_multiple_dataframes(
+        [df_id, df_id_site, df_site],
+        dfs_name=["X", "Y", "Z"],
+        index_cols=["id", "site"],
+        join="inner",
+    )
+    # All three can merge into a single component via greedy merging
+    assert len(comps) == 1, f"Expected a single merged component, got {[(k, d.shape) for k, d in comps]}"
+
+
+# Optional real EHR root for integration-style checks
+EHR_ROOT = Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-iv-3.1"
+
+
+@pytest.fixture(scope="module")
+def ehr_root() -> Path:
+    if not EHR_ROOT.exists():
+        pytest.skip(f"EHR root not found: {EHR_ROOT}")
+    return EHR_ROOT
+
+
+def test_read_and_merge_real_ehr_if_available(ehr_root: Path) -> None:
+    hosp_adm = ehr_root / "hosp" / "admissions.csv"
+    icu_stays = ehr_root / "icu" / "icustays.csv"
+    if not hosp_adm.exists() or not icu_stays.exists():
+        pytest.skip("Required EHR tables (admissions or icustays) not found; skipping integration test")
+
+    # Load minimal subsets with expected keys
+    adm_df = read_tabular(
+        hosp_adm,
+        subset_cols=["admittime"],
+        index_cols=["subject_id", "hadm_id"],
+        raise_errors=False,
+    )
+    stays_df = read_tabular(
+        icu_stays,
+        subset_cols=["first_careunit"],
+        index_cols=["subject_id", "hadm_id"],
+        raise_errors=False,
+    )
+
+    comps = merge_multiple_dataframes(
+        [adm_df, stays_df],
+        dfs_name=["admissions", "icustays"],
+        index_cols=["subject_id", "hadm_id"],
+        join="inner",
+    )
+    key_sets = [keys for keys, _ in comps]
+    assert ("hadm_id", "subject_id") in key_sets or (
+        "subject_id",
+        "hadm_id",
+    ) in key_sets, f"Expected a merged component on subject_id/hadm_id, got keys: {key_sets}"
diff --git a/tests/load_data/test_text.py b/tests/load_data/test_text.py
new file mode 100644
index 0000000..f8b3360
--- /dev/null
+++ b/tests/load_data/test_text.py
@@ -0,0 +1,153 @@
+"""Tests for clinical notes (text) loading utilities.
+
+This suite validates the public APIs in ``mmai25_hackathon.load_data.text``:
+
+- ``load_mimic_iv_notes(note_path, ...)``: loads a selected notes CSV (e.g., radiology), verifies required ID columns,
+  optionally merges detail CSV, strips/filters empty ``text``, and returns a ``pd.DataFrame``.
+- ``extract_text_from_note(note, ...)``: extracts the ``text`` field and optionally returns metadata.
+
+Prerequisite
+------------
+Optional real-data integration uses:
+``${PWD}/MMAI25Hackathon/mimic-iv/mimic-iv-note-deidentified-free-text-clinical-notes-2.2/note``.
+If unavailable, integration tests are skipped; unit tests still validate core behavior with synthetic CSVs.
+"""
+
+import logging
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from mmai25_hackathon.load_data.text import extract_text_from_note, load_mimic_iv_notes
+
+# Optional real dataset path for integration-style checks
+TEXT_ROOT = (
+    Path.cwd() / "MMAI25Hackathon" / "mimic-iv" / "mimic-iv-note-deidentified-free-text-clinical-notes-2.2" / "note"
+)
+
+
+@pytest.fixture(scope="module")
+def real_notes_root() -> Path:
+    if not TEXT_ROOT.exists():
+        pytest.skip(f"Notes root not found: {TEXT_ROOT}")
+    return TEXT_ROOT
+
+
+def _w(path: Path, df: pd.DataFrame) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(df.to_csv(index=False))
+
+
+@pytest.mark.parametrize("as_str", [True, False])
+def test_invalid_notes_path_raises(tmp_path: Path, as_str: bool) -> None:
+    arg = str(tmp_path / "missing") if as_str else (tmp_path / "missing")
+    with pytest.raises(FileNotFoundError):
+        load_mimic_iv_notes(arg, subset="radiology")
+
+
+def test_missing_main_csv_raises(tmp_path: Path) -> None:
+    with pytest.raises(FileNotFoundError):
+        load_mimic_iv_notes(tmp_path, subset="radiology")
+
+
+def test_required_id_columns_missing_raises(tmp_path: Path) -> None:
+    # Missing note_id
+    _w(tmp_path / "radiology.csv", pd.DataFrame({"subject_id": [1], "text": ["t"]}))
+    with pytest.raises(KeyError):
+        load_mimic_iv_notes(tmp_path, subset="radiology")
+
+
+def test_missing_detail_when_requested_raises(tmp_path: Path) -> None:
+    _w(
+        tmp_path / "radiology.csv",
+        pd.DataFrame({"note_id": [1], "subject_id": [101], "text": ["hello"]}),
+    )
+    with pytest.raises(FileNotFoundError):
+        load_mimic_iv_notes(tmp_path, subset="radiology", include_detail=True)
+
+
+def test_load_notes_filters_empty_text_and_merges_detail(tmp_path: Path) -> None:
+    # Main notes: include one empty/whitespace-only text to exercise filtering
+    _w(
+        tmp_path / "radiology.csv",
+        pd.DataFrame(
+            {
+                "note_id": [1, 2],
+                "subject_id": [101, 102],
+                "text": [" hello ", " "],
+                "note_type": ["r", "r"],
+            }
+        ),
+    )
+    # Detail includes an extra column to verify merge and absence of suffix when no collision
+    _w(
+        tmp_path / "radiology_detail.csv",
+        pd.DataFrame({"note_id": [1, 2], "subject_id": [101, 102], "detail_field": ["A", "B"]}),
+    )
+
+    df = load_mimic_iv_notes(tmp_path, subset="radiology", include_detail=True)
+    assert not df.empty, "Notes DataFrame is unexpectedly empty"
+    assert set(["note_id", "subject_id"]).issubset(df.columns), f"Missing ID columns: {list(df.columns)}"
+    assert "detail_field" in df.columns, f"Merged detail column missing: columns={list(df.columns)}"
+
+    # After filtering empty text, only note_id==1 should remain
+    assert set(df["note_id"]) == {1}, f"Expected only note_id=1 after filtering, got {set(df['note_id'])}"
+    assert (
+        df.loc[df["note_id"] == 1, "text"].iloc[0] == "hello"
+    ), f"Text should be stripped; got {df.loc[df['note_id']==1, 'text'].iloc[0]!r}"
+
+
+def test_detail_missing_required_id_columns_raises(tmp_path: Path) -> None:
+    _w(
+        tmp_path / "radiology.csv",
+        pd.DataFrame({"note_id": [1], "subject_id": [101], "text": ["x"]}),
+    )
+    # Detail exists but missing 'subject_id'
+    _w(
+        tmp_path / "radiology_detail.csv",
+        pd.DataFrame({"note_id": [1], "extra": ["A"]}),
+    )
+    with pytest.raises(KeyError):
+        load_mimic_iv_notes(tmp_path, subset="radiology", include_detail=True)
+
+
+def test_load_notes_without_text_column_warns(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
+    _w(
+        tmp_path / "radiology.csv",
+        pd.DataFrame({"note_id": [1], "subject_id": [101], "note_type": ["r"]}),
+    )
+    caplog.set_level(logging.WARNING)
+    df = load_mimic_iv_notes(tmp_path, subset="radiology", include_detail=False, subset_cols=["note_type"])
+    assert "text" not in df.columns, f"'text' unexpectedly present: {list(df.columns)}"
+    assert any(
+        "do not include a 'text'" in r.getMessage() for r in caplog.records
+    ), "Expected warning about missing 'text' column"
+
+
+def test_extract_text_from_note_success_and_metadata() -> None:
+    note = pd.Series({"note_id": 1, "subject_id": 101, "text": "Patient stable.", "note_type": "Discharge"})
+    t = extract_text_from_note(note)
+    assert t == "Patient stable.", f"Unexpected text extracted: {t!r}"
+
+    t2, meta = extract_text_from_note(note, include_metadata=True)
+    assert t2 == "Patient stable.", f"Unexpected text extracted: {t2!r}"
+    assert meta.get("note_id") == 1 and meta.get("subject_id") == 101, f"Unexpected metadata returned: {meta}"
+
+
+def test_extract_text_from_note_missing_text_raises() -> None:
+    with pytest.raises(KeyError):
+        extract_text_from_note(pd.Series({"note_id": 1, "subject_id": 101}))
+
+
+@pytest.mark.parametrize("as_str", [True, False])
+def test_integration_load_notes_if_available(real_notes_root: Path, as_str: bool) -> None:
+    arg = str(real_notes_root) if as_str else real_notes_root
+    df = load_mimic_iv_notes(arg, subset="radiology", include_detail=False)
+    # It's acceptable for the subset to be empty if the sample doesn't include radiology,
+    # but the loader should return a DataFrame with columns present.
+    assert isinstance(df, pd.DataFrame), "Expected DataFrame from load_mimic_iv_notes"
+    if not df.empty:
+        assert {"note_id", "subject_id"}.issubset(df.columns), f"Missing required ID columns: {df.columns}"
+        if "text" in df.columns:
+            assert (df["text"].astype(str).str.len() > 0).any(), "All texts are empty after loading"
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..d11fd80
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,98 @@
+"""Tests for base dataset/dataloader/sampler utilities.
+
+This suite validates the public classes in ``mmai25_hackathon.dataset``:
+
+- ``BaseDataset``: abstract interface; ``__len__``/``__getitem__`` must be implemented; ``__repr__``
+  uses ``extra_repr``; ``prepare_data`` is optional.
+- ``BaseDataLoader``: batches simple PyG graphs and yields ``Batch`` objects.
+- ``BaseSampler``: remains abstract for iteration (``__iter__`` raises ``NotImplementedError``).
+
+Prerequisite
+------------
+Graph batching test requires ``torch`` and ``torch_geometric``; if unavailable, the test is skipped.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from mmai25_hackathon.dataset import BaseDataLoader, BaseDataset, BaseSampler
+
+
+def test_base_dataset_instantiation_raises() -> None:
+    with pytest.raises(NotImplementedError):
+        BaseDataset()
+
+
+def test_incomplete_subclass_abstract_methods_raise() -> None:
+    class Incomplete(BaseDataset):
+        def __init__(self) -> None:
+            # Override to avoid BaseDataset.__init__ raising, but keep abstract methods unimplemented
+            pass
+
+    ds = Incomplete()
+    with pytest.raises(NotImplementedError):
+        _ = len(ds)
+    with pytest.raises(NotImplementedError):
+        _ = ds[0]
+    with pytest.raises(NotImplementedError):
+        _ = ds + ds
+    with pytest.raises(NotImplementedError):
+        ds.prepare_data()
+
+
+def test_complete_subclass_repr_and_len() -> None:
+    class ToyDataset(BaseDataset):
+        def __init__(self, items: list[int]) -> None:
+            self._items = list(items)
+
+        def __len__(self) -> int:
+            return len(self._items)
+
+        def __getitem__(self, idx: int) -> int:
+            return self._items[idx]
+
+    ds = ToyDataset([1, 2, 3])
+    assert len(ds) == 3, f"Expected length 3, got {len(ds)}"
+    assert repr(ds) == "ToyDataset(sample_size=3)", f"Unexpected repr: {repr(ds)!r}"
+
+
+def test_base_dataloader_batches_graphs() -> None:
+    # Skip if torch_geometric/torch unavailable
+    pytest.importorskip("torch_geometric")
+    torch = pytest.importorskip("torch")
+    from torch_geometric.data import Batch, Data
+
+    class GraphDataset:
+        def __init__(self) -> None:
+            self._graphs = [
+                Data(x=torch.randn(3, 2), edge_index=torch.empty(2, 0, dtype=torch.long)),
+                Data(x=torch.randn(2, 2), edge_index=torch.empty(2, 0, dtype=torch.long)),
+                Data(x=torch.randn(4, 2), edge_index=torch.empty(2, 0, dtype=torch.long)),
+                Data(x=torch.randn(1, 2), edge_index=torch.empty(2, 0, dtype=torch.long)),
+                Data(x=torch.randn(2, 2), edge_index=torch.empty(2, 0, dtype=torch.long)),
+            ]
+
+        def __len__(self) -> int:
+            return len(self._graphs)
+
+        def __getitem__(self, idx: int) -> Data:  # type: ignore[name-defined]
+            return self._graphs[idx]
+
+    ds = GraphDataset()
+    loader = BaseDataLoader(ds, batch_size=2, shuffle=False)
+
+    total_graphs = 0
+    for batch in loader:
+        assert isinstance(batch, Batch), f"Loader should return Batch, got {type(batch)!r}"
+        assert getattr(batch, "num_graphs", 0) > 0, "Batch missing num_graphs or is zero"
+        total_graphs += int(batch.num_graphs)
+
+    assert total_graphs == len(ds), f"Total graphs {total_graphs} != dataset size {len(ds)}"
+
+
+def test_base_sampler_iteration_not_implemented() -> None:
+    # BaseSampler inherits torch.utils.data.Sampler: __iter__ is abstract and should raise
+    s = BaseSampler(data_source=range(5))  # type: ignore[call-arg]
+    with pytest.raises(NotImplementedError):
+        _ = next(iter(s))