chore: testing and coverage (#15)
* docstrings

* parsers test

* rendering test

* cli test

* app test

* mypy and ruff on tests
FBruzzesi authored Jun 7, 2024
1 parent 300aace commit 3e83b9e
Showing 19 changed files with 850 additions and 159 deletions.
33 changes: 0 additions & 33 deletions .devcontainer/devcontainer.json

This file was deleted.

6 changes: 6 additions & 0 deletions Makefile
@@ -13,9 +13,15 @@ lint:
ruff check $(sources) --fix
ruff clean

# Requires pytest-xdist (pip install pytest-xdist)
test:
pytest tests -n auto

# Requires pytest-cov (pip install pytest-cov)
test-cov:
pytest tests --cov=sksmithy -n auto

# Requires coverage (pip install coverage)
coverage:
rm -rf .coverage
(rm docs/img/coverage.svg) || (echo "No coverage.svg file found")
19 changes: 11 additions & 8 deletions README.md
@@ -56,10 +56,12 @@ smith --help
```

```terminal
Usage: smith [OPTIONS] COMMAND [ARGS]...
Awesome CLI to generate scikit-learn estimator boilerplate code
Usage: smith [OPTIONS] COMMAND [ARGS]...
CLI to generate scikit-learn estimator boilerplate code
...
╭─ Commands ─────────────────────────────────────────────────────────────────────────────╮
│ forge Generate a new shiny scikit-learn compatible estimator ✨ │
│ version Display library version. │
@@ -78,15 +80,16 @@ Generate a new shiny scikit-learn compatible estimator ✨
Depending on the estimator type the following additional information could be required:
* if the estimator is linear (classifier or regression)
* if the estimator has a `predict_proba` method (classifier or outlier detector)
* is the estimator has a `decision_function` method (classifier only)
* if the estimator implements `.predict_proba()` method (classifier or outlier detector)
* if the estimator implements `.decision_function()` method (classifier only)
Finally, the following two questions will be prompt:
* if the estimator should have tags (To know more about tags, check the dedicated scikit-learn documentation
at https://scikit-learn.org/dev/developers/develop.html#estimator-tags
* in which file the class should be saved (default is `f'{name.lower()}.py'`)
at https://scikit-learn.org/dev/developers/develop.html#estimator-tags)
* in which file the class should be saved (default is `f'{name.lower()}.py'`)
╭─ Options ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ * --name TEXT Name of the estimator [default: None] [required] │
│ * --estimator-type [classifier|outlier|regressor|transformer|cluster] Estimator type [default: None] [required] │
11 changes: 8 additions & 3 deletions pyproject.toml
@@ -17,7 +17,6 @@ dependencies = [
"jinja2>=3.0.0",
"result>=0.16.0",
"ruff>=0.4.0",
"typing-extensions>=4.4.0; python_version < '3.11'",
]

classifiers = [
@@ -34,7 +33,7 @@ classifiers = [
streamlit = ["streamlit>=1.34.0"]

[project.scripts]
smith = "sksmithy.__main__:app"
smith = "sksmithy.__main__:cli"

[tool.hatch.build.targets.sdist]
only-include = ["sksmithy"]
@@ -62,7 +61,7 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["ANN", "D", "N802", "N806", "PD901", "PT006", "PT007", "PLR0913", "S101"]
"tests/*" = ["D103","S101"]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
@@ -79,3 +78,9 @@ python_version = "3.10"

[tool.coverage.run]
source = ["sksmithy/"]
omit = [
"sksmithy/__main__.py",
"sksmithy/_arguments.py",
"sksmithy/_logger.py",
"sksmithy/_prompts.py",
]
85 changes: 2 additions & 83 deletions sksmithy/__main__.py
@@ -1,85 +1,4 @@
from pathlib import Path

import typer

from sksmithy._arguments import (
decision_function_arg,
estimator_type_arg,
linear_arg,
name_arg,
optional_params_arg,
output_file_arg,
predict_proba_arg,
required_params_arg,
sample_weight_arg,
tags_arg,
)
from sksmithy._logger import console
from sksmithy._utils import render_template

app = typer.Typer(
help="Awesome CLI to generate scikit-learn estimator boilerplate code",
rich_markup_mode="rich",
rich_help_panel="Customization and Utils",
)


@app.command()
def version() -> None:
"""Display library version."""
from importlib import metadata

__version__ = metadata.version("sklearn-smithy")
console.print(f"sklearn-smithy {__version__}", style="good")


@app.command()
def forge(
name: name_arg,
estimator_type: estimator_type_arg,
required_params: required_params_arg = "",
optional_params: optional_params_arg = "",
sample_weight: sample_weight_arg = False,
linear: linear_arg = False,
predict_proba: predict_proba_arg = False,
decision_function: decision_function_arg = False,
tags: tags_arg = "",
output_file: output_file_arg = "",
) -> None:
"""Generate a new shiny scikit-learn compatible estimator ✨
Depending on the estimator type the following additional information could be required:
* if the estimator is linear (classifier or regression)
* if the estimator has a `predict_proba` method (classifier or outlier detector)
* is the estimator has a `decision_function` method (classifier only)
Finally, the following two questions will be prompt:
* if the estimator should have tags (To know more about tags, check the dedicated scikit-learn documentation
at https://scikit-learn.org/dev/developers/develop.html#estimator-tags
* in which file the class should be saved (default is `f'{name.lower()}.py'`)
"""
forged_template = render_template(
name=name,
estimator_type=estimator_type,
required=required_params, # type: ignore[arg-type] # Callback transforms it into `list[str]`
optional=optional_params, # type: ignore[arg-type] # Callback transforms it into `list[str]`
linear=linear,
sample_weight=sample_weight,
predict_proba=predict_proba,
decision_function=decision_function,
tags=tags, # type: ignore[arg-type] # Callback transforms it into `list[str]`
)

destination_file = Path(output_file)
destination_file.parent.mkdir(parents=True, exist_ok=True)

with destination_file.open(mode="w") as destination:
destination.write(forged_template)

console.print(f"Template forged at {destination_file}", style="good")

from sksmithy.cli import cli

if __name__ == "__main__":
app()
cli()
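
The `[project.scripts]` entry `smith = "sksmithy.__main__:cli"` keeps working after this change because `__main__.py` now re-exports `cli` from `sksmithy.cli`. As a rough illustration of how a `module:attribute` specification of this kind is resolved, here is a minimal sketch. The helper name `resolve_entry_point` is hypothetical and not part of this repository, and `json:dumps` is used only so that the example runs without installing the package.

```python
from importlib import import_module
from typing import Any


def resolve_entry_point(spec: str) -> Any:
    """Resolve a 'package.module:attr.path' string to the object it names.

    Hypothetical helper for illustration; installers generate their own launcher code.
    """
    module_name, _, attr_path = spec.partition(":")
    obj: Any = import_module(module_name)
    if attr_path:
        # Walk dotted attribute paths such as "SomeClass.method".
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
    return obj


# A stdlib target stands in for "sksmithy.__main__:cli".
dumps = resolve_entry_point("json:dumps")
print(dumps({"ok": True}))  # prints {"ok": true}
```

Under the same mechanism, any name importable from `sksmithy/__main__.py` can serve as the script target, which is why the two-line module is sufficient.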
7 changes: 2 additions & 5 deletions sksmithy/_arguments.py
@@ -2,7 +2,7 @@

from typer import Option

from sksmithy._callbacks import estimator_callback, name_callback, params_callback, tags_callback
from sksmithy._callbacks import estimator_callback, linear_callback, name_callback, params_callback, tags_callback
from sksmithy._models import EstimatorType
from sksmithy._prompts import (
PROMPT_DECISION_FUNCTION,
@@ -21,7 +21,6 @@
str,
Option(
prompt=PROMPT_NAME,
prompt_required=False,
help="[bold green]Name[/bold green] of the estimator",
callback=name_callback,
),
@@ -31,7 +30,6 @@
EstimatorType,
Option(
prompt=PROMPT_ESTIMATOR,
prompt_required=False,
help="[bold green]Estimator type[/bold green]",
callback=estimator_callback,
),
@@ -60,7 +58,6 @@
Option(
is_flag=True,
prompt=PROMPT_SAMPLE_WEIGHT,
prompt_required=False,
help="Whether or not `.fit()` supports [bold green]`sample_weight`[/bold green]",
),
]
@@ -71,6 +68,7 @@
is_flag=True,
prompt=PROMPT_LINEAR,
help="Whether or not the estimator is [bold green]linear[/bold green]",
callback=linear_callback,
),
]

@@ -96,7 +94,6 @@
str,
Option(
prompt=PROMPT_TAGS,
prompt_required=False,
help="List of optional extra scikit-learn [bold green]tags[/bold green]",
callback=tags_callback,
),
72 changes: 67 additions & 5 deletions sksmithy/_callbacks.py
@@ -20,7 +20,39 @@ def _parse_wrapper(
*args: PS.args,
**kwargs: PS.kwargs,
) -> tuple[Context, CallbackParam, R]:
"""Wrap a parser to handle 'caching' logic."""
"""Wrap a parser to handle 'caching' logic.
`parser` should return a Result[R, str]
Parameters
----------
ctx
Typer context.
param
Callback parameter information.
value
Input for the parser callable.
parser
Parser function, it should return Result[R, str]
*args
Extra args for `parser`.
**kwargs
Extra kwargs for `parser`.
Returns
-------
ctx : Context
Typer context updated with extra information.
param : CallbackParam
Unchanged callback parameters.
result_value : R
Parsed value.
Raises
------
BadParameter
If parser returns Err(msg)
"""
if not ctx.obj:
ctx.obj = {}

@@ -52,7 +84,7 @@ def name_callback(ctx: Context, param: CallbackParam, value: str) -> str:


def params_callback(ctx: Context, param: CallbackParam, value: str) -> list[str]:
"""`required-params` and `optional-params` arguments callback."""
"""`required_params` and `optional_params` arguments callback."""
ctx, param, parsed_params = _parse_wrapper(ctx, param, value, params_parser)

if param.name == "optional_params" and (
@@ -76,9 +108,13 @@ def tags_callback(ctx: Context, param: CallbackParam, value: str) -> list[str]:
def estimator_callback(ctx: Context, param: CallbackParam, estimator: EstimatorType) -> str:
"""`estimator_type` argument callback.
It dynamically modifies the behaviour of the rest of the prompts based on its value.
It dynamically modifies the behaviour of the rest of the prompts based on its value:
- If not classifier or regressor, turns off linear prompt.
- If not classifier or outlier, turns off predict_proba prompt.
- If not classifier, turns off decision_function prompt.
"""
if not ctx.obj:
if not ctx.obj: # pragma: no cover
ctx.obj = {}

if param.name in ctx.obj:
@@ -87,7 +123,7 @@ def estimator_callback(ctx: Context, param: CallbackParam, estimator: EstimatorT
# !Warning: This unpacking relies on the order of the arguments in the forge command to be in the same order.
# Is there a better/more robust way of dealing with it?
linear, predict_proba, decision_function = (
op for op in ctx.command.params if op.name in {"linear", "predict_proba", "decision_function"}
opt for opt in ctx.command.params if opt.name in {"linear", "predict_proba", "decision_function"}
)

match estimator:
@@ -114,3 +150,29 @@ def estimator_callback(ctx: Context, param: CallbackParam, estimator: EstimatorT
ctx.obj[param.name] = estimator.value

return estimator.value
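
The callbacks above share a "parse once, cache by parameter name" pattern: the `_parse_wrapper` docstring describes it, and `estimator_callback` shows the cache check and the final store. Most of the `_parse_wrapper` body is collapsed in this diff, so the following is only a sketch of that pattern under stated assumptions. `Ok`, `Err`, `Context`, `Param` and `BadParameter` are simplified stand-ins for the `result` and Typer types, and `split_params` is a hypothetical parser, not the project's `params_parser`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union


@dataclass
class Ok:
    value: Any  # stand-in for result.Ok


@dataclass
class Err:
    message: str  # stand-in for result.Err


@dataclass
class Context:
    obj: Optional[dict] = None  # only the attribute used for caching is modelled


@dataclass
class Param:
    name: str


class BadParameter(ValueError):
    """Stand-in for typer.BadParameter."""


def parse_wrapper(ctx: Context, param: Param, value: str, parser: Callable[[str], Union[Ok, Err]]):
    """Parse `value` once per parameter name and reuse the stored result afterwards."""
    if not ctx.obj:
        ctx.obj = {}
    if param.name in ctx.obj:
        return ctx, param, ctx.obj[param.name]  # already parsed: skip the parser
    result = parser(value)
    if isinstance(result, Err):
        raise BadParameter(result.message)
    ctx.obj[param.name] = result.value
    return ctx, param, result.value


def split_params(value: str) -> Union[Ok, Err]:
    """Hypothetical parser: comma-separated names that must be valid identifiers."""
    names = [name.strip() for name in value.split(",") if name.strip()]
    invalid = [name for name in names if not name.isidentifier()]
    return Err(f"Invalid parameter names: {invalid}") if invalid else Ok(names)


ctx, param = Context(), Param(name="required_params")
_, _, first = parse_wrapper(ctx, param, "alpha, beta", split_params)
_, _, second = parse_wrapper(ctx, param, "ignored", split_params)  # served from the cache
print(first, second)
```

Caching by `param.name` means a callback that is invoked more than once for the same option returns the first parsed value instead of re-running the parser.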


def linear_callback(ctx: Context, param: CallbackParam, linear: bool) -> bool:
"""`linear` argument callback.
It dynamically modifies the behaviour of the rest of the prompts based on its value: if the estimator is linear,
then `decision_function` method is already implemented for a classifier.
"""
if not ctx.obj: # pragma: no cover
ctx.obj = {}

if param.name in ctx.obj: # pragma: no cover
return ctx.obj[param.name]

decision_function = next(opt for opt in ctx.command.params if opt.name == "decision_function")

match linear:
case True:
decision_function.prompt = False # type: ignore[attr-defined]
decision_function.prompt_required = False # type: ignore[attr-defined]
case False:
pass

ctx.obj[param.name] = linear

return linear
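
Taken together, the `estimator_callback` and `linear_callback` docstrings describe a small set of rules for which follow-up prompts remain active. The sketch below restates those rules as a pure function so they can be checked in isolation. The function name `prompts_to_ask` is hypothetical; the real callbacks mutate the `prompt` attributes of the Typer options rather than returning a set.

```python
def prompts_to_ask(estimator_type: str, linear: bool = False) -> set[str]:
    """Return the conditional prompts that stay enabled for a given estimator type.

    `estimator_type` uses the CLI choice values, e.g. "classifier" or "regressor".
    """
    prompts = {"linear", "predict_proba", "decision_function"}
    if estimator_type not in {"classifier", "regressor"}:
        prompts.discard("linear")
    if estimator_type not in {"classifier", "outlier"}:
        prompts.discard("predict_proba")
    if estimator_type != "classifier":
        prompts.discard("decision_function")
    if linear:
        # Per the linear_callback docstring: a linear classifier already has decision_function.
        prompts.discard("decision_function")
    return prompts


print(sorted(prompts_to_ask("classifier")))
print(sorted(prompts_to_ask("classifier", linear=True)))
print(sorted(prompts_to_ask("transformer")))
```

Expressing the rules this way sidesteps the ordering concern raised in the warning comment inside `estimator_callback`, since nothing here depends on the position of the options in the command signature.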
9 changes: 7 additions & 2 deletions sksmithy/_models.py
@@ -2,7 +2,11 @@


class EstimatorType(str, Enum):
"""List of possible estimator types."""
"""List of possible estimator types.
The reason of naming the enum with the mixin class is to simplify and have a convenient way of using the enum to
render the jinja template with the class to import.
"""

ClassifierMixin = "classifier"
OutlierMixin = "outlier"
@@ -14,7 +18,8 @@ class EstimatorType(str, Enum):
class TagType(str, Enum):
"""List of extra tags.
Description of each tag is available at https://scikit-learn.org/dev/developers/develop.html#estimator-tags.
Description of each tag is available in the dedicated section of the scikit-learn documentation:
[estimator tags](https://scikit-learn.org/dev/developers/develop.html#estimator-tags).
"""

allow_nan = "allow_nan"
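
The `EstimatorType` docstring explains that members are named after the scikit-learn mixin so the enum can drive the rendered import directly. A minimal sketch of that idea follows, limited to the two members visible in this diff. `string.Template` is used here as a stand-in for the project's Jinja template, and `render_import` is a hypothetical helper.

```python
from enum import Enum
from string import Template


class EstimatorType(str, Enum):
    """Subset of the enum shown in the diff; the remaining members are omitted."""

    ClassifierMixin = "classifier"
    OutlierMixin = "outlier"


# Stand-in for the Jinja template: the member *name* is the class to import.
IMPORT_LINE = Template("from sklearn.base import BaseEstimator, $mixin")


def render_import(estimator_type: EstimatorType) -> str:
    return IMPORT_LINE.substitute(mixin=estimator_type.name)


chosen = EstimatorType("classifier")  # the *value* is what the user types on the CLI
print(render_import(chosen))  # prints: from sklearn.base import BaseEstimator, ClassifierMixin
```

The value carries the user-facing choice while the name carries the class to import, so no separate lookup table between the two is needed.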