Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typehints to Live's public methods #770

Merged
merged 7 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
from pathlib import Path, PurePath
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal

if TYPE_CHECKING:
import numpy as np
import pandas as pd
import matplotlib
import PIL

from dvc.exceptions import DvcException
from funcy import set_in
Expand All @@ -36,7 +38,7 @@
from .plots import PLOT_TYPES, SKLEARN_PLOTS, CustomPlot, Image, Metric, NumpyEncoder
from .report import BLANK_NOTEBOOK_REPORT, make_report
from .serialize import dump_json, dump_yaml, load_yaml
from .studio import get_dvc_studio_config, post_to_studio
from .studio import StudioEventKind, get_dvc_studio_config, post_to_studio
from .utils import (
StrPath,
catch_and_warn,
Expand All @@ -62,16 +64,17 @@
logger.addHandler(handler)

ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]
SkleanPlotKind = [*SKLEARN_PLOTS.keys()]


class Live:
def __init__(
self,
dir: str = "dvclive", # noqa: A002
resume: bool = False,
report: Optional[str] = None,
report: Literal["md", "notebook", "html", None] = None,
save_dvc_exp: bool = True,
dvcyaml: Union[str, bool] = "dvc.yaml",
dvcyaml: Optional[str] = "dvc.yaml",
cache_images: bool = False,
exp_name: Optional[str] = None,
exp_message: Optional[str] = None,
Expand Down Expand Up @@ -379,11 +382,15 @@ def log_metric(
self.summary = set_in(self.summary, metric.summary_keys, val)
logger.debug(f"Logged {name}: {val}")

def log_image(self, name: str, val):
def log_image(
self,
name: str,
val: Union[np.ndarray, matplotlib.figure.Figure, PIL.Image, StrPath],
):
if not Image.could_log(val):
raise InvalidDataTypeError(name, type(val))

if isinstance(val, (str, Path)):
if isinstance(val, (str, PurePath)):
from PIL import Image as ImagePIL

val = ImagePIL.open(val)
Expand All @@ -401,10 +408,10 @@ def log_image(self, name: str, val):
def log_plot(
self,
name: str,
datapoints: pd.DataFrame | np.ndarray | List[Dict],
datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
x: str,
y: str,
template: Optional[str] = None,
template: Optional[str] = "linear",
title: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
Expand Down Expand Up @@ -434,7 +441,14 @@ def log_plot(
plot.dump(datapoints)
logger.debug(f"Logged {name}")

def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
def log_sklearn_plot(
self,
kind: SkleanPlotKind,
labels: Union[List, np.ndarray],
predictions: Union[List, Tuple, np.ndarray],
name: Optional[str] = None,
**kwargs,
):
val = (labels, predictions)

plot_config = {
Expand Down Expand Up @@ -491,7 +505,7 @@ def log_artifact(
cache: bool = True,
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
if not isinstance(path, (str, PurePath)):
raise InvalidDataTypeError(path, builtins.type(path))

if self._dvc_repo is not None:
Expand Down Expand Up @@ -527,7 +541,7 @@ def log_artifact(
)

@catch_and_warn(DvcException, logger)
def cache(self, path):
def cache(self, path: str):
if self._inside_dvc_pipeline:
existing_stage = find_overlapping_stage(self._dvc_repo, path)

Expand Down Expand Up @@ -574,7 +588,7 @@ def make_dvcyaml(self):
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event):
def post_to_studio(self, event: StudioEventKind):
post_to_studio(self, event)

def end(self):
Expand Down
3 changes: 3 additions & 0 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

from dvc_studio_client.config import get_studio_config
from dvc_studio_client.post_live_metrics import post_live_metrics
from dvc_studio_client.schema import BASE_SCHEMA

from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics, rel_path

logger = logging.getLogger("dvclive")

StudioEventKind = [*BASE_SCHEMA.schema["type"].validators]


def _get_unsent_datapoints(plot, latest_step):
return [x for x in plot if int(x["step"]) > latest_step]
Expand Down
4 changes: 2 additions & 2 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
import shutil
from pathlib import Path
from pathlib import Path, PurePath
from platform import uname
from typing import Union, List, Dict, TYPE_CHECKING
import webbrowser
Expand All @@ -26,7 +26,7 @@
np = None


StrPath = Union[str, Path]
StrPath = Union[str, PurePath]


def run_once(f):
Expand Down
Loading