-
Notifications
You must be signed in to change notification settings - Fork 38
Add Sankey diagram visualization functions #989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8c6df3e
d59b90d
dcbdaf9
3a9c5f8
1474da9
f672487
542a9fc
14abe6e
5900d21
2326aa4
c9e6b01
cae40f6
abcb9ba
3698734
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,180 @@ | ||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from typing import TYPE_CHECKING | ||||||
|
|
||||||
| import holoviews as hv | ||||||
| import numpy as np | ||||||
| import pandas as pd | ||||||
| from holoviews import opts | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from ehrdata import EHRData | ||||||
|
|
||||||
|
|
||||||
def plot_sankey(
    edata: EHRData,
    *,
    columns: list[str],
    show: bool = False,
    **kwargs,
) -> hv.Sankey:
    """Create a Sankey diagram showing relationships across observation columns.

    Every value of a column becomes a node labelled ``"<column>: <value>"``. A link is drawn
    between the nodes of each pair of consecutive entries in ``columns``; its value is the
    number of observations that share both values.

    Please call :func:`holoviews.extension` with ``"matplotlib"`` or ``"bokeh"`` before using this function to select the backend.

    Args:
        edata: Central data object containing observation data.
        columns: Names of at least two columns of ``edata.obs``, in the order in which they
            should appear in the diagram.
        show: If True, display the plot immediately. If False, only return the plot object without displaying.
        **kwargs: Additional styling options passed to `holoviews.opts.Sankey`. They take precedence
            over the defaults set here (``label_position``, ``show_values``, ``title``).
            See HoloViews Sankey documentation for full list of options.

    Returns:
        The :class:`holoviews.Sankey` element. Its data holds one row per link with the columns
        ``source``, ``target``, ``value``, ``source_level`` and ``target_level``.

    Raises:
        ValueError: If fewer than two columns are given, since no link can be built.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.diabetes_130_fairlearn(columns_obs_only=["gender", "race"])
        >>> ep.pl.plot_sankey(edata, columns=["gender", "race"])
    """
    if len(columns) < 2:
        raise ValueError("`columns` must contain at least two column names to build a Sankey diagram.")

    # A list of column names always selects a DataFrame, never a Series.
    obs = edata.obs[columns]

    # Build links between consecutive columns
    links = []
    for col_from, col_to in zip(columns, columns[1:]):
        # observed=True keeps unobserved categorical combinations from producing zero-count links.
        flows = obs.groupby([col_from, col_to], observed=True).size()
        # Iterating the grouped Series keeps each key's own dtype; row-wise iteration over a
        # frame would upcast e.g. integer codes to floats and change the node labels.
        for (value_from, value_to), count in flows.items():
            links.append(
                (
                    f"{col_from}: {value_from}",
                    f"{col_to}: {value_to}",
                    int(count),
                    col_from,
                    col_to,
                )
            )

    sankey_df = pd.DataFrame(links, columns=["source", "target", "value", "source_level", "target_level"])

    sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])

    # User-supplied options override the defaults; they are applied through the element's own
    # ``.opts`` method.
    plot_opts = {
        "label_position": "right",
        "show_values": True,
        "title": f"Flow across: {columns}",
        **kwargs,
    }
    sankey = sankey.opts(opts.Sankey(**plot_opts))

    if show:
        # Imported lazily so IPython is only needed when displaying.
        from IPython.display import display

        display(sankey)

    return sankey
