Skip to content

Commit 35ce2f3

Browse files
Dave Berenbaumdaavoo
Dave Berenbaum
andauthored
Allow for logging to Studio when not inside a repo (#646)
* post to studio even without git/dvc repo * tests for no-git scenario * studio: make no-repo paths relative to cwd * make ruff happy * don't require exp name * don't require baseline rev * refactor studio path formatting * live: Set new defaults `report=None` and `save_dvc_exp=True`. * frameworks: Drop model_file. * update examples * Write to root dvc.yaml (#687) * add dvcyaml to root * clean up dvcyaml implementation * fix existing tests * add new tests * add unit tests for updating dvcyaml * use posix paths * don't resolve symlinks * drop entire dvclive dir on cleanup * fix studio tests * revert cleanup changes * unify rel_path util func * cleanup test * refactor tests * add test for multiple dvclive instances * put dvc_file logic into _init_dvc_file --------- Co-authored-by: daavoo <[email protected]> * report: Drop "auto" logic. Fallback to `None` when conditions are not met for other types. * studio: Extract `post_to_studio` and decoulple from `make_report` (#705) * refactor(tests): Split `test_main` into separate files. Rename test_frameworks to frameworks. * fix matplotlib warning * fix studio tests * fix windows studio paths * fix windows studio paths for plots * skip fabric tests if not installed * drop dvc repo * drop dvcignore * drop unrelated test_fabric.py file * fix windows paths * fix windows paths * adapt plot paths even if no dvc repo * default baseline rev to all zeros * consolidate repro tests * set null sha as variable * add type hints to studio * limit windows path handling to studio * fix typing errors in studio module * fix mypy in live module * drop checking for dvc_file --------- Co-authored-by: daavoo <[email protected]>
1 parent b8526e9 commit 35ce2f3

File tree

5 files changed

+131
-85
lines changed

5 files changed

+131
-85
lines changed

src/dvclive/dvc.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import TYPE_CHECKING, Any, List, Optional
77

8+
from dvclive import env
89
from dvclive.plots import Image, Metric
910
from dvclive.serialize import dump_yaml
1011
from dvclive.utils import StrPath, rel_path
@@ -131,9 +132,14 @@ def _update_entries(old, new, key):
131132
def get_exp_name(name, scm, baseline_rev) -> str:
132133
from dvc.exceptions import InvalidArgumentError
133134
from dvc.repo.experiments.refs import ExpRefInfo
134-
from dvc.repo.experiments.utils import check_ref_format, get_random_exp_name
135+
from dvc.repo.experiments.utils import (
136+
check_ref_format,
137+
gen_random_name,
138+
get_random_exp_name,
139+
)
135140

136-
if name:
141+
name = name or os.getenv(env.DVC_EXP_NAME)
142+
if name and scm and baseline_rev:
137143
ref = ExpRefInfo(baseline_sha=baseline_rev, name=name)
138144
if scm.get_ref(str(ref)):
139145
logger.warning(f"Experiment conflicts with existing experiment '{name}'.")
@@ -144,7 +150,11 @@ def get_exp_name(name, scm, baseline_rev) -> str:
144150
logger.warning(e)
145151
else:
146152
return name
147-
return get_random_exp_name(scm, baseline_rev)
153+
if scm and baseline_rev:
154+
return get_random_exp_name(scm, baseline_rev)
155+
if name:
156+
return name
157+
return gen_random_name()
148158

149159

150160
def find_overlapping_stage(dvc_repo: "Repo", path: StrPath) -> Optional["Stage"]:

src/dvclive/live.py

+17-30
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@
6666

6767
ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]]
6868

69+
NULL_SHA: str = "0" * 40
70+
6971

