diff --git a/pynest/nest/raster_plot.py b/pynest/nest/raster_plot.py index dba13882ca..f19461f6de 100644 --- a/pynest/nest/raster_plot.py +++ b/pynest/nest/raster_plot.py @@ -208,7 +208,7 @@ def _from_memory(detec): return ev["times"], ev["senders"] -def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, grayscale=False, title=None, xlabel=None): +def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, grayscale=False, title=None, xlabel=None, ax=None): """Generic plotting routine. Constructs a raster plot along with an optional histogram (common part in @@ -234,55 +234,84 @@ def _make_plot(ts, ts1, node_ids, neurons, hist=True, hist_binwidth=5.0, graysca Plot title xlabel : str, optional Label for x-axis + ax : matplotlib.axes.Axes, optional + The axes object to draw the plot on. If None, a new figure + and axes will be created. """ - import matplotlib.pyplot as plt - - plt.figure() + # --- Axis Management --- + # If no axis is provided, create a new figure and axes. + # This block handles the creation of axes for both raster and histogram. + if ax is None: + import matplotlib.pyplot as plt + + fig = plt.figure() + if hist: + # Manually define positions for raster plot and histogram + ax_raster = fig.add_axes([0.1, 0.32, 0.85, 0.6]) + ax_hist = fig.add_axes([0.1, 0.1, 0.85, 0.2], sharex=ax_raster) + # Hide x-tick labels on the raster plot to avoid overlap + plt.setp(ax_raster.get_xticklabels(), visible=False) + else: + ax_raster = fig.add_subplot(111) + ax_hist = None + # If an axis is provided, use it for the raster plot. + # The histogram will not be plotted in this case. + else: + ax_raster = ax + ax_hist = None + if hist: + import warnings + warnings.warn("Histogram is disabled when an external axis is provided. Set hist=False to silence this warning.") + # --- Color settings --- if grayscale: color_marker = ".k" color_bar = "gray" else: color_marker = "." color_bar = "blue" - color_edge = "black" + # --- Label settings --- if xlabel is None: xlabel = "Time (ms)" - ylabel = "Neuron ID" - if hist: - ax1 = plt.axes([0.1, 0.3, 0.85, 0.6]) - plotid = plt.plot(ts1, node_ids, color_marker) - plt.ylabel(ylabel) - plt.xticks([]) - xlim = plt.xlim() + # --- Plotting --- + # Raster plot + plotid = ax_raster.plot(ts1, node_ids, color_marker) + ax_raster.set_ylabel(ylabel) - plt.axes([0.1, 0.1, 0.85, 0.17]) + # Set title on the main raster plot + if title is None: + ax_raster.set_title("Raster plot") + else: + ax_raster.set_title(title) + + # Histogram + if hist and ax_hist is not None: t_bins = numpy.arange(numpy.amin(ts), numpy.amax(ts), float(hist_binwidth)) n, _ = _histogram(ts, bins=t_bins) num_neurons = len(numpy.unique(neurons)) - heights = 1000 * n / (hist_binwidth * num_neurons) - - plt.bar(t_bins, heights, width=hist_binwidth, color=color_bar, edgecolor=color_edge) - plt.yticks([int(x) for x in numpy.linspace(0.0, int(max(heights) * 1.1) + 5, 4)]) - plt.ylabel("Rate (spks/s)") - plt.xlabel(xlabel) - plt.xlim(xlim) - plt.axes(ax1) - else: - plotid = plt.plot(ts1, node_ids, color_marker) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - - if title is None: - plt.title("Raster plot") + + # Avoid division by zero if no neurons are provided + if num_neurons > 0: + heights = 1000 * n / (hist_binwidth * num_neurons) + else: + heights = numpy.zeros_like(n) + + # The number of bins is one less than the number of bin edges + ax_hist.bar(t_bins[:-1], heights, width=hist_binwidth, color=color_bar, edgecolor=color_edge, align='edge') + + if heights.any() and max(heights) > 0: + ax_hist.set_yticks([int(x) for x in numpy.linspace(0.0, int(max(heights) * 1.1) + 5, 4)]) + + ax_hist.set_ylabel("Rate (spks/s)") + ax_hist.set_xlabel(xlabel) + ax_raster.set_xlim(ax_hist.get_xlim()) # Ensure x-limits match else: - plt.title(title) - - plt.draw() + # If no histogram, set the x-label on the raster plot itself + ax_raster.set_xlabel(xlabel) return plotid