Add a public basis validator #314

Open
BalzaniEdoardo opened this issue Feb 20, 2025 · 1 comment

Comments

@BalzaniEdoardo
Collaborator

Add a public function that performs sanity checks on a user-provided object to determine whether it is a valid basis. In practice, passing these checks with:

  • check_io=True should guarantee that the basis can be composed.
  • require_sklearn=True should check the availability of the estimator machinery get_params and set_params.
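For illustration, here is a minimal sketch of a duck-typed object that the two options above are meant to accept. `ToyBasis` and its `n_basis_funcs` parameter are hypothetical names used only for this example; they are not part of the library.

```python
import copy

import numpy as np


class ToyBasis:
    """Hypothetical duck-typed basis, used only for illustration."""

    def __init__(self, n_basis_funcs=3):
        self.n_basis_funcs = n_basis_funcs

    def compute_features(self, x):
        # Return a (n_samples, n_basis_funcs) matrix of polynomial powers of x.
        x = np.asarray(x, dtype=float)
        return np.stack([x**k for k in range(self.n_basis_funcs)], axis=1)

    def get_params(self, deep=True):
        # Mirror the scikit-learn convention: constructor arguments as a dict.
        return {"n_basis_funcs": self.n_basis_funcs}

    def set_params(self, **params):
        # Mirror the scikit-learn convention: set attributes and return self.
        for name, value in params.items():
            setattr(self, name, value)
        return self


basis = ToyBasis(n_basis_funcs=4)

# What check_io=True would exercise: a linspace between 0 and 1 as input.
features = basis.compute_features(np.linspace(0, 1, 5))
print(features.shape, features.ndim)

# What require_sklearn=True would exercise: a parameter round trip on a copy.
clone = copy.deepcopy(basis).set_params(**basis.get_params())
print(clone.get_params() == basis.get_params())
```

The output keeps the sample axis first and is 2-dimensional, which is the structure the input/output check looks for.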

Below is a draft of the function. In this PR we should make sure that, if the checks pass, the basis can actually be used with sklearn.

import copy
from typing import Any, Sequence

import numpy as np
from numpy.typing import NDArray

# ``TransformerBasis`` is assumed to be importable from within the package.


def is_valid_basis(object: Any, check_io: bool | NDArray | Sequence[NDArray] = True, require_sklearn: bool = False) -> bool:
    """
    Check if the object is a valid basis.

    The checks performed depend on the arguments:

    - If `require_sklearn` and `check_io` are both False, the function only checks for the availability of the
    `compute_features` method.

    - If `check_io` is set to True, it checks that `compute_features` can be called on
    a linspace between 0 and 1 as a minimal example, and checks the output structure.

    - If `check_io` is one or more numpy arrays, it calls `compute_features` on them and checks the output structure.

    - If `require_sklearn` is True, it checks that the `get_params` and `set_params` methods exist and work.

    """
    if require_sklearn:
        # The scikit-learn estimator API requires both ``get_params`` and ``set_params``.
        sklearn_interface = all(hasattr(object, val) for val in ("get_params", "set_params"))
        if not sklearn_interface:
            print("Not compatible with scikit-learn.")
            return False

        try:
            # Round-trip the parameters on a copy so the original object is left untouched.
            new = copy.deepcopy(object)
            new.set_params(**new.get_params())
        except Exception as e:
            print(f"``get_params`` and ``set_params`` did not work as expected. Error was {e}")
            return False
        try:
            TransformerBasis(object)
        except Exception as e:
            print(f"Cannot convert to a transformer with error:\n{e}")
            return False

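    if check_io is False:
        # Minimal check: only require that the method exists.
        if hasattr(object, "compute_features"):
            return True
        else:
            print("Does not implement ``compute_features``.")
            return False

    if check_io is True:
        # Default minimal example: a 1D linspace between 0 and 1.
        inp = (np.linspace(0, 1, 5),)
    elif hasattr(check_io, "shape"):
        # A single array was provided.
        inp = (check_io,)
    else:
        # A sequence of arrays was provided.
        inp = check_io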

    for x in inp:
        try:
            out = object.compute_features(x)
        except Exception as e:
            print(f"Cannot call ``compute_features`` over {x} with error:\n{e}\n"
                  "If your basis requires a specific input structure (N-dimensional or more than one input), "
                  "please pass the input directly, i.e. call ``is_valid_basis(my_basis, check_io=my_input)``.")
            return False

        if not all(hasattr(out, val) for val in ("shape", "ndim")):
            print("The output of ``compute_features`` is not an array.")
            return False

        elif out.shape[0] != x.shape[0]:
            print("The output of ``compute_features`` does not preserve the first axis. The size "
                  f"of the first axis of the input was {x.shape[0]}, that of the output was {out.shape[0]} instead!")
            return False
        elif out.ndim != 2:
            print(f"The output of ``compute_features`` is not 2-dimensional. The output dimensionality is {out.ndim}")
            return False

    # All inputs passed every check.
    return True
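To make the intended behavior of the output-structure checks concrete, here is a small standalone sketch of the same three conditions applied to a few candidate outputs. The helper name `output_problem` is hypothetical and is not part of the draft. One deliberate difference from the draft, offered as a suggestion: the dimensionality is tested before the first axis, so that a 0-dimensional output is reported as "not 2-dimensional" instead of raising an `IndexError` on `out.shape[0]`.

```python
import numpy as np


def output_problem(out, n_samples):
    """Return a description of the first failed check, or None if all pass.

    Hypothetical helper that mirrors the output checks described in the draft.
    """
    if not all(hasattr(out, attr) for attr in ("shape", "ndim")):
        return "not an array"
    # Check ndim first so that 0-dimensional outputs do not raise on shape[0].
    if out.ndim != 2:
        return "not 2-dimensional"
    if out.shape[0] != n_samples:
        return "first axis not preserved"
    return None


x = np.linspace(0, 1, 5)
candidates = {
    "valid": np.ones((5, 3)),
    "list output": [[1.0, 2.0]] * 5,
    "1D output": np.ones(5),
    "0D output": np.array(1.0),
    "wrong first axis": np.ones((4, 3)),
}
for name, out in candidates.items():
    print(name, "->", output_problem(out, x.shape[0]))
```

Cases like these could serve as the basis for unit tests of the validator, one per failure branch.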
@BalzaniEdoardo
Collaborator Author

Also, check that the object has a str label. This should always be checked.
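A minimal sketch of what such a check could look like, assuming the label is exposed as an attribute named `label` (the attribute name and the helper `has_str_label` are assumptions for illustration, not confirmed by the draft):

```python
def has_str_label(obj):
    """Return True if ``obj`` exposes a ``label`` attribute that is a string.

    Hypothetical helper; assumes the label lives in an attribute called ``label``.
    """
    return isinstance(getattr(obj, "label", None), str)


class Labeled:
    label = "my_basis"


class NonStringLabel:
    label = 42


class Unlabeled:
    pass


for candidate in (Labeled(), NonStringLabel(), Unlabeled()):
    print(type(candidate).__name__, has_str_label(candidate))
```

Since this check should always run, it would sit at the top of the validator, before the branches controlled by `require_sklearn` and `check_io`.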
