Skip to content

Commit bf073b9

Browse files
authored
fix(studio): package data to send in main thread (#860)
1 parent 40e4b4e commit bf073b9

File tree

3 files changed

+155
-45
lines changed

3 files changed

+155
-45
lines changed

src/dvclive/live.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
inside_notebook,
5656
matplotlib_installed,
5757
open_file_in_browser,
58+
parse_metrics,
5859
)
5960
from .vscode import (
6061
cleanup_dvclive_step_completed,
@@ -135,7 +136,7 @@ def __init__(
135136
self._save_dvc_exp: bool = save_dvc_exp
136137
self._step: Optional[int] = None
137138
self._metrics: Dict[str, Any] = {}
138-
self._images: Dict[str, Any] = {}
139+
self._images: Dict[str, Image] = {}
139140
self._params: Dict[str, Any] = {}
140141
self._plots: Dict[str, Any] = {}
141142
self._artifacts: Dict[str, Dict] = {}
@@ -901,19 +902,41 @@ def make_dvcyaml(self):
901902
"""
902903
make_dvcyaml(self)
903904

905+
def _get_live_data(self) -> Optional[dict[str, Any]]:
906+
params = load_yaml(self.params_file) if os.path.isfile(self.params_file) else {}
907+
plots, metrics = parse_metrics(self)
908+
909+
# Plots can grow large, we don't want to keep in memory data
910+
# that we 100% sent already
911+
plots_to_send = {}
912+
plots_start_idx = {}
913+
for name, plot in plots.items():
914+
num_points_sent = self._num_points_sent_to_studio.get(name, 0)
915+
plots_to_send[name] = plot[num_points_sent:]
916+
plots_start_idx[name] = num_points_sent
917+
918+
return {
919+
"params": params,
920+
"plots": plots_to_send,
921+
"plots_start_idx": plots_start_idx,
922+
"metrics": metrics,
923+
"images": list(self._images.values()),
924+
"step": self.step,
925+
}
926+
904927
def post_data_to_studio(self):
905928
if not self._studio_queue:
906929
self._studio_queue = queue.Queue()
907930

908931
def worker():
909932
while True:
910-
item = self._studio_queue.get()
911-
post_to_studio(item, "data")
933+
item, data = self._studio_queue.get()
934+
post_to_studio(item, "data", data)
912935
self._studio_queue.task_done()
913936

914937
threading.Thread(target=worker, daemon=True).start()
915938

916-
self._studio_queue.put(self)
939+
self._studio_queue.put((self, self._get_live_data()))
917940

918941
def _wait_for_studio_updates_posted(self):
919942
if self._studio_queue:

src/dvclive/studio.py

+34-23
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import os
77
from pathlib import PureWindowsPath
8-
from typing import TYPE_CHECKING, Literal, Mapping
8+
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional
99

1010
from dvc.exceptions import DvcException
1111
from dvc_studio_client.config import get_studio_config
@@ -14,9 +14,9 @@
1414
from .utils import catch_and_warn
1515

1616
if TYPE_CHECKING:
17+
from dvclive.plots.image import Image
1718
from dvclive.live import Live
18-
from dvclive.serialize import load_yaml
19-
from dvclive.utils import parse_metrics, rel_path, StrPath
19+
from dvclive.utils import rel_path, StrPath
2020

2121
logger = logging.getLogger("dvclive")
2222

@@ -50,23 +50,24 @@ def _adapt_image(image_path: StrPath):
5050
return base64.b64encode(fobj.read()).decode("utf-8")
5151

5252

53-
def _adapt_images(live: Live):
53+
def _adapt_images(live: Live, images: list[Image]):
5454
return {
5555
_adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)}
56-
for image in live._images.values()
56+
for image in images
5757
if image.step > live._latest_studio_step
5858
}
5959

6060

61-
def get_studio_updates(live: Live):
62-
if os.path.isfile(live.params_file):
63-
params_file = live.params_file
64-
params_file = _adapt_path(live, params_file)
65-
params = {params_file: load_yaml(live.params_file)}
66-
else:
67-
params = {}
61+
def _get_studio_updates(live: Live, data: dict[str, Any]):
62+
params = data["params"]
63+
plots = data["plots"]
64+
plots_start_idx = data["plots_start_idx"]
65+
metrics = data["metrics"]
66+
images = data["images"]
6867

69-
plots, metrics = parse_metrics(live)
68+
params_file = live.params_file
69+
params_file = _adapt_path(live, params_file)
70+
params = {params_file: params}
7071

7172
metrics_file = live.metrics_file
7273
metrics_file = _adapt_path(live, metrics_file)
@@ -75,11 +76,12 @@ def get_studio_updates(live: Live):
7576
plots_to_send = {}
7677
for name, plot in plots.items():
7778
path = _adapt_path(live, name)
78-
num_points_sent = live._num_points_sent_to_studio.get(path, 0)
79-
plots_to_send[path] = _cast_to_numbers(plot[num_points_sent:])
79+
start_idx = plots_start_idx.get(name, 0)
80+
num_points_sent = live._num_points_sent_to_studio.get(name, 0)
81+
plots_to_send[path] = _cast_to_numbers(plot[num_points_sent - start_idx :])
8082

8183
plots_to_send = {k: {"data": v} for k, v in plots_to_send.items()}
82-
plots_to_send.update(_adapt_images(live))
84+
plots_to_send.update(_adapt_images(live, images))
8385

8486
return metrics, params, plots_to_send
8587

@@ -91,16 +93,22 @@ def get_dvc_studio_config(live: Live):
9193
return get_studio_config(dvc_studio_config=config)
9294

9395

94-
def increment_num_points_sent_to_studio(live, plots):
95-
for name, plot in plots.items():
96+
def increment_num_points_sent_to_studio(live, plots_sent, data):
97+
for name, _ in data["plots"].items():
98+
path = _adapt_path(live, name)
99+
plot = plots_sent.get(path, {})
96100
if "data" in plot:
97101
num_points_sent = live._num_points_sent_to_studio.get(name, 0)
98102
live._num_points_sent_to_studio[name] = num_points_sent + len(plot["data"])
99103
return live
100104

101105

102106
@catch_and_warn(DvcException, logger)
103-
def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901
107+
def post_to_studio( # noqa: C901
108+
live: Live,
109+
event: Literal["start", "data", "done"],
110+
data: Optional[dict[str, Any]] = None,
111+
):
104112
if event in live._studio_events_to_skip:
105113
return
106114

@@ -111,8 +119,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
111119
if subdir := live._subdir:
112120
kwargs["subdir"] = subdir
113121
elif event == "data":
114-
metrics, params, plots = get_studio_updates(live)
115-
kwargs["step"] = live.step # type: ignore
122+
assert data is not None # noqa: S101
123+
metrics, params, plots = _get_studio_updates(live, data)
124+
kwargs["step"] = data["step"] # type: ignore
116125
kwargs["metrics"] = metrics
117126
kwargs["params"] = params
118127
kwargs["plots"] = plots
@@ -128,15 +137,17 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
128137
studio_repo_url=live._repo_url,
129138
**kwargs, # type: ignore
130139
)
140+
131141
if not response:
132142
logger.warning(f"`post_to_studio` `{event}` failed.")
133143
if event == "start":
134144
live._studio_events_to_skip.add("start")
135145
live._studio_events_to_skip.add("data")
136146
live._studio_events_to_skip.add("done")
137147
elif event == "data":
138-
live = increment_num_points_sent_to_studio(live, plots)
139-
live._latest_studio_step = live.step
148+
assert data is not None # noqa: S101
149+
live = increment_num_points_sent_to_studio(live, plots, data)
150+
live._latest_studio_step = data["step"]
140151

141152
if event == "done":
142153
live._studio_events_to_skip.add("done")

0 commit comments

Comments
 (0)