7072
class Live:
7173
def __init__(
@@ -136,8 +138,8 @@ def __init__(
136138
self._report_notebook = None
137139
self._init_report()
138140

139-
self._baseline_rev: Optional[str] = None
140-
self._exp_name: Optional[str] = exp_name
141+
self._baseline_rev: str = os.getenv(env.DVC_EXP_BASELINE_REV, NULL_SHA)
142+
self._exp_name: Optional[str] = exp_name or os.getenv(env.DVC_EXP_NAME)
141143
self._exp_message: Optional[str] = exp_message
142144
self._experiment_rev: Optional[str] = None
143145
self._inside_dvc_exp: bool = False
@@ -156,7 +158,7 @@ def __init__(
156158
else:
157159
self._init_cleanup()
158160

159-
self._latest_studio_step = self.step if resume else -1
161+
self._latest_studio_step: int = self.step if resume else -1
160162
self._studio_events_to_skip: Set[str] = set()
161163
self._dvc_studio_config: Dict[str, Any] = {}
162164
self._init_studio()
@@ -189,28 +191,36 @@ def _init_cleanup(self):
189191
os.remove(dvc_file)
190192

191193
@catch_and_warn(DvcException, logger)
192-
def _init_dvc(self):
194+
def _init_dvc(self): # noqa: C901
193195
from dvc.scm import NoSCM
194196

195197
if os.getenv(env.DVC_ROOT, None):
196198
self._inside_dvc_pipeline = True
197199
self._init_dvc_pipeline()
198200
self._dvc_repo = get_dvc_repo()
199201

202+
scm = self._dvc_repo.scm if self._dvc_repo else None
203+
if isinstance(scm, NoSCM):
204+
scm = None
205+
if scm:
206+
self._baseline_rev = scm.get_rev()
207+
self._exp_name = get_exp_name(self._exp_name, scm, self._baseline_rev)
208+
logger.info(f"Logging to experiment '{self._exp_name}'")
209+
200210
dvc_logger = logging.getLogger("dvc")
201211
dvc_logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "WARNING").upper())
202212

203213
self._dvc_file = self._init_dvc_file()
204214

205-
if (self._dvc_repo is None) or isinstance(self._dvc_repo.scm, NoSCM):
215+
if not scm:
206216
if self._save_dvc_exp:
207217
logger.warning(
208218
"Can't save experiment without a Git Repo."
209219
"\nCreate a Git repo (`git init`) and commit (`git commit`)."
210220
)
211221
self._save_dvc_exp = False
212222
return
213-
if self._dvc_repo.scm.no_commits:
223+
if scm.no_commits:
214224
if self._save_dvc_exp:
215225
logger.warning(
216226
"Can't save experiment to an empty Git Repo."
@@ -230,12 +240,7 @@ def _init_dvc(self):
230240
if self._inside_dvc_pipeline:
231241
return
232242

233-
self._baseline_rev = self._dvc_repo.scm.get_rev()
234243
if self._save_dvc_exp:
235-
self._exp_name = get_exp_name(
236-
self._exp_name, self._dvc_repo.scm, self._baseline_rev
237-
)
238-
logger.info(f"Logging to experiment '{self._exp_name}'")
239244
mark_dvclive_only_started(self._exp_name)
240245
self._include_untracked.append(self.dir)
241246

@@ -249,8 +254,6 @@ def _init_dvc_file(self) -> str:
249254
def _init_dvc_pipeline(self):
250255
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
251256
# `dvc exp` execution
252-
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "")
253-
self._exp_name = os.getenv(env.DVC_EXP_NAME, "")
254257
self._inside_dvc_exp = True
255258
if self._save_dvc_exp:
256259
logger.info("Ignoring `save_dvc_exp` because `dvc exp run` is running")
@@ -275,22 +278,6 @@ def _init_studio(self):
275278
logger.debug("Skipping `studio` report `start` and `done` events.")
276279
self._studio_events_to_skip.add("start")
277280
self._studio_events_to_skip.add("done")
278-
elif self._dvc_repo is None:
279-
logger.warning(
280-
"Can't connect to Studio without a DVC Repo."
281-
"\nYou can create a DVC Repo by calling `dvc init`."
282-
)
283-
self._studio_events_to_skip.add("start")
284-
self._studio_events_to_skip.add("data")
285-
self._studio_events_to_skip.add("done")
286-
elif not self._save_dvc_exp:
287-
logger.warning(
288-
"Can't connect to Studio without creating a DVC experiment."
289-
"\nIf you have a DVC Pipeline, run it with `dvc exp run`."
290-
)
291-
self._studio_events_to_skip.add("start")
292-
self._studio_events_to_skip.add("data")
293-
self._studio_events_to_skip.add("done")
294281
else:
295282
self.post_to_studio("start")
296283

@@ -840,7 +827,7 @@ def make_dvcyaml(self):
840827
make_dvcyaml(self)
841828

842829
@catch_and_warn(DvcException, logger)
843-
def post_to_studio(self, event: str):
830+
def post_to_studio(self, event: Literal["start", "data", "done"]):
844831
post_to_studio(self, event)
845832

846833
def end(self):

src/dvclive/studio.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
# ruff: noqa: SLF001
2+
from __future__ import annotations
23
import base64
34
import logging
45
import math
56
import os
7+
from pathlib import PureWindowsPath
8+
from typing import TYPE_CHECKING, Literal, Mapping
69

710
from dvc_studio_client.config import get_studio_config
811
from dvc_studio_client.post_live_metrics import post_live_metrics
912

13+
if TYPE_CHECKING:
14+
from dvclive.live import Live
1015
from dvclive.serialize import load_yaml
11-
from dvclive.utils import parse_metrics, rel_path
16+
from dvclive.utils import parse_metrics, rel_path, StrPath
1217

1318
logger = logging.getLogger("dvclive")
1419

1520

16-
def _get_unsent_datapoints(plot, latest_step):
21+
def _get_unsent_datapoints(plot: Mapping, latest_step: int):
1722
return [x for x in plot if int(x["step"]) > latest_step]
1823

1924

20-
def _cast_to_numbers(datapoints):
25+
def _cast_to_numbers(datapoints: Mapping):
2126
for datapoint in datapoints:
2227
for k, v in datapoint.items():
2328
if k == "step":
@@ -33,31 +38,33 @@ def _cast_to_numbers(datapoints):
3338
return datapoints
3439

3540

36-
def _adapt_path(live, name):
41+
def _adapt_path(live: Live, name: StrPath):
3742
if live._dvc_repo is not None:
3843
name = rel_path(name, live._dvc_repo.root_dir)
44+
if os.name == "nt":
45+
name = str(PureWindowsPath(name).as_posix())
3946
return name
4047

4148

42-
def _adapt_plot_datapoints(live, plot):
49+
def _adapt_plot_datapoints(live: Live, plot: Mapping):
4350
datapoints = _get_unsent_datapoints(plot, live._latest_studio_step)
4451
return _cast_to_numbers(datapoints)
4552

4653

47-
def _adapt_image(image_path):
54+
def _adapt_image(image_path: StrPath):
4855
with open(image_path, "rb") as fobj:
4956
return base64.b64encode(fobj.read()).decode("utf-8")
5057

5158

52-
def _adapt_images(live):
59+
def _adapt_images(live: Live):
5360
return {
5461
_adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)}
5562
for image in live._images.values()
5663
if image.step > live._latest_studio_step
5764
}
5865

