diff --git a/optics_framework/api/action_keyword.py b/optics_framework/api/action_keyword.py index ef2fd93b..f1614d75 100644 --- a/optics_framework/api/action_keyword.py +++ b/optics_framework/api/action_keyword.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Callable +from typing import Callable, Optional, Any from optics_framework.common.logging_config import logger, apply_logger_format_to_all from optics_framework.common.optics_builder import OpticsBuilder from optics_framework.common.strategies import StrategyManager @@ -7,9 +7,8 @@ from .verifier import Verifier import time -# Action Executor Decorator - +# Action Executor Decorator def with_self_healing(func: Callable) -> Callable: @wraps(func) def wrapper(self, element, *args, **kwargs): @@ -54,8 +53,8 @@ def __init__(self, builder: OpticsBuilder): # Click actions @with_self_healing def press_element( - self, element, repeat=1, offset_x=0, offset_y=0, event_name=None, *, located - ): + self, element: str, repeat: int = 1, offset_x: int = 0, offset_y: int = 0, event_name: Optional[str] = None, *, located: Any + ) -> None: """ Press a specified element. @@ -75,7 +74,7 @@ def press_element( logger.debug(f"Pressing element '{element}'") self.driver.press_element(located, repeat, event_name) - def press_by_percentage(self, percent_x, percent_y, repeat=1, event_name=None): + def press_by_percentage(self, percent_x: float, percent_y: float, repeat: int = 1, event_name: Optional[str] = None) -> None: """ Press an element by percentage coordinates. @@ -99,7 +98,7 @@ def press_by_percentage(self, percent_x, percent_y, repeat=1, event_name=None): y_coor = int(screen_height * percent_y) self.driver.press_coordinates(x_coor, y_coor, event_name) - def press_by_coordinates(self, coor_x, coor_y, repeat=1, event_name=None): + def press_by_coordinates(self, coor_x: int, coor_y: int, repeat: int = 1, event_name: Optional[str] = None) -> None: """ Press an element by absolute coordinates. @@ -111,7 +110,7 @@ def press_by_coordinates(self, coor_x, coor_y, repeat=1, event_name=None): utils.capture_screenshot("press_by_coordinates") self.driver.press_coordinates(coor_x, coor_y, event_name) - def press_element_with_index(self, element, index=0, event_name=None): + def press_element_with_index(self, element: str, index: int = 0, event_name: Optional[str] = None) -> None: """ Press a specified text at a given index. @@ -154,7 +153,7 @@ def press_element_with_index(self, element, index=0, event_name=None): 'XPath is not supported for index based location. Provide the attribute as text.') @with_self_healing - def detect_and_press(self, element, timeout, event_name=None, *, located): + def detect_and_press(self, element: str, timeout: int, event_name: Optional[str] = None, *, located: Any) -> None: """ Detect and press a specified element. @@ -174,11 +173,12 @@ def detect_and_press(self, element, timeout, event_name=None, *, located): x, y, event_name=event_name) else: logger.debug(f"Pressing detected element '{element}'") - self.driver.press_element(located, repeat=1, event_name=event_name) + self.driver.press_element( + located, repeat=1, event_name=event_name) @DeprecationWarning @with_self_healing - def press_checkbox(self, element, event_name=None, *, located): + def press_checkbox(self, element: str, event_name: Optional[str] = None, *, located: Any) -> None: """ Press a specified checkbox element. @@ -189,7 +189,7 @@ def press_checkbox(self, element, event_name=None, *, located): @DeprecationWarning @with_self_healing - def press_radio_button(self, element, event_name=None, *, located): + def press_radio_button(self, element: str, event_name: Optional[str] = None, *, located: Any) -> None: """ Press a specified radio button. @@ -198,7 +198,7 @@ def press_radio_button(self, element, event_name=None, *, located): """ self.press_element(element, event_name=event_name, located=located) - def select_dropdown_option(self, element, option, event_name=None): + def select_dropdown_option(self, element: str, option: str, event_name: Optional[str] = None) -> None: """ Select a specified dropdown option. @@ -209,7 +209,7 @@ def select_dropdown_option(self, element, option, event_name=None): pass # Swipe and Scroll actions - def swipe(self, coor_x, coor_y, direction='right', swipe_length=50, event_name=None): + def swipe(self, coor_x: int, coor_y: int, direction: str = 'right', swipe_length: int = 50, event_name: Optional[str] = None) -> None: """ Perform a swipe action in a specified direction. @@ -223,7 +223,7 @@ def swipe(self, coor_x, coor_y, direction='right', swipe_length=50, event_name=N self.driver.swipe(coor_x, coor_y, direction, swipe_length, event_name) @DeprecationWarning - def swipe_seekbar_to_right_android(self, element, event_name=None): + def swipe_seekbar_to_right_android(self, element: str, event_name: Optional[str] = None) -> None: """ Swipe a seekbar to the right. @@ -232,7 +232,7 @@ def swipe_seekbar_to_right_android(self, element, event_name=None): utils.capture_screenshot("swipe_seekbar_to_right_android") self.driver.swipe_element(element, 'right', 50, event_name) - def swipe_until_element_appears(self, element, direction, timeout, event_name=None): + def swipe_until_element_appears(self, element: str, direction: str, timeout: int, event_name: Optional[str] = None) -> None: """ Swipe in a specified direction until an element appears. @@ -252,18 +252,14 @@ def swipe_until_element_appears(self, element, direction, timeout, event_name=No time.sleep(3) @with_self_healing - def swipe_from_element(self, element, direction, swipe_length, event_name=None, *, located): + def swipe_from_element(self, element: str, direction: str, swipe_length: int, event_name: Optional[str] = None, *, located: Any) -> None: """ Perform a swipe action starting from a specified element. :param element: The element to swipe from (Image template, OCR template, or XPath). - :type element: str :param direction: The swipe direction (up, down, left, right). - :type direction: str :param swipe_length: The length of the swipe. - :type swipe_length: int or float :param event_name: The event triggering the swipe. - :type event_name: str """ if isinstance(located, tuple): x, y = located @@ -274,54 +270,45 @@ def swipe_from_element(self, element, direction, swipe_length, event_name=None, self.driver.swipe_element( located, direction, swipe_length, event_name) - def scroll(self, direction, event_name=None): + def scroll(self, direction: str, event_name: Optional[str] = None) -> None: """ Perform a scroll action in a specified direction. :param direction: The scroll direction (up, down, left, right). - :type direction: str :param event_name: The event triggering the scroll. - :type event_name: str """ utils.capture_screenshot("scroll") self.driver.scroll(direction, 1000, event_name) @with_self_healing - def scroll_until_element_appears(self, element, direction, timeout, event_name=None, *, located): + def scroll_until_element_appears(self, element: str, direction: str, timeout: int, event_name: Optional[str] = None, *, located: Any) -> None: """ Scroll in a specified direction until an element appears. :param element: The target element (Image template, OCR template, or XPath). - :type element: str :param direction: The scroll direction (up, down, left, right). - :type direction: str :param timeout: Timeout for the scroll operation. - :type timeout: int or float :param event_name: The event triggering the scroll. - :type event_name: str """ utils.capture_screenshot("scroll_until_element_appears") start_time = time.time() while time.time() - start_time < int(timeout): - result = self.verifier.assert_presence(element, timeout=3, rule="any") + result = self.verifier.assert_presence( + element, timeout=3, rule="any") if result: break - self.driver.scroll(direction,1000, event_name) + self.driver.scroll(direction, 1000, event_name) time.sleep(3) @with_self_healing - def scroll_from_element(self, element, direction, scroll_length, event_name, *, located): + def scroll_from_element(self, element: str, direction: str, scroll_length: int, event_name: Optional[str] = None, *, located: Any) -> None: """ Perform a scroll action starting from a specified element. :param element: The element to scroll from (Image template, OCR template, or XPath). - :type element: str :param direction: The scroll direction (up, down, left, right). - :type direction: str :param scroll_length: The length of the scroll. - :type scroll_length: int or float :param event_name: The event triggering the scroll. - :type event_name: str """ utils.capture_screenshot("scroll_from_element") self.swipe_from_element( @@ -329,16 +316,13 @@ def scroll_from_element(self, element, direction, scroll_length, event_name, *, # Text input actions @with_self_healing - def enter_text(self, element, text, event_name=None, *, located): + def enter_text(self, element: str, text: str, event_name: Optional[str] = None, *, located: Any) -> None: """ Enter text into a specified element. :param element: The target element (Image template, OCR template, or XPath). - :type element: str :param text: The text to be entered. - :type text: str :param event_name: The event triggering the input. - :type event_name: str """ if isinstance(located, tuple): x, y = located @@ -350,54 +334,45 @@ def enter_text(self, element, text, event_name=None, *, located): self.driver.enter_text_element(located, text, event_name) @DeprecationWarning - def enter_text_using_keyboard_android(self, text, event_name=None): + def enter_text_using_keyboard_android(self, text: str, event_name: Optional[str] = None) -> None: """ Enter text using the keyboard. :param text: The text to be entered. - :type text: str :param event_name: The event triggering the input. - :type event_name: str """ utils.capture_screenshot("enter_text_using_keyboard_android") self.driver.enter_text_using_keyboard(text, event_name) @with_self_healing - def enter_number(self, element, number, event_name=None, *, located): + def enter_number(self, element: str, number: float, event_name: Optional[str] = None, *, located: Any) -> None: """ Enter a specified number into an element. :param element: The target element (Image template, OCR template, or XPath). - :type element: str :param number: The number to be entered. - :type number: int or float :param event_name: The event triggering the input. - :type event_name: str """ utils.capture_screenshot("enter_number") self.enter_text(element, str(number), event_name, located=located) - def press_keycode(self, keycode, event_name): + def press_keycode(self, keycode: int, event_name: str) -> None: """ Press a specified keycode. :param keycode: The keycode to be pressed. - :type keycode: int :param event_name: The event triggering the press. - :type event_name: str """ utils.capture_screenshot("press_keycode") self.driver.press_keycode(keycode, event_name) @with_self_healing - def clear_element_text(self, element, event_name=None, *, located): + def clear_element_text(self, element: str, event_name: Optional[str] = None, *, located: Any) -> None: """ Clear text from a specified element. :param element: The target element (Image template, OCR template, or XPath). - :type element: str :param event_name: The event triggering the action. - :type event_name: str """ if isinstance(located, tuple): x, y = located @@ -409,32 +384,34 @@ def clear_element_text(self, element, event_name=None, *, located): logger.debug(f"Clearing text from element '{element}'") self.driver.clear_text_element(located, event_name) - def get_text(self, element): + def get_text(self, element: str) -> Optional[str]: """ Get the text from a specified element. :param element: The target element (Image template, OCR template, or XPath). - :type element: str - :return: The text from the element. - :rtype: str + :return: The text from the element or None if not supported. """ utils.capture_screenshot("get_text") - element_source_type = type(self.element_source.current_instance).__name__ + element_source_type = type( + self.element_source.current_instance).__name__ element_type = utils.determine_element_type(element) if element_type in ["Text", "XPath"]: if 'appium' in element_source_type.lower(): element = self.element_source.locate(element) return self.driver.get_text_element(element) else: - logger.error('Get Text is not supported for vision based search yet.') + logger.error( + 'Get Text is not supported for vision based search yet.') + return None else: - logger.error('Get Text is not supported for image based search yet.') + logger.error( + 'Get Text is not supported for image based search yet.') + return None - def sleep(self, duration): + def sleep(self, duration: int) -> None: """ Sleep for a specified duration. - :param duration: The duration of the sleep. - :type duration: int or float + :param duration: The duration of the sleep in seconds. """ time.sleep(int(duration)) diff --git a/optics_framework/api/app_management.py b/optics_framework/api/app_management.py index 6e4977f9..5953eefe 100644 --- a/optics_framework/api/app_management.py +++ b/optics_framework/api/app_management.py @@ -1,6 +1,8 @@ +from typing import Optional from optics_framework.common.logging_config import logger, apply_logger_format_to_all from optics_framework.common.optics_builder import OpticsBuilder + @apply_logger_format_to_all("internal") class AppManagement: """ @@ -12,74 +14,69 @@ class AppManagement: Attributes: driver (object): The driver instance for managing applications. """ + def __init__(self, builder: OpticsBuilder): self.driver = builder.get_driver() if self.driver is None: - logger.exception(f"Driver '{builder.driver_config}' could not be initialized.") + logger.error("Driver could not be initialized due to not being provided.") + # Optionally raise an exception if this should halt execution + # raise ValueError(f"Driver '{builder.driver_config}' could not be initialized.") - def initialise_setup(self): + def initialise_setup(self) -> None: """ - Set up the environment for the driver module. + Sets up the environment for the driver module. This method should be called before performing any application management operations. """ logger.debug("Initialising setup for AppManagement.") - def launch_app(self, event_name: str | None = None): + def launch_app(self, event_name: Optional[str] = None) -> None: """ - Launch the specified application. + Launches the specified application. - :param event_name: The event triggering the app launch. - :type event_name: str + :param event_name: The event triggering the app launch, if any. """ self.driver.launch_app(event_name) - def start_appium_session(self, event_name: str | None = None): + def start_appium_session(self, event_name: Optional[str] = None) -> None: """ - Start an Appium session. + Starts an Appium session. - :param event_name: The event triggering the session start. - :type event_name: str + :param event_name: The event triggering the session start, if any. """ self.driver.launch_app(event_name) - def start_other_app(self, package_name: str, event_name: str): + def start_other_app(self, package_name: str, event_name: Optional[str] = None) -> None: """ - Start another application. + Starts another application. :param package_name: The package name of the application. - :type package_name: str - :param event_name: The event triggering the app start. - :type event_name: str + :param event_name: The event triggering the app start, if any. """ pass - def close_and_terminate_app(self, package_name: str, event_name: str): + def close_and_terminate_app(self, package_name: str, event_name: Optional[str] = None) -> None: """ - Close and terminate a specified application. + Closes and terminates a specified application. :param package_name: The package name of the application. - :type package_name: str - :param event_name: The event triggering the app termination. - :type event_name: str + :param event_name: The event triggering the app termination, if any. """ pass - def force_terminate_app(self, event_name: str): + def force_terminate_app(self, event_name: Optional[str] = None) -> None: """ - Forcefully terminate the specified application. + Forcefully terminates the specified application. - :param event_name: The event triggering the forced termination. - :type event_name: str + :param event_name: The event triggering the forced termination, if any. """ pass - def get_app_version(self): + def get_app_version(self) -> Optional[str]: """ - Get the version of the application. + Gets the version of the application. - :return: The version of the application. - :rtype: str + :return: The version of the application, or None if not available. """ - self.driver.get_app_version() + return self.driver.get_app_version() diff --git a/optics_framework/api/flow_control.py b/optics_framework/api/flow_control.py index 024e9543..9afb583a 100644 --- a/optics_framework/api/flow_control.py +++ b/optics_framework/api/flow_control.py @@ -1,37 +1,44 @@ import re -from typing import Optional, Any, List, Union, Tuple -import os +from typing import Optional, Any, List, Union, Tuple, Callable, Dict +import os.path +import ast +from functools import wraps import json -import requests import csv +import requests from optics_framework.common.logging_config import logger from optics_framework.common.runner.test_runnner import TestRunner -import ast -from functools import wraps + def raw_params(*indices): """Decorator to mark parameter indices that should remain unresolved.""" - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) - wrapper._raw_param_indices = indices + wrapper._raw_param_indices = indices # pylint: disable=protected-access # type: ignore[attr-defined] return wrapper return decorator class FlowControl: - def __init__(self, runner: TestRunner | None): - self.runner: TestRunner | None = runner + """Manages control flow operations (loops, conditions, data) for a TestRunner.""" + + def __init__(self, runner: Optional[TestRunner] = None): + self.runner: Optional[TestRunner] = runner - def _ensure_runner(self): + def _ensure_runner(self) -> None: + """Ensures a TestRunner instance is set.""" if self.runner is None: - raise Exception( + raise ValueError( "FlowControl.runner is not set. Please assign a valid runner instance before using FlowControl.") def execute_module(self, module_name: str) -> List[Any]: + """Executes a module's keywords with resolved parameters.""" self._ensure_runner() + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") if module_name not in self.runner.modules: raise ValueError(f"Module '{module_name}' not found.") @@ -54,7 +61,10 @@ def execute_module(self, module_name: str) -> List[Any]: @raw_params(1, 3, 5, 7, 9, 11, 13, 15) def run_loop(self, target: str, *args: str) -> List[Any]: + """Runs a loop over a target module, either by count or with variables.""" self._ensure_runner() + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") if len(args) == 1: return self._loop_by_count(target, args[0]) return self._loop_with_variables(target, args) @@ -88,8 +98,11 @@ def _loop_with_variables(self, target: str, args: Tuple[str, ...]) -> List[Any]: var_names, parsed_iterables = self._parse_variable_iterable_pairs( variables, iterables) min_length = min(len(lst) for lst in parsed_iterables) + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") if not isinstance(self.runner.elements, dict): self.runner.elements = {} + runner_elements: Dict[str, Any] = self.runner.elements results = [] for i in range(min_length): @@ -97,7 +110,7 @@ def _loop_with_variables(self, target: str, args: Tuple[str, ...]) -> List[Any]: value = iterable_values[i] logger.debug( f"[RUN LOOP] Iteration {i+1}: Setting {var_name} = {value}") - self.runner.elements[var_name] = value + runner_elements[var_name] = value logger.debug( f"[RUN LOOP] Iteration {i+1}: Executing target '{target}'") result = self.execute_module(target) @@ -105,13 +118,13 @@ def _loop_with_variables(self, target: str, args: Tuple[str, ...]) -> List[Any]: return results def _parse_variable_iterable_pairs(self, variables: Tuple[str, ...], iterables: Tuple[str, ...]) -> Tuple[List[str], List[List[Any]]]: - """Parse variable names and their corresponding iterables.""" + """Parses variable names and their corresponding iterables.""" var_names = self._parse_variable_names(variables) parsed_iterables = self._parse_iterables(variables, iterables) return var_names, parsed_iterables def _parse_variable_names(self, variables: Tuple[str, ...]) -> List[str]: - """Extract and clean variable names from the input tuple.""" + """Extracts and cleans variable names from the input tuple.""" var_names = [] for variable in variables: var_name = variable.strip() @@ -124,7 +137,7 @@ def _parse_variable_names(self, variables: Tuple[str, ...]) -> List[str]: return var_names def _parse_iterables(self, variables: Tuple[str, ...], iterables: Tuple[str, ...]) -> List[List[Any]]: - """Parse iterables into lists, handling JSON strings and validating input.""" + """Parses iterables into lists, handling JSON strings and validating input.""" parsed_iterables = [] for i, iterable in enumerate(iterables): parsed = self._parse_single_iterable(iterable, variables[i]) @@ -135,7 +148,7 @@ def _parse_iterables(self, variables: Tuple[str, ...], iterables: Tuple[str, ... return parsed_iterables def _parse_single_iterable(self, iterable: Any, variable: str) -> List[Any]: - """Parse a single iterable, converting JSON strings or validating lists.""" + """Parses a single iterable, converting JSON strings or validating lists.""" if isinstance(iterable, str): try: values = json.loads(iterable) @@ -152,8 +165,11 @@ def _parse_single_iterable(self, iterable: Any, variable: str) -> List[Any]: raise ValueError( f"Expected a list or JSON string for iterable of variable '{variable}', got {type(iterable).__name__}.") - def condition(self, *args) -> Optional[List[Any]]: + def condition(self, *args: str) -> Optional[List[Any]]: + """Evaluates conditions and executes corresponding targets.""" self._ensure_runner() + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") if not args: raise ValueError("No condition-target pairs provided.") pairs, else_target = self._split_condition_args(args) @@ -195,11 +211,14 @@ def _is_condition_true(self, cond: str) -> bool: raise ValueError(f"Error evaluating condition '{cond}': {e}") def _resolve_condition(self, cond: str) -> str: + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") + runner_elements: Dict[str, Any] = self.runner.elements pattern = re.compile(r"\$\{([^}]+)\}") def replacer(match): var_name = match.group(1).strip() - value = self.runner.elements.get(var_name) + value = runner_elements.get(var_name) if value is None: raise ValueError( f"Variable '{var_name}' not found for condition resolution.") @@ -211,16 +230,21 @@ def replacer(match): return pattern.sub(replacer, cond) @raw_params(0) - def read_data(self, input_element: str, file_path: Union[str, List[Any]], index: Optional[int] = None): + def read_data(self, input_element: str, file_path: Union[str, List[Any]], index: Optional[int] = None) -> List[Any]: + """Reads data from a file, API, or list and stores it in runner.elements.""" self._ensure_runner() + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") elem_name = self._extract_element_name(input_element) data = self._load_data(file_path, index) if not isinstance(self.runner.elements, dict): self.runner.elements = {} - self.runner.elements[elem_name] = data + runner_elements: Dict[str, Any] = self.runner.elements + runner_elements[elem_name] = data return data def _extract_element_name(self, input_element: str) -> str: + """Extracts and cleans the element name from input.""" elem_name = input_element.strip() if elem_name.startswith("${") and elem_name.endswith("}"): return elem_name[2:-1].strip() @@ -229,6 +253,7 @@ def _extract_element_name(self, input_element: str) -> str: return elem_name def _load_data(self, file_path: Union[str, List[Any]], index: Optional[int]) -> List[Any]: + """Loads data from a list, API, or CSV file.""" # Direct list input if isinstance(file_path, list): return file_path @@ -292,14 +317,19 @@ def _load_data(self, file_path: Union[str, List[Any]], index: Optional[int]) -> "Unsupported file format. Use CSV or provide a list/URL.") @raw_params(0) - def evaluate(self, param1: str, param2: str): + def evaluate(self, param1: str, param2: str) -> Any: + """Evaluates an expression and stores the result in runner.elements.""" self._ensure_runner() + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") var_name = self._extract_variable_name(param1) result = self._compute_expression(param2) - self.runner.elements[var_name] = str(result) + runner_elements: Dict[str, Any] = self.runner.elements + runner_elements[var_name] = str(result) return result def _extract_variable_name(self, param1: str) -> str: + """Extracts and cleans the variable name from param1.""" var_name = param1.strip() if var_name.startswith("${") and var_name.endswith("}"): return var_name[2:-1].strip() @@ -308,16 +338,22 @@ def _extract_variable_name(self, param1: str) -> str: return var_name def _compute_expression(self, param2: str) -> Any: + """Computes an expression by resolving variables and evaluating it.""" + if self.runner is None: # Explicit check instead of assert + raise ValueError("Runner is None after ensure_runner call.") + runner_elements: Dict[str, Any] = self.runner.elements + def replace_var(match): var_name = match.group(1) - if var_name not in self.runner.elements: + if var_name not in runner_elements: raise ValueError( f"Variable '{var_name}' not found in elements.") - return str(self.runner.elements[var_name]) + return str(runner_elements[var_name]) param2_resolved = re.sub(r"\$\{([^}]+)\}", replace_var, param2) return self._safe_eval(param2_resolved) def _safe_eval(self, expression: str) -> Any: + """Safely evaluates an expression with restricted operations.""" try: node = ast.parse(expression, mode='eval') allowed_nodes = (ast.Expression, ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare, @@ -330,7 +366,7 @@ def _safe_eval(self, expression: str) -> Any: if not isinstance(node, allowed_nodes) and not isinstance(node, allowed_operators): raise ValueError( f"Unsafe expression detected: {expression}") - return eval(expression, {"__builtins__": None}, {}) # nosec + return eval(expression, {"__builtins__": None}, {}) # nosec except Exception as e: raise ValueError( f"Error evaluating expression '{expression}': {e}") diff --git a/optics_framework/api/verifier.py b/optics_framework/api/verifier.py index 929dca6a..16c39ac0 100644 --- a/optics_framework/api/verifier.py +++ b/optics_framework/api/verifier.py @@ -1,14 +1,17 @@ +from typing import Optional, Any, List from optics_framework.common.logging_config import logger, apply_logger_format_to_all from optics_framework.common import utils from optics_framework.common.optics_builder import OpticsBuilder import time + @apply_logger_format_to_all("internal") class Verifier: """ Provides methods to verify elements, screens, and data integrity. """ _instance = None + def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(Verifier, cls).__new__(cls) @@ -24,66 +27,65 @@ def validate_element( element: str, timeout: int = 10, rule: str = "all", - event_name: str | None = None, + event_name: Optional[str] = None, ) -> None: """ Verifies the specified element. :param element: The element to be verified (Image template, OCR template, or XPath). - :type element: str - :param timeout: The time to wait for verification. - :type timeout: int - :param rule: The rule used for verification. - :type rule: str - :param event_name: The name of the event associated with the verification. - :type event_name: str + :param timeout: The time to wait for verification in seconds. + :param rule: The rule used for verification ("all" or "any"). + :param event_name: The name of the event associated with the verification, if any. """ logger.debug(f"Validating element: {element}") logger.debug(f"Timeout: {timeout} and Rule: {rule}") self.assert_presence(element, timeout, rule, event_name) - def is_element( - self, element: str, element_state: str, timeout: int, event_name: str + self, + element: str, + element_state: str, + timeout: int, + event_name: Optional[str] = None, ) -> None: """ - Checks if the specified element is Enabled/Disabled/Visible/Invisible. + Checks if the specified element is in a given state (e.g., Enabled/Disabled/Visible/Invisible). :param element: The element to be checked (Image template, OCR template, or XPath). - :type element: str - :param element_state: The state of the element (visible, invisible, enabled, disabled). - :type element_state: str - :param timeout: The time to wait for the element. - :type timeout: int - :param event_name: The name of the event associated with the check. - :type event_name: str + :param element_state: The state to verify (visible, invisible, enabled, disabled). + :param timeout: The time to wait for the element in seconds. + :param event_name: The name of the event associated with the check, if any. """ pass - def assert_equality(self, output, expression) -> None: + def assert_equality(self, output: Any, expression: Any, event_name: Optional[str] = None) -> None: """ Compares two values for equality. :param output: The first value to be compared. - :type output: any :param expression: The second value to be compared. - :type expression: any - :param event_name: The name of the event associated with the comparison. - :type event_name: str + :param event_name: The name of the event associated with the comparison, if any. """ pass - def vision_search(self, elements: list[str], timeout: int, rule: str) -> bool: + def vision_search(self, elements: List[str], timeout: int, rule: str) -> bool: """ - Vision based search for elements + Performs a vision-based search for elements. + + :param elements: List of elements to search for (Image templates or OCR templates). + :param timeout: The time to wait for elements to appear in seconds. + :param rule: The rule for verification ("any" or "all"). + :return: True if the rule is satisfied, False otherwise. """ rule = rule.lower() timeout = int(timeout) found_text = False found_image = False # Group elements by type - texts = [el for el in elements if utils.determine_element_type(el) == 'Text'] - images = [el for el in elements if utils.determine_element_type(el) == 'Image'] + texts = [ + el for el in elements if utils.determine_element_type(el) == 'Text'] + images = [ + el for el in elements if utils.determine_element_type(el) == 'Image'] # Shared resources element_status = { @@ -93,9 +95,9 @@ def vision_search(self, elements: list[str], timeout: int, rule: str) -> bool: start_time = time.time() while (time.time() - start_time) < timeout: - # Capture a screenshot - timestamp = utils.get_current_time_for_events() # Get timestamp when the screenshot is taken + # Get timestamp when the screenshot is taken + timestamp = utils.get_current_time_for_events() frame = self.element_source.capture() if frame is None: @@ -106,11 +108,13 @@ def vision_search(self, elements: list[str], timeout: int, rule: str) -> bool: # Search for text elements if texts: - found_text = self.assert_texts_vision(frame, texts, element_status, rule) + found_text = self.assert_texts_vision( + frame, texts, element_status, rule) # Search for image elements if images: - found_image = self.assert_images_vision(frame, images, element_status, rule) + found_image = self.assert_images_vision( + frame, images, element_status, rule) # If rule is 'any' and either text or image is found, stop early if rule == 'any' and (found_text or found_image): @@ -121,24 +125,24 @@ def vision_search(self, elements: list[str], timeout: int, rule: str) -> bool: if rule == 'all' and all( item['found'] for status in element_status.values() for item in status.values() ): + utils.annotate_and_save(frame, element_status) return True time.sleep(0.5) - # Final annotation before returning - utils.annotate_and_save(frame, element_status) + # Final annotation before returning + utils.annotate_and_save(frame, element_status) return any(item['found'] for status in element_status.values() for item in status.values()) - def assert_texts_vision(self, frame, texts, element_status, rule): + def assert_texts_vision(self, frame: Any, texts: List[str], element_status: dict, rule: str) -> bool: """ Searches for the given texts in a single frame using OCR. - Args: - frame (numpy.ndarray): The image frame to search in. - texts (list): List of text elements to search for. - element_status (dict): Dictionary storing found element statuses. - Returns: - bool: True if an element is found (for 'any' rule), False otherwise. + :param frame: The image frame to search in (e.g., numpy.ndarray). + :param texts: List of text elements to search for. + :param element_status: Dictionary storing found element statuses. + :param rule: The rule for verification ("any" or "all"). + :return: True if the rule is satisfied, False otherwise. """ found_any = False @@ -149,7 +153,8 @@ def assert_texts_vision(self, frame, texts, element_status, rule): if found: if not element_status['texts'][text]['found']: - element_status['texts'][text] = {'found': True, 'bbox': bbox} + element_status['texts'][text] = { + 'found': True, 'bbox': bbox} logger.debug(f"Text '{text}' found at bbox: {bbox}.") found_any = True @@ -163,18 +168,15 @@ def assert_texts_vision(self, frame, texts, element_status, rule): item['found'] for item in element_status['texts'].values() ) - - def assert_images_vision(self, frame, images, element_status, rule): + def assert_images_vision(self, frame: Any, images: List[str], element_status: dict, rule: str) -> bool: """ Searches for the given images in a single frame using template matching. - Args: - frame (numpy.ndarray): The image frame to search in. - images (list): List of image templates to search for. - element_status (dict): Dictionary storing found element statuses. - rule (str): 'any' (stop when one is found) or 'all' (search all). - Returns: - bool: True if an element is found (for 'any' rule), False otherwise. + :param frame: The image frame to search in (e.g., numpy.ndarray). + :param images: List of image templates to search for. + :param element_status: Dictionary storing found element statuses. + :param rule: The rule for verification ("any" or "all"). + :return: True if the rule is satisfied, False otherwise. """ found_any = False @@ -199,60 +201,60 @@ def assert_images_vision(self, frame, images, element_status, rule): item['found'] for item in element_status['images'].values() ) - def assert_presence(self, elements, timeout=30, rule='any', event_name=None) -> bool: + def assert_presence(self, elements: str, timeout: int = 30, rule: str = 'any', event_name: Optional[str] = None) -> bool: """ Asserts the presence of elements. - :param elements: The elements to be checked (Image template, OCR template, or XPath). - :type elements: list - :param timeout: The time to wait for the elements. - :type timeout: int - :param rule: The rule used for verification. - :type rule: str - :param event_name: The name of the event associated with the assertion. - :type event_name: str + :param elements: Comma-separated string of elements to check (Image templates, OCR templates, or XPaths). + :param timeout: The time to wait for the elements in seconds. + :param rule: The rule for verification ("any" or "all"). + :param event_name: The name of the event associated with the assertion, if any. + :return: True if the rule is satisfied, False otherwise. """ - element_source_type = type(self.element_source.current_instance).__name__ + element_source_type = type( + self.element_source.current_instance).__name__ rule = rule.lower() timeout = int(timeout) - elements = elements.split(',') + elements_list = elements.split(',') # Group elements by type - texts = [el for el in elements if utils.determine_element_type(el) == 'Text'] - xpaths = [el for el in elements if utils.determine_element_type(el) == 'XPath'] - images = [el for el in elements if utils.determine_element_type(el) == 'Image'] + texts = [ + el for el in elements_list if utils.determine_element_type(el) == 'Text'] + xpaths = [ + el for el in elements_list if utils.determine_element_type(el) == 'XPath'] + images = [ + el for el in elements_list if utils.determine_element_type(el) == 'Image'] if 'appium' in element_source_type.lower(): - # calls assert presence from appium driver + # Calls assert presence from appium driver if images: - logger.error("Image search is not supported for Appium based search") + logger.error( + "Image search is not supported for Appium based search") return False texts_xpaths = texts + xpaths - result = self.element_source.assert_elements(texts_xpaths, timeout, rule) - + result = self.element_source.assert_elements( + texts_xpaths, timeout, rule) else: - # vision search + # Vision search if xpaths: - logger.error("XPath search is not supported for Vision based search") + logger.error( + "XPath search is not supported for Vision based search") return False texts_images = texts + images result = self.vision_search(texts_images, timeout, rule) + if event_name: - # Trigger event + # Trigger event (placeholder) pass return result - def validate_screen(self, elements, timeout=30, rule='any', event_name=None) -> None: + def validate_screen(self, elements: str, timeout: int = 30, rule: str = 'any', event_name: Optional[str] = None) -> None: """ - Verifies the specified screen. - - :param elements: The elements to be verified (Image template, OCR template, or XPath). - :type elements: list - :param timeout: The time to wait for verification. - :type timeout: int - :param rule: The rule used for verification. - :type rule: str - :param event_name: The name of the event associated with the verification. - :type event_name: str + Verifies the specified screen by checking element presence. + + :param elements: Comma-separated string of elements to verify (Image templates, OCR templates, or XPaths). + :param timeout: The time to wait for verification in seconds. + :param rule: The rule for verification ("any" or "all"). + :param event_name: The name of the event associated with the verification, if any. """ self.assert_presence(elements, timeout, rule, event_name) diff --git a/optics_framework/common/base_factory.py b/optics_framework/common/base_factory.py index 522d52a1..acdd18ed 100644 --- a/optics_framework/common/base_factory.py +++ b/optics_framework/common/base_factory.py @@ -3,9 +3,11 @@ import importlib import pkgutil import inspect +from pydantic import BaseModel, Field from optics_framework.common.logging_config import logger, apply_logger_format_to_all T = TypeVar("T") +S = TypeVar("S") # New TypeVar for FactoryState @apply_logger_format_to_all("internal") @@ -13,9 +15,17 @@ class GenericFactory(Generic[T]): """ A generic factory class for discovering and instantiating modules dynamically. """ - _MODULES: Dict[str, str] = {} - _INSTANCES: Dict[str, T] = {} - _DISCOVERED: bool = False + class FactoryState(BaseModel, Generic[S]): + """Pydantic model to manage factory state.""" + modules: Dict[str, str] = Field( + default_factory=dict) # {name: module_path} + instances: Dict[str, S] = Field( + default_factory=dict) # {name: instance} + + class Config: + arbitrary_types_allowed = True # Allow generic S + + _state: FactoryState[T] = FactoryState() # Bind to T from GenericFactory @classmethod def discover(cls, package: str) -> None: @@ -43,7 +53,7 @@ def _recursive_discover(cls, package_paths, base_package: str) -> None: """ for _, module_name, is_pkg in pkgutil.iter_modules(package_paths): full_module_name = f"{base_package}.{module_name}" - cls._MODULES[module_name] = full_module_name + cls._state.modules[module_name] = full_module_name logger.debug(f"Registered module: {full_module_name}") if is_pkg: cls._discover_subpackage(full_module_name) @@ -59,7 +69,7 @@ def _discover_subpackage(cls, full_module_name: str) -> None: f"Failed to import subpackage '{full_module_name}': {e}") @staticmethod - def _find_class(module, interface: Type[T]) -> Optional[Type[T]]: + def _find_class(module: ModuleType, interface: Type[T]) -> Optional[Type[T]]: """ Find a class in the module that implements the specified interface. """ @@ -76,12 +86,15 @@ def get(cls, name: Union[str, List[Union[str, dict]], None], interface: Type[T]) cls._ensure_discovery(interface) if isinstance(name, (list, dict)): return cls._get_fallback_instance(name, interface) + if name is None: + raise ValueError( + "Name cannot be None for single instance retrieval") return cls._get_single_instance(name, interface) @classmethod def _ensure_discovery(cls, interface: Type[T]) -> None: """Ensure modules have been discovered.""" - if not cls._MODULES: + if not cls._state.modules: raise RuntimeError( f"No modules discovered for {interface.__name__}. Call `discover` first.") @@ -115,11 +128,11 @@ def _get_fallback_instance(cls, name: Union[List[Union[str, dict]], dict], inter @classmethod def _get_single_instance(cls, name: str, interface: Type[T]) -> T: """Retrieve or create a single instance for a module name.""" - if name in cls._INSTANCES: + if name in cls._state.instances: logger.debug(f"Returning cached instance for: {name}") - return cls._INSTANCES[name] + return cls._state.instances[name] - module_path = cls._MODULES.get(name) + module_path = cls._state.modules.get(name) # pylint: disable=no-member if not module_path: raise ValueError(f"Unknown module requested: '{name}'") @@ -130,8 +143,7 @@ def _get_single_instance(cls, name: str, interface: Type[T]) -> T: f"No valid class found in '{module_path}' implementing {interface.__name__}") instance = cls_obj() - if name is not None: - cls._INSTANCES[name] = instance + cls._state.instances[name] = instance logger.debug( f"Successfully instantiated {cls_obj.__name__} from {module_path}") return instance @@ -139,12 +151,20 @@ def _get_single_instance(cls, name: str, interface: Type[T]) -> T: @classmethod def clear_cache(cls) -> None: """Clear cached instances.""" - cls._INSTANCES.clear() + cls._state.instances.clear() # pylint: disable=no-member logger.debug("Cleared instance cache.") -class FallbackProxy(Generic[T]): - def __init__(self, instances: List[T]): - self.instances = instances + +class FallbackProxy(BaseModel, Generic[T]): + """Pydantic-based proxy for fallback instances.""" + instances: List[T] = Field(default_factory=list) + current_instance: Optional[T] = None + + class Config: + arbitrary_types_allowed = True # Allow generic T + + def __init__(self, instances: List[T], **data): + super().__init__(instances=instances, **data) self.current_instance = instances[0] if instances else None def __getattr__(self, attr): diff --git a/optics_framework/common/config_handler.py b/optics_framework/common/config_handler.py index cd233d49..9f08644a 100644 --- a/optics_framework/common/config_handler.py +++ b/optics_framework/common/config_handler.py @@ -1,10 +1,11 @@ import os import yaml from typing import List, Dict, Any, Optional -from dataclasses import dataclass, asdict, field +from pydantic import BaseModel, Field from collections.abc import Mapping import logging + def deep_merge(d1: dict, d2: dict) -> dict: """ Recursively merge two dictionaries, giving priority to d2. @@ -18,30 +19,29 @@ def deep_merge(d1: dict, d2: dict) -> dict: return merged -@dataclass -class DependencyConfig: +class DependencyConfig(BaseModel): """Configuration for all dependency types.""" enabled: bool url: Optional[str] = None - capabilities: Optional[Dict[str, Any]] = None + capabilities: Dict[str, Any] = Field(default_factory=dict) - def __post_init__(self): - if self.capabilities is None: - self.capabilities = {} + class Config: + """Pydantic V2 configuration.""" + # Allow arbitrary types for capabilities (e.g., Any) + arbitrary_types_allowed = True -@dataclass -class Config: +class Config(BaseModel): """Default configuration structure.""" console: bool = True driver_sources: List[Dict[str, DependencyConfig] - ] = field(default_factory=list) + ] = Field(default_factory=list) elements_sources: List[Dict[str, DependencyConfig] - ] = field(default_factory=list) + ] = Field(default_factory=list) text_detection: List[Dict[str, DependencyConfig] - ] = field(default_factory=list) + ] = Field(default_factory=list) image_detection: List[Dict[str, DependencyConfig] - ] = field(default_factory=list) + ] = Field(default_factory=list) file_log: bool = False json_log: bool = False json_path: Optional[str] = None @@ -49,27 +49,22 @@ class Config: log_path: Optional[str] = None project_path: Optional[str] = None - def __post_init__(self): + def __init__(self, **data): + super().__init__(**data) + # Post-init logic from the original dataclass if not self.driver_sources: self.driver_sources = [ {"appium": DependencyConfig( enabled=False, url="http://127.0.0.1:4723", - capabilities={ - "deviceName": None, - "platformName": None, - "automationName": None, - "appPackage": None, - "appActivity": None - } - )}, + capabilities={})}, {"ble": DependencyConfig( enabled=False, url=None, capabilities={})} ] if not self.elements_sources: self.elements_sources = [ {"appium_find_element": DependencyConfig( - enabled=True, url=None, capabilities={})}, + enabled=False, url=None, capabilities={})}, {"appium_page_source": DependencyConfig( enabled=False, url=None, capabilities={})}, {"device_screenshot": DependencyConfig( @@ -92,6 +87,11 @@ def __post_init__(self): enabled=False, url=None, capabilities={})} ] + class Config: + """Pydantic V2 configuration.""" + # Allow arbitrary types in capabilities + arbitrary_types_allowed = True + class ConfigHandler: _instance = None @@ -133,13 +133,14 @@ def _ensure_global_config(self) -> None: os.makedirs(os.path.dirname( self.global_config_path), exist_ok=True) with open(self.global_config_path, "w", encoding="utf-8") as f: - yaml.dump(asdict(self.config), f, default_flow_style=False) + yaml.dump(self.config.model_dump(), + f, default_flow_style=False) def load(self) -> Config: global_config = self._load_yaml(self.global_config_path) or {} project_config = self._load_yaml( self.project_config_path) if self.project_config_path else {} - default_dict = asdict(self.config) + default_dict = self.config.model_dump() merged = deep_merge(default_dict, global_config) merged = deep_merge(merged, project_config) self.config = Config(**merged) @@ -164,15 +165,17 @@ def _precompute_enabled_configs(self) -> None: for name, details in item.items(): if isinstance(details, dict) and details.get("enabled", False): enabled_names.append(name) + elif isinstance(details, DependencyConfig) and details.enabled: + enabled_names.append(name) self._enabled_configs[key] = enabled_names # Store just the names def get_dependency_config(self, dependency_type: str, name: str) -> Optional[Dict[str, Any]]: # Still available if detailed config is needed elsewhere for item in getattr(self.config, dependency_type): - if name in item and item[name].get("enabled", False): + if name in item and item[name].enabled: return { - "url": item[name].get("url"), - "capabilities": item[name].get("capabilities", {}) + "url": item[name].url, + "capabilities": item[name].capabilities } return None @@ -184,7 +187,8 @@ def get(self, key: str, default: Any = None) -> Any: def save_config(self) -> None: try: with open(self.global_config_path, "w", encoding="utf-8") as f: - yaml.dump(asdict(self.config), f, default_flow_style=False) + yaml.dump(self.config.model_dump(), + f, default_flow_style=False) except Exception as e: raise e diff --git a/optics_framework/common/driver_interface.py b/optics_framework/common/driver_interface.py index 2bf4b194..0922087e 100644 --- a/optics_framework/common/driver_interface.py +++ b/optics_framework/common/driver_interface.py @@ -32,9 +32,8 @@ def get_app_version(self) -> str: """ pass - @abstractmethod - def press_coordinates(self,coor_x, coor_y, event_name) -> None: + def press_coordinates(self, coor_x, coor_y, event_name) -> None: """ Press an element by absolute coordinates. :param coor_x: X coordinate of the press. @@ -48,7 +47,7 @@ def press_coordinates(self,coor_x, coor_y, event_name) -> None: pass @abstractmethod - def press_element(self, element,repeat, event_name) -> None: + def press_element(self, element, repeat, event_name) -> None: """ Press an element using Appium. :param element: The element to be pressed. @@ -205,7 +204,6 @@ def scroll(self, direction, duration, event_name) -> None: """ pass - @abstractmethod def get_text_element(self, element) -> str: """ diff --git a/optics_framework/common/elementsource_interface.py b/optics_framework/common/elementsource_interface.py index 0abafb2d..f659a9e1 100644 --- a/optics_framework/common/elementsource_interface.py +++ b/optics_framework/common/elementsource_interface.py @@ -1,42 +1,54 @@ from abc import ABC, abstractmethod +from typing import Optional, Tuple, Any + class ElementSourceInterface(ABC): """ - Abstract base class for application drivers. + Abstract base class for element source drivers. + + This interface defines methods for capturing and interacting with screen elements + (e.g., images, UI components) within an application or environment, implementing + the :class:`ElementSourceInterface`. - This interface enforces the implementation of essential methods - for interacting with applications. + Implementers should handle specific element types (e.g., image bytes, templates) + as needed. """ @abstractmethod - def capture(self): + def capture(self) -> None: """ Capture the current screen state. :return: None :rtype: None """ - + pass @abstractmethod def locate(self, element, index=None, strategy=None) -> tuple: """ - Locate a template image within a larger image. + Locate an element within the current screen state. - :param image: The image to search. - :param template: The template to search for. - :return: A tuple containing the coordinates of the located template. - :rtype: tuple + :param element: The element to search for (e.g., template image, UI component). + :type element: Any + :return: A tuple (x, y) representing the center of the element, or None if not found. + :rtype: Optional[Tuple[int, int]] """ pass @abstractmethod - def assert_elements(self, elements, timeout=30, rule='any') -> None: + def assert_elements(self, elements: Any, timeout: int = 30, rule: str = 'any') -> None: """ Assert the presence of elements on the screen. - :param elements: The elements to be asserted. - :raises NotImplementedError: If the method is not implemented in a subclass. + + :param elements: The elements to check for presence (e.g., list of templates). + :type elements: Any + :param timeout: Time in seconds to wait for elements to appear (default: 30). + :type timeout: int + :param rule: Assertion rule ('any' for at least one, 'all' for all; default: 'any'). + :type rule: str :return: None :rtype: None + :raises AssertionError: If the assertion fails based on the rule. """ pass diff --git a/optics_framework/common/execution.py b/optics_framework/common/execution.py index 559dc13d..1cfe6b89 100644 --- a/optics_framework/common/execution.py +++ b/optics_framework/common/execution.py @@ -1,7 +1,7 @@ -# execution.py import asyncio from abc import ABC, abstractmethod -from typing import Optional, Dict +from typing import Optional, Dict, List +from pydantic import BaseModel, Field, ConfigDict from optics_framework.common.session_manager import SessionManager, Session from optics_framework.common.runner.test_runnner import TestRunner, TreeResultPrinter, TerminalWidthProvider, IResultPrinter, TestCaseResult from optics_framework.common.runner.keyword_register import KeywordRegistry @@ -35,6 +35,40 @@ def start_run(self, total_test_cases: int) -> None: pass +class TestCaseData(BaseModel): + """Structure for test cases loaded from CSV.""" + test_cases: Dict[str, List[str]] = Field( + default_factory=dict) # {test_case: [test_steps]} + + +class ModuleData(BaseModel): + """Structure for modules loaded from CSV.""" + modules: Dict[str, List[tuple[str, List[str]]]] = Field( + default_factory=dict) # {module_name: [(module_step, [params])]} + + +class ElementData(BaseModel): + """Structure for elements loaded from CSV.""" + elements: Dict[str, str] = Field( + default_factory=dict) # {element_name: element_id} + + +class ExecutionParams(BaseModel): + """Parameters for ExecutionEngine.execute.""" + model_config = ConfigDict( + arbitrary_types_allowed=True) # Allow asyncio.Queue + + session_id: str + mode: str + test_case: Optional[str] = None + keyword: Optional[str] = None + params: List[str] = Field(default_factory=list) + event_queue: Optional[asyncio.Queue] = None + test_cases: TestCaseData = Field(default_factory=TestCaseData) + modules: ModuleData = Field(default_factory=ModuleData) + elements: ElementData = Field(default_factory=ElementData) + + class Executor(ABC): """Abstract base class for execution strategies.""" @abstractmethod @@ -64,12 +98,12 @@ async def execute(self, session: Session, runner: TestRunner, event_queue: Optio await event_queue.put({"execution_id": session.session_id, "status": status, "message": message}) raise ValueError(message) result = runner.execute_test_case(self.test_case) - status = "PASS" if result["status"] == "PASS" else "FAIL" + status = "PASS" if result.status == "PASS" else "FAIL" message = f"Test case {self.test_case} completed with status {status}" else: runner.run_all() status = "PASS" if all( - tc["status"] == "PASS" for tc in runner.result_printer.test_state.values()) else "FAIL" + tc.status == "PASS" for tc in runner.result_printer.test_state.values()) else "FAIL" message = "All test cases completed" if event_queue: @@ -98,12 +132,12 @@ async def execute(self, session: Session, runner: TestRunner, event_queue: Optio await event_queue.put({"execution_id": session.session_id, "status": status, "message": message}) raise ValueError(message) result = runner.dry_run_test_case(self.test_case) - status = "PASS" if result["status"] == "PASS" else "FAIL" + status = "PASS" if result.status == "PASS" else "FAIL" message = f"Dry run for test case {self.test_case} completed with status {status}" else: runner.dry_run_all() status = "PASS" if all( - tc["status"] == "PASS" for tc in runner.result_printer.test_state.values()) else "FAIL" + tc.status == "PASS" for tc in runner.result_printer.test_state.values()) else "FAIL" message = "All test cases dry run completed" if event_queue: @@ -113,7 +147,7 @@ async def execute(self, session: Session, runner: TestRunner, event_queue: Optio class KeywordExecutor(Executor): """Executes a single keyword.""" - def __init__(self, keyword: str, params: list[str]): + def __init__(self, keyword: str, params: List[str]): self.keyword = keyword self.params = params @@ -145,9 +179,9 @@ class RunnerFactory: def create_runner( session: Session, use_printer: bool, - test_cases: Dict, - modules: Dict, - elements: Dict + test_cases: Dict[str, List[str]], + modules: Dict[str, List[tuple[str, List[str]]]], + elements: Dict[str, str] ) -> TestRunner: result_printer = TreeResultPrinter( TerminalWidthProvider()) if use_printer else NullResultPrinter() @@ -155,7 +189,7 @@ def create_runner( test_cases=test_cases, modules=modules, elements=elements, - keyword_map={}, + keyword_map={}, # Initialize with empty dict result_printer=result_printer ) registry = KeywordRegistry() @@ -167,7 +201,7 @@ def create_runner( registry.register(app_management) registry.register(flow_control) registry.register(verifier) - runner.keyword_map = registry.keyword_map + runner.keyword_map = registry.keyword_map # No type hint needed now return runner @@ -183,54 +217,70 @@ async def execute( mode: str, test_case: Optional[str] = None, keyword: Optional[str] = None, - params: list[str] = [], + params: Optional[List[str]] = None, event_queue: Optional[asyncio.Queue] = None, - test_cases: Optional[Dict] = None, - modules: Optional[Dict] = None, - elements: Optional[Dict] = None + test_cases: Optional[Dict[str, List[str]]] = None, + modules: Optional[Dict[str, List[tuple[str, List[str]]]]] = None, + elements: Optional[Dict[str, str]] = None ) -> None: - session = self.session_manager.get_session(session_id) + if params is None: + params = [] + params_model = ExecutionParams( + session_id=session_id, + mode=mode, + test_case=test_case, + keyword=keyword, + params=params, + event_queue=event_queue, + test_cases=TestCaseData(test_cases=test_cases or {}), + modules=ModuleData(modules=modules or {}), + elements=ElementData(elements=elements or {}) + ) + + session = self.session_manager.get_session(params_model.session_id) if not session: - if event_queue: - await event_queue.put({"status": "ERROR", "message": "Session not found"}) + if params_model.event_queue: + await params_model.event_queue.put({"status": "ERROR", "message": "Session not found"}) raise ValueError("Session not found") runner = RunnerFactory.create_runner( session, - use_printer=event_queue is None, - test_cases=test_cases or {}, - modules=modules or {}, - elements=elements or {} + use_printer=params_model.event_queue is None, + # Pylint suppression for false positives on Pydantic field access + test_cases=params_model.test_cases.test_cases, # pylint: disable=no-member + modules=params_model.modules.modules, # pylint: disable=no-member + elements=params_model.elements.elements # pylint: disable=no-member ) if runner.result_printer: runner.result_printer.start_live() - if mode == "batch": - executor = BatchExecutor(test_case) - elif mode == "dry_run": - executor = DryRunExecutor(test_case) - elif mode == "keyword": - if not keyword: - if event_queue: - await event_queue.put({"status": "ERROR", "message": "Keyword mode requires a keyword"}) + if params_model.mode == "batch": + executor = BatchExecutor(params_model.test_case) + elif params_model.mode == "dry_run": + executor = DryRunExecutor(params_model.test_case) + elif params_model.mode == "keyword": + if not params_model.keyword: + if params_model.event_queue: + await params_model.event_queue.put({"status": "ERROR", "message": "Keyword mode requires a keyword"}) raise ValueError("Keyword mode requires a keyword") - executor = KeywordExecutor(keyword, params) + executor = KeywordExecutor( + params_model.keyword, params_model.params) else: - if event_queue: - await event_queue.put({"status": "ERROR", "message": f"Unknown mode: {mode}"}) - raise ValueError(f"Unknown mode: {mode}") + if params_model.event_queue: + await params_model.event_queue.put({"status": "ERROR", "message": f"Unknown mode: {params_model.mode}"}) + raise ValueError(f"Unknown mode: {params_model.mode}") try: - if event_queue: - await event_queue.put({ - "execution_id": session_id, + if params_model.event_queue: + await params_model.event_queue.put({ + "execution_id": params_model.session_id, "status": "RUNNING", - "message": f"Starting {mode} execution" + "message": f"Starting {params_model.mode} execution" }) - await executor.execute(session, runner, event_queue) + await executor.execute(session, runner, params_model.event_queue) except Exception as e: - if event_queue: - await event_queue.put({"status": "FAIL", "message": f"Execution failed: {str(e)}"}) + if params_model.event_queue: + await params_model.event_queue.put({"status": "FAIL", "message": f"Execution failed: {str(e)}"}) raise finally: if runner.result_printer: diff --git a/optics_framework/common/expose_api.py b/optics_framework/common/expose_api.py index b64f30ed..3b554b9c 100644 --- a/optics_framework/common/expose_api.py +++ b/optics_framework/common/expose_api.py @@ -1,4 +1,3 @@ -# api.py import json import uuid import asyncio @@ -9,15 +8,15 @@ from optics_framework.common.session_manager import SessionManager, Session from optics_framework.common.execution import ExecutionEngine from optics_framework.common.logging_config import logger, apply_logger_format_to_all +from optics_framework.common.config_handler import Config, DependencyConfig app = FastAPI(title="Optics Framework API", version="1.0") session_manager = SessionManager() class SessionConfig(BaseModel): - """Schema for session creation.""" - driver_sources: list[str] - app_param: dict = {} + """Schema for session creation, aligned with Config expectations.""" + driver_sources: list[str] # Names of enabled drivers elements_sources: list[str] = [] text_detection: list[str] = [] image_detection: list[str] = [] @@ -32,28 +31,72 @@ class ExecuteRequest(BaseModel): params: list[str] = [] -@app.post("/v1/sessions") +class SessionResponse(BaseModel): + """Schema for session creation response.""" + session_id: str + status: str = "created" + + +class ExecutionResponse(BaseModel): + """Schema for execution response.""" + execution_id: str + status: str = "started" + + +class TerminationResponse(BaseModel): + """Schema for session termination response.""" + status: str = "terminated" + + +class ExecutionEvent(BaseModel): + """Schema for execution event payloads.""" + execution_id: str + status: str # e.g., "ERROR", "FAIL", "SUCCESS" + message: Optional[str] = None + + +@app.post("/v1/sessions", response_model=SessionResponse) async def create_session(config: SessionConfig): - """Creates a new session.""" + """ + Creates a new session with the specified configuration. + + :param config: Configuration for the session (enabled dependency names). + :return: Details of the created session. + """ try: - session_config = config.dict() + # Transform SessionConfig into Config format + session_config_dict = { + "driver_sources": [{"name": DependencyConfig(enabled=True)} for name in config.driver_sources], + "elements_sources": [{"name": DependencyConfig(enabled=True)} for name in config.elements_sources], + "text_detection": [{"name": DependencyConfig(enabled=True)} for name in config.text_detection], + "image_detection": [{"name": DependencyConfig(enabled=True)} for name in config.image_detection], + "project_path": config.project_path + } + session_config = Config(**session_config_dict) session_id = session_manager.create_session(session_config) logger.info( - f"Created session {session_id} with config: {session_config}") - return {"session_id": session_id, "status": "created"} + f"Created session {session_id} with config: {config.model_dump()}") + return SessionResponse(session_id=session_id) except Exception as e: logger.error(f"Failed to create session: {e}") raise HTTPException( status_code=500, detail=f"Session creation failed: {e}") -@app.post("/v1/sessions/{session_id}/execute") +@app.post("/v1/sessions/{session_id}/execute", response_model=ExecutionResponse) async def execute( session_id: str, request: ExecuteRequest, background_tasks: BackgroundTasks ): - """Triggers execution in a session.""" + """ + Triggers execution in a session as a background task. + + :param session_id: ID of the session to execute in. + :param request: Execution request details. + :param background_tasks: FastAPI background task handler. + :return: Execution start confirmation. + """ session = session_manager.get_session(session_id) if not session: logger.error(f"Session not found: {session_id}") @@ -61,7 +104,7 @@ async def execute( execution_id = str(uuid.uuid4()) logger.info( - f"Starting execution {execution_id} for session {session_id} with request: {request.dict()}") + f"Starting execution {execution_id} for session {session_id} with request: {request.model_dump()}") engine = ExecutionEngine(session_manager) background_tasks.add_task( @@ -75,12 +118,17 @@ async def execute( request.params, session.event_queue ) - return {"execution_id": execution_id, "status": "started"} + return ExecutionResponse(execution_id=execution_id) @app.get("/v1/sessions/{session_id}/events") async def stream_events(session_id: str): - """Streams execution events for a session.""" + """ + Streams execution events for a session via Server-Sent Events (SSE). + + :param session_id: ID of the session to stream events from. + :return: SSE stream of execution events. + """ session = session_manager.get_session(session_id) if not session: logger.error(f"Session not found for event streaming: {session_id}") @@ -90,13 +138,18 @@ async def stream_events(session_id: str): return EventSourceResponse(event_generator(session)) -@app.delete("/v1/sessions/{session_id}") +@app.delete("/v1/sessions/{session_id}", response_model=TerminationResponse) async def delete_session(session_id: str): - """Terminates a session.""" + """ + Terminates a session. + + :param session_id: ID of the session to terminate. + :return: Termination confirmation. + """ try: session_manager.terminate_session(session_id) logger.info(f"Terminated session: {session_id}") - return {"status": "terminated"} + return TerminationResponse() except Exception as e: logger.error(f"Failed to terminate session {session_id}: {e}") raise HTTPException( @@ -120,11 +173,13 @@ async def run_execution( logger.error( f"Session {session_id} not found or invalid during execution {execution_id}") if event_queue: - await event_queue.put({ - "execution_id": execution_id, - "status": "ERROR", - "message": "Session not found or invalid" - }) + await event_queue.put( + ExecutionEvent( + execution_id=execution_id, + status="ERROR", + message="Session not found or invalid" + ).model_dump() + ) return try: @@ -137,13 +192,22 @@ async def run_execution( event_queue=event_queue ) logger.info(f"Execution {execution_id} completed successfully") + await event_queue.put( + ExecutionEvent( + execution_id=execution_id, + status="SUCCESS", + message="Execution completed" + ).model_dump() + ) except Exception as e: logger.error(f"Execution {execution_id} failed: {e}") - await event_queue.put({ - "execution_id": execution_id, - "status": "FAIL", - "message": f"Execution failed: {str(e)}" - }) + await event_queue.put( + ExecutionEvent( + execution_id=execution_id, + status="FAIL", + message=f"Execution failed: {str(e)}" + ).model_dump() + ) @apply_logger_format_to_all("user") @@ -158,7 +222,13 @@ async def event_generator(session: Session): except Exception as e: logger.error( f"Error in event streaming for session {session.session_id}: {e}") - yield {"data": json.dumps({"status": "ERROR", "message": f"Event streaming failed: {e}"})} + yield {"data": json.dumps( + ExecutionEvent( + execution_id="unknown", + status="ERROR", + message=f"Event streaming failed: {e}" + ).model_dump() + )} break diff --git a/optics_framework/common/image_interface.py b/optics_framework/common/image_interface.py index 5f45e9bc..bf1b7cb6 100644 --- a/optics_framework/common/image_interface.py +++ b/optics_framework/common/image_interface.py @@ -1,17 +1,21 @@ from abc import ABC -from typing import Optional, Tuple +from typing import Optional, Tuple, Any class ImageInterface(ABC): """ - Abstract base class for vision-based detection models. + Abstract base class for image processing engines. - This interface enforces the implementation of a :meth:`detect` method, - which processes input data to identify specific objects, patterns, or text. + This interface defines methods for detecting and locating images or objects + within input data (e.g., images or video frames), implementing the + :class:`ImageInterface`. + + Implementers should handle specific input types (e.g., image bytes, file paths) + and reference data as needed. """ - def element_exist(self, input_data, reference_data) -> Optional[Tuple[int, int]]: + def element_exist(self, input_data: Any, reference_data: Any) -> Optional[Tuple[int, int]]: """ Find the location of a reference image within the input data. @@ -19,37 +23,38 @@ def element_exist(self, input_data, reference_data) -> Optional[Tuple[int, int]] :type input_data: Any :param reference_data: The reference data used for matching or comparison. :type reference_data: Any - :return: A tuple ``(x, y)`` representing the top-left corner of the reference image, - or ``None`` if the image is not found. + :return: A tuple (x, y) representing the top-left corner of the reference image, + or None if not found. :rtype: Optional[Tuple[int, int]] """ pass def locate(self, input_data, image, index=None) -> Optional[Tuple[int, int]]: """ - Find the location of text within the input data. + Find the location of a reference image within the input data. :param input_data: The input source (e.g., image, video frame) for detection. - :type - input_data: Any - :param text: The text to search for. - :type text: str - :return: A tuple ``(x, y)`` representing the centre of the text, - or ``None`` if the text is not found. + :type input_data: Any + :param image: The reference image to search for. + :type image: Any + :return: A tuple (x, y) representing the center of the image, or None if not found. :rtype: Optional[Tuple[int, int]] """ pass def find_element(self, input_data, image, index=None) -> Optional[Tuple[bool, Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]]]: """ - Find the location of an image within the input data. + Locate a specific image in the input data and return detailed detection info. :param input_data: The input source (e.g., image, video frame) for detection. :type input_data: Any - :param image: The image to search for. + :param image: The reference image to locate. :type image: Any - :return: A tuple containing a boolean indicating whether the image was found, - the center coordinates of the image, and the bounding box coordinates. + :return: A tuple (found, center, bounds) where: + - found: bool indicating if the image was found + - center: (x, y) coordinates of the image center + - bounds: ((x1, y1), (x2, y2)) bounding box (top-left, bottom-right) + Returns None if not found. :rtype: Optional[Tuple[bool, Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]]] """ pass diff --git a/optics_framework/common/logging_config.py b/optics_framework/common/logging_config.py index 4e64f493..5c3fe23b 100644 --- a/optics_framework/common/logging_config.py +++ b/optics_framework/common/logging_config.py @@ -1,78 +1,93 @@ import json -import logging.config +import logging +from typing import Optional, Dict, Any, Callable from rich.logging import RichHandler from contextlib import contextmanager import threading from pathlib import Path from functools import wraps -from optics_framework.common.config_handler import ConfigHandler +from optics_framework.common.config_handler import ConfigHandler, Config from pythonjsonlogger.json import JsonFormatter -# --- Thread-local variable for temporary configuration --- +# Thread-local variable for temporary configuration _log_context = threading.local() -_log_context.format_type = "internal" # default mode is "internal" +_log_context.format_type = "internal" # Default mode is "internal" # Get the singleton config instance config_handler = ConfigHandler.get_instance() -config = config_handler.config +config: Config = config_handler.config # Type as Pydantic Config logger = logging.getLogger("optics_framework") -# --- Dynamic Filter --- class DynamicFilter(logging.Filter): - def filter(self, record): + """Filter logs based on mode and configured level.""" + + def filter(self, record: logging.LogRecord) -> bool: + """Filters log records dynamically. + + :param record: The log record to filter. + :return: True if the record should be logged, False otherwise. + """ current = getattr(_log_context, "format_type", "internal") - config_level_name = str( - config_handler.get("log_level", "INFO")).upper() + config_level_name = str(config.log_level).upper() # Use Config field config_level = getattr(logging, config_level_name, logging.INFO) - # Apply level filtering only in "user" mode for console; "internal" mode sees all + # Apply level filtering only in "user" mode for console; "internal" sees all if current == "user": return record.levelno >= config_level return True # "internal" mode passes all logs to handlers -# --- Universal Formatter --- - class UniversalFormatter(logging.Formatter): - def __init__(self): - internal_fmt = ( - "[%(asctime)s] [%(levelname)-8s] %(message)-65s | " - "%(name)s:%(funcName)s:%(lineno)d" - ) + """Formatter switching between internal and user formats.""" + + def __init__(self) -> None: + internal_fmt = ("%(message)-65s") user_fmt = "%(message)s" - datefmt = "%Y-%m-%d %H:%M:%S" + datefmt = "%H:%M:%S" self.datefmt = datefmt self.internal_formatter = logging.Formatter( internal_fmt, datefmt=datefmt) self.user_formatter = logging.Formatter(user_fmt) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: + """Formats the log record based on current mode. + + :param record: The log record to format. + :return: Formatted log string. + """ fmt = getattr(_log_context, "format_type", "internal") for attr in ["test_case", "test_module", "keyword"]: if not hasattr(record, attr): setattr(record, attr, "N/A") if fmt == "user": return self.user_formatter.format(record) - else: - return self.internal_formatter.format(record) + return self.internal_formatter.format(record) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name == "_style": fmt = getattr(_log_context, "format_type", "internal") if fmt == "user": return self.user_formatter._style - else: - return self.internal_formatter._style + return self.internal_formatter._style raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'") class EnhancedJsonFormatter(JsonFormatter): - def add_fields(self, log_record, record, message_dict): + """JSON formatter with enhanced fields.""" + + def add_fields(self, log_record: Dict[str, Any], record: logging.LogRecord, message_dict: Dict[str, Any]) -> None: + """Adds custom fields to the JSON log record. + + :param log_record: The dictionary to populate with log data. + :param record: The original log record. + :param message_dict: The message dictionary from the record. + """ super().add_fields(log_record, record, message_dict) log_record["timestamp"] = self.formatTime(record) log_record["level"] = record.levelname - log_record["message"] = record.msg + # Use getMessage() for consistency + log_record["message"] = record.getMessage() log_record["test_case"] = getattr(record, "test_case", "N/A") log_record["test_module"] = getattr(record, "test_module", "N/A") log_record["keyword"] = getattr(record, "keyword", "N/A") @@ -84,17 +99,23 @@ def add_fields(self, log_record, record, message_dict): class HierarchicalJsonHandler(logging.Handler): - """ - Custom logging handler that accumulates log records into a nested dictionary. - """ + """Custom handler that accumulates logs into a nested JSON structure.""" - def __init__(self, filename): + def __init__(self, filename: str | Path) -> None: + """Initializes the handler with a target file. + + :param filename: Path to the JSON log file. + """ super().__init__() - self.filename = filename - self.logs = {} + self.filename = str(filename) # Ensure string for consistency + self.logs: Dict[str, Dict[str, Dict[str, list[Dict[str, Any]]]]] = {} self.setFormatter(EnhancedJsonFormatter()) - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: + """Emits a log record to the nested dictionary. + + :param record: The log record to emit. + """ try: json_record = self.format(record) log_entry = json.loads(json_record) @@ -108,21 +129,21 @@ def emit(self, record): except Exception: self.handleError(record) - def flush(self): + def flush(self) -> None: + """Writes the accumulated logs to the file.""" with open(self.filename, 'w', encoding='utf-8') as f: json.dump(self.logs, f, indent=2) logging.root.handlers = [] - logger.setLevel(logging.DEBUG) # Root logger captures all levels -# --- Console (Rich) Handler --- -console_level_name = str(config_handler.get("log_level", "INFO")).upper() +# Console (Rich) Handler +console_level_name = str(config.log_level).upper() # Line 172 fix console_level = getattr(logging, console_level_name, logging.INFO) rich_handler = RichHandler( - rich_tracebacks=bool(config_handler.get("backtrace", True)), - tracebacks_show_locals=bool(config_handler.get("diagnose", True)), + rich_tracebacks=True, # Use Config fields + tracebacks_show_locals=True, show_time=True, show_level=True, ) @@ -131,7 +152,9 @@ def flush(self): rich_handler.setLevel(console_level) # Respect configured log level logger.addHandler(rich_handler) -def initialize_additional_handlers(): + +def initialize_additional_handlers() -> None: + """Initializes file and JSON handlers based on configuration.""" project_path = config_handler.get_project_path() if not project_path: logger.warning("Project path not set; defaulting to ~/.optics") @@ -139,12 +162,11 @@ def initialize_additional_handlers(): else: logger.debug(f"Using project path: {project_path}") - # --- File Handler --- - if config_handler.get("file_log", True): + # File Handler + if config.file_log: default_log_path = Path(project_path) / "execution_output" / "logs.log" - configured_log_path = config_handler.config.get("log_path") - log_path = Path(str( - configured_log_path if configured_log_path is not None else default_log_path)).expanduser() + log_path = Path( + config.log_path if config.log_path is not None else default_log_path).expanduser() logger.debug(f"Log file path: {log_path}") log_path.parent.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(log_path, mode='w') @@ -153,56 +175,55 @@ def initialize_additional_handlers(): file_handler.setLevel(logging.DEBUG) # File captures all logs logger.addHandler(file_handler) - # --- JSON Handler --- - if config_handler.get("json_log", True): + # JSON Handler + if config.json_log: default_json_path = Path(project_path) / \ "execution_output" / "logs.json" - configured_json_path = config_handler.config.get("json_path") - json_path = Path(str( - configured_json_path if configured_json_path is not None else default_json_path)).expanduser() + json_path = Path( + config.json_path if config.json_path is not None else default_json_path).expanduser() logger.debug(f"JSON log file path: {json_path}") json_path.parent.mkdir(parents=True, exist_ok=True) json_handler = HierarchicalJsonHandler(json_path) json_handler.setLevel(logging.DEBUG) # JSON captures all logs logger.addHandler(json_handler) -# --- Context Manager and Decorators --- - @contextmanager def set_logger_format(fmt: str): - """ - Temporarily set the logger mode for the duration of a context. + """Temporarily sets the logger mode for the duration of a context. + :param fmt: Logger mode ("internal" or "user"). """ old_format = getattr(_log_context, "format_type", "internal") _log_context.format_type = fmt try: - yield + yield # Explicit yield for contextmanager finally: _log_context.format_type = old_format -def use_logger_format(fmt: str | None = None): - """ - Decorator to automatically set the logger mode for a function call. - :param fmt: Logger mode ("internal" or "user"). +def use_logger_format(fmt: Optional[str] = None) -> Callable: + """Decorator to set the logger mode for a function call. + + :param fmt: Logger mode ("internal" or "user"), defaults to None (uses "internal"). + :return: Decorated function. """ - def decorator(func): + def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: with set_logger_format(fmt or "internal"): return func(*args, **kwargs) return wrapper return decorator -def apply_logger_format_to_all(fmt: str | None = None): - """ - Class decorator that applies a specific logger mode to every callable method. - :param fmt: Logger mode for all methods in the class. +def apply_logger_format_to_all(fmt: Optional[str] = None) -> Callable: + """Class decorator to apply a logger mode to all callable methods. + + :param fmt: Logger mode for all methods in the class, defaults to None (uses "internal"). + :return: Decorated class. """ - def decorator(cls): + def decorator(cls: type) -> type: for attr_name in dir(cls): if not attr_name.startswith("__"): attribute = getattr(cls, attr_name) diff --git a/optics_framework/common/optics_builder.py b/optics_framework/common/optics_builder.py index 6b1b1a1b..496cb238 100644 --- a/optics_framework/common/optics_builder.py +++ b/optics_framework/common/optics_builder.py @@ -1,58 +1,63 @@ from typing import Union, List, Dict, Optional, Type, TypeVar from optics_framework.common.factories import DeviceFactory, ElementSourceFactory, ImageFactory, TextFactory +from pydantic import BaseModel T = TypeVar('T') # Generic type for the build method +class OpticsConfig(BaseModel): + """Configuration for OpticsBuilder.""" + driver_config: Optional[Union[str, List[Union[str, Dict]]]] = None + element_source_config: Optional[Union[str, List[Union[str, Dict]]]] = None + image_config: Optional[Union[str, List[Union[str, Dict]]]] = None + text_config: Optional[Union[str, List[Union[str, Dict]]]] = None + + class OpticsBuilder: """ A builder that sets configurations and instantiates drivers for Optics Framework API classes. """ def __init__(self): - self.driver_config: Optional[Union[str, List[Union[str, Dict]]]] = None - self.element_source_config: Optional[Union[str, - List[Union[str, Dict]]]] = None - self.image_config: Optional[Union[str, List[Union[str, Dict]]]] = None - self.text_config: Optional[Union[str, List[Union[str, Dict]]]] = None + self.config = OpticsConfig() # Fluent methods to set configurations def add_driver(self, config: Union[str, List[Union[str, Dict]]]) -> 'OpticsBuilder': - self.driver_config = config + self.config.driver_config = config return self def add_element_source(self, config: Union[str, List[Union[str, Dict]]]) -> 'OpticsBuilder': - self.element_source_config = config + self.config.element_source_config = config return self def add_image_detection(self, config: Union[str, List[Union[str, Dict]]]) -> 'OpticsBuilder': - self.image_config = config + self.config.image_config = config return self def add_text_detection(self, config: Union[str, List[Union[str, Dict]]]) -> 'OpticsBuilder': - self.text_config = config + self.config.text_config = config return self # Methods to instantiate drivers def get_driver(self): - if not self.driver_config: + if not self.config.driver_config: raise ValueError("Driver configuration must be set") - return DeviceFactory.get_driver(self.driver_config) + return DeviceFactory.get_driver(self.config.driver_config) def get_element_source(self): - if not self.element_source_config: + if not self.config.element_source_config: raise ValueError("Element source configuration must be set") - return ElementSourceFactory.get_driver(self.element_source_config) + return ElementSourceFactory.get_driver(self.config.element_source_config) def get_image_detection(self): - if not self.image_config: + if not self.config.image_config: return None - return ImageFactory.get_driver(self.image_config) + return ImageFactory.get_driver(self.config.image_config) def get_text_detection(self): - if not self.text_config: + if not self.config.text_config: return None - return TextFactory.get_driver(self.text_config) + return TextFactory.get_driver(self.config.text_config) def build(self, cls: Type[T]) -> T: """ @@ -62,5 +67,5 @@ def build(self, cls: Type[T]) -> T: :return: An instance of the specified class. :raises ValueError: If required configurations are missing for the specified class. """ - instance = cls(self) # type: ignore + instance = cls(self) # type: ignore return instance diff --git a/optics_framework/common/runner/csv_reader.py b/optics_framework/common/runner/csv_reader.py index cedafc1c..e83e605a 100644 --- a/optics_framework/common/runner/csv_reader.py +++ b/optics_framework/common/runner/csv_reader.py @@ -1,5 +1,6 @@ import csv from abc import ABC, abstractmethod +from typing import Optional from optics_framework.common.logging_config import logger, use_logger_format @@ -45,13 +46,13 @@ def read_modules(self, file_path: str) -> dict: pass @abstractmethod - def read_elements(self, file_path: str) -> dict: + def read_elements(self, file_path: Optional[str]) -> dict: """ Read a file containing element information and return a dictionary mapping element names to their corresponding element IDs. - :param file_path: Path to the file. - :type file_path: str + :param file_path: Path to the file, or None if elements are not provided. + :type file_path: Optional[str] :return: A dictionary where keys are element names and values are element IDs. :rtype: dict """ @@ -138,18 +139,20 @@ def read_modules(self, file_path: str) -> dict: modules[module_name].append((keyword, params)) return modules - def read_elements(self, file_path: str) -> dict: + def read_elements(self, file_path: Optional[str]) -> dict: """ Read a CSV file containing element information and return a dictionary mapping - element names to their corresponding element IDs. + element names to their corresponding element IDs. Returns an empty dict if file_path is None. The CSV file is expected to have columns 'Element_Name' and 'Element_ID'. - :param file_path: Path to the elements CSV file. - :type file_path: str + :param file_path: Path to the elements CSV file, or None if not provided. + :type file_path: Optional[str] :return: A dictionary where keys are element names and values are element IDs. :rtype: dict """ + if not file_path: + return {} rows = self.read_file(file_path) elements = {} for row in rows: diff --git a/optics_framework/common/runner/keyword_register.py b/optics_framework/common/runner/keyword_register.py index d5763039..3d06a755 100644 --- a/optics_framework/common/runner/keyword_register.py +++ b/optics_framework/common/runner/keyword_register.py @@ -1,3 +1,4 @@ +from typing import Callable, Dict, Optional from optics_framework.common.logging_config import logger, apply_logger_format_to_all @@ -17,9 +18,9 @@ def __init__(self): Sets up an empty dictionary to store the mapping between keyword function names and their methods. """ - self.keyword_map = {} + self.keyword_map: Dict[str, Callable[..., object]] = {} - def register(self, instance): + def register(self, instance: object) -> None: """ Register all public callable methods of an instance. @@ -28,7 +29,6 @@ def register(self, instance): name is encountered, a warning is logged. :param instance: The instance whose methods are to be registered. - :type instance: object """ for method_name in dir(instance): if not method_name.startswith("_"): @@ -40,7 +40,7 @@ def register(self, instance): ) self.keyword_map[method_name] = method - def get_method(self, func_name): + def get_method(self, func_name: str) -> Optional[Callable[..., object]]: """ Retrieve a method by its function name. @@ -48,8 +48,6 @@ def get_method(self, func_name): If no such method exists, None is returned. :param func_name: The name of the function to retrieve. - :type func_name: str :return: The callable method if found; otherwise, None. - :rtype: callable or None """ return self.keyword_map.get(func_name) diff --git a/optics_framework/common/runner/test.py b/optics_framework/common/runner/test.py deleted file mode 100644 index 6a7d3a29..00000000 --- a/optics_framework/common/runner/test.py +++ /dev/null @@ -1,18 +0,0 @@ -from time import sleep -from rich.live import Live -from rich.table import Table - -def generate_table(): - table = Table(title="Live Data Table") - table.add_column("Row", justify="right", style="cyan", no_wrap=True) - table.add_column("Description", style="magenta") - table.add_column("Value", justify="right", style="green") - - for row in range(1, 11): - table.add_row(str(row), f"Description {row}", f"{row * 10}") - return table - -with Live(generate_table(), refresh_per_second=2) as live: - for _ in range(20): # Update the table 20 times - sleep(0.5) - live.update(generate_table()) \ No newline at end of file diff --git a/optics_framework/common/runner/test_runnner.py b/optics_framework/common/runner/test_runnner.py index 717e3f7a..68635bdd 100644 --- a/optics_framework/common/runner/test_runnner.py +++ b/optics_framework/common/runner/test_runnner.py @@ -1,7 +1,8 @@ import time import shutil import abc -from typing import Callable, Dict, List, Optional, Tuple, TypedDict, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, Any +from pydantic import BaseModel, Field from optics_framework.common.logging_config import logger, apply_logger_format_to_all, HierarchicalJsonHandler from rich.live import Live from rich.tree import Tree @@ -11,7 +12,8 @@ from rich.console import Group -class KeywordResult(TypedDict): +class KeywordResult(BaseModel): + """Result of a single keyword execution.""" name: str resolved_name: str elapsed: str @@ -19,18 +21,20 @@ class KeywordResult(TypedDict): reason: str -class ModuleResult(TypedDict): +class ModuleResult(BaseModel): + """Result of a module execution.""" name: str elapsed: str status: str - keywords: List[KeywordResult] + keywords: List[KeywordResult] = Field(default_factory=list) -class TestCaseResult(TypedDict): +class TestCaseResult(BaseModel): + """Result of a test case execution.""" name: str elapsed: str status: str - modules: List[ModuleResult] + modules: List[ModuleResult] = Field(default_factory=list) class IResultPrinter(abc.ABC): @@ -117,28 +121,28 @@ def _render_tree(self) -> Group: tree = Tree("Test Suite", style="bold white") for tc_result in self.test_state.values(): test_case_node = tree.add(self.create_label( - tc_result["name"], tc_result["elapsed"], tc_result["status"], 0)) - for module in tc_result["modules"]: + tc_result.name, tc_result.elapsed, tc_result.status, 0)) + for module in tc_result.modules: module_node = test_case_node.add(self.create_label( - module["name"], module["elapsed"], module["status"], 1)) - for keyword in module["keywords"]: + module.name, module.elapsed, module.status, 1)) + for keyword in module.keywords: module_node.add(self.create_label( - keyword["resolved_name"], keyword["elapsed"], keyword["status"], 2)) + keyword.resolved_name, keyword.elapsed, keyword.status, 2)) completed = sum(1 for tc in self.test_state.values() - if tc["status"] in ["PASS", "FAIL"]) + if tc.status in ["PASS", "FAIL"]) if self.task_id is not None: self.progress.update(self.task_id, completed=completed) total, passed, failed = len(self.test_state), sum(1 for tc in self.test_state.values( - ) if tc["status"] == "PASS"), sum(1 for tc in self.test_state.values() if tc["status"] == "FAIL") + ) if tc.status == "PASS"), sum(1 for tc in self.test_state.values() if tc.status == "FAIL") summary_text = f"Total Test Cases: {total} | Passed: {passed} | Failed: {failed}" summary_panel = Panel( summary_text, style="green" if failed == 0 else "red") return Group(self.progress, tree, summary_panel) def print_tree_log(self, test_case_result: TestCaseResult) -> None: - self.test_state[test_case_result["name"]] = test_case_result + self.test_state[test_case_result.name] = test_case_result if self._live: self._live.update(self._render_tree()) @@ -146,7 +150,7 @@ def start_live(self) -> None: if not self._live: self._live = Live(self._render_tree(), refresh_per_second=10) self._live.start() - self._live.console.log("testing started") + self._live.console.log("Testing started") def stop_live(self) -> None: if self._live: @@ -161,7 +165,7 @@ def __init__( test_cases: Dict[str, List[str]], modules: Dict[str, List[Tuple[str, List[str]]]], elements: Dict[str, str], - keyword_map: Dict[str, Callable[..., None]], + keyword_map: Dict[str, Callable[..., Any]], result_printer: IResultPrinter ) -> None: self.test_cases = test_cases @@ -184,19 +188,19 @@ def resolve_param(self, param: str) -> str: return resolved_value def _init_test_case(self, test_case_name: str) -> TestCaseResult: - return {"name": test_case_name, "elapsed": "0.00s", "status": "NOT RUN", "modules": []} + return TestCaseResult(name=test_case_name, elapsed="0.00s", status="NOT RUN") def _init_module(self, module_name: str) -> ModuleResult: - return {"name": module_name, "elapsed": "0.00s", "status": "NOT RUN", "keywords": []} + return ModuleResult(name=module_name, elapsed="0.00s", status="NOT RUN") def _init_keyword(self, keyword: str) -> KeywordResult: - return {"name": keyword, "resolved_name": keyword, "elapsed": "0.00s", "status": "NOT RUN", "reason": ""} + return KeywordResult(name=keyword, resolved_name=keyword, elapsed="0.00s", status="NOT RUN", reason="") def _update_status(self, result: Union[TestCaseResult, ModuleResult, KeywordResult], status: str, elapsed: Optional[float] = None) -> None: - result["status"] = status + result.status = status if elapsed is not None: - result["elapsed"] = f"{elapsed:.2f}s" - if "modules" in result: # TestCaseResult + result.elapsed = f"{elapsed:.2f}s" + if isinstance(result, TestCaseResult): self.result_printer.print_tree_log(result) def _execute_keyword(self, keyword: str, params: List[str], keyword_result: KeywordResult, module_result: ModuleResult, test_case_result: TestCaseResult, start_time: float, extra: Dict[str, str]) -> bool: @@ -208,8 +212,8 @@ def _execute_keyword(self, keyword: str, params: List[str], keyword_result: Keyw method = self.keyword_map.get(func_name) if not method: logger.error(f"Keyword not found: {keyword}", extra=extra) - keyword_result.update( - {"reason": "Keyword not found", "elapsed": f"{time.time() - start_time:.2f}s"}) + keyword_result.reason = "Keyword not found" + keyword_result.elapsed = f"{time.time() - start_time:.2f}s" self._update_status(keyword_result, "FAIL") self._update_status(module_result, "FAIL") self._update_status(test_case_result, "FAIL") @@ -220,8 +224,8 @@ def _execute_keyword(self, keyword: str, params: List[str], keyword_result: Keyw raw_indices = getattr(method, '_raw_param_indices', []) resolved_params = [param if i in raw_indices else self.resolve_param( param) for i, param in enumerate(params)] - keyword_result["resolved_name"] = f"{keyword} ({', '.join(str(p) for p in resolved_params)})" - method(*resolved_params) + keyword_result.resolved_name = f"{keyword} ({', '.join(str(p) for p in resolved_params)})" + method(*resolved_params) # Return value ignored logger.debug( f"Keyword '{keyword}' executed successfully", extra=extra) self._update_status(keyword_result, "PASS", @@ -231,8 +235,8 @@ def _execute_keyword(self, keyword: str, params: List[str], keyword_result: Keyw except Exception as e: logger.error( f"Error executing keyword '{keyword}': {e}", extra=extra) - keyword_result.update( - {"reason": str(e), "elapsed": f"{time.time() - start_time:.2f}s"}) + keyword_result.reason = str(e) + keyword_result.elapsed = f"{time.time() - start_time:.2f}s" self._update_status(keyword_result, "FAIL") self._update_status(module_result, "FAIL") self._update_status(test_case_result, "FAIL") @@ -242,7 +246,8 @@ def _execute_keyword(self, keyword: str, params: List[str], keyword_result: Keyw def _process_module(self, module_name: str, test_case_result: TestCaseResult, extra: Dict[str, str]) -> bool: logger.debug(f"Loading module: {module_name}", extra=extra) module_result = self._init_module(module_name) - test_case_result["modules"].append(module_result) + test_case_result.modules.append( + module_result) # pylint: disable=no-member self.result_printer.print_tree_log(test_case_result) if module_name not in self.modules: @@ -258,14 +263,15 @@ def _process_module(self, module_name: str, test_case_result: TestCaseResult, ex for keyword, params in self.modules[module_name]: keyword_result = self._init_keyword(keyword) - module_result["keywords"].append(keyword_result) + module_result.keywords.append( # pylint: disable=no-member + keyword_result) extra["keyword"] = keyword keyword_start = time.time() if not self._execute_keyword(keyword, params, keyword_result, module_result, test_case_result, keyword_start, extra): return False - module_result["elapsed"] = f"{time.time() - module_start:.2f}s" - module_result["status"] = "PASS" if all( - k["status"] == "PASS" for k in module_result["keywords"]) else "FAIL" + module_result.elapsed = f"{time.time() - module_start:.2f}s" + module_result.status = "PASS" if all( + k.status == "PASS" for k in module_result.keywords) else "FAIL" self.result_printer.print_tree_log(test_case_result) return True @@ -290,9 +296,9 @@ def execute_test_case(self, test_case_name: str) -> TestCaseResult: if not self._process_module(module_name, test_case_result, self._extra(test_case_name, module_name)): return test_case_result - test_case_result["elapsed"] = f"{time.time() - start_time:.2f}s" - test_case_result["status"] = "PASS" if all( - m["status"] == "PASS" for m in test_case_result["modules"]) else "FAIL" + test_case_result.elapsed = f"{time.time() - start_time:.2f}s" + test_case_result.status = "PASS" if all( + m.status == "PASS" for m in test_case_result.modules) else "FAIL" logger.debug("Completed test case execution", extra=extra) self.result_printer.print_tree_log(test_case_result) return test_case_result @@ -331,10 +337,10 @@ def _dry_run_keyword(self, keyword: str, params: List[str], keyword_result: Keyw try: resolved_params = [self.resolve_param(param) for param in params] if resolved_params: - keyword_result["resolved_name"] = f"{keyword} ({', '.join(resolved_params)})" + keyword_result.resolved_name = f"{keyword} ({', '.join(resolved_params)})" except ValueError as e: logger.error(f"Parameter resolution failed: {e}", extra=extra) - keyword_result["reason"] = str(e) + keyword_result.reason = str(e) self._update_status(keyword_result, "FAIL") self._update_status(module_result, "FAIL") self._update_status(test_case_result, "FAIL") @@ -344,7 +350,7 @@ def _dry_run_keyword(self, keyword: str, params: List[str], keyword_result: Keyw func_name = "_".join(keyword.split()).lower() if func_name not in self.keyword_map: logger.error(f"Keyword not found: {keyword}", extra=extra) - keyword_result["reason"] = "Keyword not found" + keyword_result.reason = "Keyword not found" self._update_status(keyword_result, "FAIL") self._update_status(module_result, "FAIL") self._update_status(test_case_result, "FAIL") @@ -358,7 +364,8 @@ def _dry_run_keyword(self, keyword: str, params: List[str], keyword_result: Keyw def _dry_run_module(self, module_name: str, test_case_result: TestCaseResult, extra: Dict[str, str]) -> bool: logger.debug(f"Loading module: {module_name}", extra=extra) module_result = self._init_module(module_name) - test_case_result["modules"].append(module_result) + test_case_result.modules.append( + module_result) # pylint: disable=no-member self.result_printer.print_tree_log(test_case_result) self._update_status(module_result, "RUNNING") @@ -366,13 +373,14 @@ def _dry_run_module(self, module_name: str, test_case_result: TestCaseResult, ex for keyword, params in self.modules.get(module_name, []): keyword_result = self._init_keyword(keyword) - module_result["keywords"].append(keyword_result) + module_result.keywords.append( # pylint: disable=no-member + keyword_result) extra["keyword"] = keyword if not self._dry_run_keyword(keyword, params, keyword_result, module_result, test_case_result, extra): return False - module_result["status"] = "PASS" if all( - k["status"] == "PASS" for k in module_result["keywords"]) else "FAIL" - module_result["elapsed"] = "0.00s" + module_result.status = "PASS" if all( + k.status == "PASS" for k in module_result.keywords) else "FAIL" + module_result.elapsed = "0.00s" self.result_printer.print_tree_log(test_case_result) return True @@ -398,9 +406,9 @@ def dry_run_test_case(self, test_case_name: str) -> TestCaseResult: if not self._dry_run_module(module_name, test_case_result, self._extra(test_case_name, module_name)): return test_case_result - test_case_result["elapsed"] = f"{time.time() - start_time:.2f}s" - test_case_result["status"] = "PASS" if all( - m["status"] == "PASS" for m in test_case_result["modules"]) else "FAIL" + test_case_result.elapsed = f"{time.time() - start_time:.2f}s" + test_case_result.status = "PASS" if all( + m.status == "PASS" for m in test_case_result.modules) else "FAIL" logger.debug("Completed dry run", extra=extra) self.result_printer.print_tree_log(test_case_result) return test_case_result diff --git a/optics_framework/common/session_manager.py b/optics_framework/common/session_manager.py index cd3ce2fe..9b5eff26 100644 --- a/optics_framework/common/session_manager.py +++ b/optics_framework/common/session_manager.py @@ -2,14 +2,14 @@ import asyncio from abc import ABC, abstractmethod from typing import Dict, Optional -from optics_framework.common.config_handler import ConfigHandler +from optics_framework.common.config_handler import Config, ConfigHandler from optics_framework.common.optics_builder import OpticsBuilder class SessionHandler(ABC): """Abstract interface for session management.""" @abstractmethod - def create_session(self, config: dict) -> str: + def create_session(self, config: Config) -> str: pass @abstractmethod @@ -24,12 +24,10 @@ def terminate_session(self, session_id: str) -> None: class Session: """Represents a single execution session with config and optics.""" - def __init__(self, session_id: str, config: dict, project_path: str): + def __init__(self, session_id: str, config: Config): self.session_id = session_id self.config_handler = ConfigHandler.get_instance() - self.config_handler.set_project(project_path) - self.config_handler.load() - self.config = self.config_handler.config + self.config = config # Fetch enabled dependency names driver_sources = self.config_handler.get("driver_sources", []) @@ -56,11 +54,10 @@ class SessionManager(SessionHandler): def __init__(self): self.sessions: Dict[str, Session] = {} - def create_session(self, config: dict) -> str: + def create_session(self, config: Config) -> str: """Creates a new session with a unique ID.""" session_id = str(uuid.uuid4()) - project_path = config.get("project_path", "") - self.sessions[session_id] = Session(session_id, config, project_path) + self.sessions[session_id] = Session(session_id, config) return session_id def get_session(self, session_id: str) -> Optional[Session]: diff --git a/optics_framework/common/strategies.py b/optics_framework/common/strategies.py index f5d97ed0..5447c29a 100644 --- a/optics_framework/common/strategies.py +++ b/optics_framework/common/strategies.py @@ -1,33 +1,49 @@ from abc import ABC, abstractmethod import inspect -from typing import Union, Tuple, Generator, Set +from typing import List, Union, Tuple, Generator, Set from optics_framework.common.base_factory import FallbackProxy from optics_framework.common.elementsource_interface import ElementSourceInterface from optics_framework.common import utils from optics_framework.common.logging_config import logger -# Locator Strategy Interface + class LocatorStrategy(ABC): + """Abstract base class for element location strategies.""" + @property @abstractmethod def element_source(self) -> ElementSourceInterface: - """The element source this strategy operates on.""" + """Returns the element source this strategy operates on.""" pass @abstractmethod - def locate(self, element) -> Union[object, Tuple[int, int]]: - """Locate an element and return either an element object or coordinates (x, y).""" + def locate(self, element: str) -> Union[object, Tuple[int, int]]: + """Locates an element and returns either an element object or coordinates (x, y). + + :param element: The element identifier (e.g., XPath, text, image path). + :return: Either an element object or a tuple of (x, y) coordinates. + """ pass @staticmethod @abstractmethod def supports(element_type: str, element_source: ElementSourceInterface) -> bool: - """Determine if this strategy supports the given element type and source.""" + """Determines if this strategy supports the given element type and source. + + :param element_type: The type of element (e.g., 'XPath', 'Text', 'Image'). + :param element_source: The source to check compatibility with. + :return: True if supported, False otherwise. + """ pass @staticmethod - def _is_method_implemented(element_source, method_name: str) -> bool: - """Check if the method is implemented and not a stub.""" + def _is_method_implemented(element_source: ElementSourceInterface, method_name: str) -> bool: + """Checks if the method is implemented and not a stub. + + :param element_source: The source to inspect. + :param method_name: The name of the method to check. + :return: True if implemented, False if abstract or a stub. + """ method = getattr(element_source, method_name) if inspect.isabstract(method): return False @@ -37,10 +53,10 @@ def _is_method_implemented(element_source, method_name: str) -> bool: except (OSError, TypeError): return True -# Concrete Strategies - class XPathStrategy(LocatorStrategy): + """Strategy for locating elements via XPath.""" + def __init__(self, element_source: ElementSourceInterface): self._element_source = element_source @@ -48,7 +64,7 @@ def __init__(self, element_source: ElementSourceInterface): def element_source(self) -> ElementSourceInterface: return self._element_source - def locate(self, element): + def locate(self, element: str) -> Union[object, Tuple[int, int]]: return self.element_source.locate(element) @staticmethod @@ -57,6 +73,8 @@ def supports(element_type: str, element_source: ElementSourceInterface) -> bool: class TextElementStrategy(LocatorStrategy): + """Strategy for locating text elements directly via the element source.""" + def __init__(self, element_source: ElementSourceInterface): self._element_source = element_source @@ -64,7 +82,7 @@ def __init__(self, element_source: ElementSourceInterface): def element_source(self) -> ElementSourceInterface: return self._element_source - def locate(self, element): + def locate(self, element: str) -> Union[object, Tuple[int, int]]: return self.element_source.locate(element) @staticmethod @@ -73,6 +91,8 @@ def supports(element_type: str, element_source: ElementSourceInterface) -> bool: class TextDetectionStrategy(LocatorStrategy): + """Strategy for locating text elements using text detection.""" + def __init__(self, element_source: ElementSourceInterface, text_detection): self._element_source = element_source self.text_detection = text_detection @@ -81,7 +101,7 @@ def __init__(self, element_source: ElementSourceInterface, text_detection): def element_source(self) -> ElementSourceInterface: return self._element_source - def locate(self, element): + def locate(self, element: str) -> Union[object, Tuple[int, int]]: screenshot = self.element_source.capture() return self.text_detection.locate(screenshot, element) @@ -91,6 +111,8 @@ def supports(element_type: str, element_source: ElementSourceInterface) -> bool: class ImageDetectionStrategy(LocatorStrategy): + """Strategy for locating image elements using image detection.""" + def __init__(self, element_source: ElementSourceInterface, image_detection): self._element_source = element_source self.image_detection = image_detection @@ -99,7 +121,7 @@ def __init__(self, element_source: ElementSourceInterface, image_detection): def element_source(self) -> ElementSourceInterface: return self._element_source - def locate(self, element): + def locate(self, element: str) -> Union[object, Tuple[int, int]]: screenshot = self.element_source.capture() return self.image_detection.locate(screenshot, element) @@ -107,10 +129,10 @@ def locate(self, element): def supports(element_type: str, element_source: ElementSourceInterface) -> bool: return element_type == "Image" and LocatorStrategy._is_method_implemented(element_source, "capture") -# Strategy Factory - class StrategyFactory: + """Factory for creating locator strategies.""" + def __init__(self, text_detection, image_detection): self.text_detection = text_detection self.image_detection = image_detection @@ -123,36 +145,45 @@ def __init__(self, text_detection, image_detection): "image_detection": self.image_detection}), ] - def create_strategies(self, element_source: ElementSourceInterface) -> list: - """Create strategies compatible with the given element source using the registry.""" + def create_strategies(self, element_source: ElementSourceInterface) -> List[LocatorStrategy]: + """Creates strategies compatible with the given element source. + + :param element_source: The source to build strategies for. + :return: List of compatible strategy instances. + """ strategies = [] for strategy_cls, element_type, extra_args in self._strategy_registry: if strategy_cls.supports(element_type, element_source): strategies.append(strategy_cls(element_source, **extra_args)) return strategies -# Result Wrapper - class LocateResult: + """Wrapper for location results from a strategy.""" + def __init__(self, value: Union[object, Tuple[int, int]], strategy: LocatorStrategy): self.value = value self.strategy = strategy self.is_coordinates = isinstance(value, tuple) -# Strategy Manager - class StrategyManager: + """Manages multiple locator strategies for element location.""" + def __init__(self, element_source: ElementSourceInterface, text_detection, image_detection): self.element_source = element_source self.factory = StrategyFactory(text_detection, image_detection) self.strategies = self._build_strategies() - logger.debug(f"Built strategies: {self.strategies}") + logger.debug( + f"Built strategies: {[s.__class__.__name__ for s in self.strategies]}") def _build_strategies(self) -> Set[LocatorStrategy]: - """Build a set of all strategies from the element source.""" - all_strategies = set() + """Builds a set of all strategies from the element source. + + :return: Set of strategy instances. + :raises ValueError: If no strategies are available. + """ + all_strategies: Set[LocatorStrategy] = set() if isinstance(self.element_source, FallbackProxy): for instance in self.element_source.instances: all_strategies.update(self.factory.create_strategies(instance)) @@ -165,11 +196,15 @@ def _build_strategies(self) -> Set[LocatorStrategy]: "No strategies available for the given element source") return all_strategies - def locate(self, element) -> Generator[LocateResult, None, None]: - """Yield applicable strategies' results in order of attempt.""" + def locate(self, element: str) -> Generator[LocateResult, None, None]: + """Yields applicable strategies' results in order of attempt. + + :param element: The element identifier to locate. + :yields: LocateResult objects with location data and strategy used. + """ element_type = utils.determine_element_type(element) - applicable_strategies = {s for s in self.strategies if s.supports( - element_type, s.element_source)} + applicable_strategies = { + s for s in self.strategies if s.supports(element_type, s.element_source)} for strategy in applicable_strategies: try: diff --git a/optics_framework/common/text_interface.py b/optics_framework/common/text_interface.py index 69256deb..fa5902ac 100644 --- a/optics_framework/common/text_interface.py +++ b/optics_framework/common/text_interface.py @@ -4,10 +4,14 @@ class TextInterface(ABC): """ - Abstract base class for vision-based detection models. + Abstract base class for text processing engines. - This interface enforces the implementation of a :meth:`detect` method, - which processes input data to identify specific objects, patterns, or text. + This interface defines methods for detecting and locating text (e.g., via OCR) + within input data, such as images or video frames, implementing the + :class:`TextInterface`. + + Implementers should handle specific input types (e.g., image bytes, file paths) + and reference data as needed. """ def element_exist(self, input_data, reference_data) -> Optional[Tuple[int, int]]: @@ -18,37 +22,38 @@ def element_exist(self, input_data, reference_data) -> Optional[Tuple[int, int]] :type input_data: Any :param reference_data: The reference data used for matching or comparison. :type reference_data: Any - :return: A tuple ``(x, y)`` representing the top-left corner of the reference image, - or ``None`` if the image is not found. + :return: A tuple (x, y) representing the top-left corner of the reference image, + or None if not found. :rtype: Optional[Tuple[int, int]] """ pass def locate(self, input_data, text, index=None) -> Optional[Tuple[int, int]]: """ - Find the location of text within the input data. + Find the location of specific text within the input data. :param input_data: The input source (e.g., image, video frame) for detection. - :type - input_data: Any + :type input_data: Any :param text: The text to search for. :type text: str - :return: A tuple ``(x, y)`` representing the centre of the text, - or ``None`` if the text is not found. + :return: A tuple (x, y) representing the center of the text, or None if not found. :rtype: Optional[Tuple[int, int]] """ pass def find_element(self, input_data, text, index=None) -> Optional[Tuple[bool, Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]]]: """ - Locate a specific text in the given input data using OCR and return the coordinates. + Locate specific text in the input data using OCR and return detailed detection info. :param input_data: The input source (e.g., image, video frame) for detection. :type input_data: Any :param text: The text to locate in the input data. :type text: str - :return: A tuple containing a boolean indicating if the text was found, - the coordinates of the text, and the bounding box of the text. + :return: A tuple (found, center, bounds) where: + - found: bool indicating if the text was found + - center: (x, y) coordinates of the text center + - bounds: ((x1, y1), (x2, y2)) bounding box (top-left, bottom-right) + Returns None if not found. :rtype: Optional[Tuple[bool, Tuple[int, int], Tuple[Tuple[int, int], Tuple[int, int]]]] """ pass diff --git a/optics_framework/helper/cli.py b/optics_framework/helper/cli.py index 373a7ef5..9c2387e4 100644 --- a/optics_framework/helper/cli.py +++ b/optics_framework/helper/cli.py @@ -1,5 +1,7 @@ import argparse import sys +from typing import Optional +from pydantic import BaseModel from optics_framework.common.logging_config import apply_logger_format_to_all from optics_framework.helper.list_keyword import main as list_main from optics_framework.helper.config_manager import main as config_main @@ -8,6 +10,7 @@ from optics_framework.helper.execute import execute_main, dryrun_main from optics_framework.helper.generate import generate_test_file as generate_framework_code + class Command: """ Abstract base class for CLI commands. @@ -26,66 +29,43 @@ def register(self, subparsers: argparse._SubParsersAction): :type subparsers: argparse._SubParsersAction :raises NotImplementedError: If the subclass does not implement this method. """ - raise NotImplementedError("Subclasses must implement the `register` method.") + raise NotImplementedError( + "Subclasses must implement the `register` method.") - def execute(self, args: argparse.Namespace): + def execute(self, args): """ Execute the command using the provided arguments. - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - :raises NotImplementedError: If the subclass does not implement this method. + :param args: The parsed command-line arguments (Pydantic model or argparse.Namespace). """ - raise NotImplementedError("Subclasses must implement the `execute` method.") + raise NotImplementedError( + "Subclasses must implement the `execute` method.") @apply_logger_format_to_all("user") class ListCommand(Command): - """ - Command to list all available methods in the API. - - This command calls the :func:`list_main` function to display the available methods. - """ - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the list command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ parser = subparsers.add_parser( "list", help="List all available methods in the API" ) parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the list command. - - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ + def execute(self, args): list_main() -@apply_logger_format_to_all("user") -class GenerateCommand(Command): - """ - Command to generate test framework code. - This command generates test framework code using the provided options. - """ +class GenerateArgs(BaseModel): + """Arguments for the generate command.""" + project_path: str + output_file: str = "generated_test.py" - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the generate command with the provided subparsers. - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ - parser = subparsers.add_parser("generate", help="Generate test framework code") - parser.add_argument("project_path", - help="Project name (required)") +@apply_logger_format_to_all("user") +class GenerateCommand(Command): + def register(self, subparsers: argparse._SubParsersAction): + parser = subparsers.add_parser( + "generate", help="Generate test framework code") + parser.add_argument("project_path", help="Project name (required)") parser.add_argument( "output_file", help="Path to the output file where the code will be generated", @@ -94,59 +74,37 @@ def register(self, subparsers: argparse._SubParsersAction): ) parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the generate command. + def execute(self, args): + generate_args = GenerateArgs( + project_path=args.project_path, output_file=args.output_file) + generate_framework_code( + generate_args.project_path, generate_args.output_file) - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ - generate_framework_code(args.project_path, args.output_file) class ConfigCommand(Command): - """ - Command to manage configuration. - - This command delegates to the :func:`config_main` function for configuration management. - """ - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the config command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ parser = subparsers.add_parser("config", help="Manage configuration") parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the config command. - - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ + def execute(self, args): config_main() -@apply_logger_format_to_all("user") -class InitCommand(Command): - """ - Command to initialize a new project. +class InitArgs(BaseModel): + """Arguments for the init command.""" + name: str + path: Optional[str] = None + force: bool = False + template: Optional[str] = None + git_init: bool = False - This command creates a new project using the provided options. - """ +@apply_logger_format_to_all("user") +class InitCommand(Command): def register(self, subparsers: argparse._SubParsersAction): - """ - Register the init command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ parser = subparsers.add_parser("init", help="Initialize a new project") - parser.add_argument("--name", required=True, help="Project name (required)") + parser.add_argument("--name", required=True, + help="Project name (required)") parser.add_argument( "--path", help="Directory where the project will be created" ) @@ -163,31 +121,26 @@ def register(self, subparsers: argparse._SubParsersAction): ) parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the init command. + def execute(self, args): + init_args = InitArgs( + name=args.name, + path=args.path, + force=args.force, + template=args.template, + git_init=args.git_init + ) + create_project(init_args) - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ - create_project(args) + +class DryRunArgs(BaseModel): + """Arguments for the dry_run command.""" + folder_path: str + test_name: str = "" @apply_logger_format_to_all("user") class DryRunCommand(Command): - """ - Command to generate a dry run report. - - This command generates a dry run report using CSV files for test cases, modules, and elements. - """ - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the dry run command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ parser = subparsers.add_parser( "dry_run", help="Execute test cases from CSV files" ) @@ -203,32 +156,21 @@ def register(self, subparsers: argparse._SubParsersAction): ) parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the dry run command. + def execute(self, args): + dry_run_args = DryRunArgs( + folder_path=args.folder_path, test_name=args.test_name) + dryrun_main(dry_run_args.folder_path, dry_run_args.test_name) - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ - dryrun_main(args.folder_path, args.test_name) +class ExecuteArgs(BaseModel): + """Arguments for the execute command.""" + folder_path: str + test_name: str = "" @apply_logger_format_to_all("user") class ExecuteCommand(Command): - """ - Command to execute test cases from CSV files. - - This command runs test cases located in a specified folder, optionally filtering by test name. - """ - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the execute command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ parser = subparsers.add_parser( "execute", help="Execute test cases from CSV files" ) @@ -244,39 +186,20 @@ def register(self, subparsers: argparse._SubParsersAction): ) parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the test cases. - - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ - execute_main(args.folder_path, args.test_name) + def execute(self, args): + execute_args = ExecuteArgs( + folder_path=args.folder_path, test_name=args.test_name) + execute_main(execute_args.folder_path, execute_args.test_name) @apply_logger_format_to_all("user") class VersionCommand(Command): - """ - Command to display the current version of the Optics Framework. - """ - def register(self, subparsers: argparse._SubParsersAction): - """ - Register the version command with the provided subparsers. - - :param subparsers: The argparse subparsers object. - :type subparsers: argparse._SubParsersAction - """ - parser = subparsers.add_parser("version", help="Print the current version") + parser = subparsers.add_parser( + "version", help="Print the current version") parser.set_defaults(func=self.execute) - def execute(self, args: argparse.Namespace): - """ - Execute the version command. - - :param args: The parsed command-line arguments. - :type args: argparse.Namespace - """ + def execute(self, args): print(f"Optics Framework {VERSION}") @@ -287,7 +210,8 @@ def main(): This function sets up the argument parser, registers all commands, parses the command-line arguments, and dispatches the appropriate command function. """ - parser = argparse.ArgumentParser(prog="optics", description="Optics Framework CLI") + parser = argparse.ArgumentParser( + prog="optics", description="Optics Framework CLI") subparsers = parser.add_subparsers(dest="command", required=True) # Register all commands. diff --git a/optics_framework/helper/config_manager.py b/optics_framework/helper/config_manager.py index e29bad01..0c245328 100644 --- a/optics_framework/helper/config_manager.py +++ b/optics_framework/helper/config_manager.py @@ -1,181 +1,206 @@ -import curses -import curses.textpad +from textual.app import App, ComposeResult +from textual.widgets import Header, Footer, ListView, ListItem, Label, Input, Button +from textual.containers import Vertical, Horizontal, Container +from textual.screen import ModalScreen +from optics_framework.common.config_handler import ConfigHandler, DependencyConfig import ast -# Import updated ConfigHandler -from optics_framework.common.config_handler import ConfigHandler -class LoggerTUI: - """ - A text-based UI for editing logger configuration. - """ +class QuitConfirmScreen(ModalScreen[bool]): + """Modal screen to confirm quitting without saving.""" - def __init__(self, stdscr): - self.stdscr = stdscr - # Use the updated singleton instance of ConfigHandler - self.config_handler = ConfigHandler.get_instance() - self.options = list(self.config_handler.config.keys()) - self.current_index = 0 - self.init_curses() - self.run() - - def init_curses(self): - curses.curs_set(0) - curses.start_color() - curses.init_pair(1, curses.COLOR_CYAN, curses.COLOR_BLACK) - curses.init_pair(2, curses.COLOR_YELLOW, curses.COLOR_BLACK) - curses.init_pair(3, curses.COLOR_GREEN, curses.COLOR_BLACK) - - def confirm_quit(self): - height, width = self.stdscr.getmaxyx() - win = curses.newwin(5, 50, height // 2 - 2, width // 2 - 25) - win.box() - win.addstr(1, 2, "Quit without saving? (y/n)", curses.color_pair(2)) - win.refresh() - - while True: - key = win.getch() - if key in [ord("y"), ord("Y")]: - return True - elif key in [ord("n"), ord("N")]: - return False - - def run(self): - while True: - self.stdscr.clear() - self.display_menu() - key = self.stdscr.getch() - - if key == curses.KEY_UP: - self.current_index = ( - self.current_index - 1) % len(self.options) - elif key == curses.KEY_DOWN: - self.current_index = ( - self.current_index + 1) % len(self.options) - elif key == ord(" "): - self.modify_value() - elif key == ord("s"): - try: - self.config_handler.save_config() - except Exception as e: - self.show_error_message(f"Error saving config: {e}") - exit(0) - elif key == ord("q"): - if self.confirm_quit(): - exit(0) - - def display_menu(self): - height, width = self.stdscr.getmaxyx() - title = "Logger Configuration" - self.stdscr.addstr( - 1, - (width // 2 - len(title) // 2), - title, - curses.A_BOLD | curses.color_pair(2), + def compose(self) -> ComposeResult: + yield Vertical( + Label("Quit without saving? (y/n)", classes="modal-title"), + Horizontal( + Button("Yes", variant="error", id="yes"), + Button("No", variant="primary", id="no"), + classes="modal-buttons" + ), + classes="modal" ) - # Always fetch the latest config dynamically - config = self.config_handler.config + def on_button_pressed(self, event: Button.Pressed) -> None: + self.dismiss(event.button.id == "yes") - for idx, key in enumerate(self.options): - prefix = "> " if idx == self.current_index else " " - value = str(config[key]) # Fetch latest value - color = ( - curses.color_pair( - 1) if idx == self.current_index else curses.A_NORMAL - ) - self.stdscr.addstr(idx + 3, 2, f"{prefix}{key}: {value}", color) - - footer = "[SPACE] Edit [S] Save [Q] Quit" - self.stdscr.addstr( - height - 2, (width // 2 - len(footer) // - 2), footer, curses.color_pair(3) + +class ErrorScreen(ModalScreen[None]): + """Modal screen to display error messages.""" + + def __init__(self, message: str): + super().__init__() + self.message = message + + def compose(self) -> ComposeResult: + yield Vertical( + Label(self.message, classes="error-message"), + Button("OK", variant="primary", id="ok"), + classes="modal" ) - def modify_value(self): - key = self.options[self.current_index] - config = self.config_handler.config # Fetch latest config + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "ok": + self.dismiss(None) + + +class LoggerTUI(App): + """A Textual-based UI for editing logger configuration.""" + CSS = """ + Screen { + align: center middle; + background: $background; + } + Header { + background: $primary; + } + Footer { + background: $secondary; + } + ListView { + height: 80%; + width: 80%; + border: solid $accent; + padding: 1; + } + ListItem { + padding: 0 1; + } + ListItem.--highlight { + background: $primary-darken-1; + } + .option-label { + color: $text; + } + .editing { + height: 3; + margin: 1 0; + } + .modal { + width: 40; + height: 10; + background: $panel; + border: solid $accent; + padding: 1; + } + .modal-title { + color: $warning; + text-align: center; + } + .modal-buttons { + margin-top: 1; + align: center middle; + } + .error-message { + color: $error; + text-align: center; + } + """ + + BINDINGS = [ + ("up", "move_up", "Move up"), + ("down", "move_down", "Move down"), + ("space", "edit", "Edit value"), + ("s", "save", "Save config"), + ("q", "quit", "Quit"), + ] + + def __init__(self): + super().__init__() + self.config_handler = ConfigHandler.get_instance() + self.options = list(self.config_handler.config.model_fields.keys()) + self.selected_index = 0 # Changed to plain int + + def compose(self) -> ComposeResult: + yield Header() + yield ListView(*[ListItem(Label(f"{key}: {self.get_value(key)}", classes="option-label")) + for key in self.options], id="config-list") + yield Footer() + + def on_mount(self) -> None: + self.query_one("#config-list").focus() + + def get_value(self, key: str) -> str: + """Fetch and format the current config value.""" + value = getattr(self.config_handler.config, key) + if key in self.config_handler.DEPENDENCY_KEYS: + return str(self.config_handler.get(key)) + return str(value) + + def action_move_up(self) -> None: + self.selected_index = max(0, self.selected_index - 1) + self.refresh_list() + + def action_move_down(self) -> None: + self.selected_index = min( + len(self.options) - 1, self.selected_index + 1) + self.refresh_list() + + def refresh_list(self) -> None: + list_view = self.query_one("#config-list", ListView) + for idx, key in enumerate(self.options): + list_view.children[idx].query_one(Label).update( + f"{key}: {self.get_value(key)}") + list_view.index = self.selected_index + + async def action_edit(self) -> None: + key = self.options[self.selected_index] + current_value = getattr(self.config_handler.config, key) - if isinstance(config[key], bool): - config[key] = not config[key] + if isinstance(current_value, bool): + setattr(self.config_handler.config, key, not current_value) + self.refresh_list() else: - new_value = self.get_validated_input(config[key]) - if new_value is not None: - config[key] = new_value - - def get_text_input(self, current_value): - height, width = self.stdscr.getmaxyx() - win = curses.newwin(5, 50, height // 2 - 2, width // 2 - 25) - win.box() - win.addstr(1, 2, "Enter new value:", curses.color_pair(2)) - win.addstr(2, 2, f"[Current: {current_value}]", curses.color_pair(1)) - win.refresh() - - curses.echo() - input_value = "" - cursor_x = 2 - - while True: - win.addstr(3, cursor_x, input_value) - win.refresh() - key = win.getch() - - if key in [10, 13]: # Enter key - break - elif key in [127, curses.KEY_BACKSPACE]: - if input_value: - input_value = input_value[:-1] - win.addstr(3, cursor_x, " " * 48) - elif key == ord("q"): - curses.noecho() - if self.confirm_quit(): - exit(0) - else: - win.clear() - win.box() - win.addstr(1, 2, "Enter new value:", curses.color_pair(2)) - win.addstr( - 2, 2, f"[Current: {current_value}]", curses.color_pair( - 1) - ) - win.refresh() + input_widget = Input(placeholder=str( + current_value), id="edit-input") + confirm_button = Button( + "Confirm", variant="success", id="confirm-edit") + self.mount( + Container(input_widget, confirm_button, classes="editing")) + + def on_input_submitted(self, event: Input.Submitted) -> None: + self.handle_edit_confirm(event.value) + + def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "confirm-edit": + input_value = self.query_one("#edit-input", Input).value + self.handle_edit_confirm(input_value) + + def handle_edit_confirm(self, new_value: str) -> None: + key = self.options[self.selected_index] + current_value = getattr(self.config_handler.config, key) + + try: + if isinstance(current_value, list) and key in self.config_handler.DEPENDENCY_KEYS: + parsed = ast.literal_eval(new_value) + if not isinstance(parsed, list) or not all(isinstance(x, str) for x in parsed): + raise ValueError("Must be a list of strings") + setattr(self.config_handler.config, key, + [{"name": DependencyConfig(enabled=True)} for name in parsed]) else: - input_value += chr(key) - - curses.noecho() - return input_value.strip() if input_value else current_value - - def get_validated_input(self, current_value): - while True: - new_value = self.get_text_input(current_value) - try: - # For list types, use ast.literal_eval to parse user input - if isinstance(current_value, list): - parsed = ast.literal_eval(new_value) - if isinstance(parsed, list): - return parsed - else: - raise ValueError - else: - # Attempt to cast new_value to the type of the current value - return type(current_value)(new_value) - except Exception: - self.show_error_message("Invalid input!") - - def show_error_message(self, message): - height, width = self.stdscr.getmaxyx() - win = curses.newwin( - 3, len(message) + 6, height // 2, width // 2 - (len(message) // 2) - ) - win.box() - win.addstr(1, 2, message, curses.color_pair(2)) - win.refresh() - curses.napms(1500) + parsed = type(current_value)(new_value) + setattr(self.config_handler.config, key, parsed) + self.refresh_list() + except Exception as e: + self.push_screen(ErrorScreen( + f"Invalid input: {e}"), lambda _: None) + finally: + self.query_one(".editing").remove() + + def action_save(self) -> None: + try: + self.config_handler.save_config() + self.exit(0) + except Exception as e: + self.push_screen(ErrorScreen( + f"Error saving config: {e}"), lambda _: None) + + async def action_quit(self) -> None: + self.push_screen(QuitConfirmScreen(), self.handle_quit) + + def handle_quit(self, confirmed: bool | None) -> None: + if confirmed is True: + self.exit(0) def main(): - curses.wrapper(LoggerTUI) - - -if __name__ == "__main__": - main() + LoggerTUI().run() diff --git a/optics_framework/helper/execute.py b/optics_framework/helper/execute.py index 4ea8fcf4..e545cd15 100644 --- a/optics_framework/helper/execute.py +++ b/optics_framework/helper/execute.py @@ -1,21 +1,21 @@ import os import asyncio from typing import Optional, Tuple -from dataclasses import asdict +from pydantic import BaseModel, field_validator from optics_framework.common.logging_config import logger, apply_logger_format_to_all -from optics_framework.common.config_handler import ConfigHandler # Updated import +from optics_framework.common.config_handler import ConfigHandler from optics_framework.common.runner.csv_reader import CSVDataReader from optics_framework.common.session_manager import SessionManager from optics_framework.common.execution import ExecutionEngine @apply_logger_format_to_all("user") -def find_csv_files(folder_path: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: +def find_csv_files(folder_path: str) -> Tuple[str, str, Optional[str]]: """ Search for CSV files in a folder and categorize them by reading their headers. :param folder_path: Path to the project folder. - :return: Tuple of paths to test_cases, modules, and elements CSV files. + :return: Tuple of paths to test_cases (required), modules (required), and elements (optional) CSV files. """ test_cases, modules, elements = None, None, None for file in os.listdir(folder_path): @@ -38,55 +38,71 @@ def find_csv_files(folder_path: str) -> Tuple[Optional[str], Optional[str], Opti elif "element_name" in headers and "element_id" in headers: elements = file_path logger.debug(f"Found elements file: {file_path}") + + if not test_cases or not modules: + missing = [f for f, p in [ + ("test_cases", test_cases), ("modules", modules)] if not p] + logger.error( + f"Missing required CSV files in {folder_path}: {', '.join(missing)}") + raise ValueError(f"Required CSV files missing: {', '.join(missing)}") return test_cases, modules, elements +class RunnerArgs(BaseModel): + """Arguments for BaseRunner initialization.""" + folder_path: str + test_name: str = "" + + @field_validator('folder_path') + @classmethod + def folder_path_must_exist(cls, v: str) -> str: + """Ensure folder_path is an existing directory.""" + abs_path = os.path.abspath(v) + if not os.path.isdir(abs_path): + raise ValueError(f"Invalid project folder: {abs_path}") + return abs_path + + @field_validator('test_name') + @classmethod + def strip_test_name(cls, v: str) -> str: + """Strip whitespace from test_name.""" + return v.strip() + + @apply_logger_format_to_all("user") class BaseRunner: """Base class for running test cases from CSV files using ExecutionEngine.""" - def __init__(self, folder_path: str, test_name: str = ""): - self.folder_path = os.path.abspath(folder_path) - self.test_name = test_name.strip() + def __init__(self, args: RunnerArgs): + self.folder_path = args.folder_path + self.test_name = args.test_name - # Validate folder and CSV files - if not os.path.isdir(self.folder_path): - logger.error(f"Project folder does not exist: {self.folder_path}") - raise ValueError(f"Invalid project folder: {self.folder_path}") + # Validate CSV files (test_cases and modules required, elements optional) test_cases_file, modules_file, elements_file = find_csv_files( self.folder_path) - if not all([test_cases_file, modules_file, elements_file]): - missing = [f for f, p in [ - ("test_cases", test_cases_file), - ("modules", modules_file), - ("elements", elements_file) - ] if not p] - logger.error( - f"Missing required CSV files in {self.folder_path}: {', '.join(missing)}") - raise ValueError( - f"Incomplete CSV file set: missing {', '.join(missing)}") # Load CSV data csv_reader = CSVDataReader() self.test_cases_data = csv_reader.read_test_cases(test_cases_file) self.modules_data = csv_reader.read_modules(modules_file) - self.elements_data = csv_reader.read_elements(elements_file) + self.elements_data = csv_reader.read_elements( + elements_file) if elements_file else {} + if not self.test_cases_data: logger.debug(f"No test cases found in {test_cases_file}") - # Load and validate configuration using the new ConfigHandler + # Load and validate configuration using ConfigHandler self.config_handler = ConfigHandler.get_instance() self.config_handler.set_project(self.folder_path) self.config_handler.load() - self.config = self.config_handler.config # Now a Config dataclass instance + self.config = self.config_handler.config # Ensure project_path is set in the Config object self.config.project_path = self.folder_path logger.debug(f"Loaded configuration: {self.config}") - # Check required configs using the new get() method - required_configs = [ - "driver_sources", "elements_sources"] + # Check required configs using the get() method + required_configs = ["driver_sources", "elements_sources"] missing_configs = [ key for key in required_configs if not self.config_handler.get(key)] if missing_configs: @@ -97,8 +113,7 @@ def __init__(self, folder_path: str, test_name: str = ""): # Setup session self.manager = SessionManager() - # Pass the Config object as a dict for session creation - self.session_id = self.manager.create_session(asdict(self.config)) + self.session_id = self.manager.create_session(self.config) self.engine = ExecutionEngine(self.manager) async def run(self, mode: str): @@ -143,13 +158,15 @@ async def execute(self): def execute_main(folder_path: str, test_name: str = ""): """Entry point for execute command.""" - runner = ExecuteRunner(folder_path, test_name) + args = RunnerArgs(folder_path=folder_path, test_name=test_name) + runner = ExecuteRunner(args) asyncio.run(runner.execute()) def dryrun_main(folder_path: str, test_name: str = ""): """Entry point for dry run command.""" - runner = DryRunRunner(folder_path, test_name) + args = RunnerArgs(folder_path=folder_path, test_name=test_name) + runner = DryRunRunner(args) asyncio.run(runner.execute()) diff --git a/pyproject.toml b/pyproject.toml index 3dc1dcda..7b9c77bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ google-cloud-vision = "^3.10.0" sse-starlette = "^2.2.1" fastapi = "^0.115.12" requests = "^2.32.3" +textual = "^3.0.0" [tool.poetry.group.dev.dependencies] tox = "~4.24.1"