diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py
index 5815bc602bde..d311ba934073 100644
--- a/python-package/lightgbm/__init__.py
+++ b/python-package/lightgbm/__init__.py
@@ -6,7 +6,7 @@
 from pathlib import Path

 from .basic import Booster, Dataset, Sequence, register_logger
-from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
+from .callback import early_stopping, log_evaluation, progress_bar, record_evaluation, reset_parameter
 from .engine import CVBooster, cv, train

 try:
@@ -32,5 +32,5 @@
            'train', 'cv',
            'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
            'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
-           'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
+           'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'progress_bar',
            'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 868a6fc15534..30b08a0b149c 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -1,16 +1,29 @@
 # coding: utf-8
 """Callbacks library."""
+from __future__ import annotations
+
 import collections
+import importlib
+import warnings
+from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union

 from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

+if TYPE_CHECKING:
+    # Annotation-only names: annotations are never evaluated at runtime here,
+    # so tqdm stays optional and is imported on demand by _ProgressBarCallback.
+    from typing import Literal, Type
+
+    import tqdm
+
 __all__ = [
     'early_stopping',
     'log_evaluation',
+    'progress_bar',
     'record_evaluation',
     'reset_parameter',
 ]

 _EvalResultDict = Dict[str, Dict[str, List[Any]]]
