Skip to content
Open
1 change: 1 addition & 0 deletions ehrapy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
missing_values_heatmap,
missing_values_matrix,
)
from ehrapy.plot._sankey import plot_sankey, plot_sankey_time
from ehrapy.plot._scanpy_pl_api import * # noqa: F403
from ehrapy.plot._survival_analysis import cox_ph_forestplot, kaplan_meier, ols
from ehrapy.plot.causal_inference._dowhy import causal_effect
Expand Down
180 changes: 180 additions & 0 deletions ehrapy/plot/_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import holoviews as hv
import numpy as np
import pandas as pd
from holoviews import opts

if TYPE_CHECKING:
from ehrdata import EHRData


def plot_sankey(
    edata: EHRData,
    *,
    columns: list[str],
    show: bool = False,
    **kwargs,
) -> hv.Sankey:
    """Create a Sankey diagram showing relationships across observation columns.

    Every pair of consecutive entries in ``columns`` contributes one layer of links.
    A link connects the node ``"<column>: <value>"`` of the first column to the
    corresponding node of the second column, and its width is the number of
    observations sharing that combination of values. Rows with a missing value in
    either column of a pair are not counted for that pair.

    Please call :func:`holoviews.extension` with ``"matplotlib"`` or ``"bokeh"`` before using this function to select the backend.

    Args:
        edata: Central data object containing observation data.
        columns: Column names from `edata.obs` to visualize, in the order in which they
            should appear in the diagram. At least two columns are required.
        show: If True, display the plot immediately. If False, only return the plot object without displaying.
        **kwargs: Additional styling options passed to `holoviews.opts.Sankey`. See HoloViews Sankey documentation for full list of options.
            Entries given here override the defaults set by this function
            (``label_position``, ``show_values`` and ``title``).

    Returns:
        holoviews.Sankey

    Raises:
        ValueError: If fewer than two columns are given.

    Examples:
        >>> import ehrapy as ep
        >>> import ehrdata as ed
        >>> import holoviews as hv
        >>> hv.extension("bokeh")
        >>> edata = ed.dt.diabetes_130_fairlearn(columns_obs_only=["gender", "race"])
        >>> ep.pl.plot_sankey(edata, columns=["gender", "race"])

    """
    # NOTE(review): PR feedback asks for explicit sizing parameters (height, width, ...)
    # consistent with the survival analysis plots; for now they are only reachable via **kwargs.
    # NOTE(review): PR feedback also suggests configuring a default backend instead of
    # requiring the caller to run ``holoviews.extension`` first; this function does not select one.
    if len(columns) < 2:
        raise ValueError("`columns` must contain at least two column names to draw flows between them.")

    # Indexing with a list of names yields a DataFrame (not a Series), even for a single name.
    obs = edata.obs[columns]

    # Build the links between each pair of consecutive columns.
    link_frames = []
    for col_from, col_to in zip(columns, columns[1:]):
        # observed=True keeps categorical columns from emitting zero-count links for
        # category combinations that never occur in the data.
        flows = obs.groupby([col_from, col_to], observed=True).size().reset_index(name="count")
        # Labels are built column-wise: row-wise iteration would upcast mixed numeric
        # rows and turn e.g. the integer category 1 into the label "1.0".
        link_frames.append(
            pd.DataFrame(
                {
                    "source": f"{col_from}: " + flows[col_from].astype(str),
                    "target": f"{col_to}: " + flows[col_to].astype(str),
                    "value": flows["count"],
                    "source_level": col_from,
                    "target_level": col_to,
                }
            )
        )
    sankey_df = pd.concat(link_frames, ignore_index=True)

    sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])

    # User-supplied options win over the defaults.
    plot_opts = {
        "label_position": "right",
        "show_values": True,
        "title": f"Flow across: {columns}",
        **kwargs,
    }
    # NOTE(review): PR feedback asks whether these options affect other plots. They are applied
    # through the element's own ``.opts`` call, not ``opts.defaults`` -- presumably local to this
    # plot; confirm against the HoloViews customization docs.
    sankey = sankey.opts(opts.Sankey(**plot_opts))

    if show:
        from IPython.display import display

        display(sankey)

    return sankey