|
|
||||||
|
|
||||||
def plot_sankey_time(
    edata: EHRData,
    *,
    columns: list[str],
    layer: str,
    state_labels: dict[int, str] | None = None,
    show: bool = False,
    **kwargs,
) -> hv.Sankey:
    """Create a Sankey diagram showing patient state transitions over time.

    This function visualizes how patients transition between different states
    (e.g., disease severity, treatment status) across consecutive time points.
    Each node represents a state at a specific time point, and flows show the
    number of patients transitioning between states.

    Please call :func:`holoviews.extension` with ``"matplotlib"`` or ``"bokeh"`` before using this function to select the backend.

    Args:
        edata: Central data object containing the layer to visualize.
        columns: Name of the variable (an entry of ``edata.var_names``) whose states are visualized.
            Only the first entry is used; it also appears in the plot title.
        layer: Name of the layer in `edata.layers` containing the feature data to visualize.
        state_labels: Mapping from numeric state values to readable labels. If None, state values
            will be displayed as strings of their numeric codes (e.g., "0", "1", "2"). Default: None
        show: If True, display the plot immediately. If False, only return the plot object without displaying.
        **kwargs: Additional styling options passed to `holoviews.opts.Sankey`. They take precedence
            over the defaults set here (``label_position``, ``show_values``, ``title``).
            See HoloViews Sankey documentation for full list of options.

    Returns:
        The :class:`holoviews.Sankey` element. Its data holds one row per non-empty transition
        with the columns ``source``, ``target`` and ``value``.

    Raises:
        ValueError: If ``columns`` is empty, if the variable is not in ``edata.var_names``,
            or if ``edata.tem`` has fewer than two time points.

    Examples:
        >>> import numpy as np
        >>> import pandas as pd
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>>
        >>> layer = np.array(
        ...     [
        ...         [[1, 0, 1], [0, 1, 0]],  # patient 1: treatment, disease_flare
        ...         [[0, 1, 1], [1, 0, 0]],  # patient 2: treatment, disease_flare
        ...         [[1, 1, 0], [0, 0, 1]],  # patient 3: treatment, disease_flare
        ...     ]
        ... )
        >>>
        >>> edata = ed.EHRData(
        ...     layers={"layer_1": layer},
        ...     obs=pd.DataFrame(index=["patient_1", "patient_2", "patient_3"]),
        ...     var=pd.DataFrame(index=["treatment", "disease_flare"]),
        ...     tem=pd.DataFrame(index=["visit_0", "visit_1", "visit_2"]),
        ... )
        >>>
        >>> ep.pl.plot_sankey_time(
        ...     edata, columns=["disease_flare"], layer="layer_1", state_labels={0: "no flare", 1: "flare"}
        ... )
    """
    if not columns:
        raise ValueError("`columns` must name the variable to visualize.")

    # Select exactly the variable named in the title, so data and title cannot disagree
    # when the order of `columns` differs from the order of `edata.var_names`.
    variable = columns[0]
    var_mask = edata.var_names.isin([variable])
    if not var_mask.any():
        raise ValueError(f"Variable {variable!r} not found in `edata.var_names`.")

    # After dropping the variable axis the indexing below treats the result as (observation, time).
    states = edata[:, var_mask, :].layers[layer][:, 0, :]

    time_steps = edata.tem.index.tolist()
    if len(time_steps) < 2:
        raise ValueError("At least two time points are required to show transitions over time.")

    if state_labels is None:
        observed_states = np.unique(states)
        observed_states = observed_states[~np.isnan(observed_states)]
        # Label with the integer code so float layers give "0" rather than "0.0".
        state_labels = {int(state): str(int(state)) for state in observed_states}

    state_values = sorted(state_labels.keys())

    links = []
    for t, (step_from, step_to) in enumerate(zip(time_steps, time_steps[1:])):
        current, following = states[:, t], states[:, t + 1]
        for value_from in state_values:
            # Computed once per source state instead of once per (source, target) pair.
            from_mask = current == value_from
            for value_to in state_values:
                count = int(np.sum(from_mask & (following == value_to)))
                if count > 0:
                    links.append(
                        (
                            f"{state_labels[value_from]} ({step_from})",
                            f"{state_labels[value_to]} ({step_to})",
                            count,
                        )
                    )

    sankey_df = pd.DataFrame(links, columns=["source", "target", "value"])

    sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])

    # User-supplied options override the defaults; they are applied through the element's own
    # ``.opts`` method.
    plot_opts = {
        "label_position": "right",
        "show_values": True,
        "title": f"Patient flows: {variable} over time",
        **kwargs,
    }
    sankey = sankey.opts(opts.Sankey(**plot_opts))

    if show:
        # Imported lazily so IPython is only needed when displaying.
        from IPython.display import display

        display(sankey)

    return sankey
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| from pathlib import Path | ||
|
|
||
| import ehrdata as ed | ||
| import holoviews as hv | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| import ehrapy as ep | ||
|
|
||
# Directory containing this test module.
CURRENT_DIR = Path(__file__).parent
# Folder holding the baseline images used by the image-comparison tests.
_TEST_IMAGE_PATH = f"{CURRENT_DIR}/_images"
|
|
||
|
|
||
@pytest.fixture
def ehr_3d_mini():
    """Small 3D EHRData object: 5 patients x 2 variables x 5 visits with states 0, 1 and 2."""
    # NOTE(review): a reviewer asked for this fixture to be generated with the blobs helper
    # instead of hand-written arrays.
    n_patients, n_visits = 5, 5
    # Axis order is (patient, variable, visit); the variables are treatment and disease_flare.
    states = np.array(
        [
            [[0, 1, 2, 1, 2], [1, 2, 1, 2, 0]],
            [[1, 2, 0, 2, 1], [2, 1, 2, 1, 2]],
            [[2, 0, 1, 2, 0], [2, 0, 1, 1, 2]],
            [[1, 2, 1, 0, 1], [0, 2, 1, 2, 0]],
            [[0, 2, 1, 2, 2], [1, 2, 1, 0, 2]],
        ]
    )
    patient_ids = [f"patient {number}" for number in range(1, n_patients + 1)]
    visit_ids = [f"visit_{number}" for number in range(n_visits)]

    return ed.EHRData(
        layers={"layer_1": states},
        obs=pd.DataFrame(index=patient_ids),
        var=pd.DataFrame(index=["treatment", "disease_flare"]),
        tem=pd.DataFrame(index=visit_ids),
    )
