Skip to content

Commit

Permalink
feat(callback): add progress_bar callback
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed May 5, 2023
1 parent a97c444 commit 0065605
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path

from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter, progress_bar
from .engine import CVBooster, cv, train

try:
Expand All @@ -32,5 +32,5 @@
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'progress_bar',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
157 changes: 157 additions & 0 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
# coding: utf-8
"""Callbacks library."""
from __future__ import annotations

import collections
import importlib
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Literal, Type
try:
import tqdm
except ImportError:
pass

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

Expand All @@ -11,6 +21,7 @@
'log_evaluation',
'record_evaluation',
'reset_parameter',
'progress_bar',
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
Expand Down Expand Up @@ -413,3 +424,149 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
The callback that activates early stopping.
"""
return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)


class _ProgressBarCallback:
"""Internal class to handle progress bar."""
tqdm_cls: "Type[tqdm.std.tqdm]"
pbar: "tqdm.std.tqdm" | None

def __init__(
self,
tqdm_cls: Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| "Type[tqdm.std.tqdm]" = "auto",
early_stopping_callback: Any | None = None,
**tqdm_kwargs: Any,
) -> None:
"""Progress bar callback for LightGBM training.
Parameters
----------
tqdm_cls : Literal[ "auto", "autonotebook", "std", "notebook", "asyncio", "keras", "dask", "tk", "gui", "rich", "contrib.slack", "contrib.discord", "contrib.telegram", "contrib.bells", ] | Type[tqdm.std.tqdm], optional
The tqdm class or module name, by default "auto"
early_stopping_callback : _EarlyStoppingCallback | None, optional
The early stopping callback, by default None
.. rubric:: Example
.. code-block:: python
early_stopping_callback = early_stopping(stopping_rounds=50)
callbacks = [early_stopping_callback, ProgressBarCallback(early_stopping_callback=early_stopping_callback)]
estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
"""
if isinstance(tqdm_cls, str):
try:
tqdm_module = importlib.import_module(f"tqdm.{tqdm_cls}")
except ImportError as e:
raise ImportError(
f"tqdm needs to be installed to use tqdm.{tqdm_cls}") from e
self.tqdm_cls = getattr(tqdm_module, "tqdm")
else:
self.tqdm_cls = tqdm_cls
self.early_stopping_callback = early_stopping_callback
self.tqdm_kwargs = tqdm_kwargs
if "total" in tqdm_kwargs:
warnings.warn("'total' in tqdm_kwargs is ignored.", UserWarning)
self.pbar = None

def _init(self, env: CallbackEnv) -> None:
# create pbar on first call
tqdm_kwargs = self.tqdm_kwargs.copy()
tqdm_kwargs["total"] = env.end_iteration - env.begin_iteration
self.pbar = self.tqdm_cls(**tqdm_kwargs)

def __call__(self, env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
self._init(env)
assert self.pbar is not None

# update postfix
if len(env.evaluation_result_list) > 0:
# If OrderedDict is not used, the order of display is disjointed and slightly difficult to see.
# https://github.com/microsoft/LightGBM/blob/a97c444b4cf9d2755bd888911ce65ace1fe13e4b/python-package/lightgbm/callback.py#L56-66
if self.early_stopping_callback is not None:
postfix = OrderedDict(
[
(
f"{entry[0]}'s {entry[1]}",
f"{entry[2]:g}{'=' if entry[2] == best_score else ('>' if cmp_op else '<')}{best_score:g}@{best_iter}it",
)
for entry, cmp_op, best_score, best_iter in zip(
env.evaluation_result_list,
self.early_stopping_callback.cmp_op,
self.early_stopping_callback.best_score,
self.early_stopping_callback.best_iter,
)
]
)
else:
postfix = OrderedDict(
[
(f"{entry[0]}'s {entry[1]}", f"{entry[2]:g}")
for entry in env.evaluation_result_list
]
)
self.pbar.set_postfix(ordered_dict=postfix, refresh=False)

# update pbar
self.pbar.update()
self.pbar.refresh()


def progress_bar(tqdm_cls: Literal[
"auto",
"autonotebook",
"std",
"notebook",
"asyncio",
"keras",
"dask",
"tk",
"gui",
"rich",
"contrib.slack",
"contrib.discord",
"contrib.telegram",
"contrib.bells",
]
| "Type[tqdm.std.tqdm]" = "auto",
early_stopping_callback: _EarlyStoppingCallback | None = None,
**tqdm_kwargs: Any,
) -> _ProgressBarCallback:
"""Progress bar callback for LightGBM training.
Parameters
----------
tqdm_cls : Literal[ &quot;auto&quot;, &quot;autonotebook&quot;, &quot;std&quot;, &quot;notebook&quot;, &quot;asyncio&quot;, &quot;keras&quot;, &quot;dask&quot;, &quot;tk&quot;, &quot;gui&quot;, &quot;rich&quot;, &quot;contrib.slack&quot;, &quot;contrib.discord&quot;, &quot;contrib.telegram&quot;, &quot;contrib.bells&quot;, ] | Type[tqdm.std.tqdm], optional
The tqdm class or module name, by default &quot;auto&quot;
early_stopping_callback : Any | None, optional
The early stopping callback, by default None
.. rubric:: Example
.. code-block:: python
early_stopping_callback = early_stopping(stopping_rounds=50)
callbacks = [early_stopping_callback, progress_bar(early_stopping_callback=early_stopping_callback)]
estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)], callbacks=callbacks)
Returns
-------
callback : _ProgressBarCallback
The callback that displays the progress bar.
"""
return _ProgressBarCallback(tqdm_cls=tqdm_cls, early_stopping_callback=early_stopping_callback, **tqdm_kwargs)
14 changes: 14 additions & 0 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ def test_reset_parameter_callback_is_picklable(serializer):
assert callback_from_disk.before_iteration is True
assert callback.kwargs == callback_from_disk.kwargs
assert callback.kwargs == params

@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_progress_bar_callback_is_picklable(serializer):
rounds = 5
callback = lgb.progress_bar()
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 30
assert callback_from_disk.before_iteration is False
assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds

def test_progress_bar_warn_override(self) -> None:
with pytest.warns(UserWarning):
lgb.progress_bar(self.tqdm_cls, total=100, **self.tqdm_kwargs)

0 comments on commit 0065605

Please sign in to comment.