def plot_sankey_time(
    edata: EHRData,
    *,
    columns: list[str],
    layer: str,
    state_labels: dict[int, str] | None = None,
    show: bool = False,
    **kwargs,
) -> hv.Sankey:
    """Create a Sankey diagram showing patient state transitions over time.

    This function visualizes how patients transition between different states
    (e.g., disease severity, treatment status) across consecutive time points.
    Each node represents a state at a specific time point, and flows show the
    number of patients transitioning between states. Transitions that involve a
    missing value, or a state that is not a key of ``state_labels``, are not counted.

    Please call :func:`holoviews.extension` with ``"matplotlib"`` or ``"bokeh"`` before using this function to select the backend.

    Args:
        edata: Central data object containing observation data
        columns: Names of variables in `edata.var_names`. Only the first entry is visualized.
        layer: Name of the layer in `edata.layers` containing the feature data to visualize.
        state_labels: Mapping from numeric state values to readable labels. If None, state values
            will be displayed as strings of their numeric codes (e.g., "0", "1", "2"). Default: None
        show: If True, display the plot immediately. If False, only return the plot object without displaying.
        **kwargs: Additional styling options passed to `holoviews.opts.Sankey`. See HoloViews Sankey documentation for full list of options.
            Entries given here override the defaults set by this function
            (``label_position``, ``show_values`` and ``title``).

    Returns:
        holoviews.Sankey

    Raises:
        ValueError: If ``columns`` is empty.
        KeyError: If the first entry of ``columns`` is not in `edata.var_names`.

    Examples:
        >>> import numpy as np
        >>> import pandas as pd
        >>> import ehrapy as ep
        >>> import ehrdata as ed
        >>> import holoviews as hv
        >>> hv.extension("bokeh")
        >>>
        >>> layer = np.array(
        ...     [
        ...         [[1, 0, 1], [0, 1, 0]],  # patient 1: treatment, disease_flare
        ...         [[0, 1, 1], [1, 0, 0]],  # patient 2: treatment, disease_flare
        ...         [[1, 1, 0], [0, 0, 1]],  # patient 3: treatment, disease_flare
        ...     ]
        ... )
        >>>
        >>> edata = ed.EHRData(
        ...     layers={"layer_1": layer},
        ...     obs=pd.DataFrame(index=["patient_1", "patient_2", "patient_3"]),
        ...     var=pd.DataFrame(index=["treatment", "disease_flare"]),
        ...     tem=pd.DataFrame(index=["visit_0", "visit_1", "visit_2"]),
        ... )
        >>>
        >>> ep.pl.plot_sankey_time(
        ...     edata, columns=["disease_flare"], layer="layer_1", state_labels={0: "no flare", 1: "flare"}
        ... )

    """
    # NOTE(review): PR feedback asks for explicit sizing parameters (height, width, ...) consistent
    # with the other plotting functions, for a default backend, and for a simpler docstring example
    # (ideally based on "blobs" data). None of that is implemented here yet.
    if not columns:
        raise ValueError("`columns` must contain the name of the variable to visualize.")
    variable = columns[0]
    if variable not in edata.var_names:
        raise KeyError(f"{variable!r} not found in `edata.var_names`.")

    # Select exactly the variable named in the title. Masking with every entry of `columns`
    # and taking position 0 would return whichever requested variable comes first in var order.
    states = edata[:, edata.var_names == variable, :].layers[layer][:, 0, :]

    time_steps = edata.tem.index.tolist()

    if state_labels is None:
        observed = np.unique(states)
        observed = observed[~pd.isna(observed)]
        # str(int(...)) keeps labels like "1" even when the layer stores floats (1.0).
        state_labels = {int(state): str(int(state)) for state in observed}

    ordered_states = sorted(state_labels.items())

    sources, targets, values = [], [], []
    for t, (step_from, step_to) in enumerate(zip(time_steps, time_steps[1:])):
        before, after = states[:, t], states[:, t + 1]
        for value_from, name_from in ordered_states:
            # The source-state mask is shared by all target states of this time step.
            from_mask = before == value_from
            if not from_mask.any():
                continue
            for value_to, name_to in ordered_states:
                count = int(np.count_nonzero(from_mask & (after == value_to)))
                if count > 0:
                    sources.append(f"{name_from} ({step_from})")
                    targets.append(f"{name_to} ({step_to})")
                    values.append(count)

    sankey_df = pd.DataFrame({"source": sources, "target": targets, "value": values})

    sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])

    # User-supplied options win over the defaults.
    plot_opts = {
        "label_position": "right",
        "show_values": True,
        "title": f"Patient flows: {variable} over time",
        **kwargs,
    }
    # NOTE(review): PR feedback worries that these options leak beyond this plot. They are applied
    # through the element's own ``.opts`` call, not ``opts.defaults`` -- presumably local to this
    # plot; confirm against the HoloViews customization docs.
    sankey = sankey.opts(opts.Sankey(**plot_opts))

    if show:
        from IPython.display import display

        display(sankey)

    return sankey
775 changes: 775 additions & 0 deletions tests/_scripts/sankey_bokeh_expected.ipynb

Large diffs are not rendered by default.

696 changes: 696 additions & 0 deletions tests/_scripts/sankey_matplotlib_expected.ipynb

Large diffs are not rendered by default.

Binary file added tests/plot/_images/sankey_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/plot/_images/sankey_time_expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
157 changes: 157 additions & 0 deletions tests/plot/test_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from pathlib import Path

import ehrdata as ed
import holoviews as hv
import numpy as np
import pandas as pd
import pytest

import ehrapy as ep

# Directory that contains this test module.
CURRENT_DIR = Path(__file__).parent
# Folder with the reference images; tests append the plot name to build `base_path`.
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"


