Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,31 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9, '3.10', '3.11', '3.12']
python-version: [3.8, 3.9, '3.10']

steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Install importlib-metadata backport for Python < 3.10
if [[ "${{ matrix.python-version }}" == "3.8" ]] || [[ "${{ matrix.python-version }}" == "3.9" ]]; then
pip install importlib-metadata
fi
pip install tox tox-gh-actions
- name: Test with tox
run: tox -e py

lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -44,9 +48,9 @@ jobs:
type-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -59,9 +63,9 @@ jobs:
coverage:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand All @@ -80,9 +84,9 @@ jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
Expand Down
6 changes: 6 additions & 0 deletions pymars/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
try:
import importlib.metadata as importlib_metadata # type: ignore
except Exception:
import importlib_metadata # type: ignore

__all__ = ("importlib_metadata",)
9 changes: 5 additions & 4 deletions pymars/_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import logging
from typing import Union

import numpy as np

Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(self, earth_model: Earth):

def _calculate_rss_and_coeffs(
self, B_matrix: np.ndarray, y: np.ndarray, *, drop_nan_rows: bool = True
) -> tuple[float, np.ndarray | None, int]:
) -> tuple[float, Union[np.ndarray, None], int]:
if B_matrix is None or B_matrix.shape[1] == 0:
mean_y = np.mean(y)
rss = np.sum((y - mean_y)**2)
Expand Down Expand Up @@ -225,7 +226,7 @@ def run(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,

return self.current_basis_functions, self.current_coefficients

def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[float | None, np.ndarray | None]:
def _calculate_gcv_for_basis_set(self, basis_functions: list[BasisFunction]) -> tuple[Union[float, None], Union[np.ndarray, None]]:
if not basis_functions:
# This implies an intercept-only model for GCV calculation purposes
rss_intercept_only = np.sum((self.y_train - np.mean(self.y_train))**2)
Expand Down Expand Up @@ -348,8 +349,8 @@ def _get_allowable_knot_values(self, X_col_original_for_var: np.ndarray, parent_
minspan_countdown = max(0, minspan_abs - 1)
return np.array(final_allowable_knots)

def _generate_candidates(self) -> list[tuple[BasisFunction, BasisFunction | None]]:
candidate_additions: list[tuple[BasisFunction, BasisFunction | None]] = []
def _generate_candidates(self) -> list[tuple[BasisFunction, Union[BasisFunction, None]]]:
candidate_additions: list[tuple[BasisFunction, Union[BasisFunction, None]]] = []
for parent_bf in self.current_basis_functions:
if parent_bf.degree() + 1 > self.model.max_degree: continue
parent_involved_vars = parent_bf.get_involved_variables()
Expand Down
5 changes: 3 additions & 2 deletions pymars/_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import logging
from typing import Union

import numpy as np

Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self, earth_model: Earth):
self.best_basis_functions_so_far: list[BasisFunction] = []
self.best_coeffs_so_far: np.ndarray = None

def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, np.ndarray | None, int]:
def _calculate_rss_and_coeffs(self, B_matrix: np.ndarray, y_data: np.ndarray) -> tuple[float, Union[np.ndarray, None], int]:
"""
Calculates RSS, coefficients, and num_valid_rows, considering NaNs in B_matrix.
y_data is assumed finite.
Expand Down Expand Up @@ -97,7 +98,7 @@ def _build_basis_matrix(self, X_data: np.ndarray, basis_functions: list[BasisFun

def _compute_gcv_for_subset(self, X_fit_processed: np.ndarray, y_fit: np.ndarray,
missing_mask: np.ndarray, X_fit_original: np.ndarray,
basis_subset: list[BasisFunction]) -> tuple[float | None, float | None, np.ndarray | None]:
basis_subset: list[BasisFunction]) -> tuple[Union[float, None], Union[float, None], Union[np.ndarray, None]]:
"""
Computes GCV, RSS, and coefficients for a given subset of basis functions.
Returns (gcv, rss, coeffs).
Expand Down