From 63a88fc05840b9ccc51421fe158ea324b32cc33e Mon Sep 17 00:00:00 2001 From: Felix Andreas Date: Sun, 4 Apr 2021 16:56:02 +0200 Subject: [PATCH] draw positive quads at the top, negative quads below --- apace/plot.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/apace/plot.py b/apace/plot.py index a204cd4..b306c15 100644 --- a/apace/plot.py +++ b/apace/plot.py @@ -31,7 +31,7 @@ class Color: YELLOW = "#FBBF24" GREEN = "#10B981" BLUE = "#3B82F6" - ORANGE = "F97316" + ORANGE = "#F97316" PURPLE = "#8B5CF6" CYAN = "#06B6D4" WHITE = "white" @@ -71,11 +71,11 @@ def draw_elements( y_min, y_max = ax.get_ylim() rect_height = 0.05 * (y_max - y_min) if location == "top": - y0 = y_max = y_max + rect_height + y0_base = y_max = y_max + rect_height else: - y0 = y_min - rect_height + y0_base = y_min - rect_height y_min -= 3 * rect_height - plt.hlines(y0, x_min, x_max, color="black", linewidth=1) + plt.hlines(y0_base, x_min, x_max, color="black", linewidth=1) ax.set_ylim(y_min, y_max) sign = -1 @@ -93,28 +93,34 @@ def draw_elements( except KeyError: continue - y0_local = y0 if isinstance(element, Dipole) and element.angle < 0: - y0_local += rect_height / 4 + y0 = y0_base + height = rect_height / 2 + elif isinstance(element, Quadrupole): + y0 = y0_base + (np.sign(element.k1) - 1) * rect_height / 4 + height = rect_height / 2 + else: + y0 = y0_base - rect_height / 2 + height = rect_height ax.add_patch( plt.Rectangle( - (max(start, x_min), y0_local - rect_height / 2), + (max(start, x_min), y0), min(end, x_max) - max(start, x_min), - rect_height, + height, facecolor=color, clip_on=False, zorder=10, ) ) - if labels and type(element) in {Dipole, Quadrupole}: + if labels and isinstance(element, (Dipole, Quadrupole)): sign = -sign ax.annotate( element.name, - xy=((start + end) / 2, y0 + sign * rect_height), + xy=((start + end) / 2, y0_base + sign * rect_height), fontsize=FONT_SIZE, ha="center", - va="center", + va="top" if sign < 0 else "bottom", annotation_clip=False, zorder=11, ) @@ -227,24 +233,18 @@ def _twiss_plot_section( y_max=None, annotate_elements=True, annotate_lattices=True, - line_style="solid", - line_width=1.3, ref_twiss=None, - scales={"eta_x": 10}, overwrite=False, + **kwargs, ): if overwrite: ax.clear() if ref_twiss: plot_twiss( - ax, - ref_twiss, - line_style="dashed", - line_width=2.5, - alpha=0.5, + ax, ref_twiss, line_style="dashed", line_width=2.5, alpha=0.5, **kwargs ) - plot_twiss(ax, twiss, line_style=line_style, line_width=line_width, scales=scales) + plot_twiss(ax, twiss, **kwargs) x_min = max(x_min, 0) x_max = min(x_max, twiss.lattice.length) if y_min is None: @@ -328,6 +328,7 @@ def __init__( y_max=y_max, annotate_elements=False, scales=scales, + twiss_functions=self.twiss_functions, ) if sections: @@ -356,6 +357,7 @@ def __init__( y_max=y_max, annotate_elements=True, scales=scales, + twiss_functions=self.twiss_functions, ) handles, labels = self.fig.axes[0].get_legend_handles_labels()