Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.8
python_version = 3.9
warn_return_any = true
warn_unused_configs = true
warn_unused_ignores = true
Expand Down
9 changes: 5 additions & 4 deletions pymars/_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import logging
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(self, earth_model: Earth):

def _calculate_rss_and_coeffs(
self, B_matrix: np.ndarray, y: np.ndarray, *, drop_nan_rows: bool = True
) -> tuple[float, np.ndarray | None, int]:
) -> tuple[float, Optional[np.ndarray], int]:
if B_matrix is None or B_matrix.shape[1] == 0:
mean_y = np.mean(y)
rss = np.sum((y - mean_y)**2)
Expand Down Expand Up @@ -225,7 +226,7 @@ def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,

return self.current_basis_functions, self.current_coefficients

def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[float | None, np.ndarray | None]:
def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[Optional[float], Optional[np.ndarray]]:
if not basis_functions:
# This implies an intercept-only model for GCV calculation purposes
rss_intercept_only = np.sum((self.y_train - np.mean(self.y_train))**2)
Expand Down Expand Up @@ -348,8 +349,8 @@ def _get_allowable_knot_values(self, X_col_original_for_var: np.ndarray, parent_
minspan_countdown = max(0, minspan_abs - 1)
return np.array(final_allowable_knots)

def _generate_candidates(self) -> list[tuple[BasisFunction, BasisFunction | None]]:
candidate_additions: list[tuple[BasisFunction, BasisFunction | None]] = []
def _generate_candidates(self) -> list[tuple[BasisFunction, Optional[BasisFunction]]]:
candidate_additions: list[tuple[BasisFunction, Optional[BasisFunction]]] = []
for parent_bf in self.current_basis_functions:
if parent_bf.degree() + 1 > self.model.max_degree: continue
parent_involved_vars = parent_bf.get_involved_variables()
Expand Down
5 changes: 3 additions & 2 deletions pymars/_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import logging
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self, earth_model: Earth):
self.best_basis_functions_so_far: list[BasisFunction] = []
self.best_coeffs_so_far: np.ndarray = None

def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, np.ndarray | None, int]:
def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, Optional[np.ndarray], int]:
"""
Calculates RSS, coefficients, and num_valid_rows, considering NaNs in B_matrix.
y_data is assumed finite.
Expand Down Expand Up @@ -97,7 +98,7 @@ def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: list[BasisFun

def _compute_gcv_for_subset(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,
missing_mask: np.ndarray, X_fit_original: np.ndarray,
basis_subset: list[BasisFunction]) -> tuple[float | None, float | None, np.ndarray | None]:
basis_subset: list[BasisFunction]) -> tuple[Optional[float], Optional[float], Optional[np.ndarray]]:
"""
Computes GCV, RSS, and coefficients for a given subset of basis functions.
Returns (gcv, rss, coeffs).
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ exclude = [
[tool.ruff.lint.per-file-ignores]
# Allow print statements in demos
"pymars/demos/*" = ["E501", "S101"]
# Allow assert statements and long lines in tests (asserts are expected in tests)
"tests/*" = ["S101", "E501"]

[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity of 10.
Expand Down
9 changes: 9 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@
envlist = py38, py39, py310, py311, py312
isolated_build = true

[gh-actions]
python =
3.8: py38
3.9: py39
3.10: py310
3.11: py311
3.12: py312

[testenv]
deps =
pytest
pytest-cov
coverage[toml]
hypothesis
commands =
pytest tests/ --cov=pymars --cov-report=term-missing --cov-report=html
python -c "import coverage; cov = coverage.Coverage(); cov.load(); print(f'Coverage: {cov.report():.2f}%')"
Expand Down