5966

60-
def get_studio_updates(live):
67+
def get_studio_updates(live: Live):
6168
if os.path.isfile(live.params_file):
6269
params_file = live.params_file
6370
params_file = _adapt_path(live, params_file)
@@ -82,14 +89,14 @@ def get_studio_updates(live):
8289
return metrics, params, plots
8390

8491

85-
def get_dvc_studio_config(live):
92+
def get_dvc_studio_config(live: Live):
8693
config = {}
8794
if live._dvc_repo:
8895
config = live._dvc_repo.config.get("studio")
8996
return get_studio_config(dvc_studio_config=config)
9097

9198

92-
def post_to_studio(live, event):
99+
def post_to_studio(live: Live, event: Literal["start", "data", "done"]):
93100
if event in live._studio_events_to_skip:
94101
return
95102

@@ -98,7 +105,7 @@ def post_to_studio(live, event):
98105
kwargs["message"] = live._exp_message
99106
elif event == "data":
100107
metrics, params, plots = get_studio_updates(live)
101-
kwargs["step"] = live.step
108+
kwargs["step"] = live.step # type: ignore
102109
kwargs["metrics"] = metrics
103110
kwargs["params"] = params
104111
kwargs["plots"] = plots
@@ -108,10 +115,10 @@ def post_to_studio(live, event):
108115
response = post_live_metrics(
109116
event,
110117
live._baseline_rev,
111-
live._exp_name,
118+
live._exp_name, # type: ignore
112119
"dvclive",
113120
dvc_studio_config=live._dvc_studio_config,
114-
**kwargs,
121+
**kwargs, # type: ignore
115122
)
116123
if not response:
117124
logger.warning(f"`post_to_studio` `{event}` failed.")

