diff --git a/pyproject.toml b/pyproject.toml index c2e4f16e..9681a025 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,9 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "dvc>=3.33.3", + "dvc@git+https://github.com/iterative/dvc.git@refs/pull/10285/head", "dvc-render>=1.0.0,<2", - "dvc-studio-client>=0.17.1,<1", + "dvc-studio-client@git+https://github.com/iterative/dvc-studio-client.git@refs/pull/144/head", "funcy", "gto", "ruamel.yaml", diff --git a/src/dvclive/live.py b/src/dvclive/live.py index ce8d5357..6e35ba6b 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -96,6 +96,8 @@ def __init__( self._init_report() self._baseline_rev: Optional[str] = None + self._subdir: Optional[str] = None + self._exp_parent_data: Optional[Dict[str, Any]] = None self._exp_name: Optional[str] = exp_name self._exp_message: Optional[str] = exp_message self._experiment_rev: Optional[str] = None @@ -190,6 +192,10 @@ def _init_dvc(self): return self._baseline_rev = self._dvc_repo.scm.get_rev() + + self._subdir = self._dvc_repo.subrepo_relpath + self._exp_parent_data = self._dvc_repo.head_commit_info + if self._save_dvc_exp: self._exp_name = get_exp_name( self._exp_name, self._dvc_repo.scm, self._baseline_rev diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index f74db7ab..c34738e8 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -89,13 +89,18 @@ def get_dvc_studio_config(live): return get_studio_config(dvc_studio_config=config) -def post_to_studio(live, event): +def post_to_studio(live, event): # noqa: C901 if event in live._studio_events_to_skip: return kwargs = {} - if event == "start" and live._exp_message: - kwargs["message"] = live._exp_message + if event == "start": + if message := live._exp_message: + kwargs["message"] = message + if subdir := live._subdir: + kwargs["subdir"] = subdir + if dvc_experiment_parent_data := live._exp_parent_data: + kwargs["dvc_experiment_parent_data"] = dvc_experiment_parent_data elif event == "data": metrics, params, plots = get_studio_updates(live) kwargs["step"] = live.step diff --git a/tests/conftest.py b/tests/conftest.py index 1c856e45..0f689795 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,12 +19,20 @@ def mocked_dvc_repo(tmp_dir, mocker): _dvc_repo.scm.get_ref.return_value = None _dvc_repo.scm.no_commits = False _dvc_repo.experiments.save.return_value = "e" * 40 - _dvc_repo.root_dir = tmp_dir + _dvc_repo.root_dir = _dvc_repo.scm.root_dir = tmp_dir _dvc_repo.config = {} + _dvc_repo.subrepo_relpath = "" + _dvc_repo.head_commit_info = None mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo) return _dvc_repo +@pytest.fixture() +def mocked_dvc_subrepo(tmp_dir, mocker, mocked_dvc_repo): + mocked_dvc_repo.subrepo_relpath = "subdir" + return mocked_dvc_repo + + @pytest.fixture() def dvc_repo(tmp_dir): from dvc.repo import Repo diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index a22dfa55..2a471cd2 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -32,7 +32,7 @@ def get_studio_call(event_type, exp_name, **kwargs): } -def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): +def test_post_to_studio(monkeypatch, tmp_dir, mocked_dvc_repo, mocked_studio_post): live = Live() live.log_param("fooparam", 1) @@ -41,7 +41,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): mocked_post, _ = mocked_studio_post mocked_post.assert_called_with( - "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name) + "https://0.0.0.0/api/live", + **get_studio_call("start", exp_name=live._exp_name), ) live.log_metric("foo", 1) @@ -81,6 +82,18 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) +def test_post_to_studio_subrepo(tmp_dir, mocked_dvc_subrepo, mocked_studio_post): + live = Live() + live.log_param("fooparam", 1) + + mocked_post, _ = mocked_studio_post + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call("start", exp_name=live._exp_name, subdir="subdir"), + ) + + def test_post_to_studio_failed_data_request( tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post ):