diff --git a/stim1p/ui/dmd_axis_item.py b/stim1p/ui/dmd_axis_item.py new file mode 100644 index 0000000..da07dda --- /dev/null +++ b/stim1p/ui/dmd_axis_item.py @@ -0,0 +1,51 @@ +"""Axis item helpers for DMD visualisations.""" + +from __future__ import annotations + +from typing import Protocol, Sequence + +import numpy as np +import pyqtgraph as pg + + +class AxisScaleProvider(Protocol): + """Provide conversion factors for axis tick labels.""" + + def axis_unit_scale_for_orientation(self, orientation: str) -> float | None: + """Return micrometre-per-unit scale for the given axis orientation.""" + + +class MicrometreAxisItem(pg.AxisItem): + """Axis that renders tick labels in micrometres when calibration is available.""" + + def __init__(self, orientation: str, scale_provider: AxisScaleProvider): + super().__init__(orientation=orientation) + self._scale_provider = scale_provider + + def tickStrings( + self, values: Sequence[float], scale: float, spacing: float + ) -> list[str]: + if self.logMode: + return super().tickStrings(values, scale, spacing) + + per_unit = self._scale_provider.axis_unit_scale_for_orientation(self.orientation) + if per_unit is None or not np.isfinite(per_unit) or per_unit == 0.0: + return super().tickStrings(values, scale, spacing) + + spacing_um = abs(spacing * per_unit) + effective_spacing = max(spacing_um, 1e-9) + places = max(0, int(np.ceil(-np.log10(effective_spacing)))) + places = min(places, 6) + + strings: list[str] = [] + for value in values: + val_um = float(value) * per_unit + if abs(val_um) < 1e-9: + val_um = 0.0 + if abs(val_um) < 1e-3 or abs(val_um) >= 1e4: + label = f"{val_um:g}" + else: + label = f"{val_um:.{places}f}" + strings.append(label) + + return strings diff --git a/stim1p/ui/dmd_dialogs.py b/stim1p/ui/dmd_dialogs.py new file mode 100644 index 0000000..3309f8a --- /dev/null +++ b/stim1p/ui/dmd_dialogs.py @@ -0,0 +1,224 @@ +"""Dialog helpers for the DMD stimulation widget.""" + +from __future__ import annotations + +from PySide6.QtWidgets import ( + QCheckBox, + QDialog, + QDialogButtonBox, + QDoubleSpinBox, + QFormLayout, + QHBoxLayout, + QLabel, + QPushButton, + QSizePolicy, + QSpinBox, + QVBoxLayout, + QWidget, +) + + +class CalibrationDialog(QDialog): + """Collect user inputs required to build a calibration.""" + + def __init__( + self, + parent: QWidget | None = None, + *, + default_mirrors: tuple[int, int] = (100, 100), + default_pixel_size: float = 1.0, + default_invert_x: bool = False, + default_invert_y: bool = False, + ): + super().__init__(parent) + self.setWindowTitle("Calibrate DMD") + layout = QFormLayout(self) + + self._mirror_size = QSpinBox(self) + self._mirror_size.setRange(1, 8192) + default_avg = 0.5 * (float(default_mirrors[0]) + float(default_mirrors[1])) + default_size = max(1, min(8192, int(round(default_avg)))) + self._mirror_size.setValue(default_size) + layout.addRow("Square size (mirrors)", self._mirror_size) + + self._pixel_size = QDoubleSpinBox(self) + self._pixel_size.setRange(1e-6, 10_000.0) + self._pixel_size.setDecimals(6) + clamped_size = max( + self._pixel_size.minimum(), + min(self._pixel_size.maximum(), float(default_pixel_size)), + ) + self._pixel_size.setValue(clamped_size) + layout.addRow("Camera pixel size (µm)", self._pixel_size) + + self._invert_x = QCheckBox(self) + self._invert_x.setChecked(bool(default_invert_x)) + self._invert_x.setText("Flip DMD X axis (X→X−x)") + layout.addRow(self._invert_x) + + self._invert_y = QCheckBox(self) + self._invert_y.setChecked(bool(default_invert_y)) + self._invert_y.setText("Flip DMD Y axis (Y→Y−y)") + layout.addRow(self._invert_y) + + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, + parent=self, + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addRow(buttons) + + def values(self) -> tuple[int, float, bool, bool]: + return ( + self._mirror_size.value(), + self._pixel_size.value(), + self._invert_x.isChecked(), + self._invert_y.isChecked(), + ) + + +class CalibrationPreparationDialog(QDialog): + """Ask the user whether to display a calibration frame before proceeding.""" + + def __init__( + self, + parent: QWidget | None = None, + *, + default_square_size: int = 100, + can_send: bool = False, + max_square_size: int | None = None, + ): + super().__init__(parent) + self.setWindowTitle("Prepare Calibration") + self._chosen_action: str | None = "skip" + + layout = QFormLayout(self) + + message = QLabel(self) + message.setText( + "Send a bright square to the DMD before selecting\n" + "the calibration image?" + ) + message.setWordWrap(True) + message.setSizePolicy( + QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred + ) + layout.addRow(message) + + self._square_size = QSpinBox(self) + self._square_size.setRange(1, 8192) + if max_square_size is not None: + self._square_size.setMaximum(max(1, int(max_square_size))) + self._square_size.setValue(max(1, int(default_square_size))) + layout.addRow("Square size (mirrors)", self._square_size) + + button_box = QDialogButtonBox(parent=self) + send_button = QPushButton("Send to DMD", self) + send_button.setEnabled(can_send) + if not can_send: + send_button.setToolTip("Connect to the DMD to send a calibration frame.") + skip_button = QPushButton("Continue without sending", self) + skip_button.setDefault(True) + cancel_button = button_box.addButton( + QDialogButtonBox.StandardButton.Cancel + ) + button_box.addButton(send_button, QDialogButtonBox.ButtonRole.AcceptRole) + button_box.addButton(skip_button, QDialogButtonBox.ButtonRole.AcceptRole) + layout.addRow(button_box) + + send_button.clicked.connect(self._on_send_clicked) + skip_button.clicked.connect(self._on_skip_clicked) + cancel_button.clicked.connect(self.reject) + + def _on_send_clicked(self) -> None: + sender = self.sender() + if sender is None or not sender.isEnabled(): + return + self._chosen_action = "send" + self.accept() + + def _on_skip_clicked(self) -> None: + self._chosen_action = "skip" + self.accept() + + def chosen_action(self) -> str | None: + return self._chosen_action + + def square_size(self) -> int: + return int(self._square_size.value()) + + +class CyclePatternsDialog(QDialog): + """Collect parameters required to generate cycle/repeat table entries.""" + + def __init__( + self, + parent: QWidget | None = None, + *, + default_first_time: int = 0, + default_cycle_count: int = 1, + default_repeat_count: int = 1, + default_repeat_gap: int = 100, + default_cycle_gap: int = 250, + default_duration: int = 100, + ): + super().__init__(parent) + self.setWindowTitle("Cycle patterns") + main_layout = QVBoxLayout(self) + form = QFormLayout() + form.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow) + + self._cycle_count = QSpinBox(self) + self._cycle_count.setRange(1, 1_000_000) + self._cycle_count.setValue(max(1, int(default_cycle_count))) + form.addRow("Number of cycles", self._cycle_count) + + self._repeat_count = QSpinBox(self) + self._repeat_count.setRange(1, 1_000_000) + self._repeat_count.setValue(max(1, int(default_repeat_count))) + form.addRow("Repetitions per pattern", self._repeat_count) + + self._first_time = QSpinBox(self) + self._first_time.setRange(0, 3_600_000) + self._first_time.setSingleStep(10) + self._first_time.setValue(max(0, int(default_first_time))) + form.addRow("First pattern time (ms)", self._first_time) + + self._repeat_gap = QSpinBox(self) + self._repeat_gap.setRange(0, 3_600_000) + self._repeat_gap.setSingleStep(10) + self._repeat_gap.setValue(max(0, int(default_repeat_gap))) + form.addRow("Separation between repetitions (ms)", self._repeat_gap) + + self._cycle_gap = QSpinBox(self) + self._cycle_gap.setRange(0, 3_600_000) + self._cycle_gap.setSingleStep(10) + self._cycle_gap.setValue(max(0, int(default_cycle_gap))) + form.addRow("Additional gap between cycles (ms)", self._cycle_gap) + + self._duration = QSpinBox(self) + self._duration.setRange(1, 3_600_000) + self._duration.setSingleStep(10) + self._duration.setValue(max(1, int(default_duration))) + form.addRow("Duration for each entry (ms)", self._duration) + + main_layout.addLayout(form) + + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, + parent=self, + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + main_layout.addWidget(buttons) + + def values(self) -> dict[str, int | str]: + return { + "cycle_count": int(self._cycle_count.value()), + "repeat_count": int(self._repeat_count.value()), + "first_time_ms": int(self._first_time.value()), + "repeat_gap_ms": int(self._repeat_gap.value()), + "cycle_gap_ms": int(self._cycle_gap.value()), + "duration_ms": int(self._duration.value()), + } diff --git a/stim1p/ui/dmd_grid_overlay.py b/stim1p/ui/dmd_grid_overlay.py new file mode 100644 index 0000000..6f83e8e --- /dev/null +++ b/stim1p/ui/dmd_grid_overlay.py @@ -0,0 +1,49 @@ +"""Temporary overlay helpers for DMD previews.""" + +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import pyqtgraph as pg +from PySide6.QtCore import Qt + + +class GridPreviewOverlay: + """Render a temporary preview of rectangles on top of the plot.""" + + def __init__(self, plot_item: pg.PlotItem): + self._plot_item = plot_item + self._items: list[pg.PlotCurveItem] = [] + self._pen = pg.mkPen(color=(0, 200, 255, 200), width=2, style=Qt.PenStyle.DashLine) + + def set_rectangles(self, rectangles: Sequence[np.ndarray]) -> None: + rectangles = [np.asarray(rect, dtype=float) for rect in rectangles] + required = len(rectangles) + while len(self._items) < required: + item = pg.PlotCurveItem(pen=self._pen) + item.setZValue(8_750) + item.hide() + self._plot_item.addItem(item) + self._items.append(item) + + for idx, rect in enumerate(rectangles): + item = self._items[idx] + if rect.ndim != 2 or rect.shape[1] != 2: + item.hide() + continue + closed = np.vstack([rect, rect[0]]) + item.setData(closed[:, 0], closed[:, 1]) + item.show() + + for idx in range(len(rectangles), len(self._items)): + self._items[idx].hide() + + def hide(self) -> None: + for item in self._items: + item.hide() + + def clear(self) -> None: + for item in self._items: + self._plot_item.removeItem(item) + self._items.clear() diff --git a/stim1p/ui/dmd_stim_widget.py b/stim1p/ui/dmd_stim_widget.py index 154a6d5..e2c012b 100644 --- a/stim1p/ui/dmd_stim_widget.py +++ b/stim1p/ui/dmd_stim_widget.py @@ -1,13 +1,9 @@ import os import glob import math -from pathlib import Path import numpy as np from PIL import Image from datetime import timedelta -from dataclasses import dataclass -from typing import Sequence -import h5py from PySide6.QtCore import ( QEvent, @@ -18,40 +14,17 @@ from PySide6.QtGui import QTransform from PySide6.QtWidgets import ( QAbstractItemView, - QDialog, - QDialogButtonBox, - QDoubleSpinBox, QFileDialog, - QFormLayout, - QCheckBox, + QHeaderView, QHBoxLayout, - QLabel, QMessageBox, - QPushButton, - QHeaderView, - QSpinBox, - QSizePolicy, - QTableWidgetItem, QTreeWidgetItem, QWidget, - QFrame, - QVBoxLayout, ) import pyqtgraph as pg -from ..logic.calibration import ( - DMDCalibration, - compute_calibration_from_square, -) -from ..logic.geometry import ( - AxisDefinition, - axis_micrometre_scale, - axis_micrometre_to_axis_pixels, - axis_micrometre_to_global, - axis_pixels_to_axis_micrometre, -) +from ..logic.calibration import DMDCalibration from ..logic.sequence import PatternSequence -from ..logic import saving from ..stim1p import Stim1P from .qt.DMD_stim_ui import Ui_widget_dmd_stim @@ -59,293 +32,26 @@ from .calibration_preferences import CalibrationPreferences from .grid_dialog import GridDialog, GridParameters from .capture_tools import ( - AxisCapture, InteractiveRectangleCapture, PolygonDrawingCapture, ) +from .dmd_axis_item import MicrometreAxisItem +from .dmd_grid_overlay import GridPreviewOverlay + +from .dmd_widget import ( + AxisControlsMixin, + AxisRedefinitionCache, + CalibrationWorkflowMixin, + PatternSequenceIOMixin, +) -_HDF5_FILE_FILTER = "HDF5 files (*.h5 *.hdf5);;All files (*)" - - -class _MicrometreAxisItem(pg.AxisItem): - """Axis that renders tick labels in micrometres when calibration is available.""" - - def __init__(self, orientation: str, widget): - super().__init__(orientation=orientation) - self._widget = widget - - def tickStrings(self, values, scale, spacing): - if self.logMode: - return super().tickStrings(values, scale, spacing) - per_unit = self._widget._axis_unit_scale_for_orientation(self.orientation) - if per_unit is None or not np.isfinite(per_unit) or per_unit == 0.0: - return super().tickStrings(values, scale, spacing) - spacing_um = abs(spacing * per_unit) - effective_spacing = max(spacing_um, 1e-9) - places = max(0, int(np.ceil(-np.log10(effective_spacing)))) - places = min(places, 6) - strings: list[str] = [] - for value in values: - val_um = float(value) * per_unit - if abs(val_um) < 1e-9: - val_um = 0.0 - if abs(val_um) < 1e-3 or abs(val_um) >= 1e4: - label = f"{val_um:g}" - else: - label = f"{val_um:.{places}f}" - strings.append(label) - return strings - - -@dataclass -class _AxisRedefinitionCache: - previous_origin: np.ndarray - previous_angle: float - new_origin: np.ndarray - new_angle: float - shapes: dict[QTreeWidgetItem, tuple[np.ndarray, str]] - behaviour: str | None = None - - -class _CalibrationDialog(QDialog): - """Collect user inputs required to build a calibration.""" - - def __init__( - self, - parent: QWidget | None = None, - *, - default_mirrors: tuple[int, int] = (100, 100), - default_pixel_size: float = 1.0, - default_invert_x: bool = False, - default_invert_y: bool = False, - ): - super().__init__(parent) - self.setWindowTitle("Calibrate DMD") - layout = QFormLayout(self) - - self._mirror_size = QSpinBox(self) - self._mirror_size.setRange(1, 8192) - default_avg = 0.5 * (float(default_mirrors[0]) + float(default_mirrors[1])) - default_size = max(1, min(8192, int(round(default_avg)))) - self._mirror_size.setValue(default_size) - layout.addRow("Square size (mirrors)", self._mirror_size) - - self._pixel_size = QDoubleSpinBox(self) - self._pixel_size.setRange(1e-6, 10_000.0) - self._pixel_size.setDecimals(6) - clamped_size = max(self._pixel_size.minimum(), min(self._pixel_size.maximum(), float(default_pixel_size))) - self._pixel_size.setValue(clamped_size) - layout.addRow("Camera pixel size (µm)", self._pixel_size) - - self._invert_x = QCheckBox(self) - self._invert_x.setChecked(bool(default_invert_x)) - self._invert_x.setText("Flip DMD X axis (X→X−x)") - layout.addRow(self._invert_x) - - self._invert_y = QCheckBox(self) - self._invert_y.setChecked(bool(default_invert_y)) - self._invert_y.setText("Flip DMD Y axis (Y→Y−y)") - layout.addRow(self._invert_y) - - buttons = QDialogButtonBox( - QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, - parent=self, - ) - buttons.accepted.connect(self.accept) - buttons.rejected.connect(self.reject) - layout.addRow(buttons) - - def values(self) -> tuple[int, float, bool, bool]: - return ( - self._mirror_size.value(), - self._pixel_size.value(), - self._invert_x.isChecked(), - self._invert_y.isChecked(), - ) - - -class _CalibrationPreparationDialog(QDialog): - """Ask the user whether to display a calibration frame before proceeding.""" - - def __init__( - self, - parent: QWidget | None = None, - *, - default_square_size: int = 100, - can_send: bool = False, - max_square_size: int | None = None, - ): - super().__init__(parent) - self.setWindowTitle("Prepare Calibration") - self._chosen_action: str | None = "skip" - - layout = QFormLayout(self) - - message = QLabel(self) - message.setText( - "Send a bright square to the DMD before selecting\n" - "the calibration image?" - ) - message.setWordWrap(True) - message.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) - layout.addRow(message) - - self._square_size = QSpinBox(self) - self._square_size.setRange(1, 8192) - if max_square_size is not None: - self._square_size.setMaximum(max(1, int(max_square_size))) - self._square_size.setValue(max(1, int(default_square_size))) - layout.addRow("Square size (mirrors)", self._square_size) - - button_box = QDialogButtonBox(parent=self) - send_button = QPushButton("Send to DMD", self) - send_button.setEnabled(can_send) - if not can_send: - send_button.setToolTip("Connect to the DMD to send a calibration frame.") - skip_button = QPushButton("Continue without sending", self) - skip_button.setDefault(True) - cancel_button = button_box.addButton( - QDialogButtonBox.StandardButton.Cancel - ) - button_box.addButton(send_button, QDialogButtonBox.ButtonRole.AcceptRole) - button_box.addButton(skip_button, QDialogButtonBox.ButtonRole.AcceptRole) - layout.addRow(button_box) - - send_button.clicked.connect(self._on_send_clicked) - skip_button.clicked.connect(self._on_skip_clicked) - cancel_button.clicked.connect(self.reject) - - def _on_send_clicked(self) -> None: - if not self.sender() or not self.sender().isEnabled(): - return - self._chosen_action = "send" - self.accept() - - def _on_skip_clicked(self) -> None: - self._chosen_action = "skip" - self.accept() - - def chosen_action(self) -> str | None: - return self._chosen_action - - def square_size(self) -> int: - return int(self._square_size.value()) - - -class _CyclePatternsDialog(QDialog): - """Collect parameters required to generate cycle/repeat table entries.""" - - def __init__( - self, - parent: QWidget | None = None, - *, - default_first_time: int = 0, - default_cycle_count: int = 1, - default_repeat_count: int = 1, - default_repeat_gap: int = 100, - default_cycle_gap: int = 250, - default_duration: int = 100, - ): - super().__init__(parent) - self.setWindowTitle("Cycle patterns") - main_layout = QVBoxLayout(self) - form = QFormLayout() - form.setFieldGrowthPolicy(QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow) - - self._cycle_count = QSpinBox(self) - self._cycle_count.setRange(1, 1_000_000) - self._cycle_count.setValue(max(1, int(default_cycle_count))) - form.addRow("Number of cycles", self._cycle_count) - - self._repeat_count = QSpinBox(self) - self._repeat_count.setRange(1, 1_000_000) - self._repeat_count.setValue(max(1, int(default_repeat_count))) - form.addRow("Repetitions per pattern", self._repeat_count) - - self._first_time = QSpinBox(self) - self._first_time.setRange(0, 3_600_000) - self._first_time.setSingleStep(10) - self._first_time.setValue(max(0, int(default_first_time))) - form.addRow("First pattern time (ms)", self._first_time) - - self._repeat_gap = QSpinBox(self) - self._repeat_gap.setRange(0, 3_600_000) - self._repeat_gap.setSingleStep(10) - self._repeat_gap.setValue(max(0, int(default_repeat_gap))) - form.addRow("Separation between repetitions (ms)", self._repeat_gap) - - self._cycle_gap = QSpinBox(self) - self._cycle_gap.setRange(0, 3_600_000) - self._cycle_gap.setSingleStep(10) - self._cycle_gap.setValue(max(0, int(default_cycle_gap))) - form.addRow("Additional gap between cycles (ms)", self._cycle_gap) - - self._duration = QSpinBox(self) - self._duration.setRange(1, 3_600_000) - self._duration.setSingleStep(10) - self._duration.setValue(max(1, int(default_duration))) - form.addRow("Duration for each entry (ms)", self._duration) - - main_layout.addLayout(form) - - buttons = QDialogButtonBox( - QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, - parent=self, - ) - buttons.accepted.connect(self.accept) - buttons.rejected.connect(self.reject) - main_layout.addWidget(buttons) - - def values(self) -> dict[str, int | str]: - return { - "cycle_count": int(self._cycle_count.value()), - "repeat_count": int(self._repeat_count.value()), - "first_time_ms": int(self._first_time.value()), - "repeat_gap_ms": int(self._repeat_gap.value()), - "cycle_gap_ms": int(self._cycle_gap.value()), - "duration_ms": int(self._duration.value()), - } - -class _GridPreviewOverlay: - """Render a temporary preview of rectangles on top of the plot.""" - - def __init__(self, plot_item: pg.PlotItem): - self._plot_item = plot_item - self._items: list[pg.PlotCurveItem] = [] - self._pen = pg.mkPen(color=(0, 200, 255, 200), width=2, style=Qt.PenStyle.DashLine) - - def set_rectangles(self, rectangles: Sequence[np.ndarray]) -> None: - rectangles = [np.asarray(rect, dtype=float) for rect in rectangles] - required = len(rectangles) - while len(self._items) < required: - item = pg.PlotCurveItem(pen=self._pen) - item.setZValue(8_750) - item.hide() - self._plot_item.addItem(item) - self._items.append(item) - for idx, rect in enumerate(rectangles): - item = self._items[idx] - if rect.ndim != 2 or rect.shape[1] != 2: - item.hide() - continue - closed = np.vstack([rect, rect[0]]) - item.setData(closed[:, 0], closed[:, 1]) - item.show() - for idx in range(len(rectangles), len(self._items)): - self._items[idx].hide() - - def hide(self) -> None: - for item in self._items: - item.hide() - - def clear(self) -> None: - for item in self._items: - self._plot_item.removeItem(item) - self._items.clear() - - -class StimDMDWidget(QWidget): +class StimDMDWidget( + AxisControlsMixin, + CalibrationWorkflowMixin, + PatternSequenceIOMixin, + QWidget, +): """Coordinate DMD calibration, ROI editing, and pattern sequencing UI.""" _AXIS_MODE_MOVE = "move" @@ -370,12 +76,12 @@ def __init__(self, name="Stimulation DMD Widget", dmd=None, parent=None): self._axis_origin_camera = np.array([0.0, 0.0], dtype=float) self._axis_angle_rad = 0.0 self._axis_defined = False - self._axis_redefine_cache: _AxisRedefinitionCache | None = None + self._axis_redefine_cache: AxisRedefinitionCache | None = None # GraphicsLayoutWidget gives us fine control over plot + histogram layout. self._graphics_widget = pg.GraphicsLayoutWidget(parent=self) axis_items = { - "bottom": _MicrometreAxisItem("bottom", self), - "left": _MicrometreAxisItem("left", self), + "bottom": MicrometreAxisItem("bottom", self), + "left": MicrometreAxisItem("left", self), } self._plot_item = self._graphics_widget.addPlot(axisItems=axis_items) self._view_box = self._plot_item.getViewBox() @@ -424,7 +130,7 @@ def __init__(self, name="Stimulation DMD Widget", dmd=None, parent=None): self._axis_line_item.hide() self._axis_arrow_item.hide() self._axis_origin_item.hide() - self._grid_preview_overlay: _GridPreviewOverlay | None = None + self._grid_preview_overlay: GridPreviewOverlay | None = None self._grid_dialog: GridDialog | None = None # HistogramLUTWidget provides the contrast controls that users expect. @@ -873,264 +579,6 @@ def set_up(self): def _get_view_box(self) -> pg.ViewBox: return self._view_box - def _axis_definition(self) -> AxisDefinition: - origin = tuple(float(v) for v in self._axis_origin_camera.reshape(2)) - return AxisDefinition(origin_camera=origin, angle_rad=float(self._axis_angle_rad)) - - def _rotation_matrix(self, angle: float | None = None) -> np.ndarray: - angle = self._axis_angle_rad if angle is None else float(angle) - cos_a = float(np.cos(angle)) - sin_a = float(np.sin(angle)) - return np.array([[cos_a, -sin_a], [sin_a, cos_a]], dtype=float) - - def _camera_to_axis( - self, - points: np.ndarray, - *, - origin: np.ndarray | None = None, - angle: float | None = None, - ) -> np.ndarray: - """Convert camera pixel coordinates into the user-defined axis frame.""" - arr = np.asarray(points, dtype=float) - was_1d = arr.ndim == 1 - pts = np.atleast_2d(arr) - origin_vec = ( - self._axis_origin_camera - if origin is None - else np.asarray(origin, dtype=float) - ) - R = self._rotation_matrix(angle) - relative = pts - origin_vec - # Rotate into the axis frame and keep the input dimensionality. - result = (R.T @ relative.T).T - return result[0] if was_1d else result - - def _axis_to_camera( - self, - points: np.ndarray, - *, - origin: np.ndarray | None = None, - angle: float | None = None, - ) -> np.ndarray: - """Convert axis-aligned coordinates back to camera pixel indices.""" - arr = np.asarray(points, dtype=float) - was_1d = arr.ndim == 1 - pts = np.atleast_2d(arr) - origin_vec = ( - self._axis_origin_camera - if origin is None - else np.asarray(origin, dtype=float) - ) - R = self._rotation_matrix(angle) - # Rotate and translate back into camera coordinates. - result = (R @ pts.T).T + origin_vec - return result[0] if was_1d else result - - def _axis_origin_micrometre( - self, origin_camera: np.ndarray | None = None - ) -> np.ndarray: - if self._calibration is None: - raise RuntimeError("A calibration is required for micrometre conversion.") - origin_vec = ( - self._axis_origin_camera - if origin_camera is None - else np.asarray(origin_camera, dtype=float) - ) - mic = self._calibration.camera_to_micrometre(origin_vec.reshape(2, 1)).T[0] - return np.asarray(mic, dtype=float) - - def _axis_pixels_to_micrometres(self, points: np.ndarray) -> np.ndarray: - if self._calibration is None: - raise RuntimeError("A calibration is required for micrometre conversion.") - return axis_pixels_to_axis_micrometre( - points, self._axis_definition(), self._calibration - ) - - def _axis_micrometre_scale(self) -> tuple[float, float] | None: - if self._calibration is None: - return None - try: - scales = axis_micrometre_scale( - self._axis_definition(), self._calibration - ) - except Exception: - return None - scale_x = float(scales[0]) - scale_y = float(scales[1]) - if ( - not np.isfinite(scale_x) - or not np.isfinite(scale_y) - or scale_x <= 0.0 - or scale_y <= 0.0 - ): - return None - return scale_x, scale_y - - def _axis_unit_scale_for_orientation(self, orientation: str) -> float | None: - scales = self._axis_micrometre_scale() - if scales is None: - return None - orient = orientation.lower() - if orient in ("bottom", "top"): - return scales[0] - if orient in ("left", "right"): - return scales[1] - return None - - def _reproject_shapes_from_cache(self, cache: _AxisRedefinitionCache) -> None: - prev_origin = np.asarray(cache.previous_origin, dtype=float) - prev_angle = float(cache.previous_angle) - new_origin = np.asarray(cache.new_origin, dtype=float) - new_angle = float(cache.new_angle) - for item, (axis_points, shape_type) in cache.shapes.items(): - axis_pts = np.asarray(axis_points, dtype=float) - camera_pts = self._axis_to_camera(axis_pts, origin=prev_origin, angle=prev_angle) - axis_pts_new = self._camera_to_axis(camera_pts, origin=new_origin, angle=new_angle) - self.roi_manager.update_shape(item, shape_type, axis_pts_new) - - def _restore_shapes_from_cache(self, cache: _AxisRedefinitionCache) -> None: - for item, (axis_points, shape_type) in cache.shapes.items(): - self.roi_manager.update_shape(item, shape_type, axis_points) - - def _setup_axis_behaviour_controls(self) -> None: - combo = self.ui.comboBox_axis_behaviour - self._axis_behaviour_by_index = { - 0: self._AXIS_MODE_MOVE, - 1: self._AXIS_MODE_KEEP, - } - self._axis_behaviour_to_index = { - value: key for key, value in self._axis_behaviour_by_index.items() - } - for index, mode in self._axis_behaviour_by_index.items(): - combo.setItemText(index, self._AXIS_BEHAVIOUR_LABELS[mode]) - tooltip = ( - "Choose what happens to existing patterns when the axis is redefined.\n" - "A banner appears after redefining so you can switch behaviour for that change." - ) - combo.setToolTip(tooltip) - self.ui.label_axis_behaviour.setToolTip(tooltip) - stored_mode = self._preferences.axis_redefinition_mode() - index = self._axis_behaviour_to_index.get(stored_mode, 0) - combo.blockSignals(True) - combo.setCurrentIndex(index) - combo.blockSignals(False) - combo.currentIndexChanged.connect(self._on_axis_behaviour_combo_changed) - - def _setup_axis_feedback_banner(self) -> None: - frame = QFrame(self.ui.verticalLayoutWidget) - frame.setObjectName("axisBehaviourBanner") - frame.setFrameShape(QFrame.Shape.StyledPanel) - frame.setVisible(False) - frame.setAttribute(Qt.WidgetAttribute.WA_Hover, True) - - layout = QHBoxLayout(frame) - layout.setContentsMargins(8, 4, 8, 4) - layout.setSpacing(8) - - label = QLabel(frame) - layout.addWidget(label, 1) - - layout.addStretch(1) - - move_btn = QPushButton(self._AXIS_BEHAVIOUR_LABELS[self._AXIS_MODE_MOVE], frame) - keep_btn = QPushButton(self._AXIS_BEHAVIOUR_LABELS[self._AXIS_MODE_KEEP], frame) - layout.addWidget(move_btn, 0) - layout.addWidget(keep_btn, 0) - - self.ui.verticalLayout_controls.insertWidget(1, frame) - - timer = QTimer(self) - timer.setSingleShot(True) - timer.timeout.connect(self._hide_axis_feedback_banner) - - move_btn.clicked.connect(lambda: self._handle_axis_banner_choice(self._AXIS_MODE_MOVE)) - keep_btn.clicked.connect(lambda: self._handle_axis_banner_choice(self._AXIS_MODE_KEEP)) - frame.installEventFilter(self) - - self._axis_feedback_frame = frame - self._axis_feedback_label = label - self._axis_feedback_move_button = move_btn - self._axis_feedback_keep_button = keep_btn - self._axis_feedback_timer = timer - - def _axis_behaviour_from_index(self, index: int) -> str: - return self._axis_behaviour_by_index.get(index, self._AXIS_MODE_MOVE) - - def _axis_behaviour_label(self, behaviour: str) -> str: - return self._AXIS_BEHAVIOUR_LABELS.get(behaviour, behaviour) - - def _default_axis_behaviour(self) -> str: - return self._axis_behaviour_from_index(self.ui.comboBox_axis_behaviour.currentIndex()) - - def _update_axis_behaviour_combo(self, behaviour: str, *, update_preferences: bool) -> None: - index = self._axis_behaviour_to_index.get(behaviour) - if index is None: - return - combo = self.ui.comboBox_axis_behaviour - if combo.currentIndex() != index: - combo.blockSignals(True) - combo.setCurrentIndex(index) - combo.blockSignals(False) - if update_preferences: - self._preferences.set_axis_redefinition_mode(behaviour) - - def _show_axis_feedback_banner(self, cache: _AxisRedefinitionCache) -> None: - if not cache.shapes: - self._hide_axis_feedback_banner() - return - behaviour = cache.behaviour or self._default_axis_behaviour() - description = self._axis_behaviour_label(behaviour) - self._axis_feedback_label.setText(f'Axis updated; patterns set to "{description}". Change?') - self._refresh_axis_feedback_buttons(behaviour) - self._axis_feedback_frame.setVisible(True) - self._axis_feedback_timer.start(6000) - - def _hide_axis_feedback_banner(self) -> None: - self._axis_feedback_timer.stop() - self._axis_feedback_frame.setVisible(False) - - def _refresh_axis_feedback_buttons(self, behaviour: str) -> None: - move_active = behaviour == self._AXIS_MODE_MOVE - keep_active = behaviour == self._AXIS_MODE_KEEP - self._axis_feedback_move_button.setEnabled(not move_active) - self._axis_feedback_move_button.setDefault(move_active) - self._axis_feedback_keep_button.setEnabled(not keep_active) - self._axis_feedback_keep_button.setDefault(keep_active) - - def _handle_axis_banner_choice(self, behaviour: str) -> None: - cache = self._axis_redefine_cache - if cache is None: - return - if cache.behaviour == behaviour: - self._hide_axis_feedback_banner() - return - self._apply_axis_definition(cache, behaviour, fit_view=False) - self._update_axis_behaviour_combo(behaviour, update_preferences=True) - self._show_axis_feedback_banner(cache) - - def _on_axis_behaviour_combo_changed(self, index: int) -> None: - behaviour = self._axis_behaviour_from_index(index) - self._preferences.set_axis_redefinition_mode(behaviour) - if self._axis_redefine_cache is not None: - self._refresh_axis_feedback_buttons(self._axis_redefine_cache.behaviour or behaviour) - - def _update_axis_labels(self) -> None: - unit = "µm" if self._calibration is not None else "px" - axis_bottom = self._plot_item.getAxis("bottom") - axis_left = self._plot_item.getAxis("left") - axis_bottom.setLabel(f"X ({unit})") - axis_left.setLabel(f"Y ({unit})") - for axis in (axis_bottom, axis_left): - axis.picture = None - axis.update() - - def _micrometres_to_axis_pixels(self, points_um: np.ndarray) -> np.ndarray: - if self._calibration is None: - raise RuntimeError("A calibration is required for micrometre conversion.") - return axis_micrometre_to_axis_pixels( - points_um, self._axis_definition(), self._calibration - ) - def _setup_roi_properties_panel(self) -> None: stack = self.ui.stackedWidget_roi_properties stack.setCurrentWidget(self.ui.page_roi_placeholder) @@ -1344,36 +792,6 @@ def _update_image_transform(self) -> None: ) self._image_item.setTransform(transform) - def _image_axis_bounds(self) -> tuple[float, float, float, float]: - if self._current_image is None: - return (-50.0, 50.0, -50.0, 50.0) - height, width = self._current_image.shape[:2] - corners_camera = np.array( - [[0.0, 0.0], [float(width), 0.0], [float(width), float(height)], [0.0, float(height)]], - dtype=float, - ) - corners_axis = self._camera_to_axis(corners_camera) - min_x = float(np.min(corners_axis[:, 0])) - max_x = float(np.max(corners_axis[:, 0])) - min_y = float(np.min(corners_axis[:, 1])) - max_y = float(np.max(corners_axis[:, 1])) - return min_x, max_x, min_y, max_y - - def _update_axis_visuals(self) -> None: - show = self._axis_defined - for item in (self._axis_line_item, self._axis_arrow_item, self._axis_origin_item): - item.setVisible(show) - if not show: - return - min_x, max_x, min_y, max_y = self._image_axis_bounds() - span = max(max_x - min_x, max_y - min_y, 1.0) - origin_x, origin_y = 0.0, 0.0 - end_x, end_y = span * 0.25, 0.0 - self._axis_line_item.setData([origin_x, end_x], [origin_y, end_y]) - self._axis_arrow_item.setPos(end_x, end_y) - self._axis_arrow_item.setStyle(angle=0.0) - self._axis_origin_item.setData([origin_x], [origin_y]) - def _update_zoom_constraints(self, _width: int, _height: int) -> None: view_box = self._get_view_box() view_box.setLimits( @@ -1405,48 +823,6 @@ def _fit_view_to_image(self, *, use_axis: bool = True) -> None: self._update_zoom_constraints(int(span), int(span)) self._get_view_box().setRange(xRange=x_range, yRange=y_range, padding=0.0) - def _set_axis_state( - self, origin_camera: np.ndarray, angle_rad: float, defined: bool - ) -> None: - self._axis_origin_camera = np.asarray(origin_camera, dtype=float) - self._axis_angle_rad = float(angle_rad) - self._axis_defined = defined - self._update_image_transform() - self._update_axis_visuals() - self._update_listener_controls() - - def _apply_axis_definition( - self, - cache: _AxisRedefinitionCache, - behaviour: str, - *, - fit_view: bool, - ) -> None: - """Apply an axis redefinition using the supplied behaviour.""" - - self._axis_origin_camera = np.asarray(cache.new_origin, dtype=float) - self._axis_angle_rad = float(cache.new_angle) - self._axis_defined = True - self._update_image_transform() - - if cache.shapes: - if ( - behaviour == self._AXIS_MODE_MOVE - and cache.behaviour != self._AXIS_MODE_MOVE - ): - self._reproject_shapes_from_cache(cache) - elif ( - behaviour == self._AXIS_MODE_KEEP - and cache.behaviour not in (None, self._AXIS_MODE_KEEP) - ): - self._restore_shapes_from_cache(cache) - cache.behaviour = behaviour - - self._update_axis_visuals() - if fit_view: - self._fit_view_to_image() - self._update_listener_controls() - def _install_context_menu(self) -> None: try: menu = self._view_box.getMenu() @@ -1490,14 +866,6 @@ def _attempt_auto_connect(self) -> None: except Exception as exc: # noqa: BLE001 print(f"Automatic DMD connection failed: {exc}") - def _ensure_calibration_available(self) -> None: - stored_path = self.last_calibration_file_path() - if stored_path: - success, _ = self._load_calibration_from_path(stored_path) - if success: - return - self._calibrate_dmd() - def _resolve_pattern_parent(self) -> QTreeWidgetItem | None: tree = self.ui.treeWidget selected_items = tree.selectedItems() @@ -1577,9 +945,9 @@ def _draw_polygon_roi(self) -> None: array = np.array([[pt.x(), pt.y()] for pt in points], dtype=float) self._create_roi_item(parent_item, array, "polygon") - def _ensure_grid_preview_overlay(self) -> _GridPreviewOverlay: + def _ensure_grid_preview_overlay(self) -> GridPreviewOverlay: if self._grid_preview_overlay is None: - self._grid_preview_overlay = _GridPreviewOverlay(self._plot_item) + self._grid_preview_overlay = GridPreviewOverlay(self._plot_item) return self._grid_preview_overlay def _open_grid_dialog(self) -> None: @@ -1646,15 +1014,6 @@ def _reset_image_view(self) -> None: view_box.enableAutoRange(pg.ViewBox.XYAxes, enable=False) self._fit_view_to_image(use_axis=self._axis_defined) - def remember_calibration_file(self, path: str) -> None: - """Store the path to the most recently used calibration file.""" - self._last_calibration_file_path = path - self._preferences.set_last_calibration_file_path(path) - - def last_calibration_file_path(self) -> str: - """Return the last calibration file recorded for this session.""" - return self._last_calibration_file_path - def _apply_auto_levels_full(self) -> None: self._apply_histogram_levels(percentile=None) @@ -1814,384 +1173,6 @@ def _load_image(self, path: str | None = "") -> None: self._set_image(image, fit_to_view=True, auto_contrast=True) - def _calibrate_dmd(self): - action = self._prompt_calibration_action() - if action is None: - return - if action == "load": - self._load_calibration_from_dialog() - elif action == "define": - self._define_new_calibration() - - def _prompt_calibration_action(self) -> str | None: - prompt = QMessageBox(self) - prompt.setWindowTitle("Calibrate DMD") - prompt.setIcon(QMessageBox.Icon.Question) - prompt.setText("Choose how to obtain a DMD calibration.") - load_button = prompt.addButton( - "Load calibration file", QMessageBox.ButtonRole.ActionRole - ) - define_button = prompt.addButton( - "Define new calibration", QMessageBox.ButtonRole.ActionRole - ) - prompt.addButton(QMessageBox.StandardButton.Cancel) - if self._calibration is None: - prompt.setDefaultButton(define_button) - else: - prompt.setDefaultButton(load_button) - prompt.exec() - clicked = prompt.clickedButton() - if clicked is None: - return None - standard = prompt.standardButton(clicked) - if standard == QMessageBox.StandardButton.Cancel: - return None - if clicked is load_button: - return "load" - if clicked is define_button: - return "define" - return None - - def _load_calibration_from_dialog(self) -> None: - last_path = self.last_calibration_file_path() - initial = "" - if last_path: - candidate = Path(str(last_path)).expanduser() - if candidate.exists(): - initial = str(candidate) - else: - parent = candidate.parent - if parent.exists(): - initial = str(parent) - file_filter = "Calibration files (*.h5 *.hdf5);;All files (*)" - file_path, _ = QFileDialog.getOpenFileName( - self, - "Select calibration file", - initial, - file_filter, - ) - if not file_path: - return - success, error = self._load_calibration_from_path(file_path) - if not success: - QMessageBox.warning( - self, - "Calibration load error", - f"Unable to load calibration file:\n{error}", - ) - - def _prompt_calibration_preparation(self) -> tuple[str, int] | None: - mirror_counts = self._preferences.mirror_counts() - default_mirror = int( - max(1, round(0.5 * (float(mirror_counts[0]) + float(mirror_counts[1])))) - ) - dmd_shape: tuple[int, int] | None - try: - dmd_shape = self._stim.dmd_shape() - except Exception: # noqa: BLE001 - dmd_shape = None - - dialog = _CalibrationPreparationDialog( - self, - default_square_size=default_mirror, - can_send=self._stim.is_dmd_connected, - max_square_size=min(dmd_shape) if dmd_shape is not None else None, - ) - if dialog.exec() != QDialog.DialogCode.Accepted: - return None - action = dialog.chosen_action() - size = dialog.square_size() - if action is None: - return None - return action, size - - def _send_calibration_frame(self, square_size: int) -> bool: - if not self._stim.is_dmd_connected: - QMessageBox.information( - self, - "DMD disconnected", - "Connect to the DMD before sending a calibration frame.", - ) - return False - try: - self._stim.display_calibration_frame(square_size) - except Exception as exc: # noqa: BLE001 - QMessageBox.warning( - self, - "Calibration frame error", - str(exc), - ) - return False - return True - - def _define_new_calibration(self) -> None: - """Guide the user through loading a calibration image and storing it.""" - - preparation = self._prompt_calibration_preparation() - if preparation is None: - return - action, square_size = preparation - if action == "send": - self._send_calibration_frame(square_size) - - initial_dir = self.ui.lineEdit_image_folder_path.text().strip() - stored_image_path = self._preferences.last_calibration_image_path() - if stored_image_path: - stored_dir = os.path.dirname(stored_image_path) - if stored_dir: - initial_dir = stored_dir - file_filter = ( - "Image files (*.png *.jpg *.jpeg *.tif *.tiff *.gif);;All files (*)" - ) - file_path, _ = QFileDialog.getOpenFileName( - self, - "Select calibration image", - initial_dir if initial_dir else "", - file_filter, - ) - if not file_path: - return - try: - with Image.open(file_path) as pil_image: - calibration_image = np.array(pil_image) - except Exception as exc: - QMessageBox.warning( - self, - "Calibration image error", - f"Unable to load calibration image:\n{exc}", - ) - return - self._preferences.set_last_calibration_image_path(file_path) - selected_dir = os.path.dirname(file_path) - if selected_dir: - self.ui.lineEdit_image_folder_path.setText(selected_dir) - - previous_image = self._current_image - previous_view = self._capture_view_state() - selected_items = self.ui.treeWidget.selectedItems() - selected_item = selected_items[0] if selected_items else None - self.roi_manager.clear_visible_only() - - # Swap the display to the calibration image but remember the previous - # session so we can restore it if the workflow is cancelled mid-way. - self._set_image( - calibration_image, - fit_to_view=True, - apply_axis=False, - auto_contrast=True, - ) - - diagonal_points = self._prompt_calibration_diagonal() - if diagonal_points is None: - QMessageBox.information( - self, - "Calibration cancelled", - "No calibration diagonal was drawn. Calibration has been cancelled.", - ) - self._restore_after_calibration(previous_image, previous_view, selected_item) - return - - invert_defaults = self._preferences.axes_inverted() - default_invert_x = bool(invert_defaults[0]) - default_invert_y = bool(invert_defaults[1]) - default_mirrors = self._preferences.mirror_counts() - if action == "send": - default_mirrors = (int(square_size), int(square_size)) - dialog = _CalibrationDialog( - self, - default_mirrors=default_mirrors, - default_pixel_size=self._preferences.pixel_size(), - default_invert_x=default_invert_x, - default_invert_y=default_invert_y, - ) - if dialog.exec() != QDialog.DialogCode.Accepted: - self._restore_after_calibration(previous_image, previous_view, selected_item) - return - - square_mirrors, pixel_size, invert_x, invert_y = dialog.values() - self._preferences.set_mirror_counts(square_mirrors, square_mirrors) - self._preferences.set_pixel_size(pixel_size) - self._preferences.set_axes_inverted(invert_x, invert_y) - camera_shape = ( - int(calibration_image.shape[1]), - int(calibration_image.shape[0]), - ) - if self.dmd is not None and hasattr(self.dmd, "shape"): - try: - dmd_shape = tuple(int(v) for v in self.dmd.shape) - except Exception: - dmd_shape = (1024, 768) - else: - dmd_shape = (1024, 768) - - try: - calibration = compute_calibration_from_square( - diagonal_points, - square_mirrors, - pixel_size, - camera_shape=camera_shape, - dmd_shape=dmd_shape, - invert_x=invert_x, - invert_y=invert_y, - ) - except ValueError as exc: - QMessageBox.warning(self, "Calibration failed", str(exc)) - self._restore_after_calibration(previous_image, previous_view, selected_item) - return - - self.calibration = calibration - print( - "Updated DMD calibration: pixels/mirror=(%.3f, %.3f), µm/mirror=(%.3f, %.3f), rotation=%.2f°" - % ( - calibration.camera_pixels_per_mirror[0], - calibration.camera_pixels_per_mirror[1], - calibration.micrometers_per_mirror[0], - calibration.micrometers_per_mirror[1], - np.degrees(calibration.camera_rotation_rad), - ) - ) - self._prompt_save_calibration(calibration) - self._restore_after_calibration(previous_image, previous_view, selected_item) - - def _prompt_save_calibration(self, calibration: DMDCalibration) -> None: - response = QMessageBox.question( - self, - "Save calibration", - "Do you want to save this calibration to a file?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.Yes, - ) - if response != QMessageBox.StandardButton.Yes: - return - last_path = self.last_calibration_file_path() - initial = "" - if last_path: - candidate = Path(str(last_path)).expanduser() - if candidate.exists(): - initial = str(candidate) - else: - parent = candidate.parent - if parent.exists(): - initial = str(parent / candidate.name) - file_filter = "Calibration files (*.h5 *.hdf5);;All files (*)" - file_path, _ = QFileDialog.getSaveFileName( - self, - "Save calibration", - initial, - file_filter, - ) - if not file_path: - return - root, ext = os.path.splitext(file_path) - if not ext: - file_path = f"{file_path}.h5" - self._save_calibration_to_path(calibration, file_path) - - def _save_calibration_to_path( - self, calibration: DMDCalibration, file_path: str - ) -> bool: - path = Path(str(file_path)).expanduser() - try: - path.parent.mkdir(parents=True, exist_ok=True) - except Exception: - pass - try: - saving.save_calibration(str(path), calibration) - except Exception as exc: - QMessageBox.warning( - self, - "Save failed", - f"Unable to save calibration file:\n{exc}", - ) - return False - self.remember_calibration_file(str(path)) - print(f"Saved DMD calibration to {path}") - return True - - def _load_calibration_from_path(self, path_str: str) -> tuple[bool, str | None]: - """Load calibration from disk and activate it.""" - path = Path(str(path_str)).expanduser() - try: - calibration = saving.load_calibration(str(path)) - except Exception as exc: - message = f"{path}: {exc}" - print(f"Failed to load stored calibration from {message}") - return False, message - self.calibration = calibration - self.remember_calibration_file(str(path)) - try: - pixel_size = calibration.camera_pixel_size_um - except AttributeError: - pixel_size = None - if pixel_size is not None: - self._preferences.set_pixel_size(float(pixel_size)) - print(f"Active DMD calibration: {path}") - return True, None - - def _prompt_calibration_diagonal(self) -> np.ndarray | None: - """Capture the diagonal of the illuminated calibration square.""" - - prompt = QMessageBox(self) - prompt.setWindowTitle("Select calibration diagonal") - prompt.setIcon(QMessageBox.Icon.Information) - prompt.setText("Draw the diagonal of the illuminated calibration square.") - prompt.setInformativeText( - "Left-click and drag to draw the line. Right-click or press Esc to cancel." - ) - prompt.setStandardButtons( - QMessageBox.StandardButton.Cancel | QMessageBox.StandardButton.Ok - ) - if prompt.exec() != QMessageBox.StandardButton.Ok: - return None - - capture = AxisCapture(self._get_view_box(), self) - segment = capture.exec() - if segment is None: - return None - start, end = segment - start_xy = np.array([start.x(), start.y()], dtype=float) - end_xy = np.array([end.x(), end.y()], dtype=float) - if not np.all(np.isfinite(start_xy)) or not np.all(np.isfinite(end_xy)): - return None - if np.linalg.norm(end_xy - start_xy) < 1e-9: - return None - return np.vstack((start_xy, end_xy)) - - def _capture_view_state( - self, - ) -> tuple[tuple[float, float], tuple[float, float]] | None: - view = self._get_view_box() - try: - x_range, y_range = view.viewRange() - except Exception: - return None - return (tuple(x_range), tuple(y_range)) - - def _restore_after_calibration( - self, - previous_image: np.ndarray | None, - previous_view_range: tuple[tuple[float, float], tuple[float, float]] | None, - selected_item: QTreeWidgetItem | None, - ) -> None: - if previous_image is not None: - self._set_image(previous_image, auto_contrast=True) - if previous_view_range is not None: - x_range, y_range = previous_view_range - self._get_view_box().setRange( - xRange=x_range, - yRange=y_range, - padding=0.0, - ) - else: - self._image_item.clear() - self._current_levels = None - self._current_image = None - - if selected_item is not None: - self.roi_manager.show_for_item(selected_item) - else: - self.roi_manager.clear_visible_only() - def _change_folder(self): try: filename = self.ui.lineEdit_image_folder_path.text() @@ -2226,387 +1207,3 @@ def _show_grid(self): show = self.ui.pushButton_show_grid.isChecked() self._plot_item.showGrid(show, show) - def _define_axis(self): - button = self.ui.pushButton_define_axis - if not button.isEnabled(): - return - button.setChecked(True) - print( - "Axis tool: click to set origin, drag to direction, release to confirm. Right-click or Esc cancels." - ) - capture = AxisCapture(self._get_view_box(), self) - result = capture.exec() - button.setChecked(False) - if result is None: - return - origin_view, end_view = result - origin_axis = np.array([origin_view.x(), origin_view.y()], dtype=float) - end_axis = np.array([end_view.x(), end_view.y()], dtype=float) - vector_axis = end_axis - origin_axis - if np.linalg.norm(vector_axis) < 1e-6: - return - origin_camera = self._axis_to_camera(origin_axis) - direction_camera = self._rotation_matrix() @ vector_axis - angle_camera = float(np.arctan2(direction_camera[1], direction_camera[0])) - shapes_export = { - item: (np.asarray(points, dtype=float), shape_type) - for item, (points, shape_type) in self.roi_manager.export_shape_points().items() - } - cache = _AxisRedefinitionCache( - previous_origin=self._axis_origin_camera.copy(), - previous_angle=self._axis_angle_rad, - new_origin=np.asarray(origin_camera, dtype=float), - new_angle=angle_camera, - shapes=shapes_export, - ) - self._axis_redefine_cache = cache - behaviour = self._default_axis_behaviour() - self._apply_axis_definition(cache, behaviour, fit_view=True) - if cache.shapes: - self._show_axis_feedback_banner(cache) - else: - self._hide_axis_feedback_banner() - - def _new_model(self): - self.model = PatternSequence( - patterns=[], sequence=[], timings=[], durations=[], descriptions=[] - ) - self.ui.lineEdit_file_path.clear() - print("Loaded empty PatternSequence") - - def _read_table_ms(self): - timings, durations, sequence = [], [], [] - rows = self.ui.tableWidget.rowCount() - for r in range(rows): - t_item = self.ui.tableWidget.item(r, 0) - d_item = self.ui.tableWidget.item(r, 1) - s_item = self.ui.tableWidget.item(r, 2) - try: - if t_item and s_item: - t_text = (t_item.text() or "").strip() - s_text = (s_item.text() or "").strip() - if not t_text or not s_text: - continue - d_text = (d_item.text() if d_item else "") or "" - d_text = d_text.strip() - t = int(t_text) - d = int(d_text) if d_text else 0 - s = int(s_text) - timings.append(t) - durations.append(d) - sequence.append(s) - except Exception: - continue - return timings, durations, sequence - - def _write_table_ms(self, model: PatternSequence): - t_ms = model.timings_milliseconds - d_ms = model.durations_milliseconds - seq = model.sequence - self._updating_table = True - self.table_manager.ensure_desc_column() - self.ui.tableWidget.setRowCount(len(seq)) - for r, (t, d, s) in enumerate(zip(t_ms, d_ms, seq)): - self.ui.tableWidget.setItem(r, 0, QTableWidgetItem(str(int(t)))) - self.ui.tableWidget.setItem(r, 1, QTableWidgetItem(str(int(d)))) - self.ui.tableWidget.setItem(r, 2, QTableWidgetItem(str(int(s)))) - self.table_manager.set_sequence_row_description(r, int(s)) - self._updating_table = False - - def _load_patterns_file(self): - initial = self.ui.lineEdit_file_path.text().strip() - file_path, _ = QFileDialog.getOpenFileName( - self, - "Select pattern sequence", - initial if initial else "", - _HDF5_FILE_FILTER, - ) - if not file_path: - return - if self._calibration is None: - QMessageBox.warning( - self, - "Calibration required", - "Load or compute a DMD calibration before loading patterns.", - ) - return - try: - self.model = saving.load_pattern_sequence(file_path) - except RuntimeError as exc: - QMessageBox.warning(self, "Calibration required", str(exc)) - return - except Exception as exc: # noqa: BLE001 - QMessageBox.critical(self, "Load failed", str(exc)) - return - self.ui.lineEdit_file_path.setText(file_path) - print(f"Loaded PatternSequence from {file_path}") - - def _save_file(self) -> None: - model = self._collect_model() - if model is None: - return - current_path = self.ui.lineEdit_file_path.text().strip() - if not current_path: - target_path = self._prompt_save_path("Save pattern sequence", "") - if not target_path: - return - else: - target_path = self._ensure_h5_extension(current_path) - saved_path = self._write_pattern_sequence(target_path, model) - if saved_path: - self.ui.lineEdit_file_path.setText(saved_path) - - def _save_file_as(self) -> None: - model = self._collect_model() - if model is None: - return - current_path = self.ui.lineEdit_file_path.text().strip() - target_path = self._prompt_save_path("Save pattern sequence as", current_path) - if not target_path: - return - saved_path = self._write_pattern_sequence(target_path, model) - if saved_path: - self.ui.lineEdit_file_path.setText(saved_path) - - def _export_patterns_for_analysis(self) -> None: - model = self._collect_model() - if model is None: - return - current_path = self.ui.lineEdit_file_path.text().strip() - target_path = self._prompt_save_path("Export pattern sequence", current_path) - if not target_path: - return - saved_path = self._write_pattern_sequence(target_path, model, silent=True) - if not saved_path: - return - try: - self._write_analysis_metadata(saved_path, model) - except Exception as exc: # noqa: BLE001 - QMessageBox.warning( - self, - "Export incomplete", - "Pattern sequence saved, but analysis metadata could not be written.\n" - f"Reason: {exc}", - ) - return - print(f"Exported PatternSequence with analysis metadata to {saved_path}") - - def _collect_model(self) -> PatternSequence | None: - try: - return self.model - except RuntimeError as exc: - QMessageBox.warning(self, "Calibration required", str(exc)) - except Exception as exc: # noqa: BLE001 - QMessageBox.critical(self, "Pattern export failed", str(exc)) - return None - - def _prompt_save_path(self, title: str, initial_path: str) -> str: - initial = initial_path.strip() - file_path, _ = QFileDialog.getSaveFileName( - self, - title, - initial if initial else "", - _HDF5_FILE_FILTER, - ) - if not file_path: - return "" - return self._ensure_h5_extension(file_path) - - @staticmethod - def _ensure_h5_extension(path: str) -> str: - trimmed = path.strip() - if not trimmed: - return "" - candidate = Path(trimmed) - if candidate.suffix.lower() in {".h5", ".hdf5"}: - return str(candidate) - return str(candidate.with_suffix(".h5")) - - def _write_pattern_sequence( - self, - file_path: str, - model: PatternSequence, - *, - silent: bool = False, - ) -> str | None: - target = self._ensure_h5_extension(file_path) - if not target: - return None - try: - saving.save_pattern_sequence(target, model) - except Exception as exc: # noqa: BLE001 - QMessageBox.critical(self, "Save failed", str(exc)) - return None - if not silent: - print(f"Saved PatternSequence to {target}") - return target - - def _write_analysis_metadata( - self, - file_path: str, - model: PatternSequence, - ) -> None: - if self._calibration is None: - raise RuntimeError("A calibration must be available to export analysis metadata.") - calibration = self._calibration - with h5py.File(file_path, "a") as handle: - if "analysis" in handle: - del handle["analysis"] - analysis_grp = handle.create_group("analysis") - analysis_grp.attrs["version"] = 1 - analysis_grp.attrs["generator"] = "StimDMDWidget" - - axis_grp = analysis_grp.create_group("axis") - axis_grp.attrs["defined"] = bool(self._axis_defined) - axis_grp.attrs["coordinate_system"] = "camera_pixels" - axis_grp.attrs["camera_shape"] = np.asarray(calibration.camera_shape, dtype=np.int64) - if self._axis_defined: - axis_grp.create_dataset( - "origin_camera", - data=np.asarray(self._axis_origin_camera, dtype=np.float64), - ) - axis_grp.create_dataset( - "angle_rad", - data=np.array(self._axis_angle_rad, dtype=np.float64), - ) - try: - origin_um = self._axis_origin_micrometre() - axis_grp.create_dataset( - "origin_micrometre", - data=np.asarray(origin_um, dtype=np.float64), - ) - except Exception: # noqa: BLE001 - pass - - patterns_grp = analysis_grp.create_group("patterns_camera") - patterns_grp.attrs["coordinate_system"] = "camera_pixels" - - descriptions = model.descriptions or [] - shape_types = model.shape_types or [] - - axis_def = self._axis_definition() - - for pattern_index, pattern in enumerate(model.patterns): - pattern_grp = patterns_grp.create_group(f"pattern_{pattern_index}") - if pattern_index < len(descriptions) and descriptions[pattern_index]: - pattern_grp.attrs["description"] = descriptions[pattern_index] - for poly_index, polygon in enumerate(pattern): - points = np.asarray(polygon, dtype=np.float64) - if points.ndim != 2 or points.shape[1] != 2: - continue - if self._axis_defined: - global_um = axis_micrometre_to_global( - points, - axis_def, - calibration, - ) - else: - global_um = points - camera_points = calibration.micrometre_to_camera(global_um.T).T - dataset = pattern_grp.create_dataset( - f"polygon_{poly_index}", data=camera_points - ) - if ( - pattern_index < len(shape_types) - and poly_index < len(shape_types[pattern_index]) - ): - dataset.attrs["shape_type"] = str( - shape_types[pattern_index][poly_index] - ) - - def _add_row_table(self): - self.ui.tableWidget.insertRow(self.ui.tableWidget.rowCount()) - - def _remove_row_table(self): - rows = sorted( - {i.row() for i in self.ui.tableWidget.selectedIndexes()}, reverse=True - ) - for r in rows: - self.ui.tableWidget.removeRow(r) - - def _cycle_patterns(self) -> None: - pattern_count = self.ui.treeWidget.topLevelItemCount() - if pattern_count <= 0: - QMessageBox.information( - self, - "No patterns", - "Create at least one pattern before cycling them.", - ) - return - - table = self.ui.tableWidget - last_time = 0 - if table.rowCount() > 0: - last_item = table.item(table.rowCount() - 1, 0) - if last_item is not None: - try: - last_time = int(last_item.text()) - except Exception: - last_time = 0 - - default_repeat_gap = 100 - default_cycle_gap = 250 - default_duration = 100 - default_first_time = last_time + default_repeat_gap if table.rowCount() > 0 else 0 - - dialog = _CyclePatternsDialog( - self, - default_first_time=default_first_time, - default_cycle_count=1, - default_repeat_count=1, - default_repeat_gap=default_repeat_gap, - default_cycle_gap=default_cycle_gap, - default_duration=default_duration, - ) - if dialog.exec() != QDialog.DialogCode.Accepted: - return - - params = dialog.values() - cycle_count = max(1, int(params["cycle_count"])) - repeat_count = max(1, int(params["repeat_count"])) - first_time = int(params["first_time_ms"]) - repeat_gap = int(params["repeat_gap_ms"]) - cycle_gap = int(params["cycle_gap_ms"]) - duration = int(params["duration_ms"]) - - if cycle_count <= 0 or repeat_count <= 0: - return - - entries: list[tuple[int, int]] = [] - current_time = first_time - - for cycle_index in range(cycle_count): - for pattern_idx in range(pattern_count): - for repeat_index in range(repeat_count): - entries.append((pattern_idx, current_time)) - is_last_entry = ( - cycle_index == cycle_count - 1 - and pattern_idx == pattern_count - 1 - and repeat_index == repeat_count - 1 - ) - if is_last_entry: - continue - current_time += repeat_gap - end_of_cycle = ( - repeat_index == repeat_count - 1 - and pattern_idx == pattern_count - 1 - ) - if end_of_cycle: - current_time += cycle_gap - - if not entries: - return - - self.table_manager.ensure_desc_column() - signals_were_blocked = self.ui.tableWidget.blockSignals(True) - try: - for pattern_idx, start_time in entries: - row = table.rowCount() - table.insertRow(row) - table.setItem(row, 0, QTableWidgetItem(str(start_time))) - table.setItem(row, 1, QTableWidgetItem(str(duration))) - table.setItem(row, 2, QTableWidgetItem(str(pattern_idx))) - self.table_manager.set_sequence_row_description(row, pattern_idx) - finally: - self.ui.tableWidget.blockSignals(signals_were_blocked) - self.table_manager.refresh_sequence_descriptions() diff --git a/stim1p/ui/dmd_widget/__init__.py b/stim1p/ui/dmd_widget/__init__.py new file mode 100644 index 0000000..99ca44d --- /dev/null +++ b/stim1p/ui/dmd_widget/__init__.py @@ -0,0 +1,12 @@ +"""Support helpers for :mod:`stim1p.ui.dmd_stim_widget`.""" + +from .axis import AxisControlsMixin, AxisRedefinitionCache +from .calibration import CalibrationWorkflowMixin +from .pattern_io import PatternSequenceIOMixin + +__all__ = [ + "AxisControlsMixin", + "AxisRedefinitionCache", + "CalibrationWorkflowMixin", + "PatternSequenceIOMixin", +] diff --git a/stim1p/ui/dmd_widget/axis.py b/stim1p/ui/dmd_widget/axis.py new file mode 100644 index 0000000..4492fb3 --- /dev/null +++ b/stim1p/ui/dmd_widget/axis.py @@ -0,0 +1,503 @@ +"""Axis utilities shared by :class:`stim1p.ui.dmd_stim_widget.StimDMDWidget`. + +This module gathers all axis-related helpers that were previously embedded in the +main widget. The helpers are grouped as a mixin so the widget can opt-in to +axis behaviour without inheriting an unwieldy monolithic implementation. The +functions are intentionally verbose and document the underlying geometry +transformations so contributors can reason about the math without having to +reverse-engineer the matrix operations each time. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +from PySide6.QtCore import Qt, QTimer +from PySide6.QtWidgets import QFrame, QHBoxLayout, QLabel, QPushButton + +from ...logic.calibration import DMDCalibration +from ...logic.geometry import ( + AxisDefinition, + axis_micrometre_scale, + axis_micrometre_to_axis_pixels, + axis_pixels_to_axis_micrometre, +) +from ..capture_tools import AxisCapture + +if TYPE_CHECKING: # pragma: no cover - import for type checking only + from PySide6.QtWidgets import QTreeWidgetItem + + +@dataclass +class AxisRedefinitionCache: + """Cache state from an axis redefinition interaction. + + The cache stores the previous and pending axis definitions alongside the + projected shape coordinates. After the user chooses how to treat existing + ROIs the cache allows us to either restore the original points or reproject + them into the new axis frame without recalculating the capture. + """ + + previous_origin: np.ndarray + previous_angle: float + new_origin: np.ndarray + new_angle: float + shapes: dict["QTreeWidgetItem", tuple[np.ndarray, str]] + behaviour: str | None = None + + +class AxisControlsMixin: + """Mixin implementing axis definition and behaviour helpers. + + The mixin assumes the host widget exposes several attributes used + throughout the original implementation (for example ``self.ui`` or + ``self.roi_manager``). Centralising the logic here makes it possible to + document the coordinate conversions and UI state transitions in one place + instead of scattering them through the widget class. + """ + + _calibration: DMDCalibration | None + _axis_origin_camera: np.ndarray + _axis_angle_rad: float + _axis_defined: bool + _axis_redefine_cache: AxisRedefinitionCache | None + + def _axis_definition(self) -> AxisDefinition: + """Build an :class:`AxisDefinition` describing the current axis.""" + + origin = tuple(float(v) for v in self._axis_origin_camera.reshape(2)) + return AxisDefinition(origin_camera=origin, angle_rad=float(self._axis_angle_rad)) + + def _rotation_matrix(self, angle: float | None = None) -> np.ndarray: + """Return the 2D rotation matrix for ``angle`` radians. + + If no ``angle`` is supplied we reuse the current axis angle. The helper + is intentionally explicit so the accompanying unit tests and future + contributors can audit the trigonometric operations without digging into + NumPy internals. + """ + + angle = self._axis_angle_rad if angle is None else float(angle) + cos_a = float(np.cos(angle)) + sin_a = float(np.sin(angle)) + return np.array([[cos_a, -sin_a], [sin_a, cos_a]], dtype=float) + + def _camera_to_axis( + self, + points: np.ndarray, + *, + origin: np.ndarray | None = None, + angle: float | None = None, + ) -> np.ndarray: + """Convert camera pixel coordinates into the user-defined axis frame. + + ``points`` can be a single 2D coordinate or an array of coordinates. We + normalise the input to a two-dimensional array, subtract the chosen + origin, rotate into axis space and return the result with the original + dimensionality preserved. + """ + + arr = np.asarray(points, dtype=float) + was_1d = arr.ndim == 1 + pts = np.atleast_2d(arr) + origin_vec = ( + self._axis_origin_camera if origin is None else np.asarray(origin, dtype=float) + ) + R = self._rotation_matrix(angle) + relative = pts - origin_vec + result = (R.T @ relative.T).T + return result[0] if was_1d else result + + def _axis_to_camera( + self, + points: np.ndarray, + *, + origin: np.ndarray | None = None, + angle: float | None = None, + ) -> np.ndarray: + """Convert axis-aligned coordinates back to camera pixel indices. + + This performs the inverse transform to :meth:`_camera_to_axis` and is + documented separately so it is obvious that the rotation matrix is used + in the forward direction. Maintaining symmetry between the two helpers + is important for ROI reprojection. + """ + + arr = np.asarray(points, dtype=float) + was_1d = arr.ndim == 1 + pts = np.atleast_2d(arr) + origin_vec = ( + self._axis_origin_camera if origin is None else np.asarray(origin, dtype=float) + ) + R = self._rotation_matrix(angle) + result = (R @ pts.T).T + origin_vec + return result[0] if was_1d else result + + def _axis_origin_micrometre( + self, origin_camera: np.ndarray | None = None + ) -> np.ndarray: + """Return the axis origin in micrometres using the active calibration.""" + + if self._calibration is None: + raise RuntimeError("A calibration is required for micrometre conversion.") + origin_vec = ( + self._axis_origin_camera if origin_camera is None else np.asarray(origin_camera, dtype=float) + ) + mic = self._calibration.camera_to_micrometre(origin_vec.reshape(2, 1)).T[0] + return np.asarray(mic, dtype=float) + + def _axis_pixels_to_micrometres(self, points: np.ndarray) -> np.ndarray: + """Convert axis-space pixels to micrometres, validating calibration.""" + + if self._calibration is None: + raise RuntimeError("A calibration is required for micrometre conversion.") + return axis_pixels_to_axis_micrometre(points, self._axis_definition(), self._calibration) + + def _micrometres_to_axis_pixels(self, points_um: np.ndarray) -> np.ndarray: + """Convert micrometre coordinates into axis pixels using calibration.""" + + if self._calibration is None: + raise RuntimeError("A calibration is required for micrometre conversion.") + return axis_micrometre_to_axis_pixels(points_um, self._axis_definition(), self._calibration) + + def _axis_micrometre_scale(self) -> tuple[float, float] | None: + """Compute micrometre-per-pixel scale factors for both axis directions.""" + + if self._calibration is None: + return None + try: + scales = axis_micrometre_scale(self._axis_definition(), self._calibration) + except Exception: + return None + scale_x = float(scales[0]) + scale_y = float(scales[1]) + if ( + not np.isfinite(scale_x) + or not np.isfinite(scale_y) + or scale_x <= 0.0 + or scale_y <= 0.0 + ): + return None + return scale_x, scale_y + + def axis_unit_scale_for_orientation(self, orientation: str) -> float | None: + """Return the micrometre scale factor corresponding to an axis label.""" + + scales = self._axis_micrometre_scale() + if scales is None: + return None + orient = orientation.lower() + if orient in ("bottom", "top"): + return scales[0] + if orient in ("left", "right"): + return scales[1] + return None + + def _reproject_shapes_from_cache(self, cache: AxisRedefinitionCache) -> None: + """Reproject cached ROI points into the new axis definition.""" + + prev_origin = np.asarray(cache.previous_origin, dtype=float) + prev_angle = float(cache.previous_angle) + new_origin = np.asarray(cache.new_origin, dtype=float) + new_angle = float(cache.new_angle) + for item, (axis_points, shape_type) in cache.shapes.items(): + axis_pts = np.asarray(axis_points, dtype=float) + camera_pts = self._axis_to_camera(axis_pts, origin=prev_origin, angle=prev_angle) + axis_pts_new = self._camera_to_axis(camera_pts, origin=new_origin, angle=new_angle) + self.roi_manager.update_shape(item, shape_type, axis_pts_new) + + def _restore_shapes_from_cache(self, cache: AxisRedefinitionCache) -> None: + """Restore ROI geometry captured before redefining the axis.""" + + for item, (axis_points, shape_type) in cache.shapes.items(): + self.roi_manager.update_shape(item, shape_type, axis_points) + + def _setup_axis_behaviour_controls(self) -> None: + """Initialise the "axis behaviour" combo box and supporting mappings.""" + + combo = self.ui.comboBox_axis_behaviour + self._axis_behaviour_by_index = { + 0: self._AXIS_MODE_MOVE, + 1: self._AXIS_MODE_KEEP, + } + self._axis_behaviour_to_index = { + value: key for key, value in self._axis_behaviour_by_index.items() + } + for index, mode in self._axis_behaviour_by_index.items(): + combo.setItemText(index, self._AXIS_BEHAVIOUR_LABELS[mode]) + tooltip = ( + "Choose what happens to existing patterns when the axis is redefined.\n" + "A banner appears after redefining so you can switch behaviour for that change." + ) + combo.setToolTip(tooltip) + self.ui.label_axis_behaviour.setToolTip(tooltip) + stored_mode = self._preferences.axis_redefinition_mode() + index = self._axis_behaviour_to_index.get(stored_mode, 0) + combo.blockSignals(True) + combo.setCurrentIndex(index) + combo.blockSignals(False) + combo.currentIndexChanged.connect(self._on_axis_behaviour_combo_changed) + + def _setup_axis_feedback_banner(self) -> None: + """Create a transient banner that appears after redefining the axis.""" + + frame = QFrame(self.ui.verticalLayoutWidget) + frame.setObjectName("axisBehaviourBanner") + frame.setFrameShape(QFrame.Shape.StyledPanel) + frame.setVisible(False) + frame.setAttribute(Qt.WidgetAttribute.WA_Hover, True) + + layout = QHBoxLayout(frame) + layout.setContentsMargins(8, 4, 8, 4) + layout.setSpacing(8) + + label = QLabel(frame) + layout.addWidget(label, 1) + + layout.addStretch(1) + + move_btn = QPushButton(self._AXIS_BEHAVIOUR_LABELS[self._AXIS_MODE_MOVE], frame) + keep_btn = QPushButton(self._AXIS_BEHAVIOUR_LABELS[self._AXIS_MODE_KEEP], frame) + layout.addWidget(move_btn, 0) + layout.addWidget(keep_btn, 0) + + self.ui.verticalLayout_controls.insertWidget(1, frame) + + timer = QTimer(self) + timer.setSingleShot(True) + timer.timeout.connect(self._hide_axis_feedback_banner) + + move_btn.clicked.connect(lambda: self._handle_axis_banner_choice(self._AXIS_MODE_MOVE)) + keep_btn.clicked.connect(lambda: self._handle_axis_banner_choice(self._AXIS_MODE_KEEP)) + frame.installEventFilter(self) + + self._axis_feedback_frame = frame + self._axis_feedback_label = label + self._axis_feedback_move_button = move_btn + self._axis_feedback_keep_button = keep_btn + self._axis_feedback_timer = timer + + def _axis_behaviour_from_index(self, index: int) -> str: + """Translate a combo box index into a behaviour identifier.""" + + return self._axis_behaviour_by_index.get(index, self._AXIS_MODE_MOVE) + + def _axis_behaviour_label(self, behaviour: str) -> str: + """Return a human readable label for an axis redefinition mode.""" + + return self._AXIS_BEHAVIOUR_LABELS.get(behaviour, behaviour) + + def _default_axis_behaviour(self) -> str: + """Resolve the default behaviour taking stored preferences into account.""" + + return self._axis_behaviour_from_index(self.ui.comboBox_axis_behaviour.currentIndex()) + + def _update_axis_behaviour_combo(self, behaviour: str, *, update_preferences: bool) -> None: + """Synchronise the behaviour combo box and, optionally, saved preference.""" + + index = self._axis_behaviour_to_index.get(behaviour) + if index is None: + return + combo = self.ui.comboBox_axis_behaviour + if combo.currentIndex() != index: + combo.blockSignals(True) + combo.setCurrentIndex(index) + combo.blockSignals(False) + if update_preferences: + self._preferences.set_axis_redefinition_mode(behaviour) + + def _show_axis_feedback_banner(self, cache: AxisRedefinitionCache) -> None: + """Display the axis feedback banner summarising the applied behaviour.""" + + if not cache.shapes: + self._hide_axis_feedback_banner() + return + behaviour = cache.behaviour or self._default_axis_behaviour() + description = self._axis_behaviour_label(behaviour) + self._axis_feedback_label.setText( + f'Axis updated; patterns set to "{description}". Change?' + ) + self._refresh_axis_feedback_buttons(behaviour) + self._axis_feedback_frame.setVisible(True) + self._axis_feedback_timer.start(6000) + + def _hide_axis_feedback_banner(self) -> None: + """Hide and reset the feedback banner timer.""" + + self._axis_feedback_timer.stop() + self._axis_feedback_frame.setVisible(False) + + def _refresh_axis_feedback_buttons(self, behaviour: str) -> None: + """Enable/disable the axis feedback buttons based on ``behaviour``.""" + + move_active = behaviour == self._AXIS_MODE_MOVE + keep_active = behaviour == self._AXIS_MODE_KEEP + self._axis_feedback_move_button.setEnabled(not move_active) + self._axis_feedback_move_button.setDefault(move_active) + self._axis_feedback_keep_button.setEnabled(not keep_active) + self._axis_feedback_keep_button.setDefault(keep_active) + + def _handle_axis_banner_choice(self, behaviour: str) -> None: + """Respond to the user clicking a behaviour button in the banner.""" + + cache = self._axis_redefine_cache + if cache is None: + return + if cache.behaviour == behaviour: + self._hide_axis_feedback_banner() + return + self._apply_axis_definition(cache, behaviour, fit_view=False) + self._update_axis_behaviour_combo(behaviour, update_preferences=True) + self._show_axis_feedback_banner(cache) + + def _on_axis_behaviour_combo_changed(self, index: int) -> None: + """Persist a manual change in the behaviour combo box.""" + + behaviour = self._axis_behaviour_from_index(index) + self._preferences.set_axis_redefinition_mode(behaviour) + if self._axis_redefine_cache is not None: + self._refresh_axis_feedback_buttons( + self._axis_redefine_cache.behaviour or behaviour + ) + + def _update_axis_labels(self) -> None: + """Refresh the axis labels to reflect whether calibration is active.""" + + unit = "µm" if self._calibration is not None else "px" + axis_bottom = self._plot_item.getAxis("bottom") + axis_left = self._plot_item.getAxis("left") + axis_bottom.setLabel(f"X ({unit})") + axis_left.setLabel(f"Y ({unit})") + for axis in (axis_bottom, axis_left): + axis.picture = None + axis.update() + + def _image_axis_bounds(self) -> tuple[float, float, float, float]: + """Return the bounding box of the current image expressed in axis space.""" + + if self._current_image is None: + return (-50.0, 50.0, -50.0, 50.0) + height, width = self._current_image.shape[:2] + corners_camera = np.array( + [ + [0.0, 0.0], + [float(width), 0.0], + [float(width), float(height)], + [0.0, float(height)], + ], + dtype=float, + ) + corners_axis = self._camera_to_axis(corners_camera) + min_x = float(np.min(corners_axis[:, 0])) + max_x = float(np.max(corners_axis[:, 0])) + min_y = float(np.min(corners_axis[:, 1])) + max_y = float(np.max(corners_axis[:, 1])) + return min_x, max_x, min_y, max_y + + def _update_axis_visuals(self) -> None: + """Synchronise the origin/arrow items with the current axis definition.""" + + show = self._axis_defined + for item in (self._axis_line_item, self._axis_arrow_item, self._axis_origin_item): + item.setVisible(show) + if not show: + return + min_x, max_x, min_y, max_y = self._image_axis_bounds() + span = max(max_x - min_x, max_y - min_y, 1.0) + origin_x, origin_y = 0.0, 0.0 + end_x, end_y = span * 0.25, 0.0 + self._axis_line_item.setData([origin_x, end_x], [origin_y, end_y]) + self._axis_arrow_item.setPos(end_x, end_y) + self._axis_arrow_item.setStyle(angle=0.0) + self._axis_origin_item.setData([origin_x], [origin_y]) + + def _set_axis_state(self, origin_camera: np.ndarray, angle_rad: float, defined: bool) -> None: + """Update axis properties and refresh associated visuals/listeners.""" + + self._axis_origin_camera = np.asarray(origin_camera, dtype=float) + self._axis_angle_rad = float(angle_rad) + self._axis_defined = defined + self._update_image_transform() + self._update_axis_visuals() + self._update_listener_controls() + + def _apply_axis_definition( + self, + cache: AxisRedefinitionCache, + behaviour: str, + *, + fit_view: bool, + ) -> None: + """Apply an axis redefinition using the supplied behaviour. + + ``cache`` contains both the new axis parameters and a snapshot of the + ROI geometry prior to the change. Depending on ``behaviour`` the method + either reprojects the shapes into the new axis or keeps them untouched. + ``fit_view`` allows callers to control whether the image view is + realigned after the update (useful after drawing a brand-new axis). + """ + + self._axis_origin_camera = np.asarray(cache.new_origin, dtype=float) + self._axis_angle_rad = float(cache.new_angle) + self._axis_defined = True + self._update_image_transform() + + if cache.shapes: + if behaviour == self._AXIS_MODE_MOVE and cache.behaviour != self._AXIS_MODE_MOVE: + self._reproject_shapes_from_cache(cache) + elif ( + behaviour == self._AXIS_MODE_KEEP + and cache.behaviour not in (None, self._AXIS_MODE_KEEP) + ): + self._restore_shapes_from_cache(cache) + cache.behaviour = behaviour + + self._update_axis_visuals() + if fit_view: + self._fit_view_to_image() + self._update_listener_controls() + + def _define_axis(self) -> None: + """Enter the interactive axis capture tool and update state accordingly.""" + + button = self.ui.pushButton_define_axis + if not button.isEnabled(): + return + button.setChecked(True) + print( + "Axis tool: click to set origin, drag to direction, release to confirm. Right-click or Esc cancels." + ) + capture = AxisCapture(self._get_view_box(), self) + result = capture.exec() + button.setChecked(False) + if result is None: + return + origin_view, end_view = result + origin_axis = np.array([origin_view.x(), origin_view.y()], dtype=float) + end_axis = np.array([end_view.x(), end_view.y()], dtype=float) + vector_axis = end_axis - origin_axis + if np.linalg.norm(vector_axis) < 1e-6: + return + origin_camera = self._axis_to_camera(origin_axis) + direction_camera = self._rotation_matrix() @ vector_axis + angle_camera = float(np.arctan2(direction_camera[1], direction_camera[0])) + shapes_export = { + item: (np.asarray(points, dtype=float), shape_type) + for item, (points, shape_type) in self.roi_manager.export_shape_points().items() + } + cache = AxisRedefinitionCache( + previous_origin=self._axis_origin_camera.copy(), + previous_angle=self._axis_angle_rad, + new_origin=np.asarray(origin_camera, dtype=float), + new_angle=angle_camera, + shapes=shapes_export, + ) + self._axis_redefine_cache = cache + behaviour = self._default_axis_behaviour() + self._apply_axis_definition(cache, behaviour, fit_view=True) + if cache.shapes: + self._show_axis_feedback_banner(cache) + else: + self._hide_axis_feedback_banner() diff --git a/stim1p/ui/dmd_widget/calibration.py b/stim1p/ui/dmd_widget/calibration.py new file mode 100644 index 0000000..51a8cfa --- /dev/null +++ b/stim1p/ui/dmd_widget/calibration.py @@ -0,0 +1,465 @@ +"""Calibration workflow helpers used by :class:`StimDMDWidget`. + +The calibration routines manage several UI flows: prompting the user to capture +new calibration imagery, sending calibration frames to the hardware, and +persisting the resulting :class:`~stim1p.logic.calibration.DMDCalibration` +objects. The original widget bundled these responsibilities directly inside +event handlers, which made it hard to understand the overall flow. The mixin +documented below keeps the procedural steps in one place and adds commentary so +future maintainers can follow the user journey end-to-end. +""" + +from __future__ import annotations + +import os +from pathlib import Path +import numpy as np +from PIL import Image +from PySide6.QtWidgets import QDialog, QFileDialog, QMessageBox + +from ...logic import saving +from ...logic.calibration import DMDCalibration, compute_calibration_from_square +from ..capture_tools import AxisCapture +from ..dmd_dialogs import CalibrationDialog, CalibrationPreparationDialog + +class CalibrationWorkflowMixin: + """Mixin collecting the calibration related routines. + + The mixin expects the widget to expose ``self._stim``, ``self.ui`` and a + ``self._preferences`` object. Each helper is documented with the user + interaction it supports, making it easier to adjust the workflow without + accidentally omitting a state restoration step. + """ + + _last_calibration_file_path: str + _current_image: np.ndarray | None + _preferences: object + + def remember_calibration_file(self, path: str) -> None: + """Store the path to the most recently used calibration file.""" + + self._last_calibration_file_path = path + self._preferences.set_last_calibration_file_path(path) + + def last_calibration_file_path(self) -> str: + """Return the last calibration file recorded for this session.""" + + return self._last_calibration_file_path + + def _ensure_calibration_available(self) -> None: + """Ensure a calibration exists by loading a stored file or prompting. + + This helper centralises the "lazy loading" behaviour used by actions + that require calibration data. If the previously stored file cannot be + loaded we immediately fall back to guiding the user through the full + calibration wizard. + """ + + stored_path = self.last_calibration_file_path() + if stored_path: + success, _ = self._load_calibration_from_path(stored_path) + if success: + return + self._calibrate_dmd() + + def _calibrate_dmd(self) -> None: + """Handle the top-level calibration choice dialog.""" + + action = self._prompt_calibration_action() + if action is None: + return + if action == "load": + self._load_calibration_from_dialog() + elif action == "define": + self._define_new_calibration() + + def _prompt_calibration_action(self) -> str | None: + """Ask the user whether to load a calibration file or create a new one.""" + + prompt = QMessageBox(self) + prompt.setWindowTitle("Calibrate DMD") + prompt.setIcon(QMessageBox.Icon.Question) + prompt.setText("Choose how to obtain a DMD calibration.") + load_button = prompt.addButton( + "Load calibration file", QMessageBox.ButtonRole.ActionRole + ) + define_button = prompt.addButton( + "Define new calibration", QMessageBox.ButtonRole.ActionRole + ) + prompt.addButton(QMessageBox.StandardButton.Cancel) + if self._calibration is None: + prompt.setDefaultButton(define_button) + else: + prompt.setDefaultButton(load_button) + prompt.exec() + clicked = prompt.clickedButton() + if clicked is None: + return None + standard = prompt.standardButton(clicked) + if standard == QMessageBox.StandardButton.Cancel: + return None + if clicked is load_button: + return "load" + if clicked is define_button: + return "define" + return None + + def _load_calibration_from_dialog(self) -> None: + """Open a file dialog and load the selected calibration file.""" + + last_path = self.last_calibration_file_path() + initial = "" + if last_path: + candidate = Path(str(last_path)).expanduser() + if candidate.exists(): + initial = str(candidate) + else: + parent = candidate.parent + if parent.exists(): + initial = str(parent) + file_filter = "Calibration files (*.h5 *.hdf5);;All files (*)" + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select calibration file", + initial, + file_filter, + ) + if not file_path: + return + success, error = self._load_calibration_from_path(file_path) + if not success: + QMessageBox.warning( + self, + "Calibration load error", + f"Unable to load calibration file:\n{error}", + ) + + def _prompt_calibration_preparation(self) -> tuple[str, int] | None: + """Prepare the user for capturing a calibration image. + + The dialog both determines the calibration square size and optionally + allows sending the calibration pattern to the DMD hardware before the + image is captured. + """ + + mirror_counts = self._preferences.mirror_counts() + default_mirror = int( + max(1, round(0.5 * (float(mirror_counts[0]) + float(mirror_counts[1])))) + ) + try: + dmd_shape = self._stim.dmd_shape() + except Exception: # noqa: BLE001 + dmd_shape = None + + dialog = CalibrationPreparationDialog( + self, + default_square_size=default_mirror, + can_send=self._stim.is_dmd_connected, + max_square_size=min(dmd_shape) if dmd_shape is not None else None, + ) + if dialog.exec() != QDialog.DialogCode.Accepted: + return None + action = dialog.chosen_action() + size = dialog.square_size() + if action is None: + return None + return action, size + + def _send_calibration_frame(self, square_size: int) -> bool: + """Display the calibration pattern on the DMD if connected.""" + + if not self._stim.is_dmd_connected: + QMessageBox.information( + self, + "DMD disconnected", + "Connect to the DMD before sending a calibration frame.", + ) + return False + try: + self._stim.display_calibration_frame(square_size) + except Exception as exc: # noqa: BLE001 + QMessageBox.warning( + self, + "Calibration frame error", + str(exc), + ) + return False + return True + + def _define_new_calibration(self) -> None: + """Guide the user through loading a calibration image and storing it.""" + + preparation = self._prompt_calibration_preparation() + if preparation is None: + return + action, square_size = preparation + if action == "send": + self._send_calibration_frame(square_size) + + initial_dir = self.ui.lineEdit_image_folder_path.text().strip() + stored_image_path = self._preferences.last_calibration_image_path() + if stored_image_path: + stored_dir = os.path.dirname(stored_image_path) + if stored_dir: + initial_dir = stored_dir + file_filter = ( + "Image files (*.png *.jpg *.jpeg *.tif *.tiff *.gif);;All files (*)" + ) + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select calibration image", + initial_dir if initial_dir else "", + file_filter, + ) + if not file_path: + return + try: + with Image.open(file_path) as pil_image: + calibration_image = np.array(pil_image) + except Exception as exc: + QMessageBox.warning( + self, + "Calibration image error", + f"Unable to load calibration image:\n{exc}", + ) + return + self._preferences.set_last_calibration_image_path(file_path) + selected_dir = os.path.dirname(file_path) + if selected_dir: + self.ui.lineEdit_image_folder_path.setText(selected_dir) + + previous_image = self._current_image + previous_view = self._capture_view_state() + selected_items = self.ui.treeWidget.selectedItems() + selected_item = selected_items[0] if selected_items else None + self.roi_manager.clear_visible_only() + + self._set_image( + calibration_image, + fit_to_view=True, + apply_axis=False, + auto_contrast=True, + ) + + diagonal_points = self._prompt_calibration_diagonal() + if diagonal_points is None: + QMessageBox.information( + self, + "Calibration cancelled", + "No calibration diagonal was drawn. Calibration has been cancelled.", + ) + self._restore_after_calibration(previous_image, previous_view, selected_item) + return + + invert_defaults = self._preferences.axes_inverted() + default_invert_x = bool(invert_defaults[0]) + default_invert_y = bool(invert_defaults[1]) + default_mirrors = self._preferences.mirror_counts() + if action == "send": + default_mirrors = (int(square_size), int(square_size)) + dialog = CalibrationDialog( + self, + default_mirrors=default_mirrors, + default_pixel_size=self._preferences.pixel_size(), + default_invert_x=default_invert_x, + default_invert_y=default_invert_y, + ) + if dialog.exec() != QDialog.DialogCode.Accepted: + self._restore_after_calibration(previous_image, previous_view, selected_item) + return + + square_mirrors, pixel_size, invert_x, invert_y = dialog.values() + self._preferences.set_mirror_counts(square_mirrors, square_mirrors) + self._preferences.set_pixel_size(pixel_size) + self._preferences.set_axes_inverted(invert_x, invert_y) + camera_shape = ( + int(calibration_image.shape[1]), + int(calibration_image.shape[0]), + ) + if self.dmd is not None and hasattr(self.dmd, "shape"): + try: + dmd_shape = tuple(int(v) for v in self.dmd.shape) + except Exception: + dmd_shape = (1024, 768) + else: + dmd_shape = (1024, 768) + + try: + calibration = compute_calibration_from_square( + diagonal_points, + square_mirrors, + pixel_size, + camera_shape=camera_shape, + dmd_shape=dmd_shape, + invert_x=invert_x, + invert_y=invert_y, + ) + except ValueError as exc: + QMessageBox.warning(self, "Calibration failed", str(exc)) + self._restore_after_calibration(previous_image, previous_view, selected_item) + return + + self.calibration = calibration + print( + "Updated DMD calibration: pixels/mirror=(%.3f, %.3f), µm/mirror=(%.3f, %.3f), rotation=%.2f°" + % ( + calibration.camera_pixels_per_mirror[0], + calibration.camera_pixels_per_mirror[1], + calibration.micrometers_per_mirror[0], + calibration.micrometers_per_mirror[1], + np.degrees(calibration.camera_rotation_rad), + ) + ) + self._prompt_save_calibration(calibration) + self._restore_after_calibration(previous_image, previous_view, selected_item) + + def _prompt_save_calibration(self, calibration: DMDCalibration) -> None: + """Ask the user if they want to persist the computed calibration.""" + + response = QMessageBox.question( + self, + "Save calibration", + "Do you want to save this calibration to a file?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + if response != QMessageBox.StandardButton.Yes: + return + last_path = self.last_calibration_file_path() + initial = "" + if last_path: + candidate = Path(str(last_path)).expanduser() + if candidate.exists(): + initial = str(candidate) + else: + parent = candidate.parent + if parent.exists(): + initial = str(parent / candidate.name) + file_filter = "Calibration files (*.h5 *.hdf5);;All files (*)" + file_path, _ = QFileDialog.getSaveFileName( + self, + "Save calibration", + initial, + file_filter, + ) + if not file_path: + return + root, ext = os.path.splitext(file_path) + if not ext: + file_path = f"{file_path}.h5" + self._save_calibration_to_path(calibration, file_path) + + def _save_calibration_to_path(self, calibration: DMDCalibration, file_path: str) -> bool: + """Persist ``calibration`` to ``file_path`` and update the MRU list.""" + + path = Path(str(file_path)).expanduser() + try: + path.parent.mkdir(parents=True, exist_ok=True) + except Exception: + pass + try: + saving.save_calibration(str(path), calibration) + except Exception as exc: + QMessageBox.warning( + self, + "Save failed", + f"Unable to save calibration file:\n{exc}", + ) + return False + self.remember_calibration_file(str(path)) + print(f"Saved DMD calibration to {path}") + return True + + def _load_calibration_from_path(self, path_str: str) -> tuple[bool, str | None]: + """Load calibration from disk and activate it. + + Returns a ``(success, error_message)`` tuple mirroring the legacy + implementation so callers can distinguish between cancellations and + actual errors. + """ + + path = Path(str(path_str)).expanduser() + try: + calibration = saving.load_calibration(str(path)) + except Exception as exc: + message = f"{path}: {exc}" + print(f"Failed to load stored calibration from {message}") + return False, message + self.calibration = calibration + self.remember_calibration_file(str(path)) + try: + pixel_size = calibration.camera_pixel_size_um + except AttributeError: + pixel_size = None + if pixel_size is not None: + self._preferences.set_pixel_size(float(pixel_size)) + print(f"Active DMD calibration: {path}") + return True, None + + def _prompt_calibration_diagonal(self) -> np.ndarray | None: + """Capture the diagonal of the illuminated calibration square.""" + + prompt = QMessageBox(self) + prompt.setWindowTitle("Select calibration diagonal") + prompt.setIcon(QMessageBox.Icon.Information) + prompt.setText("Draw the diagonal of the illuminated calibration square.") + prompt.setInformativeText( + "Left-click and drag to draw the line. Right-click or press Esc to cancel." + ) + prompt.setStandardButtons( + QMessageBox.StandardButton.Cancel | QMessageBox.StandardButton.Ok + ) + if prompt.exec() != QMessageBox.StandardButton.Ok: + return None + + capture = AxisCapture(self._get_view_box(), self) + segment = capture.exec() + if segment is None: + return None + start, end = segment + start_xy = np.array([start.x(), start.y()], dtype=float) + end_xy = np.array([end.x(), end.y()], dtype=float) + if not np.all(np.isfinite(start_xy)) or not np.all(np.isfinite(end_xy)): + return None + if np.linalg.norm(end_xy - start_xy) < 1e-9: + return None + return np.vstack((start_xy, end_xy)) + + def _capture_view_state( + self, + ) -> tuple[tuple[float, float], tuple[float, float]] | None: + """Remember the current view box range so it can be restored later.""" + + view = self._get_view_box() + try: + x_range, y_range = view.viewRange() + except Exception: + return None + return (tuple(x_range), tuple(y_range)) + + def _restore_after_calibration( + self, + previous_image: np.ndarray | None, + previous_view_range: tuple[tuple[float, float], tuple[float, float]] | None, + selected_item, + ) -> None: + """Return the widget to the state it had before running calibration.""" + + if previous_image is not None: + self._set_image(previous_image, auto_contrast=True) + if previous_view_range is not None: + x_range, y_range = previous_view_range + self._get_view_box().setRange( + xRange=x_range, + yRange=y_range, + padding=0.0, + ) + else: + self._image_item.clear() + self._current_levels = None + self._current_image = None + + if selected_item is not None: + self.roi_manager.show_for_item(selected_item) + else: + self.roi_manager.clear_visible_only() diff --git a/stim1p/ui/dmd_widget/pattern_io.py b/stim1p/ui/dmd_widget/pattern_io.py new file mode 100644 index 0000000..9a1a385 --- /dev/null +++ b/stim1p/ui/dmd_widget/pattern_io.py @@ -0,0 +1,414 @@ +"""Pattern sequence import/export helpers for :class:`StimDMDWidget`. + +The pattern table in the widget doubles as both a UI editor and an interface to +persist experiments. This module documents how the pieces fit together: +reading/writing table contents, serialising pattern metadata and exporting +analysis information alongside the raw sequence. Adding commentary here keeps +the high-level responsibilities close to the implementation. +""" + +from __future__ import annotations + +from pathlib import Path + +import h5py +import numpy as np +from PySide6.QtWidgets import QDialog, QFileDialog, QMessageBox, QTableWidgetItem + +from ...logic import saving +from ...logic.calibration import DMDCalibration +from ...logic.geometry import axis_micrometre_to_global +from ...logic.sequence import PatternSequence + + +class PatternSequenceIOMixin: + """Mixin bundling pattern table handling and persistence helpers. + + The mixin is responsible for translating the mutable table widget state into + :class:`~stim1p.logic.sequence.PatternSequence` objects and vice versa. It + also encapsulates how files are located so UI code can focus on wiring + signals. + """ + + _calibration: DMDCalibration | None + + def _read_table_ms(self): + """Extract timing/duration/sequence columns from the pattern table.""" + + timings, durations, sequence = [], [], [] + rows = self.ui.tableWidget.rowCount() + for r in range(rows): + t_item = self.ui.tableWidget.item(r, 0) + d_item = self.ui.tableWidget.item(r, 1) + s_item = self.ui.tableWidget.item(r, 2) + try: + if t_item and s_item: + t_text = (t_item.text() or "").strip() + s_text = (s_item.text() or "").strip() + if not t_text or not s_text: + continue + d_text = (d_item.text() if d_item else "") or "" + d_text = d_text.strip() + t = int(t_text) + d = int(d_text) if d_text else 0 + s = int(s_text) + timings.append(t) + durations.append(d) + sequence.append(s) + except Exception: + continue + return timings, durations, sequence + + def _write_table_ms(self, model: PatternSequence): + """Populate the pattern table with values from ``model``.""" + + t_ms = model.timings_milliseconds + d_ms = model.durations_milliseconds + seq = model.sequence + self._updating_table = True + self.table_manager.ensure_desc_column() + self.ui.tableWidget.setRowCount(len(seq)) + for r, (t, d, s) in enumerate(zip(t_ms, d_ms, seq)): + self.ui.tableWidget.setItem(r, 0, QTableWidgetItem(str(int(t)))) + self.ui.tableWidget.setItem(r, 1, QTableWidgetItem(str(int(d)))) + self.ui.tableWidget.setItem(r, 2, QTableWidgetItem(str(int(s)))) + self.table_manager.set_sequence_row_description(r, int(s)) + self._updating_table = False + + def _collect_model(self) -> PatternSequence | None: + """Return the current :class:`PatternSequence` or surface validation errors.""" + + try: + return self.model + except RuntimeError as exc: + QMessageBox.warning(self, "Calibration required", str(exc)) + except Exception as exc: # noqa: BLE001 + QMessageBox.critical(self, "Pattern export failed", str(exc)) + return None + + def _prompt_save_path(self, title: str, initial_path: str) -> str: + """Prompt the user for a file path and normalise the extension.""" + + initial = initial_path.strip() + file_path, _ = QFileDialog.getSaveFileName( + self, + title, + initial if initial else "", + "HDF5 files (*.h5 *.hdf5);;All files (*)", + ) + if not file_path: + return "" + return self._ensure_h5_extension(file_path) + + @staticmethod + def _ensure_h5_extension(path: str) -> str: + """Append ``.h5`` to ``path`` if no HDF5-compatible suffix is present.""" + + trimmed = path.strip() + if not trimmed: + return "" + candidate = Path(trimmed) + if candidate.suffix.lower() in {".h5", ".hdf5"}: + return str(candidate) + return str(candidate.with_suffix(".h5")) + + def _write_pattern_sequence( + self, + file_path: str, + model: PatternSequence, + *, + silent: bool = False, + ) -> str | None: + """Serialise ``model`` to disk and optionally announce the destination.""" + + target = self._ensure_h5_extension(file_path) + if not target: + return None + try: + saving.save_pattern_sequence(target, model) + except Exception as exc: # noqa: BLE001 + QMessageBox.critical(self, "Save failed", str(exc)) + return None + if not silent: + print(f"Saved PatternSequence to {target}") + return target + + def _write_analysis_metadata( + self, + file_path: str, + model: PatternSequence, + ) -> None: + """Append coordinate metadata to the exported pattern sequence file.""" + + if self._calibration is None: + raise RuntimeError("A calibration must be available to export analysis metadata.") + calibration = self._calibration + with h5py.File(file_path, "a") as handle: + if "analysis" in handle: + del handle["analysis"] + analysis_grp = handle.create_group("analysis") + analysis_grp.attrs["version"] = 1 + analysis_grp.attrs["generator"] = "StimDMDWidget" + + axis_grp = analysis_grp.create_group("axis") + axis_grp.attrs["defined"] = bool(self._axis_defined) + axis_grp.attrs["coordinate_system"] = "camera_pixels" + axis_grp.attrs["camera_shape"] = np.asarray(calibration.camera_shape, dtype=np.int64) + if self._axis_defined: + axis_grp.create_dataset( + "origin_camera", + data=np.asarray(self._axis_origin_camera, dtype=np.float64), + ) + axis_grp.create_dataset( + "angle_rad", + data=np.array(self._axis_angle_rad, dtype=np.float64), + ) + try: + origin_um = self._axis_origin_micrometre() + axis_grp.create_dataset( + "origin_micrometre", + data=np.asarray(origin_um, dtype=np.float64), + ) + except Exception: # noqa: BLE001 + pass + + patterns_grp = analysis_grp.create_group("patterns_camera") + patterns_grp.attrs["coordinate_system"] = "camera_pixels" + + descriptions = model.descriptions or [] + shape_types = model.shape_types or [] + + axis_def = self._axis_definition() + + for pattern_index, pattern in enumerate(model.patterns): + pattern_grp = patterns_grp.create_group(f"pattern_{pattern_index}") + if pattern_index < len(descriptions) and descriptions[pattern_index]: + pattern_grp.attrs["description"] = descriptions[pattern_index] + for poly_index, polygon in enumerate(pattern): + points = np.asarray(polygon, dtype=np.float64) + if points.ndim != 2 or points.shape[1] != 2: + continue + if self._axis_defined: + global_um = axis_micrometre_to_global( + points, + axis_def, + calibration, + ) + else: + global_um = points + camera_points = calibration.micrometre_to_camera(global_um.T).T + dataset = pattern_grp.create_dataset( + f"polygon_{poly_index}", data=camera_points + ) + if ( + pattern_index < len(shape_types) + and poly_index < len(shape_types[pattern_index]) + ): + dataset.attrs["shape_type"] = str( + shape_types[pattern_index][poly_index] + ) + + def _load_patterns_file(self): + """Load a pattern sequence from disk and populate the editor.""" + + initial = self.ui.lineEdit_file_path.text().strip() + file_path, _ = QFileDialog.getOpenFileName( + self, + "Select pattern sequence", + initial if initial else "", + "HDF5 files (*.h5 *.hdf5);;All files (*)", + ) + if not file_path: + return + if self._calibration is None: + QMessageBox.warning( + self, + "Calibration required", + "Load or compute a DMD calibration before loading patterns.", + ) + return + try: + self.model = saving.load_pattern_sequence(file_path) + except RuntimeError as exc: + QMessageBox.warning(self, "Calibration required", str(exc)) + return + except Exception as exc: # noqa: BLE001 + QMessageBox.critical(self, "Load failed", str(exc)) + return + self.ui.lineEdit_file_path.setText(file_path) + print(f"Loaded PatternSequence from {file_path}") + + def _save_file(self) -> None: + """Save the current pattern sequence, prompting for a path if necessary.""" + + model = self._collect_model() + if model is None: + return + current_path = self.ui.lineEdit_file_path.text().strip() + if not current_path: + target_path = self._prompt_save_path("Save pattern sequence", "") + if not target_path: + return + else: + target_path = self._ensure_h5_extension(current_path) + saved_path = self._write_pattern_sequence(target_path, model) + if saved_path: + self.ui.lineEdit_file_path.setText(saved_path) + + def _save_file_as(self) -> None: + """Always prompt for a location before saving the current sequence.""" + + model = self._collect_model() + if model is None: + return + current_path = self.ui.lineEdit_file_path.text().strip() + target_path = self._prompt_save_path("Save pattern sequence as", current_path) + if not target_path: + return + saved_path = self._write_pattern_sequence(target_path, model) + if saved_path: + self.ui.lineEdit_file_path.setText(saved_path) + + def _export_patterns_for_analysis(self) -> None: + """Persist the sequence and supplement it with analysis metadata.""" + + model = self._collect_model() + if model is None: + return + current_path = self.ui.lineEdit_file_path.text().strip() + target_path = self._prompt_save_path("Export pattern sequence", current_path) + if not target_path: + return + saved_path = self._write_pattern_sequence(target_path, model, silent=True) + if not saved_path: + return + try: + self._write_analysis_metadata(saved_path, model) + except Exception as exc: # noqa: BLE001 + QMessageBox.warning( + self, + "Export incomplete", + "Pattern sequence saved, but analysis metadata could not be written.\n" + f"Reason: {exc}", + ) + return + print(f"Exported PatternSequence with analysis metadata to {saved_path}") + + def _new_model(self): + """Reset the editor to an empty :class:`PatternSequence`.""" + + self.model = PatternSequence( + patterns=[], sequence=[], timings=[], durations=[], descriptions=[] + ) + self.ui.lineEdit_file_path.clear() + print("Loaded empty PatternSequence") + + def _add_row_table(self): + """Insert a new empty row at the bottom of the pattern table.""" + + self.ui.tableWidget.insertRow(self.ui.tableWidget.rowCount()) + + def _remove_row_table(self): + """Delete any selected rows from the pattern table.""" + + rows = sorted( + {i.row() for i in self.ui.tableWidget.selectedIndexes()}, reverse=True + ) + for r in rows: + self.ui.tableWidget.removeRow(r) + + def _cycle_patterns(self) -> None: + """Expand the table so each pattern is repeated according to user input.""" + + pattern_count = self.ui.treeWidget.topLevelItemCount() + if pattern_count <= 0: + QMessageBox.information( + self, + "No patterns", + "Create at least one pattern before cycling them.", + ) + return + + table = self.ui.tableWidget + last_time = 0 + if table.rowCount() > 0: + last_item = table.item(table.rowCount() - 1, 0) + if last_item is not None: + try: + last_time = int(last_item.text()) + except Exception: + last_time = 0 + + default_repeat_gap = 100 + default_cycle_gap = 250 + default_duration = 100 + default_first_time = last_time + default_repeat_gap if table.rowCount() > 0 else 0 + + dialog = self._cycle_dialog_factory( + default_first_time=default_first_time, + default_cycle_count=1, + default_repeat_count=1, + default_repeat_gap=default_repeat_gap, + default_cycle_gap=default_cycle_gap, + default_duration=default_duration, + ) + if dialog.exec() != QDialog.DialogCode.Accepted: + return + + params = dialog.values() + cycle_count = max(1, int(params["cycle_count"])) + repeat_count = max(1, int(params["repeat_count"])) + first_time = int(params["first_time_ms"]) + repeat_gap = int(params["repeat_gap_ms"]) + cycle_gap = int(params["cycle_gap_ms"]) + duration = int(params["duration_ms"]) + + if cycle_count <= 0 or repeat_count <= 0: + return + + entries: list[tuple[int, int]] = [] + current_time = first_time + + for cycle_index in range(cycle_count): + for pattern_idx in range(pattern_count): + for repeat_index in range(repeat_count): + # Each ``entries`` tuple stores the pattern index and the + # start time for a single presentation. We build the full + # list first to keep the table mutation isolated below. + entries.append((pattern_idx, current_time)) + is_last_entry = ( + cycle_index == cycle_count - 1 + and pattern_idx == pattern_count - 1 + and repeat_index == repeat_count - 1 + ) + if is_last_entry: + continue + current_time += repeat_gap + end_of_cycle = ( + repeat_index == repeat_count - 1 + and pattern_idx == pattern_count - 1 + ) + if end_of_cycle: + current_time += cycle_gap + + if not entries: + return + + self.table_manager.ensure_desc_column() + signals_were_blocked = self.ui.tableWidget.blockSignals(True) + try: + for pattern_idx, start_time in entries: + row = table.rowCount() + table.insertRow(row) + table.setItem(row, 0, QTableWidgetItem(str(start_time))) + table.setItem(row, 1, QTableWidgetItem(str(duration))) + table.setItem(row, 2, QTableWidgetItem(str(pattern_idx))) + self.table_manager.set_sequence_row_description(row, pattern_idx) + finally: + self.ui.tableWidget.blockSignals(signals_were_blocked) + self.table_manager.refresh_sequence_descriptions() + + def _cycle_dialog_factory(self, **defaults): + from ..dmd_dialogs import CyclePatternsDialog + + return CyclePatternsDialog(self, **defaults)