tests/test_dvc.py

+8-34
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,24 @@ def test_get_dvc_repo_subdir(tmp_dir):
2929
def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
3030
live = Live(save_dvc_exp=save)
3131
live.end()
32+
assert live._baseline_rev is not None
33+
assert live._exp_name is not None
3234
if save:
33-
assert live._baseline_rev is not None
34-
assert live._exp_name is not None
3535
mocked_dvc_repo.experiments.save.assert_called_with(
3636
name=live._exp_name,
3737
include_untracked=[live.dir, "dvc.yaml"],
3838
force=True,
3939
message=None,
4040
)
4141
else:
42-
assert live._baseline_rev is not None
43-
assert live._exp_name is None
4442
mocked_dvc_repo.experiments.save.assert_not_called()
4543

4644

47-
def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
45+
def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch):
4846
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo")
4947
monkeypatch.setenv(DVC_EXP_NAME, "bar")
5048
monkeypatch.setenv(DVC_ROOT, tmp_dir)
5149

52-
mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
5350
live = Live()
5451
live.end()
5552

@@ -60,31 +57,6 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker):
6057
assert live._inside_dvc_pipeline
6158

6259

63-
def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):
64-
dvc_repo = mocker.MagicMock()
65-
dvc_stage = mocker.MagicMock()
66-
dvc_file = mocker.MagicMock()
67-
dvc_repo.index.stages = [dvc_stage, dvc_file]
68-
dvc_repo.scm.get_rev.return_value = "current_rev"
69-
dvc_repo.scm.get_ref.return_value = None
70-
dvc_repo.scm.no_commits = False
71-
dvc_repo.config = {}
72-
dvc_repo.root_dir = tmp_dir
73-
mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo)
74-
live = Live()
75-
assert live._save_dvc_exp
76-
assert live._baseline_rev is not None
77-
assert live._exp_name is not None
78-
live.end()
79-
80-
dvc_repo.experiments.save.assert_called_with(
81-
name=live._exp_name,
82-
include_untracked=[live.dir, "dvc.yaml"],
83-
force=True,
84-
message=None,
85-
)
86-
87-
8860
def test_exp_save_with_dvc_files(tmp_dir, mocker):
8961
dvc_repo = mocker.MagicMock()
9062
dvc_file = mocker.MagicMock()
@@ -166,7 +138,7 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch):
166138
mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"]
167139
mocked_dvc_repo.scm.add.side_effect = DvcException("foo")
168140

169-
with Live(dvcyaml=False) as live:
141+
with Live() as live:
170142
live.summary["foo"] = 1
171143

172144

@@ -204,10 +176,12 @@ def test_no_scm_repo(tmp_dir, mocker):
204176
assert live._save_dvc_exp is False
205177

206178

207-
def test_dvc_repro(tmp_dir, monkeypatch, mocker):
179+
def test_dvc_repro(tmp_dir, monkeypatch, mocked_dvc_repo, mocked_studio_post):
208180
monkeypatch.setenv(DVC_ROOT, "root")
209-
mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
210181
live = Live(save_dvc_exp=True)
182+
assert live._baseline_rev is not None
183+
assert live._exp_name is not None
184+
assert not live._studio_events_to_skip
211185
assert not live._save_dvc_exp
212186

213187

0 commit comments

Comments
 (0)