From 276ddc46729fe6b83edb2aa113ede7608b16549b Mon Sep 17 00:00:00 2001
From: Matt Seddon <mattseddon@hotmail.com>
Date: Wed, 7 Feb 2024 15:14:45 +1100
Subject: [PATCH] Post additional monorepo information to Studio

---
 pyproject.toml               |  4 ++--
 src/dvclive/live.py          |  6 ++++++
 src/dvclive/studio.py        | 11 ++++++++---
 tests/conftest.py            | 10 +++++++++-
 tests/test_post_to_studio.py | 17 +++++++++++++++--
 5 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c2e4f16e..9681a025 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,9 +31,9 @@ classifiers = [
 ]
 dynamic = ["version"]
 dependencies = [
-    "dvc>=3.33.3",
+    "dvc@git+https://github.com/iterative/dvc.git@refs/pull/10285/head",
     "dvc-render>=1.0.0,<2",
-    "dvc-studio-client>=0.17.1,<1",
+    "dvc-studio-client@git+https://github.com/iterative/dvc-studio-client.git@refs/pull/144/head",
     "funcy",
     "gto",
     "ruamel.yaml",
diff --git a/src/dvclive/live.py b/src/dvclive/live.py
index ce8d5357..6e35ba6b 100644
--- a/src/dvclive/live.py
+++ b/src/dvclive/live.py
@@ -96,6 +96,8 @@ def __init__(
         self._init_report()
 
         self._baseline_rev: Optional[str] = None
+        self._subdir: Optional[str] = None
+        self._exp_parent_data: Optional[Dict[str, Any]] = None
         self._exp_name: Optional[str] = exp_name
         self._exp_message: Optional[str] = exp_message
         self._experiment_rev: Optional[str] = None
@@ -190,6 +192,10 @@ def _init_dvc(self):
             return
 
         self._baseline_rev = self._dvc_repo.scm.get_rev()
+
+        self._subdir = self._dvc_repo.subrepo_relpath
+        self._exp_parent_data = self._dvc_repo.head_commit_info
+
         if self._save_dvc_exp:
             self._exp_name = get_exp_name(
                 self._exp_name, self._dvc_repo.scm, self._baseline_rev
diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py
index f74db7ab..c34738e8 100644
--- a/src/dvclive/studio.py
+++ b/src/dvclive/studio.py
@@ -89,13 +89,18 @@ def get_dvc_studio_config(live):
     return get_studio_config(dvc_studio_config=config)
 
 
-def post_to_studio(live, event):
+def post_to_studio(live, event):  # noqa: C901
     if event in live._studio_events_to_skip:
         return
 
     kwargs = {}
-    if event == "start" and live._exp_message:
-        kwargs["message"] = live._exp_message
+    if event == "start":
+        if message := live._exp_message:
+            kwargs["message"] = message
+        if subdir := live._subdir:
+            kwargs["subdir"] = subdir
+        if dvc_experiment_parent_data := live._exp_parent_data:
+            kwargs["dvc_experiment_parent_data"] = dvc_experiment_parent_data
     elif event == "data":
         metrics, params, plots = get_studio_updates(live)
         kwargs["step"] = live.step
diff --git a/tests/conftest.py b/tests/conftest.py
index 1c856e45..0f689795 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,12 +19,20 @@ def mocked_dvc_repo(tmp_dir, mocker):
     _dvc_repo.scm.get_ref.return_value = None
     _dvc_repo.scm.no_commits = False
     _dvc_repo.experiments.save.return_value = "e" * 40
-    _dvc_repo.root_dir = tmp_dir
+    _dvc_repo.root_dir = _dvc_repo.scm.root_dir = tmp_dir
     _dvc_repo.config = {}
+    _dvc_repo.subrepo_relpath = ""
+    _dvc_repo.head_commit_info = None
     mocker.patch("dvclive.live.get_dvc_repo", return_value=_dvc_repo)
     return _dvc_repo
 
 
+@pytest.fixture()
+def mocked_dvc_subrepo(tmp_dir, mocker, mocked_dvc_repo):
+    mocked_dvc_repo.subrepo_relpath = "subdir"
+    return mocked_dvc_repo
+
+
 @pytest.fixture()
 def dvc_repo(tmp_dir):
     from dvc.repo import Repo
diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py
index a22dfa55..2a471cd2 100644
--- a/tests/test_post_to_studio.py
+++ b/tests/test_post_to_studio.py
@@ -32,7 +32,7 @@ def get_studio_call(event_type, exp_name, **kwargs):
     }
 
 
-def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
+def test_post_to_studio(monkeypatch, tmp_dir, mocked_dvc_repo, mocked_studio_post):
     live = Live()
     live.log_param("fooparam", 1)
 
@@ -41,7 +41,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
     mocked_post, _ = mocked_studio_post
 
     mocked_post.assert_called_with(
-        "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name)
+        "https://0.0.0.0/api/live",
+        **get_studio_call("start", exp_name=live._exp_name),
     )
 
     live.log_metric("foo", 1)
@@ -81,6 +82,18 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
     )
 
 
+def test_post_to_studio_subrepo(tmp_dir, mocked_dvc_subrepo, mocked_studio_post):
+    live = Live()
+    live.log_param("fooparam", 1)
+
+    mocked_post, _ = mocked_studio_post
+
+    mocked_post.assert_called_with(
+        "https://0.0.0.0/api/live",
+        **get_studio_call("start", exp_name=live._exp_name, subdir="subdir"),
+    )
+
+
 def test_post_to_studio_failed_data_request(
     tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post
 ):