@pytest.fixture
def ehr_3d_mini():
    """Small 3D EHRData object: 5 patients x 2 variables x 5 visits, states coded 0-2."""
    # NOTE(review): PR feedback suggests ideally building this fixture from "blobs" data
    # instead of a hand-written array.
    patients = ["patient 1", "patient 2", "patient 3", "patient 4", "patient 5"]
    variables = ["treatment", "disease_flare"]
    visits = ["visit_0", "visit_1", "visit_2", "visit_3", "visit_4"]

    # One entry per patient; each holds one row of per-visit values for every variable.
    states = np.array(
        [
            [[0, 1, 2, 1, 2], [1, 2, 1, 2, 0]],
            [[1, 2, 0, 2, 1], [2, 1, 2, 1, 2]],
            [[2, 0, 1, 2, 0], [2, 0, 1, 1, 2]],
            [[1, 2, 1, 0, 1], [0, 2, 1, 2, 0]],
            [[0, 2, 1, 2, 2], [1, 2, 1, 0, 2]],
        ]
    )

    return ed.EHRData(
        layers={"layer_1": states},
        obs=pd.DataFrame(index=patients),
        var=pd.DataFrame(index=variables),
        tem=pd.DataFrame(index=visits),
    )


@pytest.fixture
def diabetes_130_fairlearn_sample_100():
    """First 100 observations of the diabetes dataset, loaded with race and gender as obs-only columns."""
    obs_only = ["race", "gender"]
    full_dataset = ed.dt.diabetes_130_fairlearn(columns_obs_only=obs_only)
    return full_dataset[:100]


def test_sankey_plot(diabetes_130_fairlearn_sample_100, check_same_image):
    """Render the gender -> race Sankey with matplotlib and compare it to the reference image."""
    hv.extension("matplotlib")
    sample = diabetes_130_fairlearn_sample_100.copy()

    rendered = hv.render(
        ep.pl.plot_sankey(sample, columns=["gender", "race"]),
        backend="matplotlib",
    )

    check_same_image(fig=rendered, base_path=f"{_TEST_IMAGE_PATH}/sankey", tol=2e-1)


def test_sankey_time_plot(ehr_3d_mini, check_same_image):
    """Render the disease_flare transition Sankey with matplotlib and compare it to the reference image."""
    hv.extension("matplotlib")

    flare_names = {0: "no flare", 1: "mid flare", 2: "severe flare"}
    diagram = ep.pl.plot_sankey_time(
        ehr_3d_mini,
        columns=["disease_flare"],
        layer="layer_1",
        state_labels=flare_names,
    )
    rendered = hv.render(diagram, backend="matplotlib")

    check_same_image(fig=rendered, base_path=f"{_TEST_IMAGE_PATH}/sankey_time", tol=2e-1)


def test_sankey_bokeh_plot(diabetes_130_fairlearn_sample_100):
    """Check the link table behind the gender -> race Sankey against counts taken from ``obs``."""
    hv.extension("bokeh")
    edata = diabetes_130_fairlearn_sample_100.copy()
    obs = edata.obs

    sankey = ep.pl.plot_sankey(edata, columns=["gender", "race"])
    assert isinstance(sankey, hv.Sankey)

    links = sankey.data
    for required in ("source", "target", "value"):
        assert required in links.columns

    assert len(links) > 0  # at least one flow
    assert (links["value"] > 0).all()  # flow values positive
    assert links["value"].sum() == len(obs)  # total flow must match total obs

    # each flow matches the number of observations with that gender/race combination
    for source, target, flow in zip(links["source"], links["target"], links["value"]):
        gender = source.split(": ")[1]
        race = target.split(": ")[1]
        matching = obs[(obs["gender"] == gender) & (obs["race"] == race)]
        assert flow == len(matching)

    # node labels carry the name of the column they came from
    assert all(source.startswith("gender:") for source in links["source"].unique())
    assert all(target.startswith("race:") for target in links["target"].unique())


def test_sankey_time_bokeh_plot(ehr_3d_mini):
    """Check labels and patient conservation in the link table of the time Sankey."""
    hv.extension("bokeh")
    edata = ehr_3d_mini
    flare_names = {0: "no flare", 1: "mid flare", 2: "severe flare"}

    sankey = ep.pl.plot_sankey_time(
        edata,
        columns=["disease_flare"],
        layer="layer_1",
        state_labels=flare_names,
    )
    assert isinstance(sankey, hv.Sankey)

    links = sankey.data
    for required in ("source", "target", "value"):
        assert required in links.columns

    assert len(links) > 0
    assert (links["value"] > 0).all()

    # every node label holds a state name followed by the time step in parentheses
    expected_names = ["no flare", "mid flare", "severe flare"]
    for endpoint in ("source", "target"):
        for node in links[endpoint].unique():
            assert any(name in node for name in expected_names)
            assert "(" in node and ")" in node

    # patients leaving the first time point must equal patients arriving at the last one
    visits = edata.tem.index.tolist()
    first_time, last_time = visits[0], visits[-1]

    leaves_first = links["source"].str.contains(f"\\({first_time}\\)", regex=True)
    reaches_last = links["target"].str.contains(f"\\({last_time}\\)", regex=True)
    assert links[leaves_first]["value"].sum() == links[reaches_last]["value"].sum()

    # every patient contributes exactly one unit of flow per transition
    expected_total_flow = edata.n_obs * (len(visits) - 1)
    assert links["value"].sum() == expected_total_flow
Loading