Skip to content

Commit 28e370b

Browse files
add typehints to Live public methods
1 parent fb4a2a5 commit 28e370b

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

src/dvclive/live.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
import shutil
99
import tempfile
1010
from pathlib import Path
11-
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
11+
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal
1212

1313
if TYPE_CHECKING:
1414
import numpy as np
1515
import pandas as pd
16+
import matplotlib
17+
import PIL
1618

1719
from dvc.exceptions import DvcException
1820
from funcy import set_in
@@ -62,6 +64,16 @@
6264
logger.addHandler(handler)
6365

6466
ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]
67+
TemplatePlotKind = Literal[
68+
"linear",
69+
"simple",
70+
"scatter",
71+
"smooth",
72+
"confusion",
73+
"confusion_normalized",
74+
"bar_horizontal",
75+
"bar_horizontal_sorted",
76+
]
6577

6678

6779
class Live:
@@ -71,7 +83,7 @@ def __init__(
7183
resume: bool = False,
7284
report: Optional[str] = None,
7385
save_dvc_exp: bool = True,
74-
dvcyaml: Union[str, bool] = "dvc.yaml",
86+
dvcyaml: Optional[str] = "dvc.yaml",
7587
cache_images: bool = False,
7688
exp_name: Optional[str] = None,
7789
exp_message: Optional[str] = None,
@@ -379,7 +391,11 @@ def log_metric(
379391
self.summary = set_in(self.summary, metric.summary_keys, val)
380392
logger.debug(f"Logged {name}: {val}")
381393

382-
def log_image(self, name: str, val):
394+
def log_image(
395+
self,
396+
name: str,
397+
val: Union[np.ndarray, matplotlib.figure.Figure, PIL.Image, StrPath],
398+
):
383399
if not Image.could_log(val):
384400
raise InvalidDataTypeError(name, type(val))
385401

@@ -401,10 +417,10 @@ def log_image(self, name: str, val):
401417
def log_plot(
402418
self,
403419
name: str,
404-
datapoints: pd.DataFrame | np.ndarray | List[Dict],
420+
datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
405421
x: str,
406422
y: str,
407-
template: Optional[str] = None,
423+
template: TemplatePlotKind = "linear",
408424
title: Optional[str] = None,
409425
x_label: Optional[str] = None,
410426
y_label: Optional[str] = None,
@@ -434,7 +450,14 @@ def log_plot(
434450
plot.dump(datapoints)
435451
logger.debug(f"Logged {name}")
436452

437-
def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
453+
def log_sklearn_plot(
454+
self,
455+
kind: Literal["calibration", "confusion_matrix", "precision_recall", "roc"],
456+
labels: Union[List, np.ndarray],
457+
predictions: Union[List, Tuple, np.ndarray],
458+
name: Optional[str] = None,
459+
**kwargs,
460+
):
438461
val = (labels, predictions)
439462

440463
plot_config = {
@@ -527,7 +550,7 @@ def log_artifact(
527550
)
528551

529552
@catch_and_warn(DvcException, logger)
530-
def cache(self, path):
553+
def cache(self, path: StrPath):
531554
if self._inside_dvc_pipeline:
532555
existing_stage = find_overlapping_stage(self._dvc_repo, path)
533556

@@ -574,7 +597,7 @@ def make_dvcyaml(self):
574597
make_dvcyaml(self)
575598

576599
@catch_and_warn(DvcException, logger)
577-
def post_to_studio(self, event):
600+
def post_to_studio(self, event: Literal["start", "data", "done"]):
578601
post_to_studio(self, event)
579602

580603
def end(self):

0 commit comments

Comments
 (0)