From a0b1c15dfe9bbe003f5fb4f423d4cc3dfc99fb6b Mon Sep 17 00:00:00 2001
From: Elias Bermudez <6505145+debermudez@users.noreply.github.com>
Date: Tue, 26 May 2026 17:46:09 -0700
Subject: [PATCH] fix: model name for plot (#998)

(cherry picked from commit f070f5a0b3686b655338b6b7aa8d76d224506454)
---
NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Hunk line counts and the diffstat were re-checked
against the hunk bodies, but leading indentation was lost in that copy and is
inferred here from the Python structure -- confirm it against the target tree
before applying (context mismatches can be tolerated with
`git apply --ignore-whitespace`).

 src/aiperf/plot/core/data_loader.py | 54 ++++++++++++++++++-------
 tests/unit/plot/test_data_loader.py | 63 +++++++++++++++++++++++++++++
 2 files changed, 103 insertions(+), 14 deletions(-)

diff --git a/src/aiperf/plot/core/data_loader.py b/src/aiperf/plot/core/data_loader.py
index c81e82f49..ff2c32eab 100644
--- a/src/aiperf/plot/core/data_loader.py
+++ b/src/aiperf/plot/core/data_loader.py
@@ -1288,16 +1288,18 @@ def _load_server_metrics_parquet(
 
         # Build series key for grouping
         df_filtered["labels_json"] = df_filtered.apply(
-            lambda row: orjson.dumps(
-                {
-                    k: row[k]
-                    for k in label_columns
-                    if pd.notna(row[k]) and row[k] != ""
-                },
-                option=orjson.OPT_SORT_KEYS,
-            ).decode()
-            if any(pd.notna(row[k]) and row[k] != "" for k in label_columns)
-            else "{}",
+            lambda row: (
+                orjson.dumps(
+                    {
+                        k: row[k]
+                        for k in label_columns
+                        if pd.notna(row[k]) and row[k] != ""
+                    },
+                    option=orjson.OPT_SORT_KEYS,
+                ).decode()
+                if any(pd.notna(row[k]) and row[k] != "" for k in label_columns)
+                else "{}"
+            ),
             axis=1,
         )
 
@@ -1607,6 +1609,33 @@ def _extract_experiment_group(self, run_path: Path, run_name: str) -> str:
 
         return result
 
+    @staticmethod
+    def _extract_model_name(config: dict[str, Any]) -> str | None:
+        """Resolve the model name from an ``input_config`` block.
+
+        YAML v2 stores it at ``models.items[].name``; legacy artifacts store
+        it at ``endpoint.model_names``. YAML v2 wins when both are present;
+        an empty/malformed ``models.items`` falls through to the legacy path.
+        """
+        models_block = config.get("models")
+        if isinstance(models_block, dict):
+            items = models_block.get("items")
+            if (
+                isinstance(items, list)
+                and items
+                and isinstance(items[0], dict)
+                and items[0].get("name")
+            ):
+                return items[0]["name"]
+
+        # Legacy: pre-YAML-v2 artifacts stored the model list on the endpoint block.
+        endpoint = config.get("endpoint")
+        if isinstance(endpoint, dict):
+            names = endpoint.get("model_names")
+            if names:
+                return names[0]
+        return None
+
     def _extract_metadata(
         self,
         run_path: Path,
@@ -1636,10 +1665,7 @@ def _extract_metadata(
 
         if aggregated and "input_config" in aggregated:
             config = aggregated["input_config"]
-            if "endpoint" in config and "model_names" in config["endpoint"]:
-                models = config["endpoint"]["model_names"]
-                if models:
-                    model = models[0]
+            model = self._extract_model_name(config)
 
             if "loadgen" in config and "concurrency" in config["loadgen"]:
                 concurrency = config["loadgen"]["concurrency"]
diff --git a/tests/unit/plot/test_data_loader.py b/tests/unit/plot/test_data_loader.py
index e160b920b..2f5f754bf 100644
--- a/tests/unit/plot/test_data_loader.py
+++ b/tests/unit/plot/test_data_loader.py
@@ -329,6 +329,69 @@ def test_extract_metadata_missing_aggregated_data(self, tmp_path: Path) -> None:
         assert metadata.model is None
         assert metadata.concurrency is None
 
+    def test_extract_metadata_model_from_yaml_v2_models_items(
+        self, tmp_path: Path
+    ) -> None:
+        """YAML v2 stores model name at input_config.models.items[].name."""
+        loader = DataLoader()
+        aggregated = {
+            "input_config": {
+                "models": {"items": [{"name": "Qwen/Qwen3-0.6B"}]},
+                "loadgen": {"concurrency": 5},
+            },
+        }
+
+        metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)
+
+        assert metadata.model == "Qwen/Qwen3-0.6B"
+
+    def test_extract_metadata_model_yaml_v2_takes_precedence(
+        self, tmp_path: Path
+    ) -> None:
+        """When both YAML v2 and legacy shapes are present, YAML v2 wins."""
+        loader = DataLoader()
+        aggregated = {
+            "input_config": {
+                "models": {"items": [{"name": "yaml-v2-model"}]},
+                "endpoint": {"model_names": ["legacy-model"]},
+            },
+        }
+
+        metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)
+
+        assert metadata.model == "yaml-v2-model"
+
+    def test_extract_metadata_model_from_legacy_endpoint_model_names(
+        self, tmp_path: Path
+    ) -> None:
+        """Legacy artifacts without models.items still resolve via endpoint.model_names."""
+        loader = DataLoader()
+        aggregated = {
+            "input_config": {
+                "endpoint": {"model_names": ["legacy-model"]},
+            },
+        }
+
+        metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)
+
+        assert metadata.model == "legacy-model"
+
+    def test_extract_metadata_model_empty_yaml_v2_items_falls_back_to_legacy(
+        self, tmp_path: Path
+    ) -> None:
+        """Empty/malformed models.items must not shadow a valid legacy entry."""
+        loader = DataLoader()
+        aggregated = {
+            "input_config": {
+                "models": {"items": []},
+                "endpoint": {"model_names": ["legacy-model"]},
+            },
+        }
+
+        metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)
+
+        assert metadata.model == "legacy-model"
+
 
 class TestDataLoaderReloadWithDetails:
     """Tests for DataLoader.reload_with_details method."""