8
8
import shutil
9
9
import tempfile
10
10
from pathlib import Path
11
- from typing import Any , Dict , List , Optional , Set , Union , TYPE_CHECKING
11
+ from typing import Any , Dict , List , Optional , Set , Tuple , Union , TYPE_CHECKING , Literal
12
12
13
13
if TYPE_CHECKING :
14
14
import numpy as np
15
15
import pandas as pd
16
+ import matplotlib
17
+ import PIL
16
18
17
19
from dvc .exceptions import DvcException
18
20
from funcy import set_in
62
64
logger .addHandler (handler )
63
65
64
66
ParamLike = Union [int , float , str , bool , List ["ParamLike" ], Dict [str , "ParamLike" ]]
67
+ TemplatePlotKind = Literal [
68
+ "linear" ,
69
+ "simple" ,
70
+ "scatter" ,
71
+ "smooth" ,
72
+ "confusion" ,
73
+ "confusion_normalized" ,
74
+ "bar_horizontal" ,
75
+ "bar_horizontal_sorted" ,
76
+ ]
65
77
66
78
67
79
class Live :
@@ -71,7 +83,7 @@ def __init__(
71
83
resume : bool = False ,
72
84
report : Optional [str ] = None ,
73
85
save_dvc_exp : bool = True ,
74
- dvcyaml : Union [str , bool ] = "dvc.yaml" ,
86
+ dvcyaml : Optional [str ] = "dvc.yaml" ,
75
87
cache_images : bool = False ,
76
88
exp_name : Optional [str ] = None ,
77
89
exp_message : Optional [str ] = None ,
@@ -379,7 +391,11 @@ def log_metric(
379
391
self .summary = set_in (self .summary , metric .summary_keys , val )
380
392
logger .debug (f"Logged { name } : { val } " )
381
393
382
- def log_image (self , name : str , val ):
394
+ def log_image (
395
+ self ,
396
+ name : str ,
397
+ val : Union [np .ndarray , matplotlib .figure .Figure , PIL .Image , StrPath ],
398
+ ):
383
399
if not Image .could_log (val ):
384
400
raise InvalidDataTypeError (name , type (val ))
385
401
@@ -401,10 +417,10 @@ def log_image(self, name: str, val):
401
417
def log_plot (
402
418
self ,
403
419
name : str ,
404
- datapoints : pd .DataFrame | np .ndarray | List [Dict ],
420
+ datapoints : Union [ pd .DataFrame , np .ndarray , List [Dict ] ],
405
421
x : str ,
406
422
y : str ,
407
- template : Optional [ str ] = None ,
423
+ template : TemplatePlotKind = "linear" ,
408
424
title : Optional [str ] = None ,
409
425
x_label : Optional [str ] = None ,
410
426
y_label : Optional [str ] = None ,
@@ -434,7 +450,14 @@ def log_plot(
434
450
plot .dump (datapoints )
435
451
logger .debug (f"Logged { name } " )
436
452
437
- def log_sklearn_plot (self , kind , labels , predictions , name = None , ** kwargs ):
453
+ def log_sklearn_plot (
454
+ self ,
455
+ kind : Literal ["calibration" , "confusion_matrix" , "precision_recall" , "roc" ],
456
+ labels : Union [List , np .ndarray ],
457
+ predictions : Union [List , Tuple , np .ndarray ],
458
+ name : Optional [str ] = None ,
459
+ ** kwargs ,
460
+ ):
438
461
val = (labels , predictions )
439
462
440
463
plot_config = {
@@ -527,7 +550,7 @@ def log_artifact(
527
550
)
528
551
529
552
@catch_and_warn (DvcException , logger )
530
- def cache (self , path ):
553
+ def cache (self , path : StrPath ):
531
554
if self ._inside_dvc_pipeline :
532
555
existing_stage = find_overlapping_stage (self ._dvc_repo , path )
533
556
@@ -574,7 +597,7 @@ def make_dvcyaml(self):
574
597
make_dvcyaml (self )
575
598
576
599
@catch_and_warn (DvcException , logger )
577
- def post_to_studio (self , event ):
600
+ def post_to_studio (self , event : Literal [ "start" , "data" , "done" ] ):
578
601
post_to_studio (self , event )
579
602
580
603
def end (self ):
0 commit comments