@@ -413,3 +426,163 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         The callback that activates early stopping.
     """
     return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
+
+
+class _ProgressBarCallback:
+    """Internal progress bar callable class."""
+
+    tqdm_cls: Type[tqdm.std.tqdm]
+    pbar: tqdm.std.tqdm | None
+
+    def __init__(
+        self,
+        tqdm_cls: Literal[
+            "auto",
+            "autonotebook",
+            "std",
+            "notebook",
+            "asyncio",
+            "keras",
+            "dask",
+            "tk",
+            "gui",
+            "rich",
+            "contrib.slack",
+            "contrib.discord",
+            "contrib.telegram",
+            "contrib.bells",
+        ]
+        | Type[tqdm.std.tqdm] = "auto",
+        early_stopping_callback: _EarlyStoppingCallback | None = None,
+        **tqdm_kwargs: Any,
+    ) -> None:
+        """Create the progress bar callback.
+
+        Parameters
+        ----------
+        tqdm_cls : str or type, optional (default="auto")
+            Name of the ``tqdm`` submodule whose ``tqdm`` class is used (e.g. "auto", "std",
+            "notebook", "rich"), or a ``tqdm`` class itself.
+        early_stopping_callback : _EarlyStoppingCallback or None, optional (default=None)
+            If given, its best score and best iteration are shown next to each metric.
+        **tqdm_kwargs
+            Other parameters passed to the ``tqdm`` class. ``total`` is ignored.
+
+        Raises
+        ------
+        ImportError
+            If ``tqdm_cls`` is a string and ``tqdm.<tqdm_cls>`` cannot be imported.
+        """
+        if isinstance(tqdm_cls, str):
+            try:
+                tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}")
+            except ImportError as e:
+                raise ImportError(f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e
+            self.tqdm_cls = tqdm_module.tqdm
+        else:
+            self.tqdm_cls = tqdm_cls
+        if "total" in tqdm_kwargs:
+            warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning)
+
+        # The training loop reads ``order`` and ``before_iteration`` from every callback.
+        # 25 runs the bar ahead of early stopping (30), which ends training by raising
+        # and would otherwise skip the bar's final update.
+        self.order = 25
+        self.before_iteration = False
+
+        self.early_stopping_callback = early_stopping_callback
+        self.tqdm_kwargs = tqdm_kwargs
+        self.pbar = None
+
+    def _init(self, env: CallbackEnv) -> None:
+        """Open a new bar sized to the number of boosting rounds."""
+        if self.pbar is not None:
+            # the callback is being reused: release the previous bar first
+            self.pbar.close()
+        tqdm_kwargs = {**self.tqdm_kwargs, "total": env.end_iteration - env.begin_iteration}
+        self.pbar = self.tqdm_cls(**tqdm_kwargs)
+
+    def _postfix(self, env: CallbackEnv) -> OrderedDict[str, str]:
+        """Return the ``"<dataset>'s <metric>" -> text`` mapping shown after the bar."""
+        early_stopping = self.early_stopping_callback
+        best = None
+        # Pair results with best values only when early stopping holds exactly one per metric
+        # (e.g. not before it has processed its first iteration).
+        if early_stopping is not None and len(early_stopping.best_score) == len(env.evaluation_result_list):
+            best = list(zip(early_stopping.best_score, early_stopping.best_iter))
+
+        # OrderedDict keeps the metrics in the order they were evaluated
+        postfix = OrderedDict()
+        for i, entry in enumerate(env.evaluation_result_list):
+            text = f"{entry[2]:g}"
+            if best is not None:
+                best_score, best_iter = best[i]
+                if entry[2] == best_score:
+                    relation = "="
+                else:
+                    relation = ">" if entry[2] > best_score else "<"
+                text += f"{relation}{best_score:g}@{best_iter}it"
+            postfix[f"{entry[0]}'s {entry[1]}"] = text
+        return postfix
+
+    def __call__(self, env: CallbackEnv) -> None:
+        """Advance the bar by one boosting round and refresh the displayed metrics."""
+        if env.iteration == env.begin_iteration or self.pbar is None:
+            self._init(env)
+
+        if env.evaluation_result_list:
+            self.pbar.set_postfix(ordered_dict=self._postfix(env), refresh=False)
+
+        self.pbar.update()
+        self.pbar.refresh()
+        if env.iteration == env.end_iteration - 1:
+            self.pbar.close()
+
+
+def progress_bar(
+    tqdm_cls: Literal[
+        "auto",
+        "autonotebook",
+        "std",
+        "notebook",
+        "asyncio",
+        "keras",
+        "dask",
+        "tk",
+        "gui",
+        "rich",
+        "contrib.slack",
+        "contrib.discord",
+        "contrib.telegram",
+        "contrib.bells",
+    ]
+    | Type[tqdm.std.tqdm] = "auto",
+    early_stopping_callback: _EarlyStoppingCallback | None = None,
+    **tqdm_kwargs: Any,
+) -> _ProgressBarCallback:
+    """Create a callback that displays a tqdm progress bar during training.
+
+    Parameters
+    ----------
+    tqdm_cls : str or type, optional (default="auto")
+        Name of the ``tqdm`` submodule whose ``tqdm`` class is used (e.g. "auto", "std",
+        "notebook", "rich"), or a ``tqdm`` class itself.
+    early_stopping_callback : _EarlyStoppingCallback or None, optional (default=None)
+        If given, its best score and best iteration are shown next to each metric.
+    **tqdm_kwargs
+        Other parameters passed to the ``tqdm`` class. ``total`` is ignored.
+
+    Returns
+    -------
+    callback : _ProgressBarCallback
+        The callback that displays the progress bar.
+
+    .. rubric:: Example
+
+    .. code-block:: python
+
+        early_stopping_callback = early_stopping(stopping_rounds=50)
+        callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)]
+        estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
+    """
+    return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs)
diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py
index cb5dc707bf43..5cfb0a7cbbd9 100644
--- a/tests/python_package_test/test_callback.py
+++ b/tests/python_package_test/test_callback.py
@@ -55,3 +55,22 @@ def test_reset_parameter_callback_is_picklable(serializer):
     assert callback_from_disk.before_iteration is True
     assert callback.kwargs == callback_from_disk.kwargs
     assert callback.kwargs == params
+
+
+@pytest.mark.parametrize('serializer', SERIALIZERS)
+def test_progress_bar_callback_is_picklable(serializer):
+    pytest.importorskip("tqdm")
+    callback = lgb.progress_bar("std", desc="training")
+    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
+    assert callback_from_disk.order == 25
+    assert callback_from_disk.before_iteration is False
+    assert callback_from_disk.tqdm_cls is callback.tqdm_cls
+    assert callback_from_disk.tqdm_kwargs == callback.tqdm_kwargs
+    assert callback_from_disk.tqdm_kwargs == {"desc": "training"}
+    assert callback_from_disk.pbar is None
+
+
+def test_progress_bar_warn_override():
+    pytest.importorskip("tqdm")
+    with pytest.warns(UserWarning, match="'total' in tqdm_kwargs is ignored"):
+        lgb.progress_bar("std", total=100)