|
|
||
|
|
||
@pytest.fixture
def diabetes_130_fairlearn_sample_100():
    """First 100 observations of the diabetes-130 (fairlearn) dataset, with race and gender kept in obs."""
    obs_only_columns = ["race", "gender"]
    full_dataset = ed.dt.diabetes_130_fairlearn(columns_obs_only=obs_only_columns)
    return full_dataset[:100]
|
|
||
|
|
||
def test_sankey_plot(diabetes_130_fairlearn_sample_100, check_same_image):
    """Render the gender -> race Sankey with matplotlib and compare it with the stored baseline image."""
    hv.extension("matplotlib")
    sample = diabetes_130_fairlearn_sample_100.copy()

    rendered = hv.render(
        ep.pl.plot_sankey(sample, columns=["gender", "race"]),
        backend="matplotlib",
    )

    check_same_image(fig=rendered, base_path=f"{_TEST_IMAGE_PATH}/sankey", tol=2e-1)
|
|
||
|
|
||
def test_sankey_time_plot(ehr_3d_mini, check_same_image):
    """Render the disease_flare transition Sankey with matplotlib and compare it with the baseline image."""
    hv.extension("matplotlib")
    flare_labels = {0: "no flare", 1: "mid flare", 2: "severe flare"}

    diagram = ep.pl.plot_sankey_time(
        ehr_3d_mini,
        columns=["disease_flare"],
        layer="layer_1",
        state_labels=flare_labels,
    )
    rendered = hv.render(diagram, backend="matplotlib")

    check_same_image(fig=rendered, base_path=f"{_TEST_IMAGE_PATH}/sankey_time", tol=2e-1)
|
|
||
|
|
||
def test_sankey_bokeh_plot(diabetes_130_fairlearn_sample_100):
    """Check the link table behind the gender -> race Sankey when the bokeh backend is active."""
    hv.extension("bokeh")
    edata = diabetes_130_fairlearn_sample_100.copy()

    sankey = ep.pl.plot_sankey(edata, columns=["gender", "race"])
    assert isinstance(sankey, hv.Sankey)

    links = sankey.data
    for required in ("source", "target", "value"):
        assert required in links.columns

    # At least one flow, all flows positive, and together they account for every observation.
    assert len(links) > 0
    assert (links["value"] > 0).all()
    assert links["value"].sum() == len(edata.obs)

    # Each flow matches the number of observations with that gender/race combination.
    obs = edata.obs
    for source, target, flow in zip(links["source"], links["target"], links["value"]):
        gender_value = source.split(": ")[1]
        race_value = target.split(": ")[1]
        matching = (obs["gender"] == gender_value) & (obs["race"] == race_value)
        assert flow == len(obs[matching])

    # Node labels carry the name of the column they come from.
    assert all(source.startswith("gender:") for source in links["source"].unique())
    assert all(target.startswith("race:") for target in links["target"].unique())
|
|
||
|
|
||
def test_sankey_time_bokeh_plot(ehr_3d_mini):
    """Check the link table behind the disease_flare transition Sankey when the bokeh backend is active."""
    hv.extension("bokeh")
    edata = ehr_3d_mini
    label_by_state = {0: "no flare", 1: "mid flare", 2: "severe flare"}

    sankey = ep.pl.plot_sankey_time(
        edata,
        columns=["disease_flare"],
        layer="layer_1",
        state_labels=label_by_state,
    )
    assert isinstance(sankey, hv.Sankey)

    links = sankey.data
    for required in ("source", "target", "value"):
        assert required in links.columns

    assert len(links) > 0
    assert (links["value"] > 0).all()

    # Every node label contains a state label followed by the time point in parentheses.
    expected_labels = list(label_by_state.values())
    for endpoint in ("source", "target"):
        for node in links[endpoint].unique():
            assert any(label in node for label in expected_labels)
            assert "(" in node and ")" in node

    # Conservation of patients: what leaves the first time point arrives at the last one.
    time_steps = edata.tem.index.tolist()
    first_time, last_time = time_steps[0], time_steps[-1]

    leaves_first = links["source"].str.contains(f"\\({first_time}\\)", regex=True)
    reaches_last = links["target"].str.contains(f"\\({last_time}\\)", regex=True)
    outflow_first = links[leaves_first]["value"].sum()
    inflow_last = links[reaches_last]["value"].sum()

    assert outflow_first == inflow_last

    # Every patient contributes exactly one transition per pair of consecutive time points.
    expected_total_flow = edata.n_obs * (len(time_steps) - 1)

    assert links["value"].sum() == expected_total_flow
Uh oh!
There was an error while loading. Please reload this page.