diff --git a/mypy.ini b/mypy.ini index 522d42e..8bd2791 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.8 +python_version = 3.9 warn_return_any = true warn_unused_configs = true warn_unused_ignores = true diff --git a/pymars/_forward.py b/pymars/_forward.py index 7ada1e0..74b3772 100644 --- a/pymars/_forward.py +++ b/pymars/_forward.py @@ -7,6 +7,7 @@ """ import logging +from typing import Optional import numpy as np @@ -53,7 +54,7 @@ def __init__(self, earth_model: Earth): def _calculate_rss_and_coeffs( self, B_matrix: np.ndarray, y: np.ndarray, *, drop_nan_rows: bool = True - ) -> tuple[float, np.ndarray | None, int]: + ) -> tuple[float, Optional[np.ndarray], int]: if B_matrix is None or B_matrix.shape[1] == 0: mean_y = np.mean(y) rss = np.sum((y - mean_y)**2) @@ -225,7 +226,7 @@ def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray, return self.current_basis_functions, self.current_coefficients - def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[float | None, np.ndarray | None]: + def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[Optional[float], Optional[np.ndarray]]: if not basis_functions: # This implies an intercept-only model for GCV calculation purposes rss_intercept_only = np.sum((self.y_train - np.mean(self.y_train))**2) @@ -348,8 +349,8 @@ def _get_allowable_knot_values(self, X_col_original_for_var: np.ndarray, parent_ minspan_countdown = max(0, minspan_abs - 1) return np.array(final_allowable_knots) - def _generate_candidates(self) -> list[tuple[BasisFunction, BasisFunction | None]]: - candidate_additions: list[tuple[BasisFunction, BasisFunction | None]] = [] + def _generate_candidates(self) -> list[tuple[BasisFunction, Optional[BasisFunction]]]: + candidate_additions: list[tuple[BasisFunction, Optional[BasisFunction]]] = [] for parent_bf in self.current_basis_functions: if parent_bf.degree() + 1 > self.model.max_degree: continue parent_involved_vars = parent_bf.get_involved_variables() diff --git a/pymars/_pruning.py b/pymars/_pruning.py index 07bb477..d086c62 100644 --- a/pymars/_pruning.py +++ b/pymars/_pruning.py @@ -8,6 +8,7 @@ """ import logging +from typing import Optional import numpy as np @@ -38,7 +39,7 @@ def __init__(self, earth_model: Earth): self.best_basis_functions_so_far: list[BasisFunction] = [] self.best_coeffs_so_far: np.ndarray = None - def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, np.ndarray | None, int]: + def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, Optional[np.ndarray], int]: """ Calculates RSS, coefficients, and num_valid_rows, considering NaNs in B_matrix. y_data is assumed finite. @@ -97,7 +98,7 @@ def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: list[BasisFun def _compute_gcv_for_subset(self, X_fit_processed: np.ndarray, y_fit: np.ndarray, missing_mask: np.ndarray, X_fit_original: np.ndarray, - basis_subset: list[BasisFunction]) -> tuple[float | None, float | None, np.ndarray | None]: + basis_subset: list[BasisFunction]) -> tuple[Optional[float], Optional[float], Optional[np.ndarray]]: """ Computes GCV, RSS, and coefficients for a given subset of basis functions. Returns (gcv, rss, coeffs). diff --git a/pyproject.toml b/pyproject.toml index 07a6159..b96097b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,8 @@ exclude = [ [tool.ruff.lint.per-file-ignores] # Allow print statements in demos "pymars/demos/*" = ["E501", "S101"] +# Allow assert statements and long lines in tests (asserts are expected in tests) +"tests/*" = ["S101", "E501"] [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity of 10. diff --git a/tox.ini b/tox.ini index de1aa2c..0285f38 100644 --- a/tox.ini +++ b/tox.ini @@ -2,11 +2,20 @@ envlist = py38, py39, py310, py311, py312 isolated_build = true +[gh-actions] +python = + 3.8: py38 + 3.9: py39 + 3.10: py310 + 3.11: py311 + 3.12: py312 + [testenv] deps = pytest pytest-cov coverage[toml] + hypothesis commands = pytest tests/ --cov=pymars --cov-report=term-missing --cov-report=html python -c "import coverage; cov = coverage.Coverage(); cov.load(); print(f'Coverage: {cov.report():.2f}%')"