Skip to content

Commit

Permalink
Update var_by_distance functionalities (#929)
Browse files Browse the repository at this point in the history
  • Loading branch information
LLehner authored Dec 11, 2024
1 parent 7d3761f commit fb6ca2d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 138 deletions.
158 changes: 96 additions & 62 deletions src/squidpy/pl/_var_by_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ def var_by_distance(
var: str | list[str],
anchor_key: str | list[str],
design_matrix_key: str = "design_matrix",
color: str | None = None,
stack_vars: bool = False,
covariate: str | None = None,
order: int = 5,
show_scatter: bool = True,
color: str | None = None,
line_palette: str | Sequence[str] | Cycler | None = None,
scatter_palette: str | Sequence[str] | Cycler | None = "viridis",
dpi: int | None = None,
Expand All @@ -51,20 +52,22 @@ def var_by_distance(
Parameters
----------
%(adata)s
design_matrix_key
Name of the design matrix, previously computed with :func:`squidpy.tl.var_by_distance`, to use.
var
Variables to plot on y-axis.
anchor_key
Anchor point column from which distances are taken.
color
Variables to plot on color palette.
design_matrix_key
Name of the design matrix, previously computed with :func:`squidpy.tl.var_by_distance`, to use.
stack_vars
Whether to show multiple variables on the same plot. Only works if 'covariate' is not specified.
covariate
A covariate for which separate regression lines are plotted for each category.
order
Order of the polynomial fit for :func:`seaborn.regplot`.
show_scatter
Whether to show a scatter plot underlying the regression line.
color
Variables in `adata.obs` to plot if 'show_scatter==True'.
line_palette
Categorical color palette used in case a covariate is specified.
scatter_palette
Expand Down Expand Up @@ -95,16 +98,18 @@ def var_by_distance(
scatterplot_kwargs = dict(scatterplot_kwargs)

# if several variables are plotted, make a panel grid
if isinstance(var, list):
if isinstance(var, list) and not stack_vars:
fig, grid = _panel_grid(
hspace=0.25,
wspace=0.75 / rcParams["figure.figsize"][0] + 0.02,
ncols=4,
num_panels=len(var),
)
axs = []
else:
elif isinstance(var, list) and stack_vars:
var = var
elif isinstance(var, str):
var = [var]
axs = []

df = adata.obsm[design_matrix_key] # get design matrix

Expand All @@ -119,88 +124,117 @@ def var_by_distance(
else:
raise ValueError(f"Variable {name} not found in `adata.var` or `adata.obs`.")

# iterate over the variables to plot
for i, v in enumerate(var):
if len(var) > 1:
ax = plt.subplot(grid[i])
axs.append(ax)
else:
# if a single variable and no grid, then one ax object suffices
fig, ax = plt.subplots(1, 1, figsize=figsize)

# if no covariate is specified, 'sns.regplot' will take the values of all observations
if covariate is None:
if covariate is None and stack_vars:
fig, ax = plt.subplots(1, 1, figsize=figsize)
if isinstance(line_palette, str) or line_palette is None:
line_palette = sns.color_palette("bright", len(var))
for i, v in enumerate(var):
sns.regplot(
data=df,
x=anchor_key,
y=v,
label=v,
order=order,
color=line_palette,
scatter=show_scatter,
color=line_palette[i],
scatter=False,
ax=ax,
line_kws=regplot_kwargs,
)
ax.legend(title=None)
ax.set(ylabel="var")
if title is not None:
ax.set(title=title)
if axis_label is None:
ax.set(xlabel=f"distance to {anchor_key}")
else:
# make a categorical color palette if none was specified and there are several regplots to be plotted
if isinstance(line_palette, str) or line_palette is None:
_set_default_colors_for_categorical_obs(adata, covariate)
line_palette = adata.uns[covariate + "_colors"]
covariate_instances = df[covariate].unique()

# iterate over all covariate values and make 'sns.regplot' for each
for i, co in enumerate(covariate_instances):
ax.set(xlabel=axis_label)

else:
# iterate over the variables to plot
for i, v in enumerate(var):
if len(var) > 1 and not stack_vars:
ax = plt.subplot(grid[i])
axs.append(ax)
else:
# if a single variable and no covariate, then one ax object suffices
fig, ax = plt.subplots(1, 1, figsize=figsize)

# if no covariate is specified, 'sns.regplot' will take the values of all observations
if covariate is None and not stack_vars:
sns.regplot(
data=df.loc[df[covariate] == co],
data=df,
x=anchor_key,
y=v,
order=order,
color=line_palette[i],
color=line_palette,
scatter=show_scatter,
ax=ax,
label=co,
line_kws=regplot_kwargs,
)
label_colors, _ = ax.get_legend_handles_labels()
ax.legend(label_colors, covariate_instances)
# add scatter plot if specified
if show_scatter:
if color is None:
plt.scatter(data=df, x=anchor_key, y=v, color="grey", **scatterplot_kwargs)
# if variable to plot on color palette is categorical, make categorical color palette
elif df[color].dtype.name == "category":
unique_colors = df[color].unique()
cNorm = colors.Normalize(vmin=0, vmax=len(unique_colors))
scalarMap = cm.ScalarMappable(norm=cNorm, cmap=scatter_palette)
for i in range(len(unique_colors)):
elif covariate is not None and not stack_vars:
# make a categorical color palette if none was specified and there are several regplots to be plotted
if isinstance(line_palette, str) or line_palette is None:
_set_default_colors_for_categorical_obs(adata, covariate)
line_palette = adata.uns[covariate + "_colors"]
covariate_instances = df[covariate].unique()

# iterate over all covariate values and make 'sns.regplot' for each
for i, co in enumerate(covariate_instances):
sns.regplot(
data=df.loc[df[covariate] == co],
x=anchor_key,
y=v,
order=order,
color=line_palette[i],
scatter=show_scatter,
ax=ax,
label=co,
line_kws=regplot_kwargs,
)
label_colors, _ = ax.get_legend_handles_labels()
ax.legend(label_colors, covariate_instances)
else:
raise ValueError("Can't stack variables and plot covariate at the same time.")

# add scatter plot if specified
if show_scatter:
if color is None:
plt.scatter(data=df, x=anchor_key, y=v, color="grey", **scatterplot_kwargs)
# if variable to plot on color palette is categorical, make categorical color palette
elif df[color].dtype.name == "category":
unique_colors = df[color].unique()
cNorm = colors.Normalize(vmin=0, vmax=len(unique_colors))
scalarMap = cm.ScalarMappable(norm=cNorm, cmap=scatter_palette)
for i in range(len(unique_colors)):
plt.scatter(
data=df.loc[df[color] == unique_colors[i]],
x=anchor_key,
y=v,
color=scalarMap.to_rgba(i),
**scatterplot_kwargs,
)
# if variable to plot on color palette is not categorical
else:
plt.scatter(
data=df.loc[df[color] == unique_colors[i]],
data=df,
x=anchor_key,
y=v,
color=scalarMap.to_rgba(i),
c=color,
cmap=scatter_palette,
**scatterplot_kwargs,
)
# if variable to plot on color palette is not categorical
if title is not None:
ax.set(title=title)
if axis_label is None:
ax.set(xlabel=f"distance to {anchor_key}")
else:
plt.scatter(
data=df,
x=anchor_key,
y=v,
c=color,
cmap=scatter_palette,
**scatterplot_kwargs,
)
if title is not None:
ax.set(title=title)
if axis_label is None:
ax.set(xlabel=f"distance to {anchor_key}")
else:
ax.set(xlabel=axis_label)
ax.set(xlabel=axis_label)

# remove line palette if it was made earlier in function
if f"{covariate}_colors" in adata.uns:
del line_palette

axs = axs if len(var) > 1 else ax
axs = axs if len(var) and not stack_vars > 1 else ax

if save is not None:
save_fig(fig, path=save, transparent=False, dpi=dpi)
Expand Down
Loading

0 comments on commit fb6ca2d

Please sign in to comment.