Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions src/aiperf/plot/core/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,16 +1288,18 @@ def _load_server_metrics_parquet(

# Build series key for grouping
df_filtered["labels_json"] = df_filtered.apply(
lambda row: orjson.dumps(
{
k: row[k]
for k in label_columns
if pd.notna(row[k]) and row[k] != ""
},
option=orjson.OPT_SORT_KEYS,
).decode()
if any(pd.notna(row[k]) and row[k] != "" for k in label_columns)
else "{}",
lambda row: (
orjson.dumps(
{
k: row[k]
for k in label_columns
if pd.notna(row[k]) and row[k] != ""
},
option=orjson.OPT_SORT_KEYS,
).decode()
if any(pd.notna(row[k]) and row[k] != "" for k in label_columns)
else "{}"
),
axis=1,
)

Expand Down Expand Up @@ -1607,6 +1609,33 @@ def _extract_experiment_group(self, run_path: Path, run_name: str) -> str:

return result

@staticmethod
def _extract_model_name(config: dict[str, Any]) -> str | None:
"""Resolve the model name from an ``input_config`` block.

YAML v2 stores it at ``models.items[].name``; legacy artifacts store
it at ``endpoint.model_names``. YAML v2 wins when both are present;
an empty/malformed ``models.items`` falls through to the legacy path.
"""
models_block = config.get("models")
if isinstance(models_block, dict):
items = models_block.get("items")
if (
isinstance(items, list)
and items
and isinstance(items[0], dict)
and items[0].get("name")
):
return items[0]["name"]

# Legacy: pre-YAML-v2 artifacts stored the model list on the endpoint block.
endpoint = config.get("endpoint")
if isinstance(endpoint, dict):
names = endpoint.get("model_names")
if names:
return names[0]
return None

def _extract_metadata(
self,
run_path: Path,
Expand Down Expand Up @@ -1636,10 +1665,7 @@ def _extract_metadata(
if aggregated and "input_config" in aggregated:
config = aggregated["input_config"]

if "endpoint" in config and "model_names" in config["endpoint"]:
models = config["endpoint"]["model_names"]
if models:
model = models[0]
model = self._extract_model_name(config)

if "loadgen" in config and "concurrency" in config["loadgen"]:
concurrency = config["loadgen"]["concurrency"]
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/plot/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,69 @@ def test_extract_metadata_missing_aggregated_data(self, tmp_path: Path) -> None:
assert metadata.model is None
assert metadata.concurrency is None

def test_extract_metadata_model_from_yaml_v2_models_items(
self, tmp_path: Path
) -> None:
"""YAML v2 stores model name at input_config.models.items[].name."""
loader = DataLoader()
aggregated = {
"input_config": {
"models": {"items": [{"name": "Qwen/Qwen3-0.6B"}]},
"loadgen": {"concurrency": 5},
},
}

metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)

assert metadata.model == "Qwen/Qwen3-0.6B"

def test_extract_metadata_model_yaml_v2_takes_precedence(
self, tmp_path: Path
) -> None:
"""When both YAML v2 and legacy shapes are present, YAML v2 wins."""
loader = DataLoader()
aggregated = {
"input_config": {
"models": {"items": [{"name": "yaml-v2-model"}]},
"endpoint": {"model_names": ["legacy-model"]},
},
}

metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)

assert metadata.model == "yaml-v2-model"

def test_extract_metadata_model_from_legacy_endpoint_model_names(
self, tmp_path: Path
) -> None:
"""Legacy artifacts without models.items still resolve via endpoint.model_names."""
loader = DataLoader()
aggregated = {
"input_config": {
"endpoint": {"model_names": ["legacy-model"]},
},
}

metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)

assert metadata.model == "legacy-model"

def test_extract_metadata_model_empty_yaml_v2_items_falls_back_to_legacy(
self, tmp_path: Path
) -> None:
"""Empty/malformed models.items must not shadow a valid legacy entry."""
loader = DataLoader()
aggregated = {
"input_config": {
"models": {"items": []},
"endpoint": {"model_names": ["legacy-model"]},
},
}

metadata = loader._extract_metadata(tmp_path / "run", None, aggregated)

assert metadata.model == "legacy-model"


class TestDataLoaderReloadWithDetails:
"""Tests for DataLoader.reload_with_details method."""
Expand Down
Loading