5
5
import math
6
6
import os
7
7
from pathlib import PureWindowsPath
8
- from typing import TYPE_CHECKING , Literal , Mapping
8
+ from typing import TYPE_CHECKING , Any , Literal , Mapping , Optional
9
9
10
10
from dvc .exceptions import DvcException
11
11
from dvc_studio_client .config import get_studio_config
14
14
from .utils import catch_and_warn
15
15
16
16
if TYPE_CHECKING :
17
+ from dvclive .plots .image import Image
17
18
from dvclive .live import Live
18
- from dvclive .serialize import load_yaml
19
- from dvclive .utils import parse_metrics , rel_path , StrPath
19
+ from dvclive .utils import rel_path , StrPath
20
20
21
21
logger = logging .getLogger ("dvclive" )
22
22
@@ -50,23 +50,24 @@ def _adapt_image(image_path: StrPath):
50
50
return base64 .b64encode (fobj .read ()).decode ("utf-8" )
51
51
52
52
53
- def _adapt_images (live : Live ):
53
+ def _adapt_images (live : Live , images : list [ Image ] ):
54
54
return {
55
55
_adapt_path (live , image .output_path ): {"image" : _adapt_image (image .output_path )}
56
- for image in live . _images . values ()
56
+ for image in images
57
57
if image .step > live ._latest_studio_step
58
58
}
59
59
60
60
61
- def get_studio_updates (live : Live ):
62
- if os .path .isfile (live .params_file ):
63
- params_file = live .params_file
64
- params_file = _adapt_path (live , params_file )
65
- params = {params_file : load_yaml (live .params_file )}
66
- else :
67
- params = {}
61
+ def _get_studio_updates (live : Live , data : dict [str , Any ]):
62
+ params = data ["params" ]
63
+ plots = data ["plots" ]
64
+ plots_start_idx = data ["plots_start_idx" ]
65
+ metrics = data ["metrics" ]
66
+ images = data ["images" ]
68
67
69
- plots , metrics = parse_metrics (live )
68
+ params_file = live .params_file
69
+ params_file = _adapt_path (live , params_file )
70
+ params = {params_file : params }
70
71
71
72
metrics_file = live .metrics_file
72
73
metrics_file = _adapt_path (live , metrics_file )
@@ -75,11 +76,12 @@ def get_studio_updates(live: Live):
75
76
plots_to_send = {}
76
77
for name , plot in plots .items ():
77
78
path = _adapt_path (live , name )
78
- num_points_sent = live ._num_points_sent_to_studio .get (path , 0 )
79
- plots_to_send [path ] = _cast_to_numbers (plot [num_points_sent :])
79
+ start_idx = plots_start_idx .get (name , 0 )
80
+ num_points_sent = live ._num_points_sent_to_studio .get (name , 0 )
81
+ plots_to_send [path ] = _cast_to_numbers (plot [num_points_sent - start_idx :])
80
82
81
83
plots_to_send = {k : {"data" : v } for k , v in plots_to_send .items ()}
82
- plots_to_send .update (_adapt_images (live ))
84
+ plots_to_send .update (_adapt_images (live , images ))
83
85
84
86
return metrics , params , plots_to_send
85
87
@@ -91,16 +93,22 @@ def get_dvc_studio_config(live: Live):
91
93
return get_studio_config (dvc_studio_config = config )
92
94
93
95
94
- def increment_num_points_sent_to_studio (live , plots ):
95
- for name , plot in plots .items ():
96
+ def increment_num_points_sent_to_studio (live , plots_sent , data ):
97
+ for name , _ in data ["plots" ].items ():
98
+ path = _adapt_path (live , name )
99
+ plot = plots_sent .get (path , {})
96
100
if "data" in plot :
97
101
num_points_sent = live ._num_points_sent_to_studio .get (name , 0 )
98
102
live ._num_points_sent_to_studio [name ] = num_points_sent + len (plot ["data" ])
99
103
return live
100
104
101
105
102
106
@catch_and_warn (DvcException , logger )
103
- def post_to_studio (live : Live , event : Literal ["start" , "data" , "done" ]): # noqa: C901
107
+ def post_to_studio ( # noqa: C901
108
+ live : Live ,
109
+ event : Literal ["start" , "data" , "done" ],
110
+ data : Optional [dict [str , Any ]] = None ,
111
+ ):
104
112
if event in live ._studio_events_to_skip :
105
113
return
106
114
@@ -111,8 +119,9 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
111
119
if subdir := live ._subdir :
112
120
kwargs ["subdir" ] = subdir
113
121
elif event == "data" :
114
- metrics , params , plots = get_studio_updates (live )
115
- kwargs ["step" ] = live .step # type: ignore
122
+ assert data is not None # noqa: S101
123
+ metrics , params , plots = _get_studio_updates (live , data )
124
+ kwargs ["step" ] = data ["step" ] # type: ignore
116
125
kwargs ["metrics" ] = metrics
117
126
kwargs ["params" ] = params
118
127
kwargs ["plots" ] = plots
@@ -128,15 +137,17 @@ def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa
128
137
studio_repo_url = live ._repo_url ,
129
138
** kwargs , # type: ignore
130
139
)
140
+
131
141
if not response :
132
142
logger .warning (f"`post_to_studio` `{ event } ` failed." )
133
143
if event == "start" :
134
144
live ._studio_events_to_skip .add ("start" )
135
145
live ._studio_events_to_skip .add ("data" )
136
146
live ._studio_events_to_skip .add ("done" )
137
147
elif event == "data" :
138
- live = increment_num_points_sent_to_studio (live , plots )
139
- live ._latest_studio_step = live .step
148
+ assert data is not None # noqa: S101
149
+ live = increment_num_points_sent_to_studio (live , plots , data )
150
+ live ._latest_studio_step = data ["step" ]
140
151
141
152
if event == "done" :
142
153
live ._studio_events_to_skip .add ("done" )
0 commit comments