Skip to content

Commit ad74985

Browse files
AlexandreKempfDave Berenbaumpre-commit-ci[bot]
authored
Add docstrings to public functions (#767)
* add docstrings to public functions in dvclive --------- Co-authored-by: Dave Berenbaum <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b1e6bf4 commit ad74985

File tree

1 file changed

+260
-6
lines changed

1 file changed

+260
-6
lines changed

src/dvclive/live.py

+260-6
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,44 @@ def __init__(
7979
exp_name: Optional[str] = None,
8080
exp_message: Optional[str] = None,
8181
):
82+
"""
83+
Initializes a DVCLive logger. A `Live()` instance is required in order to log
84+
machine learning parameters, metrics and other metadata.
85+
Warning: `Live()` will remove all existing DVCLive related files under dir
86+
unless `resume=True`.
87+
88+
Args:
89+
dir (str | Path): where to save DVCLive's outputs. Defaults to `"dvclive"`.
90+
resume (bool): if `True`, DVCLive will try to read the previous step from
91+
the metrics_file and start from that point. Defaults to `False`.
92+
report ("html", "md", "notebook", None): any of `"html"`, `"notebook"`,
93+
`"md"` or `None`. See `Live.make_report()`. Defaults to None.
94+
save_dvc_exp (bool): if `True`, DVCLive will create a new DVC experiment as
95+
part of `Live.end()`. Defaults to `True`. If you are using DVCLive
96+
inside a DVC Pipeline and running with `dvc exp run`, the option will be
97+
ignored.
98+
dvcyaml (str | None): where to write dvc.yaml file, which adds DVC
99+
configuration for metrics, plots, and parameters as part of
100+
`Live.next_step()` and `Live.end()`. If `None`, no dvc.yaml file is
101+
written. Defaults to `"dvc.yaml"`. See `Live.make_dvcyaml()`.
102+
If a string like `"subdir/dvc.yaml"`, DVCLive will write the
103+
configuration to that path (file must be named "dvc.yaml").
104+
If `None`, DVCLive will not write to "dvc.yaml" (useful if you are
105+
tracking DVCLive metrics, plots, and parameters independently and
106+
want to avoid duplication).
107+
cache_images (bool): if `True`, DVCLive will cache any images logged with
108+
`Live.log_image()` as part of `Live.end()`. Defaults to `False`.
109+
If running a DVC pipeline, `cache_images` will be ignored, and you
110+
should instead cache images as pipeline outputs.
111+
exp_name (str | None): if not `None`, and `save_dvc_exp` is `True`, the
112+
provided string will be passed to `dvc exp save --name`.
113+
If DVCLive is used inside `dvc exp run`, the option will be ignored, use
114+
`dvc exp run --name` instead.
115+
exp_message (str | None): if not `None`, and `save_dvc_exp` is `True`, the
116+
provided string will be passed to `dvc exp save --message`.
117+
If DVCLive is used inside `dvc exp run`, the option will be ignored, use
118+
`dvc exp run --message` instead.
119+
"""
82120
self.summary: Dict[str, Any] = {}
83121

84122
self._dir: str = dir
@@ -283,10 +321,11 @@ def _init_report(self):
283321

284322
def _init_test(self):
285323
"""
286-
Enables test mode that writes to temp paths and doesn't depend on repo.
324+
Enables a test mode that writes to temporary paths and doesn't depend on the
325+
repository.
287326
288-
Needed to run integration tests in external libraries like huggingface
289-
accelerate.
327+
This is needed to run integration tests in external libraries, such as
328+
HuggingFace Accelerate.
290329
"""
291330
with tempfile.TemporaryDirectory() as dirpath:
292331
self._dir = os.path.join(dirpath, self._dir)
@@ -300,6 +339,7 @@ def _init_test(self):
300339

301340
@property
302341
def dir(self) -> str: # noqa: A003
342+
"""Location of the directory to store outputs."""
303343
return self._dir
304344

305345
@property
@@ -312,6 +352,7 @@ def metrics_file(self) -> str:
312352

313353
@property
314354
def dvc_file(self) -> str:
355+
"""Path for dvc.yaml file."""
315356
return self._dvc_file
316357

317358
@property
@@ -349,6 +390,14 @@ def sync(self):
349390
self.post_to_studio("data")
350391

351392
def next_step(self):
393+
"""
394+
Signals that the current iteration has ended and increases step value by one.
395+
DVCLive uses `step` to track the history of the metrics logged with
396+
`Live.log_metric()`.
397+
You can use `Live.next_step()` to increase the step by one. In addition to
398+
increasing the `step` number, it will call `Live.make_report()`,
399+
`Live.make_dvcyaml()`, and `Live.make_summary()` by default.
400+
"""
352401
if self._step is None:
353402
self._step = 0
354403

@@ -363,6 +412,28 @@ def log_metric(
363412
timestamp: bool = False,
364413
plot: bool = True,
365414
):
415+
"""
416+
On each `Live.log_metric(name, val)` call, `DVCLive` will create a metrics
417+
history file in `{Live.plots_dir}/metrics/{name}.tsv`. Each subsequent call to
418+
`Live.log_metric(name, val)` will add a new row to
419+
`{Live.plots_dir}/metrics/{name}.tsv`. In addition, `DVCLive` will store the
420+
latest value logged in `Live.summary`, so it can be serialized with calls to
421+
`live.make_summary()`, `live.next_step()` or when exiting the `Live` context
422+
block.
423+
424+
Args:
425+
name (str): name of the metric being logged.
426+
val (int | float | str): the value to be logged.
427+
timestamp (bool): whether to automatically log timestamp in the metrics
428+
history file.
429+
plot (bool): whether to add the metric value to the metrics history file for
430+
plotting. If `False`, the metric will only be saved to the metrics
431+
summary.
432+
433+
Raises:
434+
`InvalidDataTypeError`: thrown if the provided `val` does not have a
435+
supported type.
436+
"""
366437
if not Metric.could_log(val):
367438
raise InvalidDataTypeError(name, type(val))
368439

@@ -387,6 +458,36 @@ def log_image(
387458
name: str,
388459
val: Union[np.ndarray, matplotlib.figure.Figure, PIL.Image, StrPath],
389460
):
461+
"""
462+
Saves the given image `val` to the output file `name`.
463+
464+
Supported values for val are:
465+
- A valid NumPy array (convertible to an image via `PIL.Image.fromarray`)
466+
- A `matplotlib.figure.Figure` instance
467+
- A `PIL.Image` instance
468+
- A path to an image file (`str` or `Path`). It should be in a format that is
469+
readable by `PIL.Image.open()`
470+
471+
The images will be saved in `{Live.plots_dir}/images/{name}`. When using
472+
`Live(cache_images=True)`, the images directory will also be cached as part of
473+
`Live.end()`. In that case, a `.dvc` file will be saved to track it, and the
474+
directory will be added to a `.gitignore` file to prevent Git tracking.
475+
476+
By default the images will be overwritten on each step. However, you can log
477+
images using the following pattern
478+
`live.log_image(f"folder/{live.step}.png", img)`.
479+
In `DVC Studio` and the `DVC Extension for VSCode`, folders following this
480+
pattern will be rendered using an image slider.
481+
482+
Args:
483+
name (str): name of the image file that this command will output
484+
val (np.ndarray | matplotlib.figure.Figure | PIL.Image | StrPath):
485+
image to be saved. See the list of supported values in the description.
486+
487+
Raises:
488+
`InvalidDataTypeError`: thrown if the provided `val` does not have a
489+
supported type.
490+
"""
390491
if not Image.could_log(val):
391492
raise InvalidDataTypeError(name, type(val))
392493

@@ -425,6 +526,29 @@ def log_plot(
425526
x_label: Optional[str] = None,
426527
y_label: Optional[str] = None,
427528
):
529+
"""
530+
The method will dump the provided datapoints to
531+
`{Live.dir}/plots/custom/{name}.json` and store the provided properties to be
532+
included in the plots section written by `Live.make_dvcyaml()`. The plot can be
533+
rendered with `DVC CLI`, `VSCode Extension` or `DVC Studio`.
534+
535+
Args:
536+
name (StrPath): name of the output file.
537+
datapoints (pd.DataFrame | np.ndarray | List[Dict]): Pandas DataFrame, Numpy
538+
Array or List of dictionaries containing the data for the plot.
539+
x (str): name of the key (present in the dictionaries) to use as the x axis.
540+
y (str): name of the key (present in the dictionaries) to use as the y axis.
541+
template (str): name of the `DVC plots template` to use. Defaults to
542+
`"linear"`.
543+
title (str): title to be displayed. Defaults to
544+
`"{Live.dir}/plots/custom/{name}.json"`.
545+
x_label (str): label for the x axis. Defaults to the name passed as `x`.
546+
y_label (str): label for the y axis. Defaults to the name passed as `y`.
547+
548+
Raises:
549+
`InvalidDataTypeError`: thrown if the provided `datapoints` does not have a
550+
supported type.
551+
"""
428552
# Convert the given datapoints to List[Dict]
429553
datapoints = convert_datapoints_to_list_of_dicts(datapoints=datapoints)
430554

@@ -458,6 +582,30 @@ def log_sklearn_plot(
458582
name: Optional[str] = None,
459583
**kwargs,
460584
):
585+
"""
586+
Generates a scikit-learn plot and saves the data in
587+
`{Live.dir}/plots/sklearn/{name}.json`. The method will compute and dump the
588+
`kind` plot to `{Live.dir}/plots/sklearn/{name}` in a format compatible with
589+
dvc plots. It will also store the provided properties to be included in the
590+
plots section written by `Live.make_dvcyaml()`.
591+
592+
Args:
593+
kind ("calibration" | "confusion_matrix" | "det" | "precision_recall" |
594+
"roc"): a supported plot type.
595+
labels (List | np.ndarray): array of ground truth labels.
596+
predictions (List | np.ndarray): array of predicted labels (for
597+
`"confusion_matrix"`) or predicted probabilities (for other plots).
598+
name (str): optional name of the output file. If not provided, `kind` will
599+
be used as name.
600+
kwargs: additional arguments to tune the result. Arguments are passed to the
601+
scikit-learn function (e.g. `drop_intermediate=True` for the `"roc"`
602+
type). Plus extra arguments supported by the type of a plot are:
603+
- `normalized`: defaults to `False`. `confusion_matrix` with values
604+
normalized to `<0, 1>` range.
605+
Raises:
606+
InvalidPlotTypeError: thrown if the provided `kind` does not correspond to
607+
any of the supported plots.
608+
"""
461609
val = (labels, predictions)
462610

463611
plot_config = {
@@ -493,13 +641,39 @@ def _dump_params(self):
493641
raise InvalidParameterTypeError(exc.args[0]) from exc
494642

495643
def log_params(self, params: Dict[str, ParamLike]):
496-
"""Saves the given set of parameters (dict) to yaml"""
644+
"""
645+
On each `Live.log_params(params)` call, DVCLive will write key/value pairs in
646+
the params dict to `{Live.dir}/params.yaml`.
647+
648+
Also see `Live.log_param()`.
649+
650+
Args:
651+
params (Dict[str, ParamLike]): dictionary with name/value pairs of
652+
parameters to be logged.
653+
654+
Raises:
655+
`InvalidParameterTypeError`: thrown if the parameter value is not among
656+
supported types.
657+
"""
497658
self._params.update(params)
498659
self._dump_params()
499660
logger.debug(f"Logged {params} parameters to {self.params_file}")
500661

501662
def log_param(self, name: str, val: ParamLike):
502-
"""Saves the given parameter value to yaml"""
663+
"""
664+
On each `Live.log_param(name, val)` call, DVCLive will write the name parameter
665+
to `{Live.dir}/params.yaml` with the corresponding `val`.
666+
667+
Also see `Live.log_params()`.
668+
669+
Args:
670+
name (str): name of the parameter being logged.
671+
val (ParamLike): the value to be logged.
672+
673+
Raises:
674+
`InvalidParameterTypeError`: thrown if the parameter value is not among
675+
supported types.
676+
"""
503677
self.log_params({name: val})
504678

505679
def log_artifact(
@@ -513,7 +687,46 @@ def log_artifact(
513687
copy: bool = False,
514688
cache: bool = True,
515689
):
516-
"""Tracks a local file or directory with DVC"""
690+
"""
691+
Tracks an existing directory or file with DVC.
692+
693+
Log path, saving its contents to DVC storage. Also annotate with any included
694+
metadata fields (for example, to be consumed in the model registry or automation
695+
scenarios).
696+
If `cache=True` (which is the default), uses `dvc add` to track path with DVC,
697+
saving it to the DVC cache and generating a `{path}.dvc` file that acts as a
698+
pointer to the cached data.
699+
If you include any of the optional metadata fields (type, name, desc, labels,
700+
meta), it will add an artifact and all the metadata passed as arguments to the
701+
corresponding `dvc.yaml` (unless `dvcyaml=None`). Passing `type="model"` will
702+
include it in the model registry.
703+
704+
Args:
705+
path (StrPath): an existing directory or file.
706+
type (Optional[str]): an optional type of the artifact. Common types are
707+
`"model"` or `"dataset"`.
708+
name (Optional[str]): an optional custom name of an artifact.
709+
If not provided the `path` stem (last part of the path without the
710+
file extension) will be used as the artifact name.
711+
desc (Optional[str]): an optional description of an artifact.
712+
labels (Optional[List[str]]): optional labels describing the artifact.
713+
meta (Optional[Dict[str, Any]]): optional metainformation in `key: value`
714+
format.
715+
copy (bool): copy a directory or file at path into the `dvclive/artifacts`
716+
location (default) before tracking it. The new path is used instead of
717+
the original one to track the artifact. Useful if you don't want to
718+
track the original path in your repo (for example, it is outside the
719+
repo or in a Git-ignored directory).
720+
cache (bool): cache the files with DVC to track them outside of Git.
721+
Defaults to `True`, but set to `False` if you want to annotate metadata
722+
about the artifact without storing a copy in the DVC cache.
723+
If running a DVC pipeline, `cache` will be ignored, and you should
724+
instead cache artifacts as pipeline outputs.
725+
726+
Raises:
727+
`InvalidDataTypeError`: thrown if the provided `path` does not have a
728+
supported type.
729+
"""
517730
if not isinstance(path, (str, PurePath)):
518731
raise InvalidDataTypeError(path, builtins.type(path))
519732

@@ -582,25 +795,66 @@ def cache(self, path):
582795
self._include_untracked.append(str(Path(dvc_file).parent / ".gitignore"))
583796

584797
def make_summary(self):
798+
"""
799+
Serializes a summary of the logged metrics (`Live.summary`) to
800+
`Live.metrics_file`.
801+
802+
The `Live.summary` object will contain the latest value of each metric logged
803+
with `Live.log_metric()`. It can also be modified manually.
804+
805+
`Live.next_step()` and `Live.end()` will call `Live.make_summary()` internally,
806+
so you don't need to call both.
807+
808+
The summary is usable by `dvc metrics`.
809+
"""
585810
if self._step is not None:
586811
self.summary["step"] = self.step
587812
dump_json(self.summary, self.metrics_file, cls=NumpyEncoder)
588813

589814
def make_report(self):
815+
"""
816+
Generates a report from the logged data.
817+
818+
`Live.next_step()` and `Live.end()` will call `Live.make_report()` internally,
819+
so you don't need to call both.
820+
821+
On each call, DVCLive will collect all the data logged in `{Live.dir}`, generate
822+
a report and save it in `{Live.dir}/report.{format}`. The format can be HTML
823+
or Markdown depending on the value of the `report` argument passed to `Live()`.
824+
"""
590825
if self._report_mode is not None:
591826
make_report(self)
592827
if self._report_mode == "html" and env2bool(env.DVCLIVE_OPEN):
593828
open_file_in_browser(self.report_file)
594829

595830
@catch_and_warn(DvcException, logger)
596831
def make_dvcyaml(self):
832+
"""
833+
Writes DVC configuration for metrics, plots, and parameters to `Live.dvc_file`.
834+
835+
Creates `dvc.yaml`, which describes and configures metrics, plots, and
836+
parameters. DVC tools use this file to show reports and experiments tables.
837+
`Live.next_step()` and `Live.end()` will call `Live.make_dvcyaml()` internally,
838+
so you don't need to call both (unless `dvcyaml=None`).
839+
"""
597840
make_dvcyaml(self)
598841

599842
@catch_and_warn(DvcException, logger)
600843
def post_to_studio(self, event: str):
601844
post_to_studio(self, event)
602845

603846
def end(self):
847+
"""
848+
Signals that the current experiment has ended.
849+
`Live.end()` gets automatically called when exiting the context manager. It is
850+
also called when the training ends for each of the supported ML Frameworks.
851+
852+
By default, `Live.end()` will call `Live.make_summary()`, `Live.make_dvcyaml()`,
853+
and `Live.make_report()`.
854+
855+
If `save_dvc_exp=True`, it will save a new DVC experiment and write a `dvc.yaml`
856+
file configuring what DVC will show for logged plots, metrics, and parameters.
857+
"""
604858
if self._inside_with:
605859
# Prevent `live.end` calls inside context manager
606860
return

0 commit comments

Comments
 (0)