diff --git a/mesa/examples/basic/boltzmann_wealth_model/app.py b/mesa/examples/basic/boltzmann_wealth_model/app.py index 464708536ca..1a81c9f01b3 100644 --- a/mesa/examples/basic/boltzmann_wealth_model/app.py +++ b/mesa/examples/basic/boltzmann_wealth_model/app.py @@ -14,8 +14,9 @@ def agent_portrayal(agent): return AgentPortrayalStyle( - color=agent.wealth - ) # we are using a colormap to translate wealth to color + color=agent.wealth, + tooltip={"Agent ID": agent.unique_id, "Wealth": agent.wealth}, + ) model_params = { @@ -41,7 +42,7 @@ def post_process(chart): """Post-process the Altair chart to add a colorbar legend.""" chart = chart.encode( color=alt.Color( - "color:N", + "original_color:Q", scale=alt.Scale(scheme="viridis", domain=[0, 10]), legend=alt.Legend( title="Wealth", @@ -63,12 +64,12 @@ def post_process(chart): renderer = SpaceRenderer(model, backend="altair") # Can customize the grid appearance. renderer.draw_structure(grid_color="black", grid_dash=[6, 2], grid_opacity=0.3) -renderer.draw_agents(agent_portrayal=agent_portrayal, cmap="viridis", vmin=0, vmax=10) - +renderer.draw_agents(agent_portrayal=agent_portrayal) # The post_process function is used to modify the Altair chart after it has been created. # It can be used to add legends, colorbars, or other visual elements. renderer.post_process = post_process + # Creates a line plot component from the model's "Gini" datacollector. GiniPlot = make_plot_component("Gini") diff --git a/mesa/visualization/backends/altair_backend.py b/mesa/visualization/backends/altair_backend.py index 288c3ec5dd0..0b314d6fc81 100644 --- a/mesa/visualization/backends/altair_backend.py +++ b/mesa/visualization/backends/altair_backend.py @@ -1,4 +1,9 @@ -# noqa: D100 +"""Altair-based renderer for Mesa spaces. + +This module provides an Altair-based renderer for visualizing Mesa model spaces, +agents, and property layers with interactive charting capabilities. +""" + import warnings from collections.abc import Callable from dataclasses import fields @@ -75,6 +80,7 @@ def collect_agent_data( "stroke": [], # Stroke color "strokeWidth": [], "filled": [], + "tooltip": [], } # Import here to avoid circular import issues @@ -133,6 +139,7 @@ def collect_agent_data( linewidths=dict_data.pop( "linewidths", style_fields.get("linewidths") ), + tooltip=dict_data.pop("tooltip", None), ) if dict_data: ignored_keys = list(dict_data.keys()) @@ -188,6 +195,7 @@ def collect_agent_data( # FIXME: Make filled user-controllable filled_value = True arguments["filled"].append(filled_value) + arguments["tooltip"].append(aps.tooltip) final_data = {} for k, v in arguments.items(): @@ -221,79 +229,83 @@ def draw_agents( if arguments["loc"].size == 0: return None - # To get a continuous scale for color the domain should be between [0, 1] - # that's why changing the the domain of strokeWidth beforehand. - stroke_width = [data / 10 for data in arguments["strokeWidth"]] - - # Agent data preparation - df_data = { - "x": arguments["loc"][:, 0], - "y": arguments["loc"][:, 1], - "size": arguments["size"], - "shape": arguments["shape"], - "opacity": arguments["opacity"], - "strokeWidth": stroke_width, - "original_color": arguments["color"], - "is_filled": arguments["filled"], - "original_stroke": arguments["stroke"], - } - df = pd.DataFrame(df_data) - - # To ensure distinct shapes according to agent portrayal - unique_shape_names_in_data = df["shape"].unique().tolist() - - fill_colors = [] - stroke_colors = [] - for i in range(len(df)): - filled = df["is_filled"][i] - main_color = df["original_color"][i] - stroke_spec = ( - df["original_stroke"][i] - if isinstance(df["original_stroke"][i], str) - else None - ) - if filled: - fill_colors.append(main_color) - stroke_colors.append(stroke_spec) + # Prepare a list of dictionaries, which is a robust way to create a DataFrame + records = [] + for i in range(len(arguments["loc"])): + record = { + "x": arguments["loc"][i][0], + "y": arguments["loc"][i][1], + "size": arguments["size"][i], + "shape": arguments["shape"][i], + "opacity": arguments["opacity"][i], + "strokeWidth": arguments["strokeWidth"][i] + / 10, # Scale for continuous domain + "original_color": arguments["color"][i], + } + # Add tooltip data if available + tooltip = arguments["tooltip"][i] + if tooltip: + record.update(tooltip) + + # Determine fill and stroke colors + if arguments["filled"][i]: + record["viz_fill_color"] = arguments["color"][i] + record["viz_stroke_color"] = ( + arguments["stroke"][i] + if isinstance(arguments["stroke"][i], str) + else None + ) else: - fill_colors.append(None) - stroke_colors.append(main_color) - df["viz_fill_color"] = fill_colors - df["viz_stroke_color"] = stroke_colors - - # Extract additional parameters from kwargs - # FIXME: Add more parameters to kwargs - title = kwargs.pop("title", "") - xlabel = kwargs.pop("xlabel", "") - ylabel = kwargs.pop("ylabel", "") + record["viz_fill_color"] = None + record["viz_stroke_color"] = arguments["color"][i] - # Tooltip list for interactivity - # FIXME: Add more fields to tooltip (preferably from agent_portrayal) - tooltip_list = ["x", "y"] + records.append(record) - # Handle custom colormapping - cmap = kwargs.pop("cmap", "viridis") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) + df = pd.DataFrame(records) - color_is_numeric = np.issubdtype(df["original_color"].dtype, np.number) - if color_is_numeric: - color_min = vmin if vmin is not None else df["original_color"].min() - color_max = vmax if vmax is not None else df["original_color"].max() + # Ensure all columns that should be numeric are, handling potential Nones + numeric_cols = ["x", "y", "size", "opacity", "strokeWidth"] + for col in numeric_cols: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce") - fill_encoding = alt.Fill( - "original_color:Q", - scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]), + # Handle color numeric conversion safely + if "original_color" in df.columns: + color_values = arguments["color"] + color_is_numeric = all( + isinstance(x, int | float | np.number) or x is None + for x in color_values ) - else: - fill_encoding = alt.Fill( - "viz_fill_color:N", - scale=None, - title="Color", + if color_is_numeric: + df["original_color"] = pd.to_numeric( + df["original_color"], errors="coerce" + ) + + # Get tooltip keys from the first valid record + tooltip_list = ["x", "y"] + if any(t is not None for t in arguments["tooltip"]): + first_valid_tooltip = next( + (t for t in arguments["tooltip"] if t is not None), None ) + if first_valid_tooltip is not None: + tooltip_list.extend(first_valid_tooltip.keys()) + + # Extract additional parameters from kwargs + title = kwargs.pop("title", "") + xlabel = kwargs.pop("xlabel", "") + ylabel = kwargs.pop("ylabel", "") + # FIXME: Add more parameters to kwargs + + color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"]) + fill_encoding = ( + alt.Fill("original_color:Q") + if color_is_numeric + else alt.Fill("viz_fill_color:N", scale=None, title="Color") + ) # Determine space dimensions xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits() + unique_shape_names_in_data = df["shape"].dropna().unique().tolist() chart = ( alt.Chart(df) diff --git a/mesa/visualization/backends/matplotlib_backend.py b/mesa/visualization/backends/matplotlib_backend.py index 1ae407b1683..47d73ddd749 100644 --- a/mesa/visualization/backends/matplotlib_backend.py +++ b/mesa/visualization/backends/matplotlib_backend.py @@ -27,7 +27,6 @@ OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid - CORRECTION_FACTOR_MARKER_ZOOM = 0.01 @@ -145,6 +144,10 @@ def collect_agent_data(self, space, agent_portrayal, default_size=None): ) else: aps = portray_input + if aps.tooltip is not None: + raise ValueError( + "The 'tooltip' attribute in AgentPortrayalStyle is only supported by the Altair backend." + ) # Set defaults if not provided if aps.x is None and aps.y is None: aps.x, aps.y = self._get_agent_pos(agent, space) diff --git a/mesa/visualization/components/portrayal_components.py b/mesa/visualization/components/portrayal_components.py index d15871a12d5..45d50f1d296 100644 --- a/mesa/visualization/components/portrayal_components.py +++ b/mesa/visualization/components/portrayal_components.py @@ -27,6 +27,19 @@ class AgentPortrayalStyle: x, y are determined automatically according to the agent's type (normal/CellAgent) and position in the space if not manually declared. + Attributes: + x (float | None): The x-coordinate of the agent. + y (float | None): The y-coordinate of the agent. + color (ColorLike | None): The color of the agent. + marker (str | None): The marker shape for the agent. + size (int | float | None): The size of the agent marker. + zorder (int | None): The z-order for drawing the agent. + alpha (float | None): The opacity of the agent. + edgecolors (str | tuple | None): The color of the marker's edge. + linewidths (float | int | None): The width of the marker's edge. + tooltip (dict | None): A dictionary of data to display on hover. + Note: This feature is only available with the Altair backend. + Example: >>> def agent_portrayal(agent): >>> return AgentPortrayalStyle( @@ -55,6 +68,7 @@ class AgentPortrayalStyle: alpha: float | None = 1.0 edgecolors: str | tuple | None = None linewidths: float | int | None = 1.0 + tooltip: dict | None = None def update(self, *updates_fields: tuple[str, Any]): """Updates attributes from variable (field_name, new_value) tuple arguments. diff --git a/tests/test_backends.py b/tests/test_backends.py index d4c15d36185..018e051b6e4 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -248,6 +248,7 @@ def test_altair_backend_draw_agents(): "color": np.array(["red", "blue"]), "filled": np.array([True, True]), "stroke": np.array(["black", "black"]), + "tooltip": np.array([None, None]), } ab.space_drawer.get_viz_limits = MagicMock(return_value=(0, 10, 0, 10)) assert ab.draw_agents(arguments) is not None