diff --git a/docs/architecture.md b/docs/architecture.md index 7ff049680..45b12c7d9 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -165,6 +165,18 @@ The Timing Manager uses a **credit-based flow control system** to control when r - Scales to large numbers of workers without bottlenecks - Efficient message routing minimizes overhead +### Phase baseline handshake + +Before each phase issues its first credit, TimingManager invokes a synchronous gate +on SystemController. SystemController's `BaselineCoordinator` broadcasts a +`PhaseBaselineRequestMessage` to all services that advertised +`ServiceCapability.BASELINE_COLLECTOR` and waits for acks (with a configurable +timeout). The gate releases when all collectors ack or the timeout fires. A +symmetric END gate fires after credits drain. This guarantees telemetry and +server-metrics scrapes capture clean pre/post-phase reference points without +requiring TimingManager to know about any specific collector. Failed acks count +as acks (logged), so a failing collector cannot block the phase; a dead or unresponsive collector is bounded by the gate timeout. + ### Data Flow & Messaging This section describes the end-to-end message flow during a benchmark run, showing how data moves between components through the ZMQ message bus. 
diff --git a/docs/dev/patterns.md b/docs/dev/patterns.md index fa9938fed..bb0c73a0f 100644 --- a/docs/dev/patterns.md +++ b/docs/dev/patterns.md @@ -141,6 +141,34 @@ service: **Config types:** - `CLIConfig`: unified CLI input DTO carrying both benchmark params (endpoints, loadgen) and service-runtime knobs (ZMQ ports, logging level) +## Baseline Collector Pattern + +To take a clean reading at phase boundaries, mix in `BaselineCollectorMixin`: + +```python +from aiperf.common.base_component_service import BaseComponentService +from aiperf.common.enums import BaselineKind +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin + + +class MyMonitor(BaselineCollectorMixin, BaseComponentService): + """Captures a system snapshot before profiling and after credits drain.""" + + async def collect_baseline( + self, kind: BaselineKind, phase_id: str, phase_name: str + ) -> None: + if phase_name != "profiling": + return + snapshot = await self._take_snapshot() + self._snapshots.append((kind, phase_id, snapshot)) +``` + +The mixin auto-advertises `ServiceCapability.BASELINE_COLLECTOR` via its +`extra_capabilities` ClassVar, which SystemController's `BaselineCoordinator` +keys off to fan out gate requests. Per-collector exceptions are caught and +surfaced via the ack message's `success=False` field rather than blocking the +gate. + ## Model Pattern Use `AIPerfBaseModel` for data, `BaseConfig` for configuration: diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 70a28ca3e..5462a7272 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -41,6 +41,15 @@ API server settings. Controls the host and port of the API server. 
| `AIPERF_API_SERVER_CORS_ORIGINS` | `[]` | — | List of CORS origins to allow (empty = no CORS, ['*'] = all origins) | | `AIPERF_API_SERVER_SHUTDOWN_TIMEOUT` | `5.0` | ≥ 1.0, ≤ 300.0 | Timeout in seconds for graceful API server shutdown before force-cancelling | +## BASELINE + +Phase baseline handshake settings. Controls the gate that blocks TimingManager between phases until registered baseline collectors finish their point-in-time scrape. + +| Environment Variable | Default | Constraints | Description | +|----------------------|---------|-------------|-------------| +| `AIPERF_BASELINE_GATE_TIMEOUT_S` | `5.0` | > 0.0 | Per-gate timeout (seconds). If registered baseline collectors do not all ack within this window, the gate releases with a warning and the phase proceeds without waiting for stragglers. | +| `AIPERF_BASELINE_GATE_ENABLED` | `True` | — | Master switch for the phase baseline handshake. When False, PhaseGateClient short-circuits to no-op and PhaseRunner does not wait between phases. Useful for replay/debug runs. | + ## COMPRESSION Compression settings for streaming file transfers. Controls chunk size and compression levels for zstd and gzip encodings used in dataset and results file transfers. @@ -161,6 +170,7 @@ Record processing and export configuration. Controls batch sizes, processor scal | `AIPERF_RECORD_PROCESSOR_SCALE_FACTOR` | `4` | ≥ 1, ≤ 100 | Scale factor for number of record processors to spawn based on worker count. 
Formula: 1 record processor for every X workers | | `AIPERF_RECORD_PROGRESS_REPORT_INTERVAL` | `2.0` | ≥ 0.1, ≤ 600.0 | Interval in seconds between records progress report messages | | `AIPERF_RECORD_PROCESS_RECORDS_TIMEOUT` | `300.0` | ≥ 1.0, ≤ 100000.0 | Timeout in seconds for processing record results | +| `AIPERF_RECORD_CREDITS_COMPLETE_FALLBACK_TIMEOUT` | `10.0` | ≥ 0.0, ≤ 300.0 | Maximum seconds RecordsManager waits for CreditsComplete after all profiling records are ready before finalizing defensively | ## SEARCHPLANNER @@ -180,7 +190,6 @@ Server metrics collection configuration. Controls server metrics collection freq | Environment Variable | Default | Constraints | Description | |----------------------|---------|-------------|-------------| -| `AIPERF_SERVER_METRICS_COLLECTION_FLUSH_PERIOD` | `2.0` | ≥ 0.0, ≤ 30.0 | Time in seconds to continue collecting metrics after profiling completes, allowing server-side metrics to flush/finalize before shutting down (default: 2.0s) | | `AIPERF_SERVER_METRICS_COLLECTION_INTERVAL` | `0.333` | ≥ 0.001, ≤ 300.0 | Server metrics collection interval in seconds (default: 333ms, ~3Hz) | | `AIPERF_SERVER_METRICS_EXPORT_BATCH_SIZE` | `100` | ≥ 1, ≤ 1000000 | Batch size for server metrics jsonl writer export results processor | | `AIPERF_SERVER_METRICS_REACHABILITY_TIMEOUT` | `10` | ≥ 1, ≤ 300 | Timeout in seconds for checking server metrics endpoint reachability during init | diff --git a/docs/index.yml b/docs/index.yml index 9236b6d62..06091757f 100644 --- a/docs/index.yml +++ b/docs/index.yml @@ -201,6 +201,8 @@ navigation: path: reference/tokenizer-auto-detection.md - page: Conversation Context Mode path: reference/conversation-context-mode.md + - page: Phase Baseline Handshake + path: reference/phase-baseline-handshake.md - page: List-Metric Aggregation path: reference/list-metric-aggregation.md - page: Vendor Usage Field Reference diff --git a/docs/reference/phase-baseline-handshake.md 
b/docs/reference/phase-baseline-handshake.md new file mode 100644 index 000000000..bada9509a --- /dev/null +++ b/docs/reference/phase-baseline-handshake.md @@ -0,0 +1,125 @@ +# Phase Baseline Handshake + +The phase baseline handshake captures point-in-time baseline readings at phase boundaries without coupling `TimingManager` to specific collectors. `PhaseRunner` pauses at each boundary through `PhaseGateClient`; `SystemController` fans the request out through `BaselineCoordinator`; registered baseline collectors scrape once and ACK; the gate releases once every registered collector has ACKed or `AIPERF_BASELINE_GATE_TIMEOUT_S` elapses, and the benchmark continues. + +## Component map + +```mermaid +flowchart LR + subgraph TimingManager + PR[PhaseRunner] + PGC[PhaseGateClient] + PR -->|before_phase / after_phase| PGC + end + + subgraph Controller + SC[SystemController] + BC[BaselineCoordinator] + SC --> BC + end + + subgraph Collectors[Baseline collector services] + GTM[GPUTelemetryManager] + SMM[ServerMetricsManager] + BCM[BaselineCollectorMixin] + GTM --> BCM + SMM --> BCM + end + + PGC -->|PhaseStartGateCommand / PhaseEndGateCommand| SC + BC -->|PhaseBaselineRequestMessage| BCM + BCM -->|PhaseBaselineAckMessage| SC + SC -->|PhaseGateGrantedResponse| PGC +``` + +## Per-phase message sequence + +```mermaid +sequenceDiagram + autonumber + participant Runner as PhaseRunner + participant Gate as PhaseGateClient + participant SC as SystemController + participant Coord as BaselineCoordinator + participant Collector as Baseline collectors + + Runner->>Gate: before_phase(phase_id, phase_name) + Gate->>SC: PhaseStartGateCommand + SC->>Coord: gate_phase(kind=START) + Coord->>Collector: PhaseBaselineRequestMessage(kind=START) + Collector->>Collector: collect_baseline(START, phase_id, phase_name) + Collector-->>SC: PhaseBaselineAckMessage(success=True) + SC->>Coord: handle_ack(ack) + Coord-->>SC: all registered collectors acked + SC-->>Gate: PhaseGateGrantedResponse + Gate-->>Runner: start gate released + + Runner->>Runner: issue credits and wait for 
returns + + Runner->>Gate: after_phase(phase_id, phase_name) + Gate->>SC: PhaseEndGateCommand + SC->>Coord: gate_phase(kind=END) + Coord->>Collector: PhaseBaselineRequestMessage(kind=END) + Collector->>Collector: collect_baseline(END, phase_id, phase_name) + Collector-->>SC: PhaseBaselineAckMessage(success=True) + SC->>Coord: handle_ack(ack) + Coord-->>SC: all registered collectors acked + SC-->>Gate: PhaseGateGrantedResponse + Gate-->>Runner: end gate released +``` + +## Credit ordering at phase boundaries + +```mermaid +sequenceDiagram + autonumber + participant Runner as PhaseRunner + participant Gate as PhaseGateClient + participant Controller as SystemController + participant Coord as BaselineCoordinator + participant Collectors as Baseline collectors + participant Issuer as CreditIssuer + participant Workers as Workers + + Runner->>Gate: before_phase(phase_id, phase_name) + Gate->>Controller: PhaseStartGateCommand + Controller->>Coord: gate_phase(kind=START) + Coord->>Collectors: PhaseBaselineRequestMessage(kind=START) + Collectors-->>Controller: PhaseBaselineAckMessage(success=True) + Controller-->>Gate: PhaseGateGrantedResponse + Gate-->>Runner: START gate released + Runner->>Issuer: start strategy.execute_phase() + Issuer->>Workers: publish credits for this phase + Workers-->>Runner: return credit results + Runner->>Runner: wait for sends complete, then returns drain + Runner->>Gate: after_phase(phase_id, phase_name) + Gate->>Controller: PhaseEndGateCommand + Controller->>Coord: gate_phase(kind=END) + Coord->>Collectors: PhaseBaselineRequestMessage(kind=END) + Collectors-->>Controller: PhaseBaselineAckMessage(success=True) + Controller-->>Gate: PhaseGateGrantedResponse + Gate-->>Runner: END gate released + Runner->>Runner: phase transition may complete +``` + +## TimingManager phase flow + +```mermaid +flowchart TD + A[PhaseRunner starts phase] --> B[Generate phase_id and phase_name] + B --> C[PhaseGateClient.before_phase] + C --> D[SystemController 
handles PHASE_START_GATE] + D --> E[BaselineCoordinator broadcasts START request] + E --> F[Collectors scrape START baseline] + F --> G[Collectors publish START ACKs] + G --> H[SystemController returns PhaseGateGrantedResponse] + H --> I[PhaseRunner issues credits] + I --> J[PhaseRunner waits for phase completion and returns] + J --> K[PhaseGateClient.after_phase] + K --> L[SystemController handles PHASE_END_GATE] + L --> M[BaselineCoordinator broadcasts END request] + M --> N[Collectors scrape END baseline] + N --> O[Collectors publish END ACKs] + O --> P[SystemController returns PhaseGateGrantedResponse] + P --> Q[PhaseRunner completes phase transition] +``` diff --git a/docs/server-metrics/server-metrics.md b/docs/server-metrics/server-metrics.md index e086ea07d..44b5c0d51 100644 --- a/docs/server-metrics/server-metrics.md +++ b/docs/server-metrics/server-metrics.md @@ -197,7 +197,6 @@ WARNING Disabling server metrics collection for http://127.0.0.1:60000/metrics: | Environment Variable | Default | Description | |---------------------|---------|-------------| | `AIPERF_SERVER_METRICS_COLLECTION_INTERVAL` | 0.333s | Collection frequency (333ms, ~3Hz) | -| `AIPERF_SERVER_METRICS_COLLECTION_FLUSH_PERIOD` | 2.0s | Wait time for final metrics after benchmark | | `AIPERF_SERVER_METRICS_REACHABILITY_TIMEOUT` | 10s | Timeout for endpoint reachability tests | | `AIPERF_SERVER_METRICS_EXPORT_BATCH_SIZE` | 100 | Batch size for JSONL writer | | `AIPERF_SERVER_METRICS_SHUTDOWN_DELAY` | 5.0s | Shutdown delay for command response transmission | diff --git a/src/aiperf/common/base_component_service.py b/src/aiperf/common/base_component_service.py index 5aaa3af21..b19bd9980 100644 --- a/src/aiperf/common/base_component_service.py +++ b/src/aiperf/common/base_component_service.py @@ -3,7 +3,7 @@ import asyncio import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from aiperf.common.base_service import BaseService from aiperf.common.enums import 
CommandType, LifecycleState @@ -40,6 +40,9 @@ class BaseComponentService(BaseService): publishing the current state of the service to the system controller. """ + extra_capabilities: ClassVar[tuple[str, ...]] = () + """Capabilities advertised by this service in RegisterServiceCommand. Mixins extend this.""" + def __init__( self, run: "BenchmarkRun", @@ -77,6 +80,7 @@ async def _register_service_on_start(self) -> None: # Target the system controller directly to avoid broadcasting to all services. target_service_type=ServiceType.SYSTEM_CONTROLLER, state=self.state, + capabilities=tuple(self.extra_capabilities), ) max_attempts = Environment.SERVICE.REGISTRATION_MAX_ATTEMPTS registration_interval = Environment.SERVICE.REGISTRATION_INTERVAL diff --git a/src/aiperf/common/enums/__init__.py b/src/aiperf/common/enums/__init__.py index 0b0eab9ad..420a3359d 100644 --- a/src/aiperf/common/enums/__init__.py +++ b/src/aiperf/common/enums/__init__.py @@ -6,6 +6,12 @@ BasePydanticEnumInfo, CaseInsensitiveStrEnum, ) +from aiperf.common.enums.baseline_enums import ( + BaselineKind, + ServiceCapability, + make_result_producer_capability, + parse_result_producer_capability, +) from aiperf.common.enums.enums import ( AIPerfLogLevel, AudioFormat, @@ -86,6 +92,7 @@ "BaseMetricUnitInfo", "BasePydanticBackedStrEnum", "BasePydanticEnumInfo", + "BaselineKind", "CaseInsensitiveStrEnum", "CommAddress", "CommandResponseStatus", @@ -141,6 +148,7 @@ "SSEFieldType", "ServerMetricsDiscoveryMode", "ServerMetricsFormat", + "ServiceCapability", "ServiceRegistrationStatus", "SweepMode", "SystemState", @@ -151,4 +159,6 @@ "VideoJobStatus", "VideoSynthType", "WorkerStatus", + "make_result_producer_capability", + "parse_result_producer_capability", ] diff --git a/src/aiperf/common/enums/baseline_enums.py b/src/aiperf/common/enums/baseline_enums.py new file mode 100644 index 000000000..9698e3eb3 --- /dev/null +++ b/src/aiperf/common/enums/baseline_enums.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Enums for the phase baseline handshake. + +BaselineKind tags whether a baseline reading is taken before a phase starts +issuing credits (START) or after credits have drained (END). + +ServiceCapability is a generic capability tag advertised by services in their +RegisterServiceCommand; SystemController dispatches based on membership +(e.g. BASELINE_COLLECTOR services join the BaselineCoordinator's registered set). +""" + +from aiperf.common.enums.base_enums import CaseInsensitiveStrEnum + + +class BaselineKind(CaseInsensitiveStrEnum): + """Direction of a baseline reading relative to a phase.""" + + START = "start" + END = "end" + + +class ServiceCapability(CaseInsensitiveStrEnum): + """Capability tags a service may advertise at registration time.""" + + BASELINE_COLLECTOR = "baseline_collector" + RESULT_PRODUCER = "result_producer" + + +_RESULT_PRODUCER_PREFIX = f"{ServiceCapability.RESULT_PRODUCER}:" + + +def make_result_producer_capability(domain: str) -> str: + """Build a result-producer capability tag for a result domain.""" + + return f"{_RESULT_PRODUCER_PREFIX}{domain}" + + +def parse_result_producer_capability(capability: str) -> str | None: + """Return the result domain if capability is a result-producer tag.""" + + if not capability.startswith(_RESULT_PRODUCER_PREFIX): + return None + domain = capability.removeprefix(_RESULT_PRODUCER_PREFIX) + return domain or None diff --git a/src/aiperf/common/enums/enums.py b/src/aiperf/common/enums/enums.py index 1a083d863..a6235c403 100644 --- a/src/aiperf/common/enums/enums.py +++ b/src/aiperf/common/enums/enums.py @@ -92,6 +92,8 @@ class CommandType(CaseInsensitiveStrEnum): PROFILE_COMPLETE = "profile_complete" PROFILE_CONFIGURE = "profile_configure" PROFILE_START = "profile_start" + PHASE_END_GATE = "phase_end_gate" + PHASE_START_GATE = "phase_start_gate" REGISTER_SERVICE = "register_service" SHUTDOWN 
= "shutdown" SHUTDOWN_WORKERS = "shutdown_workers" @@ -348,6 +350,8 @@ class MessageType(CaseInsensitiveStrEnum): INFERENCE_RESULTS = "inference_results" METRIC_RECORDS = "metric_records" PARSED_INFERENCE_RESULTS = "parsed_inference_results" + PHASE_BASELINE_ACK = "phase_baseline_ack" + PHASE_BASELINE_REQUEST = "phase_baseline_request" PROCESSING_STATS = "processing_stats" PROCESS_RECORDS_RESULT = "process_records_result" PROCESS_TELEMETRY_RESULT = "process_telemetry_result" diff --git a/src/aiperf/common/environment.py b/src/aiperf/common/environment.py index 946470a3c..08402d37a 100644 --- a/src/aiperf/common/environment.py +++ b/src/aiperf/common/environment.py @@ -8,6 +8,7 @@ Structure: Environment.API_SERVER.* - API server settings + Environment.BASELINE.* - Phase baseline handshake Environment.COMPRESSION.* - Compression settings for streaming file transfers Environment.DATASET.* - Dataset management Environment.DEV.* - Development and debugging settings @@ -87,6 +88,34 @@ class _APIServerSettings(BaseSettings): ) +class _BaselineSettings(BaseSettings): + """Phase baseline handshake settings. + + Controls the gate that blocks TimingManager between phases until + registered baseline collectors finish their point-in-time scrape. + """ + + model_config = SettingsConfigDict(env_prefix="AIPERF_BASELINE_") + + GATE_TIMEOUT_S: float = Field( + default=5.0, + gt=0.0, + description=( + "Per-gate timeout (seconds). If registered baseline collectors do not " + "all ack within this window, the gate releases with a warning and the " + "phase proceeds without waiting for stragglers." + ), + ) + GATE_ENABLED: bool = Field( + default=True, + description=( + "Master switch for the phase baseline handshake. When False, " + "PhaseGateClient short-circuits to no-op and PhaseRunner does not " + "wait between phases. Useful for replay/debug runs." + ), + ) + + class _CompressionSettings(BaseSettings): """Compression settings for streaming file transfers. 
@@ -592,6 +621,12 @@ class _RecordSettings(BaseSettings): default=300.0, description="Timeout in seconds for processing record results", ) + CREDITS_COMPLETE_FALLBACK_TIMEOUT: float = Field( + ge=0.0, + le=300.0, + default=10.0, + description="Maximum seconds RecordsManager waits for CreditsComplete after all profiling records are ready before finalizing defensively", + ) class _SearchPlannerSettings(BaseSettings): @@ -676,13 +711,6 @@ class _ServerMetricsSettings(BaseSettings): env_parse_enums=True, ) - COLLECTION_FLUSH_PERIOD: float = Field( - ge=0.0, - le=30.0, - default=2.0, - description="Time in seconds to continue collecting metrics after profiling completes, " - "allowing server-side metrics to flush/finalize before shutting down (default: 2.0s)", - ) COLLECTION_INTERVAL: float = Field( ge=0.001, le=300.0, @@ -1192,6 +1220,10 @@ class _Environment(BaseSettings): default_factory=_APIServerSettings, description="API server settings", ) + BASELINE: _BaselineSettings = Field( + default_factory=_BaselineSettings, + description="Phase baseline handshake settings", + ) COMPRESSION: _CompressionSettings = Field( default_factory=_CompressionSettings, description="Compression settings for streaming file transfers", diff --git a/src/aiperf/common/messages/__init__.py b/src/aiperf/common/messages/__init__.py index 4c3dad95f..84368cc55 100644 --- a/src/aiperf/common/messages/__init__.py +++ b/src/aiperf/common/messages/__init__.py @@ -6,6 +6,10 @@ Message, RequiresRequestNSMixin, ) +from aiperf.common.messages.baseline_messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, +) from aiperf.common.messages.command_messages import ( CommandAcknowledgedResponse, CommandErrorResponse, @@ -14,6 +18,9 @@ CommandSuccessResponse, CommandUnhandledResponse, ConnectionProbeMessage, + PhaseEndGateCommand, + PhaseGateGrantedResponse, + PhaseStartGateCommand, ProcessRecordsCommand, ProcessRecordsResponse, ProfileCancelCommand, @@ -96,6 +103,11 @@ "Message", 
"MetricRecordsData", "MetricRecordsMessage", + "PhaseBaselineAckMessage", + "PhaseBaselineRequestMessage", + "PhaseEndGateCommand", + "PhaseGateGrantedResponse", + "PhaseStartGateCommand", "ProcessRecordsCommand", "ProcessRecordsResponse", "ProcessRecordsResultMessage", diff --git a/src/aiperf/common/messages/base_messages.py b/src/aiperf/common/messages/base_messages.py index 4ee62e8cb..6131608e0 100644 --- a/src/aiperf/common/messages/base_messages.py +++ b/src/aiperf/common/messages/base_messages.py @@ -32,6 +32,7 @@ class Message(AIPerfBaseModel): request_ns: int | None = Field( default=None, + ge=0, description="Timestamp of the request", ) @@ -65,6 +66,7 @@ class RequiresRequestNSMixin(Message): request_ns: int = Field( # type: ignore[assignment] default_factory=time.time_ns, + ge=0, description="Timestamp of the request in nanoseconds", ) diff --git a/src/aiperf/common/messages/baseline_messages.py b/src/aiperf/common/messages/baseline_messages.py new file mode 100644 index 000000000..bf46b669d --- /dev/null +++ b/src/aiperf/common/messages/baseline_messages.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Messages for the phase baseline handshake. + +PhaseBaselineRequestMessage is broadcast by SystemController; baseline-collector +services respond with PhaseBaselineAckMessage carrying success/error status. 
+""" + +from pydantic import Field + +from aiperf.common.enums import BaselineKind, MessageType +from aiperf.common.messages.base_messages import Message +from aiperf.common.messages.service_messages import BaseServiceMessage +from aiperf.common.types import MessageTypeT + + +class PhaseBaselineRequestMessage(Message): + """Broadcast by SystemController to ask all baseline collectors to scrape.""" + + message_type: MessageTypeT = MessageType.PHASE_BASELINE_REQUEST + + phase_id: str = Field( + ..., description="UUID of the phase being gated; pairs request to ack." + ) + phase_name: str = Field( + ..., description="Human-readable phase name (warmup, profiling, ...)." + ) + kind: BaselineKind = Field( + ..., description="START before credits are issued; END after returns drain." + ) + + +class PhaseBaselineAckMessage(BaseServiceMessage): + """Sent by a baseline collector after attempting collect_baseline().""" + + message_type: MessageTypeT = MessageType.PHASE_BASELINE_ACK + + phase_id: str = Field(..., description="Phase ID this ack is for.") + kind: BaselineKind = Field(..., description="START or END.") + success: bool = Field( + ..., + description="False if collect_baseline() raised; coordinator still counts as ack.", + ) + error: str | None = Field( + default=None, description="Error string when success=False." + ) diff --git a/src/aiperf/common/messages/command_messages.py b/src/aiperf/common/messages/command_messages.py index 6a0554e01..6d76e0d5a 100644 --- a/src/aiperf/common/messages/command_messages.py +++ b/src/aiperf/common/messages/command_messages.py @@ -261,10 +261,7 @@ class ProfileStartCommand(CommandMessage): class ProfileCompleteCommand(CommandMessage): - """Command message sent when all records are received and profiling is complete. - - Triggers final scrape of server metrics to capture end state. 
- """ + """Command message sent when all records are received and profiling is complete.""" command: CommandTypeT = CommandType.PROFILE_COMPLETE @@ -291,6 +288,14 @@ class RegisterServiceCommand(CommandMessage): ..., description="The type of the service to register" ) state: LifecycleState = Field(..., description="The current state of the service") + capabilities: tuple[str, ...] = Field( + default=(), + description=( + "Capability tags advertised by this service. SystemController dispatches " + "based on membership; e.g., 'baseline_collector' joins the " + "BaselineCoordinator's registered set. Unknown tags are ignored." + ), + ) class ProcessRecordsResponse(CommandSuccessResponse): @@ -308,3 +313,27 @@ class ConnectionProbeMessage(TargetedServiceMessage): """Message containing a connection probe from a service. This is used to probe the connection to the service.""" message_type: MessageTypeT = MessageType.CONNECTION_PROBE + + +class PhaseStartGateCommand(CommandMessage): + """PhaseRunner -> SystemController: hold before issuing first credit of a phase.""" + + command: CommandTypeT = CommandType.PHASE_START_GATE + + phase_id: str = Field(..., description="UUID of the phase being gated.") + phase_name: str = Field(..., description="Phase name for diagnostics.") + + +class PhaseEndGateCommand(CommandMessage): + """PhaseRunner -> SystemController: hold after credits drain, before next phase.""" + + command: CommandTypeT = CommandType.PHASE_END_GATE + + phase_id: str = Field(..., description="UUID of the phase being gated.") + phase_name: str = Field(..., description="Phase name for diagnostics.") + + +class PhaseGateGrantedResponse(CommandSuccessResponse): + """Response from BaselineCoordinator releasing a phase gate.""" + + phase_id: str = Field(..., description="Phase ID matching the gate command.") diff --git a/src/aiperf/common/mixins/__init__.py b/src/aiperf/common/mixins/__init__.py index 2fb6a794c..56b2c244f 100644 --- a/src/aiperf/common/mixins/__init__.py 
+++ b/src/aiperf/common/mixins/__init__.py @@ -12,6 +12,7 @@ TRecordCallback, ) from aiperf.common.mixins.base_mixin import BaseMixin +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin from aiperf.common.mixins.buffered_jsonl_writer_mixin import BufferedJSONLWriterMixin from aiperf.common.mixins.command_handler_mixin import CommandHandlerMixin from aiperf.common.mixins.communication_mixin import CommunicationMixin @@ -42,6 +43,7 @@ "AIPerfLoggerMixin", "BaseMetricsCollectorMixin", "BaseMixin", + "BaselineCollectorMixin", "BufferedJSONLWriterMixin", "CombinedPhaseStats", "CommandHandlerMixin", diff --git a/src/aiperf/common/mixins/baseline_collector_mixin.py b/src/aiperf/common/mixins/baseline_collector_mixin.py new file mode 100644 index 000000000..c64b60222 --- /dev/null +++ b/src/aiperf/common/mixins/baseline_collector_mixin.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Mixin for services that take pre/post-phase baseline readings. + +Subclasses implement ``collect_baseline(kind, phase_id, phase_name)`` and gain: +- automatic ServiceCapability.BASELINE_COLLECTOR registration via extra_capabilities +- a PhaseBaselineRequestMessage handler that calls collect_baseline and publishes + PhaseBaselineAckMessage with success/error. + +The mixin assumes the host class provides ``self.publish(msg)`` and ``self.service_id`` +(both satisfied by BaseComponentService). 
+""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import ClassVar + +from aiperf.common.enums import BaselineKind, MessageType, ServiceCapability +from aiperf.common.hooks import on_message +from aiperf.common.messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, +) + + +class BaselineCollectorMixin: + """Mix into a BaseComponentService to participate in the phase baseline handshake.""" + + extra_capabilities: ClassVar[tuple[str, ...]] = ( + ServiceCapability.BASELINE_COLLECTOR, + ) + + @abstractmethod + async def collect_baseline( + self, kind: BaselineKind, phase_id: str, phase_name: str + ) -> None: + """Take a single point-in-time baseline reading. + + Implementations MUST be idempotent under retries (rare) and MUST NOT + block longer than AIPERF_BASELINE_GATE_TIMEOUT_S; the coordinator + will release the gate without their ack on timeout. + """ + + @on_message(MessageType.PHASE_BASELINE_REQUEST) + async def _on_phase_baseline_request( + self, message: PhaseBaselineRequestMessage + ) -> None: + """Drive collect_baseline and publish an ack with success/error status.""" + success = True + error: str | None = None + try: + await self.collect_baseline( + message.kind, message.phase_id, message.phase_name + ) + except Exception as exc: # noqa: BLE001 - per-collector fault tolerance + success = False + error = f"{type(exc).__name__}: {exc}" + + await self.publish( + PhaseBaselineAckMessage( + service_id=self.service_id, + phase_id=message.phase_id, + kind=message.kind, + success=success, + error=error, + ) + ) diff --git a/src/aiperf/controller/baseline_coordinator.py b/src/aiperf/controller/baseline_coordinator.py new file mode 100644 index 000000000..de6d37764 --- /dev/null +++ b/src/aiperf/controller/baseline_coordinator.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""SystemController-side coordinator for the phase baseline handshake. + +Owns the set of services that registered with capability BASELINE_COLLECTOR, +fans out PhaseBaselineRequestMessage, and gathers PhaseBaselineAckMessage +responses with a per-gate timeout. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable + +from aiperf.common.aiperf_logger import AIPerfLogger +from aiperf.common.enums import BaselineKind +from aiperf.common.messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, +) + +_logger = AIPerfLogger(__name__) + + +class BaselineCoordinator: + """Coordinates pre/post-phase baseline scrapes across registered collectors.""" + + def __init__( + self, + publish: Callable[[PhaseBaselineRequestMessage], Awaitable[None]], + gate_timeout_s: float, + ) -> None: + self._publish = publish + self._gate_timeout_s = gate_timeout_s + self._registered: set[str] = set() + self._inflight: dict[ + tuple[str, BaselineKind], + dict[str, asyncio.Future[PhaseBaselineAckMessage]], + ] = {} + + @property + def registered_count(self) -> int: + return len(self._registered) + + def register(self, service_id: str) -> None: + """Add a service to the registered baseline-collector set. Idempotent.""" + self._registered.add(service_id) + + def unregister(self, service_id: str) -> None: + """Remove a service (e.g., on heartbeat-loss eviction). 
No-op if absent.""" + self._registered.discard(service_id) + + def handle_ack(self, ack: PhaseBaselineAckMessage) -> None: + """Resolve the pending future for (phase_id, kind, service_id), if any.""" + pending = self._inflight.get((ack.phase_id, ack.kind)) + if pending is None: + return + fut = pending.get(ack.service_id) + if fut is None or fut.done(): + return + fut.set_result(ack) + + async def gate_phase( + self, phase_id: str, phase_name: str, kind: BaselineKind + ) -> None: + """Block until all currently-registered collectors ack, or timeout fires.""" + registered = tuple(self._registered) + if not registered: + return + + pending: dict[str, asyncio.Future[PhaseBaselineAckMessage]] = { + sid: asyncio.get_running_loop().create_future() for sid in registered + } + self._inflight[(phase_id, kind)] = pending + + await self._publish( + PhaseBaselineRequestMessage( + phase_id=phase_id, phase_name=phase_name, kind=kind + ) + ) + + try: + results = await asyncio.wait_for( + asyncio.gather(*pending.values(), return_exceptions=True), + timeout=self._gate_timeout_s, + ) + for ack in results: + if isinstance(ack, PhaseBaselineAckMessage) and not ack.success: + _logger.warning( + f"Baseline {kind} for phase '{phase_name}' " + f"(id={phase_id[:8]}) collector {ack.service_id!r} " + f"reported failure: {ack.error}" + ) + except asyncio.TimeoutError: + unacked: list[str] = [] + for sid, f in pending.items(): + # A future is "acked" only if it completed via set_result (not + # via cancellation from the wait_for timeout). + if f.cancelled() or not f.done(): + unacked.append(sid) + unacked.sort() + _logger.warning( + f"Baseline {kind} gate for phase '{phase_name}' " + f"(id={phase_id[:8]}) timed out after {self._gate_timeout_s}s; " + f"proceeding without acks from {unacked}. " + f"Increase AIPERF_BASELINE_GATE_TIMEOUT_S or set " + f"AIPERF_BASELINE_GATE_ENABLED=0 to disable." 
+ ) + finally: + self._inflight.pop((phase_id, kind), None) diff --git a/src/aiperf/controller/result_join_coordinator.py b/src/aiperf/controller/result_join_coordinator.py new file mode 100644 index 000000000..614a35023 --- /dev/null +++ b/src/aiperf/controller/result_join_coordinator.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tracks post-profile result producers until every registered result is joined.""" + +from __future__ import annotations + + +class ResultJoinCoordinator: + """Coordinates readiness across registered result-producing domains.""" + + def __init__(self) -> None: + self._required: dict[str, set[str]] = {} + self._completed: dict[str, set[str]] = {} + self._last_reported_pending: tuple[str, ...] = () + + @property + def ready(self) -> bool: + return self.pending_domains == () + + @property + def pending_domains(self) -> tuple[str, ...]: + return tuple( + sorted( + domain + for domain, required in self._required.items() + if required - self._completed.get(domain, set()) + ) + ) + + def register(self, domain: str, service_id: str) -> None: + self._required.setdefault(domain, set()).add(service_id) + + def unregister(self, domain: str, service_id: str) -> None: + required = self._required.get(domain) + if required is None: + return + + required.discard(service_id) + completed = self._completed.get(domain) + if completed is not None: + completed.discard(service_id) + + if not required: + self._required.pop(domain, None) + self._completed.pop(domain, None) + + def unregister_service(self, service_id: str) -> None: + for domain in tuple(self._required): + self.unregister(domain, service_id) + + def complete(self, domain: str, service_id: str) -> None: + if service_id not in self._required.get(domain, set()): + return + self._completed.setdefault(domain, set()).add(service_id) + + def complete_domain(self, domain: str) -> None: + 
required = self._required.get(domain) + if not required: + return + self._completed[domain] = set(required) + + def pending_domains_changed(self) -> tuple[str, ...] | None: + pending = self.pending_domains + if pending == self._last_reported_pending: + return None + self._last_reported_pending = pending + return pending diff --git a/src/aiperf/controller/system_controller.py b/src/aiperf/controller/system_controller.py index 9830fb7e2..bcf6558c1 100644 --- a/src/aiperf/controller/system_controller.py +++ b/src/aiperf/controller/system_controller.py @@ -16,10 +16,13 @@ ) from aiperf.common.base_service import BaseService from aiperf.common.enums import ( + BaselineKind, CommandResponseStatus, CommandType, MessageType, + ServiceCapability, ServiceRegistrationStatus, + parse_result_producer_capability, ) from aiperf.common.environment import Environment from aiperf.common.exceptions import LifecycleOperationError @@ -31,6 +34,10 @@ CommandResponse, CommandSuccessResponse, HeartbeatMessage, + PhaseBaselineAckMessage, + PhaseEndGateCommand, + PhaseGateGrantedResponse, + PhaseStartGateCommand, ProcessRecordsResultMessage, ProcessServerMetricsResultMessage, ProcessTelemetryResultMessage, @@ -56,9 +63,11 @@ from aiperf.common.models.server_metrics_models import ServerMetricsResults from aiperf.common.types import ServiceTypeT from aiperf.config.artifacts import OutputDefaults +from aiperf.controller.baseline_coordinator import BaselineCoordinator from aiperf.controller.controller_utils import print_exit_errors from aiperf.controller.protocols import ServiceManagerProtocol from aiperf.controller.proxy_manager import ProxyManager +from aiperf.controller.result_join_coordinator import ResultJoinCoordinator from aiperf.controller.system_mixins import SignalHandlerMixin from aiperf.credit.messages import CreditsCompleteMessage from aiperf.exporters.exporter_manager import ExporterManager @@ -157,9 +166,6 @@ def __init__( self._exit_errors: list[ExitErrorInfo] = [] 
self._telemetry_results: TelemetryExportData | None = None self._server_metrics_results: ServerMetricsResults | None = None - self._profile_results_received = False - self._should_wait_for_telemetry = False - self._should_wait_for_server_metrics = False self._shutdown_triggered = False self._shutdown_lock = asyncio.Lock() @@ -167,6 +173,11 @@ def __init__( self._telemetry_endpoints_reachable: list[str] = [] self._server_metrics_endpoints_configured: list[str] = [] self._server_metrics_endpoints_reachable: list[str] = [] + self._baseline_coordinator = BaselineCoordinator( + publish=self.publish, + gate_timeout_s=Environment.BASELINE.GATE_TIMEOUT_S, + ) + self._result_join_coordinator = ResultJoinCoordinator() self.debug("System Controller created") def _should_warn_osl_without_ignore_eos(self) -> bool: @@ -249,14 +260,12 @@ async def _start_services(self) -> None: await self.service_manager.run_service(ServiceType.GPU_TELEMETRY_MANAGER) else: self.info("GPU telemetry disabled via --no-gpu-telemetry") - self._should_wait_for_telemetry = False if self.run.cfg.server_metrics.enabled: self.debug("Starting optional ServerMetricsManager service") await self.service_manager.run_service(ServiceType.SERVER_METRICS_MANAGER) else: self.info("Server metrics disabled via --no-server-metrics") - self._should_wait_for_server_metrics = False # Start AIPerf API if enabled api_port = self.run.cfg.runtime.api_port or Environment.API_SERVER.PORT @@ -377,12 +386,52 @@ async def _handle_register_service_command( self.service_manager.service_map[message.service_type] = [] self.service_manager.service_map[message.service_type].append(service_info) + if ServiceCapability.BASELINE_COLLECTOR in message.capabilities: + self._baseline_coordinator.register(message.service_id) + + for capability in message.capabilities: + domain = parse_result_producer_capability(capability) + if domain is not None: + self._result_join_coordinator.register(domain, message.service_id) + try: type_name = 
ServiceType(message.service_type).name.title().replace("_", " ") except (TypeError, ValueError): type_name = message.service_type self.info(lambda: f"Registered {type_name} (id: '{message.service_id}')") + @on_command(CommandType.PHASE_START_GATE) + async def _on_phase_start_gate( + self, message: PhaseStartGateCommand + ) -> PhaseGateGrantedResponse: + await self._baseline_coordinator.gate_phase( + message.phase_id, message.phase_name, BaselineKind.START + ) + return PhaseGateGrantedResponse( + command_id=message.command_id, + service_id=self.service_id, + command=message.command, + phase_id=message.phase_id, + ) + + @on_command(CommandType.PHASE_END_GATE) + async def _on_phase_end_gate( + self, message: PhaseEndGateCommand + ) -> PhaseGateGrantedResponse: + await self._baseline_coordinator.gate_phase( + message.phase_id, message.phase_name, BaselineKind.END + ) + return PhaseGateGrantedResponse( + command_id=message.command_id, + service_id=self.service_id, + command=message.command, + phase_id=message.phase_id, + ) + + @on_message(MessageType.PHASE_BASELINE_ACK) + async def _on_phase_baseline_ack(self, message: PhaseBaselineAckMessage) -> None: + self._baseline_coordinator.handle_ack(message) + @on_message(MessageType.HEARTBEAT) async def _process_heartbeat_message(self, message: HeartbeatMessage) -> None: """Process a heartbeat message from a service. 
It will @@ -442,6 +491,8 @@ async def _process_service_error_message( service_id=message.service_id, ) ) + self._result_join_coordinator.unregister_service(message.service_id) + await self._check_and_trigger_shutdown() @on_message(MessageType.STATUS) async def _process_status_message(self, message: StatusMessage) -> None: @@ -491,9 +542,9 @@ async def _on_telemetry_status_message( self._telemetry_endpoints_configured = message.endpoints_configured self._telemetry_endpoints_reachable = message.endpoints_reachable - self._should_wait_for_telemetry = message.enabled if not message.enabled: + self._result_join_coordinator.unregister("telemetry", message.service_id) reason_msg = f": {message.reason}" if message.reason else "" self.info(f"DCGM telemetry skipped{reason_msg}") else: @@ -515,9 +566,11 @@ async def _on_server_metrics_status_message( self._server_metrics_endpoints_configured = message.endpoints_configured self._server_metrics_endpoints_reachable = message.endpoints_reachable - self._should_wait_for_server_metrics = message.enabled if not message.enabled: + self._result_join_coordinator.unregister( + "server_metrics", message.service_id + ) reason_msg = f" - {message.reason}" if message.reason else "" self.info(f"Server metrics disabled{reason_msg}") else: @@ -608,8 +661,7 @@ async def _on_process_records_result_message( f"Received process records result message with no records: {message.results.results}" ) - self._profile_results_received = True - # Coordinate with telemetry results before shutdown + self._result_join_coordinator.complete_domain("profile") await self._check_and_trigger_shutdown() @on_message(MessageType.PROCESS_TELEMETRY_RESULT) @@ -643,7 +695,7 @@ async def _on_process_telemetry_result_message( except Exception as e: self.exception(f"Error processing telemetry results message: {e!r}") finally: - self._should_wait_for_telemetry = False + self._result_join_coordinator.complete_domain("telemetry") await self._check_and_trigger_shutdown() 
@on_message(MessageType.PROCESS_SERVER_METRICS_RESULT) @@ -683,30 +735,23 @@ async def _on_process_server_metrics_result_message( except Exception as e: self.exception(f"Error processing server metrics results message: {e!r}") finally: - self._should_wait_for_server_metrics = False + self._result_join_coordinator.complete_domain("server_metrics") await self._check_and_trigger_shutdown() async def _check_and_trigger_shutdown(self) -> None: - """Check if all required results are received and trigger unified export + shutdown. - - Coordination logic: - 1. Always wait for profile results (ProcessRecordsResultMessage) - 2. If telemetry disabled OR telemetry results received → proceed - 3. If server metrics disabled OR server metrics results received → proceed - 4. Otherwise → wait (results arrive nearly simultaneously and will call this method again) - - Thread safety: - Uses self._shutdown_lock to prevent race conditions when ProcessRecordsResultMessage, - ProcessTelemetryResultMessage, and ProcessServerMetricsResultMessage arrive concurrently. - The lock ensures atomic check-and-set of _shutdown_triggered, preventing double-triggering of stop(). + """Check if all registered result producers are complete and trigger shutdown. + + Uses self._shutdown_lock to prevent races when result messages arrive + concurrently. The lock ensures atomic check-and-set of _shutdown_triggered, + preventing double-triggering of stop(). 
""" self.debug( - f"_check_and_trigger_shutdown: profile_received={self._profile_results_received}, " - f"wait_telemetry={self._should_wait_for_telemetry}, telemetry_results={self._telemetry_results is not None}, " - f"wait_server_metrics={self._should_wait_for_server_metrics}, server_metrics_results={self._server_metrics_results is not None}, " - f"shutdown_triggered={self._shutdown_triggered}" + lambda: ( + "_check_and_trigger_shutdown: " + f"pending_domains={self._result_join_coordinator.pending_domains}, " + f"shutdown_triggered={self._shutdown_triggered}" + ) ) - # Check if we should trigger shutdown (with lock protection) should_shutdown = False async with self._shutdown_lock: if self._shutdown_triggered: @@ -715,31 +760,15 @@ async def _check_and_trigger_shutdown(self) -> None: ) return - if not self._profile_results_received: - self.debug( - "_check_and_trigger_shutdown: profile results not received yet" - ) - return - - telemetry_ready_for_shutdown = ( - not self._should_wait_for_telemetry - or self._telemetry_results is not None - ) - - server_metrics_ready_for_shutdown = ( - not self._should_wait_for_server_metrics - or self._server_metrics_results is not None - ) - - if telemetry_ready_for_shutdown and server_metrics_ready_for_shutdown: + if self._result_join_coordinator.ready: self._shutdown_triggered = True should_shutdown = True self.info("All results received, initiating shutdown") - else: - if not telemetry_ready_for_shutdown: - self.info("Waiting for telemetry results...") - if not server_metrics_ready_for_shutdown: - self.info("Waiting for server metrics results...") + elif ( + pending_domains + := self._result_join_coordinator.pending_domains_changed() + ) is not None: + self.info(f"Waiting for result domains: {', '.join(pending_domains)}") # Call stop() OUTSIDE the lock to prevent deadlock if should_shutdown: @@ -885,7 +914,6 @@ async def _cancel_profiling(self) -> None: ) ) self._profile_results = response.data - self._profile_results_received 
= True break except Exception as e: # Catch ANY exception during cancellation - we must always proceed to stop(). diff --git a/src/aiperf/gpu_telemetry/manager.py b/src/aiperf/gpu_telemetry/manager.py index 21cdaebfc..2d500ab2f 100644 --- a/src/aiperf/gpu_telemetry/manager.py +++ b/src/aiperf/gpu_telemetry/manager.py @@ -5,10 +5,16 @@ import asyncio from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from aiperf.common.base_component_service import BaseComponentService -from aiperf.common.enums import CommAddress, CommandType +from aiperf.common.enums import ( + BaselineKind, + CommAddress, + CommandType, + ServiceCapability, + make_result_producer_capability, +) from aiperf.common.environment import Environment from aiperf.common.hooks import on_command, on_init, on_stop from aiperf.common.messages import ( @@ -18,6 +24,7 @@ TelemetryRecordsMessage, TelemetryStatusMessage, ) +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin from aiperf.common.models import ErrorDetails, TelemetryRecord from aiperf.common.protocols import PushClientProtocol from aiperf.gpu_telemetry.protocols import GPUTelemetryCollectorProtocol @@ -37,7 +44,7 @@ class _CollectorCandidate: kwargs: dict[str, Any] -class GPUTelemetryManager(BaseComponentService): +class GPUTelemetryManager(BaselineCollectorMixin, BaseComponentService): """Coordinates multiple TelemetryDataCollector instances for GPU telemetry collection. 
The GPUTelemetryManager coordinates multiple TelemetryDataCollector instances @@ -56,6 +63,11 @@ class GPUTelemetryManager(BaseComponentService): service_id: Optional unique identifier for this service instance """ + extra_capabilities: ClassVar[tuple[str, ...]] = ( + ServiceCapability.BASELINE_COLLECTOR, + make_result_producer_capability("telemetry"), + ) + def __init__( self, run: BenchmarkRun, @@ -246,13 +258,6 @@ async def _configure_reachable_collectors( self._collectors[source_identifier] = collector self._collector_id_to_url[candidate.collector_id] = source_identifier self.debug(f"GPU Telemetry: {source_identifier} is reachable") - baseline_failure_reason = await self._capture_collector_baseline( - collector, - candidate.collector_id, - source_identifier, - ) - if baseline_failure_reason is not None: - failure_reason = baseline_failure_reason except RuntimeError as e: failure_reason = str(e) self.error(f"GPU Telemetry: {e}") @@ -263,34 +268,6 @@ async def _configure_reachable_collectors( ) return configured_sources, failure_reason - async def _capture_collector_baseline( - self, - collector: GPUTelemetryCollectorProtocol, - collector_id: str, - source_identifier: str, - ) -> str | None: - self.info(f"GPU Telemetry: Capturing baseline metrics from {source_identifier}") - try: - await collector.initialize() - except (Exception, asyncio.CancelledError) as e: # noqa: BLE001 - self.warning( - f"GPU Telemetry: Failed to initialize {source_identifier} during " - f"baseline capture, disabling collector: {e!r}" - ) - self._collectors.pop(source_identifier, None) - self._collector_id_to_url.pop(collector_id, None) - return f"{source_identifier} initialization failed: {e}" - - try: - await collector.collect_and_process_metrics() - self.debug(f"GPU Telemetry: Captured baseline from {source_identifier}") - except Exception as e: # noqa: BLE001 - baseline scrape best-effort - self.warning( - f"GPU Telemetry: Failed to capture baseline from {source_identifier} " - 
f"(collector remains enabled): {e}" - ) - return None - async def _send_configure_status( self, configured_sources: list[str], failure_reason: str | None ) -> None: @@ -379,34 +356,53 @@ async def _handle_profile_cancel_command( """ await self._stop_all_collectors() - @on_command(CommandType.PROFILE_COMPLETE) - async def _handle_profile_complete_command( - self, message: ProfileCompleteCommand + async def collect_baseline( + self, kind: BaselineKind, phase_id: str, phase_name: str ) -> None: - """Trigger final scrape when profiling completes. - - Ensures GPU telemetry captures final state for accurate counter deltas. - This final scrape provides the end-point values needed for metrics like - energy_consumption which are computed as (final - baseline). - - Args: - message: Profile complete command from SystemController - """ + if phase_name != "profiling": + return if not self._collectors: - self.debug("GPU Telemetry: Already stopped, skipping final scrape") return - self.info("GPU Telemetry: Profiling complete, capturing final metrics...") + self.info( + f"GPU Telemetry: Capturing {kind} baseline for phase '{phase_name}'..." 
+ ) + await self._collect_once(label=str(kind)) + + async def _collect_once(self, label: str) -> None: + failures: list[str] = [] - for dcgm_url, collector in list(self._collectors.items()): + async def collect( + source_url: str, collector: GPUTelemetryCollectorProtocol + ) -> None: try: await collector.collect_and_process_metrics() - self.debug(f"GPU Telemetry: Captured final state from {dcgm_url}") - except Exception as e: + self.debug( + lambda url=source_url: f"GPU Telemetry: Captured {label} state from {url}" + ) + except Exception as e: # noqa: BLE001 - keep attempting other endpoints + failures.append(f"{source_url}: {e}") self.warning( - f"GPU Telemetry: Failed to capture final state from {dcgm_url}: {e}" + f"GPU Telemetry: Failed to capture {label} state from {source_url}: {e}" ) + await asyncio.gather( + *( + collect(source_url, collector) + for source_url, collector in list(self._collectors.items()) + ) + ) + if failures: + raise RuntimeError("; ".join(failures)) + + @on_command(CommandType.PROFILE_COMPLETE) + async def _handle_profile_complete_command( + self, message: ProfileCompleteCommand + ) -> None: + if not self._collectors: + self.debug("GPU Telemetry: Already stopped, skipping completion") + return + await self._stop_all_collectors() @on_stop diff --git a/src/aiperf/gpu_telemetry/pynvml_collector.py b/src/aiperf/gpu_telemetry/pynvml_collector.py index 8c281221e..977b69ed1 100644 --- a/src/aiperf/gpu_telemetry/pynvml_collector.py +++ b/src/aiperf/gpu_telemetry/pynvml_collector.py @@ -343,9 +343,7 @@ async def _collect_metrics_loop(self) -> None: async def collect_and_process_metrics(self) -> None: """Public alias for one-shot scrape. - ``GPUTelemetryManager`` calls this name during baseline and final-state - capture (``manager.py`` :func:`_capture_collector_baseline` and - :func:`_handle_profile_complete_command`). + ``GPUTelemetryManager`` calls this name during phase baseline capture. 
""" await self._collect_and_process_metrics() diff --git a/src/aiperf/records/records_manager.py b/src/aiperf/records/records_manager.py index cf293e1c6..e5f0c5e97 100644 --- a/src/aiperf/records/records_manager.py +++ b/src/aiperf/records/records_manager.py @@ -6,15 +6,15 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from aiperf.common.base_component_service import BaseComponentService -from aiperf.common.constants import NANOS_PER_SECOND from aiperf.common.enums import ( CommAddress, CommandType, CreditPhase, MessageType, + make_result_producer_capability, ) from aiperf.common.environment import Environment from aiperf.common.exceptions import PostProcessorDisabled @@ -107,6 +107,10 @@ class RecordsManager(PullClientMixin, BaseComponentService): many records before finalizing results. """ + extra_capabilities: ClassVar[tuple[str, ...]] = ( + make_result_producer_capability("profile"), + ) + def __init__( self, run: BenchmarkRun, @@ -151,6 +155,8 @@ def __init__( # warmup vs profiling separation. 
self._phase_branch_stats: dict[CreditPhase, BranchStats] = {} self._complete_credit_phases: set[CreditPhase] = set() + self._credits_complete_received = False + self._credits_complete_fallback_task: asyncio.Task[None] | None = None self._telemetry_state = ErrorTrackingState() self._server_metrics_state = ErrorTrackingState() @@ -237,13 +243,7 @@ async def _on_metric_records(self, message: MetricRecordsMessage) -> None: ) phase = record_data.metadata.benchmark_phase - if ( - phase in self._complete_credit_phases - and self._records_tracker.check_and_set_all_records_received_for_phase( - phase - ) - ): - await self._handle_all_records_received(phase) + await self._maybe_handle_all_records_received(phase) @on_pull_message(MessageType.TELEMETRY_RECORDS) async def _on_telemetry_records(self, message: TelemetryRecordsMessage) -> None: @@ -285,6 +285,50 @@ async def _on_server_metrics_records( if message.error: self._server_metrics_state.error_counts[message.error] += 1 + async def _maybe_handle_all_records_received(self, phase: CreditPhase) -> None: + if phase not in self._complete_credit_phases: + return + if phase == CreditPhase.PROFILING and not self._credits_complete_received: + self._maybe_schedule_credits_complete_fallback(phase) + return + if self._records_tracker.check_and_set_all_records_received_for_phase(phase): + self._cancel_credits_complete_fallback() + await self._handle_all_records_received(phase) + + def _maybe_schedule_credits_complete_fallback(self, phase: CreditPhase) -> None: + if self._credits_complete_fallback_task is not None: + return + phase_stats = self._records_tracker.create_stats_for_phase(phase) + if phase_stats.final_requests_completed is None: + return + if phase_stats.total_records < phase_stats.final_requests_completed: + return + timeout = Environment.RECORD.CREDITS_COMPLETE_FALLBACK_TIMEOUT + self.warning( + f"All profiling records arrived before CreditsComplete; waiting up to {timeout:.1f}s before finalizing defensively" + ) + 
self._credits_complete_fallback_task = self.execute_async( + self._finalize_without_credits_complete_after_timeout(phase) + ) + + def _cancel_credits_complete_fallback(self) -> None: + if self._credits_complete_fallback_task is None: + return + self._credits_complete_fallback_task.cancel() + self._credits_complete_fallback_task = None + + async def _finalize_without_credits_complete_after_timeout( + self, phase: CreditPhase + ) -> None: + await asyncio.sleep(Environment.RECORD.CREDITS_COMPLETE_FALLBACK_TIMEOUT) + if self._credits_complete_received: + return + self.warning( + "CreditsComplete was not received after all profiling records arrived; finalizing records defensively" + ) + if self._records_tracker.check_and_set_all_records_received_for_phase(phase): + await self._handle_all_records_received(phase) + async def _handle_all_records_received(self, phase: CreditPhase) -> None: """Handle the case where all records have been received.""" if phase != CreditPhase.PROFILING: @@ -325,35 +369,15 @@ async def _finalize_and_process_results( ) ) - # Trigger final server metrics scrape and wait for completion - # This ensures final metrics are pushed before we export results response = await self.send_command_and_wait_for_response( ProfileCompleteCommand(service_id=self.service_id), timeout=10.0 ) if isinstance(response, ErrorDetails): - self.warning(f"Server metrics final scrape timed out or failed: {response}") + self.warning(f"Server metrics completion timed out or failed: {response}") else: - self.debug("Server metrics final scrape completed") - - self.debug("Waiting for server metrics flush period...") - # Wait for server metrics flush period to allow final metrics to be collected - # This ensures metrics that are still being processed by the server are captured - flush_period = Environment.SERVER_METRICS.COLLECTION_FLUSH_PERIOD - phase_stats = self._records_tracker.create_stats_for_phase( - CreditPhase.PROFILING - ) - flush_end_ns = (phase_stats.requests_end_ns or 
time.time_ns()) + ( - (flush_period or 0) * NANOS_PER_SECOND - ) - sleep_dur_sec = (flush_end_ns - time.time_ns()) / NANOS_PER_SECOND - if sleep_dur_sec > 0: - self.info( - f"Waiting {sleep_dur_sec:.1f}s for server metrics flush period..." - ) - await asyncio.sleep(sleep_dur_sec) + self.debug("Server metrics completion acknowledged") - self.debug("Server metrics flush period complete, processing now...") await self._process_results(phase=phase, cancelled=cancelled) self.info("_finalize_and_process_results completed") @@ -566,12 +590,7 @@ async def _on_credit_phase_complete( f"(currently {phase_stats.total_records:,} of {phase_stats.final_requests_completed:,} records processed)..." ) - # This check is to prevent a race condition where the records manager processes - # all records before the timing manager has sent the final completed count. - if self._records_tracker.check_and_set_all_records_received_for_phase( - message.stats.phase - ): - await self._handle_all_records_received(message.stats.phase) + await self._maybe_handle_all_records_received(message.stats.phase) def _snapshot_branch_stats(self, phase: CreditPhase) -> BranchStats | None: """Return the orchestrator-published BranchStats for ``phase``. @@ -588,13 +607,9 @@ async def _on_credits_complete(self, message: CreditsCompleteMessage) -> None: self.info( "All credits complete, please wait for the results to be processed..." 
) - if ( - CreditPhase.PROFILING in self._complete_credit_phases - and self._records_tracker.check_and_set_all_records_received_for_phase( - CreditPhase.PROFILING - ) - ): - await self._handle_all_records_received(CreditPhase.PROFILING) + self._credits_complete_received = True + self._cancel_credits_complete_fallback() + await self._maybe_handle_all_records_received(CreditPhase.PROFILING) @background_task( interval=Environment.RECORD.PROGRESS_REPORT_INTERVAL, immediate=False diff --git a/src/aiperf/server_metrics/manager.py b/src/aiperf/server_metrics/manager.py index 15bd76208..4bcdbddeb 100644 --- a/src/aiperf/server_metrics/manager.py +++ b/src/aiperf/server_metrics/manager.py @@ -4,10 +4,16 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from aiperf.common.base_component_service import BaseComponentService -from aiperf.common.enums import CommAddress, CommandType +from aiperf.common.enums import ( + BaselineKind, + CommAddress, + CommandType, + ServiceCapability, + make_result_producer_capability, +) from aiperf.common.environment import Environment from aiperf.common.hooks import on_command, on_stop from aiperf.common.messages import ( @@ -19,6 +25,7 @@ ServerMetricsStatusMessage, ) from aiperf.common.metric_utils import normalize_metrics_endpoint_url +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin from aiperf.common.models import ErrorDetails, ServerMetricsRecord from aiperf.common.protocols import PushClientProtocol from aiperf.server_metrics.data_collector import ServerMetricsDataCollector @@ -27,7 +34,7 @@ from aiperf.config.resolution.plan import BenchmarkRun -class ServerMetricsManager(BaseComponentService): +class ServerMetricsManager(BaselineCollectorMixin, BaseComponentService): """Coordinates multiple ServerMetricsDataCollector instances for server metrics collection. 
The ServerMetricsManager coordinates multiple ServerMetricsDataCollector instances @@ -46,6 +53,11 @@ class ServerMetricsManager(BaseComponentService): service_id: Optional unique identifier for this service instance """ + extra_capabilities: ClassVar[tuple[str, ...]] = ( + ServiceCapability.BASELINE_COLLECTOR, + make_result_producer_capability("server_metrics"), + ) + def __init__( self, run: BenchmarkRun, @@ -129,6 +141,7 @@ async def _profile_configure_command( try: is_reachable = await collector.is_url_reachable() if is_reachable: + await collector.initialize() self._collectors[endpoint_url] = collector self.debug( lambda url=endpoint_url: f"Server Metrics: Prometheus endpoint {url} is reachable" @@ -152,20 +165,6 @@ async def _profile_configure_command( ) return - # Capture baseline metrics before profiling starts - self.info("Server Metrics: Capturing baseline metrics...") - for endpoint_url, collector in self._collectors.items(): - try: - await collector.initialize() - await collector.collect_and_process_metrics() - self.debug( - lambda url=endpoint_url: f"Server Metrics: Captured baseline from {url}" - ) - except Exception as e: - self.warning( - f"Server Metrics: Failed to capture baseline from {endpoint_url}: {e}" - ) - await self._send_server_metrics_status( enabled=True, reason=None, @@ -221,47 +220,53 @@ async def _on_start_profiling(self, message: ProfileStartCommand) -> None: f"Server Metrics: Started {started_count} collector(s) successfully" ) - @on_command(CommandType.PROFILE_COMPLETE) - async def _handle_profile_complete_command( - self, message: ProfileCompleteCommand + async def collect_baseline( + self, kind: BaselineKind, phase_id: str, phase_name: str ) -> None: - """Trigger final scrape when profiling completes. - - Performs one final metrics collection from all endpoints to capture - the end state immediately after profiling finishes. 
This ensures we - have metrics that cover the entire profiling period, including any - counter/histogram changes that occurred during the final seconds. - - Critical for accurate delta calculations on counters and histograms, - where missing the final state would undercount the actual activity. - - Idempotent: Can be called multiple times safely (e.g., if multiple - RecordsManager instances send the command). Subsequent calls are no-ops. - - Args: - message: Profile complete command from RecordsManager signaling that - all client request records have been processed - """ - # Idempotent check - skip if already stopped or no collectors + if phase_name != "profiling": + return if not self._collectors: - self.debug("Server Metrics: Already stopped, skipping final scrape") return - self.info("Server Metrics: Profiling complete, capturing final metrics...") + self.info( + f"Server Metrics: Capturing {kind} baseline for phase '{phase_name}'..." + ) + await self._collect_once(label=str(kind)) - # Trigger final scrape from all collectors - for endpoint_url, collector in list(self._collectors.items()): + async def _collect_once(self, label: str) -> None: + failures: list[str] = [] + + async def collect( + endpoint_url: str, collector: ServerMetricsDataCollector + ) -> None: try: await collector.collect_and_process_metrics() self.debug( - lambda url=endpoint_url: f"Server Metrics: Captured final state from {url}" + lambda url=endpoint_url: f"Server Metrics: Captured {label} state from {url}" ) - except Exception as e: + except Exception as e: # noqa: BLE001 - keep attempting other endpoints + failures.append(f"{endpoint_url}: {e}") self.warning( - f"Server Metrics: Failed to capture final state from {endpoint_url}: {e}" + f"Server Metrics: Failed to capture {label} state from {endpoint_url}: {e}" ) - # Stop all collectors after final scrape + await asyncio.gather( + *( + collect(endpoint_url, collector) + for endpoint_url, collector in list(self._collectors.items()) + ) + ) + 
if failures: + raise RuntimeError("; ".join(failures)) + + @on_command(CommandType.PROFILE_COMPLETE) + async def _handle_profile_complete_command( + self, message: ProfileCompleteCommand + ) -> None: + if not self._collectors: + self.debug("Server Metrics: Already stopped, skipping completion") + return + await self._stop_all_collectors() @on_command(CommandType.PROFILE_CANCEL) @@ -270,9 +275,6 @@ async def _handle_profile_cancel_command( ) -> None: """Stop all server metrics collectors when profiling is cancelled. - Called when user cancels profiling or an error occurs during profiling. - Waits for flush period to allow metrics to finalize, then stops collectors. - Args: message: Profile cancel command from SystemController """ diff --git a/src/aiperf/timing/manager.py b/src/aiperf/timing/manager.py index bfc1e667d..b4b6f3c11 100644 --- a/src/aiperf/timing/manager.py +++ b/src/aiperf/timing/manager.py @@ -27,6 +27,7 @@ from aiperf.common.models import DatasetMetadata from aiperf.credit.sticky_router import StickyCreditRouter from aiperf.timing.config import TimingConfig +from aiperf.timing.phase.phase_gate import PhaseGateClient from aiperf.timing.phase.publisher import PhasePublisher from aiperf.timing.phase_orchestrator import PhaseOrchestrator @@ -133,11 +134,18 @@ async def _profile_configure_command( self.debug(f"Configuring phase orchestrator for {self.service_id}") # Create orchestrator that executes phases + gate = PhaseGateClient( + sender=self, + service_id=self.service_id, + enabled=Environment.BASELINE.GATE_ENABLED, + timeout_s=Environment.BASELINE.GATE_TIMEOUT_S, + ) self._phase_orchestrator = PhaseOrchestrator( config=self.config, phase_publisher=self.phase_publisher, credit_router=self.sticky_router, dataset_metadata=self._dataset_metadata, + phase_gate=gate, ) await self._phase_orchestrator.initialize() diff --git a/src/aiperf/timing/phase/phase_gate.py b/src/aiperf/timing/phase/phase_gate.py new file mode 100644 index 000000000..7e210ffd3 --- 
/dev/null +++ b/src/aiperf/timing/phase/phase_gate.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""TimingManager-side client for the phase baseline handshake. + +A thin wrapper around send_command_and_wait_for_response that hides +PhaseStartGateCommand / PhaseEndGateCommand from PhaseRunner. Owns no +knowledge of telemetry, server-metrics, or any specific collector. +""" + +from __future__ import annotations + +import uuid +from typing import Protocol, TypeVar + +from aiperf.common.messages import ( + CommandMessage, + CommandResponse, + PhaseEndGateCommand, + PhaseStartGateCommand, +) +from aiperf.common.models.error_models import ErrorDetails +from aiperf.plugin.enums import ServiceType + +_GateCommandT = TypeVar("_GateCommandT", PhaseStartGateCommand, PhaseEndGateCommand) + + +class _CommandSender(Protocol): + async def send_command_and_wait_for_response( + self, message: CommandMessage, timeout: float = ... + ) -> CommandResponse | ErrorDetails: ... + + +class PhaseGateClient: + """Sends PhaseStartGate/PhaseEndGate commands and waits for the response. + + The gate semantic is "released" if any response comes back at all; the + payload is intentionally not inspected. On timeout the underlying sender + returns an `ErrorDetails` (it does not raise), and the client treats that + the same as a release so the benchmark keeps moving. 
+ """ + + def __init__( + self, + sender: _CommandSender, + service_id: str, + enabled: bool, + timeout_s: float, + ) -> None: + self._sender = sender + self._service_id = service_id + self._enabled = enabled + self._timeout_s = timeout_s + + async def before_phase(self, phase_id: str, phase_name: str) -> None: + """Block until SystemController releases the START gate (or timeout).""" + await self._send_gate(PhaseStartGateCommand, phase_id, phase_name) + + async def after_phase(self, phase_id: str, phase_name: str) -> None: + """Block until SystemController releases the END gate (or timeout).""" + await self._send_gate(PhaseEndGateCommand, phase_id, phase_name) + + async def _send_gate( + self, + command_type: type[_GateCommandT], + phase_id: str, + phase_name: str, + ) -> None: + if not self._enabled: + return + await self._sender.send_command_and_wait_for_response( + command_type( + command_id=str(uuid.uuid4()), + service_id=self._service_id, + target_service_type=ServiceType.SYSTEM_CONTROLLER, + phase_id=phase_id, + phase_name=phase_name, + ), + timeout=self._timeout_s + 1.0, + ) diff --git a/src/aiperf/timing/phase/runner.py b/src/aiperf/timing/phase/runner.py index 998bd3292..ae973999e 100644 --- a/src/aiperf/timing/phase/runner.py +++ b/src/aiperf/timing/phase/runner.py @@ -9,6 +9,7 @@ from __future__ import annotations import asyncio +import uuid from collections.abc import Callable from typing import TYPE_CHECKING @@ -21,6 +22,7 @@ from aiperf.plugin.enums import PluginType, TimingMode from aiperf.timing.branch_orchestrator import BranchOrchestrator from aiperf.timing.phase.lifecycle import PhaseLifecycle +from aiperf.timing.phase.phase_gate import PhaseGateClient from aiperf.timing.phase.progress_tracker import PhaseProgressTracker from aiperf.timing.phase.stop_conditions import StopConditionChecker from aiperf.timing.ramping import Ramper, RamperConfig, RampType @@ -81,6 +83,7 @@ def __init__( callback_handler: CreditCallbackHandler, 
url_selection_strategy: URLSelectionStrategyProtocol | None = None, branch_orchestrator: BranchOrchestrator | None = None, + phase_gate: PhaseGateClient | None = None, **kwargs, ) -> None: """Initialize phase runner. @@ -99,6 +102,8 @@ def __init__( ``_is_phase_complete`` consults ``has_pending_branch_work`` so completion blocks while DAG children are still in flight, even after ``--request-count`` is reached. + phase_gate: Optional PhaseGateClient that issues START/END handshake + commands to the SystemController. None disables gating entirely. """ super().__init__(**kwargs) self._config = config @@ -122,6 +127,8 @@ def __init__( self._cancellation_policy = cancellation_policy self._callback_handler = callback_handler self._on_phase_complete: Callable[[], None] | None = None + self._phase_gate = phase_gate + self._pending_after_phase: tuple[str, str] | None = None # Per-phase components - order matters self._scheduler = LoopScheduler() @@ -268,15 +275,71 @@ def cancel(self) -> None: def _on_return_wait_complete(self, task: asyncio.Task) -> None: """Handle completion of background return wait task (seamless mode). - Called when _return_wait_task finishes. Cancels progress reporting and - notifies the orchestrator via on_phase_complete callback. + Called when _return_wait_task finishes. Cancels progress reporting, + schedules the END gate, and notifies the orchestrator via + on_phase_complete callback. 
""" if self._progress_task: self._progress_task.cancel() + if self._pending_after_phase is not None: + phase_id, phase_name = self._pending_after_phase + self._pending_after_phase = None + self.execute_async(self._gate_after_phase(phase_id, phase_name)) + if self._on_phase_complete: self._on_phase_complete() + async def _gate_after_phase(self, phase_id: str, phase_name: str) -> None: + """Run the END gate, swallowing exceptions so they cannot crash the phase.""" + if self._phase_gate is None: + return + try: + await self._phase_gate.after_phase(phase_id, phase_name) + except Exception as e: # noqa: BLE001 - END gate failures must not crash the run + self.warning(f"after_phase gate raised for '{phase_name}': {e!r}") + + async def _gate_before_phase(self, phase_id: str, phase_name: str) -> None: + """Run the START gate. Exceptions propagate (a failed START aborts the phase).""" + if self._phase_gate is None: + return + await self._phase_gate.before_phase(phase_id, phase_name) + + def _finalize_cancelled_phase( + self, phase_id: str, phase_name: str + ) -> CreditPhaseStats: + """Mark complete + return stats for the cancelled-early-return path. + + Note: callers must still `await self._gate_after_phase(...)` separately, + because this is sync (returns the stats; doesn't await the END gate). + """ + if not self._lifecycle.is_complete: + self._lifecycle.mark_complete(grace_period_triggered=True) + self._progress.freeze_completed_counts() + self._progress.all_credits_returned_event.set() + return self._progress.create_stats(self._lifecycle) + + async def _dispatch_phase_completion( + self, phase_id: str, phase_name: str, is_final_phase: bool + ) -> None: + """Dispatch end-of-phase return-wait: seamless background vs synchronous. + + Seamless non-final phases spawn a background return-wait task and defer the + END gate to `_on_return_wait_complete`. All other phases wait synchronously, + cancel the progress task, and fire the END gate inline. 
+ """ + if self._config.seamless and not is_final_phase: + self._return_wait_task = self.execute_async( + self._wait_for_returning_complete() + ) + self._return_wait_task.add_done_callback(self._on_return_wait_complete) + # END gate is fired by _on_return_wait_complete via execute_async + self._pending_after_phase = (phase_id, phase_name) + else: + await self._wait_for_returning_complete() + self._progress_task.cancel() + await self._gate_after_phase(phase_id, phase_name) + async def run( self, is_final_phase: bool, @@ -354,6 +417,9 @@ async def _run_strategy( returning-complete pipeline. The exception path (publishing partial lifecycle state) lives in the caller's ``except``. """ + phase_id = uuid.uuid4().hex + phase_name = self._config.phase + self._concurrency_manager.configure_for_phase( self._config.phase, self._config.concurrency, @@ -364,6 +430,8 @@ async def _run_strategy( self._create_rampers(strategy) + await self._gate_before_phase(phase_id, phase_name) + self._lifecycle.start() stats = self._progress.create_stats(self._lifecycle) self.notice(self._format_phase_started(stats)) @@ -388,22 +456,11 @@ async def _run_strategy( await self._wait_for_sending_complete() if self._was_cancelled: - if not self._lifecycle.is_complete: - self._lifecycle.mark_complete(grace_period_triggered=True) - self._progress.freeze_completed_counts() - self._progress.all_credits_returned_event.set() - return self._progress.create_stats(self._lifecycle) + stats = self._finalize_cancelled_phase(phase_id, phase_name) + await self._gate_after_phase(phase_id, phase_name) + return stats - # Seamless mode: phase flows into next without waiting for returns. - # Progress task continues in background until phase complete. 
- if self._config.seamless and not is_final_phase: - self._return_wait_task = self.execute_async( - self._wait_for_returning_complete() - ) - self._return_wait_task.add_done_callback(self._on_return_wait_complete) - else: - await self._wait_for_returning_complete() - self._progress_task.cancel() + await self._dispatch_phase_completion(phase_id, phase_name, is_final_phase) for ramper in self._rampers: ramper.stop() diff --git a/src/aiperf/timing/phase_orchestrator.py b/src/aiperf/timing/phase_orchestrator.py index d829747c5..cacc9fd3c 100644 --- a/src/aiperf/timing/phase_orchestrator.py +++ b/src/aiperf/timing/phase_orchestrator.py @@ -24,6 +24,7 @@ from aiperf.plugin.enums import PluginType from aiperf.timing.concurrency import ConcurrencyManager from aiperf.timing.conversation_source import ConversationSource +from aiperf.timing.phase.phase_gate import PhaseGateClient from aiperf.timing.phase.runner import PhaseRunner from aiperf.timing.request_cancellation import RequestCancellationSimulator from aiperf.timing.url_samplers import URLSelectionStrategyProtocol @@ -86,6 +87,7 @@ def __init__( phase_publisher: PhasePublisher, credit_router: CreditRouterProtocol, dataset_metadata: DatasetMetadata, + phase_gate: PhaseGateClient | None = None, **kwargs, ) -> None: """Initialize timing strategy and orchestration components. @@ -95,12 +97,15 @@ def __init__( phase_publisher: Publishes phase events to message bus credit_router: Routes credits to workers dataset_metadata: Dataset for conversation sampling + phase_gate: Optional PhaseGateClient forwarded to each PhaseRunner + so START/END handshake commands can be issued. None disables gating. 
""" super().__init__(**kwargs) self._config = config self._phase_publisher = phase_publisher self._credit_router = credit_router self._dataset_metadata = dataset_metadata + self._phase_gate = phase_gate # Create dataset sampler SamplerClass = plugins.get_class( @@ -208,6 +213,7 @@ async def _execute_phases(self) -> None: cancellation_policy=self._cancellation_policy, callback_handler=self._callback_handler, url_selection_strategy=self._url_sampler, + phase_gate=self._phase_gate, ) # For seamless non-final phases, set callback to remove from active runners diff --git a/tests/component_integration/conftest.py b/tests/component_integration/conftest.py index e7b89aa8b..feb8374ab 100644 --- a/tests/component_integration/conftest.py +++ b/tests/component_integration/conftest.py @@ -39,7 +39,6 @@ from aiperf.cli import app from aiperf.common import random_generator as rng -from aiperf.common.environment import Environment from aiperf.plugin.enums import CommClientType # Import fakes for test harness @@ -155,15 +154,6 @@ def safe_os_kill(pid, sig): yield -@pytest.fixture(autouse=True, scope="package") -def no_server_metrics_flush_period(): - """Fixture to disable server metrics flush period.""" - original_flush_period = Environment.SERVER_METRICS.COLLECTION_FLUSH_PERIOD - Environment.SERVER_METRICS.COLLECTION_FLUSH_PERIOD = 0 - yield - Environment.SERVER_METRICS.COLLECTION_FLUSH_PERIOD = original_flush_period - - @pytest.fixture(autouse=True, scope="package") def hf_offline_mode(): """Disable HuggingFace Hub network calls for the duration of this package. 
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 07d6826d1..196d8bfbe 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -181,8 +181,6 @@ async def create_server(**kwargs: Any) -> AsyncIterator[AIPerfMockServer]: host = "127.0.0.1" url = f"http://{host}:{port}" - os.environ["AIPERF_SERVER_METRICS_COLLECTION_FLUSH_PERIOD"] = "0" - process: SpawnProcess = mp_ctx.Process( target=aiperf_mock_server_serve, kwargs={ diff --git a/tests/unit/common/enums/test_baseline_enums.py b/tests/unit/common/enums/test_baseline_enums.py new file mode 100644 index 000000000..1849b4f41 --- /dev/null +++ b/tests/unit/common/enums/test_baseline_enums.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from aiperf.common.enums import ( + BaselineKind, + ServiceCapability, + make_result_producer_capability, + parse_result_producer_capability, +) + + +def test_baseline_kind_values_lowercase(): + assert BaselineKind.START == "start" + assert BaselineKind.END == "end" + + +def test_baseline_kind_case_insensitive(): + assert BaselineKind("START") == BaselineKind.START + assert BaselineKind("End") == BaselineKind.END + + +def test_service_capability_baseline_collector_value(): + assert ServiceCapability.BASELINE_COLLECTOR == "baseline_collector" + + +def test_service_capability_result_producer_value(): + assert ServiceCapability.RESULT_PRODUCER == "result_producer" + + +def test_make_result_producer_capability_includes_domain(): + assert make_result_producer_capability("profile") == "result_producer:profile" + + +def test_parse_result_producer_capability_returns_domain(): + assert parse_result_producer_capability("result_producer:telemetry") == "telemetry" + + +def test_parse_result_producer_capability_ignores_other_capabilities(): + assert parse_result_producer_capability("baseline_collector") is None + assert 
parse_result_producer_capability("result_producer") is None + assert parse_result_producer_capability("result_producer:") is None diff --git a/tests/unit/common/messages/test_baseline_messages.py b/tests/unit/common/messages/test_baseline_messages.py new file mode 100644 index 000000000..cc9d11f6c --- /dev/null +++ b/tests/unit/common/messages/test_baseline_messages.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from aiperf.common.enums import BaselineKind, CommandType, MessageType +from aiperf.common.messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, + PhaseEndGateCommand, + PhaseGateGrantedResponse, + PhaseStartGateCommand, +) + + +def test_request_round_trip() -> None: + msg = PhaseBaselineRequestMessage( + phase_id="abc-123", + phase_name="profiling", + kind=BaselineKind.START, + ) + assert msg.message_type == MessageType.PHASE_BASELINE_REQUEST + parsed = PhaseBaselineRequestMessage.model_validate_json(msg.model_dump_json()) + assert parsed.phase_id == "abc-123" + assert parsed.kind == BaselineKind.START + + +def test_ack_success_and_error() -> None: + ok = PhaseBaselineAckMessage( + service_id="svc-1", + phase_id="abc", + kind=BaselineKind.END, + success=True, + ) + assert ok.error is None + bad = PhaseBaselineAckMessage( + service_id="svc-1", + phase_id="abc", + kind=BaselineKind.END, + success=False, + error="DCGM connection refused", + ) + assert bad.success is False + assert "DCGM" in bad.error + + +def test_gate_commands_carry_phase_metadata() -> None: + start = PhaseStartGateCommand( + service_id="svc-1", command_id="c1", phase_id="abc", phase_name="warmup" + ) + end = PhaseEndGateCommand( + service_id="svc-1", command_id="c2", phase_id="abc", phase_name="warmup" + ) + assert start.command == CommandType.PHASE_START_GATE + assert end.command == CommandType.PHASE_END_GATE + + +def test_gate_granted_response() -> None: + 
resp = PhaseGateGrantedResponse( + service_id="svc-1", + command=CommandType.PHASE_START_GATE, + command_id="c1", + phase_id="abc", + ) + assert resp.phase_id == "abc" diff --git a/tests/unit/common/messages/test_register_service_capabilities.py b/tests/unit/common/messages/test_register_service_capabilities.py new file mode 100644 index 000000000..e95e21171 --- /dev/null +++ b/tests/unit/common/messages/test_register_service_capabilities.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from aiperf.common.enums import LifecycleState, ServiceCapability +from aiperf.common.messages import RegisterServiceCommand +from aiperf.plugin.enums import ServiceType + + +def _make(**overrides) -> RegisterServiceCommand: + base = dict( + command_id="cmd-1", + service_id="svc-1", + service_type=ServiceType.SYSTEM_CONTROLLER, + state=LifecycleState.RUNNING, + ) + base.update(overrides) + return RegisterServiceCommand(**base) + + +def test_capabilities_default_empty_tuple() -> None: + cmd = _make() + assert cmd.capabilities == () + + +def test_capabilities_round_trip_with_baseline() -> None: + cmd = _make(capabilities=(ServiceCapability.BASELINE_COLLECTOR,)) + assert cmd.capabilities == (ServiceCapability.BASELINE_COLLECTOR,) + payload = cmd.model_dump_json() + parsed = RegisterServiceCommand.model_validate_json(payload) + assert parsed.capabilities == (ServiceCapability.BASELINE_COLLECTOR,) + + +def test_capabilities_accepts_tuple_of_strings() -> None: + cmd = _make(capabilities=("baseline_collector",)) + assert cmd.capabilities == ("baseline_collector",) diff --git a/tests/unit/common/mixins/test_baseline_collector_mixin.py b/tests/unit/common/mixins/test_baseline_collector_mixin.py new file mode 100644 index 000000000..b34fbe42c --- /dev/null +++ b/tests/unit/common/mixins/test_baseline_collector_mixin.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from aiperf.common.enums import BaselineKind, ServiceCapability +from aiperf.common.messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, +) +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin + + +class _StubBus: + def __init__(self) -> None: + self.published: list[PhaseBaselineAckMessage] = [] + + async def publish(self, msg: PhaseBaselineAckMessage) -> None: + self.published.append(msg) + + +class _StubCollector(BaselineCollectorMixin): + """Concrete subclass exercised in isolation (no BaseComponentService needed).""" + + def __init__(self, bus: _StubBus, *, fail: bool = False) -> None: + self._bus = bus + self.service_id = "svc-stub" + self.calls: list[tuple[BaselineKind, str]] = [] + self._fail = fail + + async def publish(self, msg: PhaseBaselineAckMessage) -> None: + await self._bus.publish(msg) + + async def collect_baseline( + self, kind: BaselineKind, phase_id: str, phase_name: str + ) -> None: + if self._fail: + raise RuntimeError("simulated DCGM failure") + self.calls.append((kind, phase_name)) + + +def test_extra_capabilities_includes_baseline_collector() -> None: + assert ServiceCapability.BASELINE_COLLECTOR in _StubCollector.extra_capabilities + + +@pytest.mark.asyncio +async def test_handler_calls_collect_and_acks_success() -> None: + bus = _StubBus() + svc = _StubCollector(bus) + await svc._on_phase_baseline_request( + PhaseBaselineRequestMessage( + phase_id="p1", phase_name="profiling", kind=BaselineKind.START + ) + ) + assert svc.calls == [(BaselineKind.START, "profiling")] + assert len(bus.published) == 1 + ack = bus.published[0] + assert ack.success is True + assert ack.service_id == "svc-stub" + assert ack.phase_id == "p1" + assert ack.kind == BaselineKind.START + assert ack.error is None + + +@pytest.mark.asyncio +async def test_handler_acks_failure_when_collect_raises() -> None: + 
bus = _StubBus() + svc = _StubCollector(bus, fail=True) + await svc._on_phase_baseline_request( + PhaseBaselineRequestMessage( + phase_id="p1", phase_name="profiling", kind=BaselineKind.END + ) + ) + ack = bus.published[0] + assert ack.success is False + assert "simulated DCGM failure" in ack.error diff --git a/tests/unit/common/test_environment_baseline.py b/tests/unit/common/test_environment_baseline.py new file mode 100644 index 000000000..352230ebe --- /dev/null +++ b/tests/unit/common/test_environment_baseline.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +from aiperf.common.environment import Environment + + +def test_baseline_defaults() -> None: + assert Environment.BASELINE.GATE_TIMEOUT_S == 5.0 + assert Environment.BASELINE.GATE_ENABLED is True + + +def test_baseline_env_override(monkeypatch) -> None: + from aiperf.common.environment import _Environment + + monkeypatch.setenv("AIPERF_BASELINE_GATE_TIMEOUT_S", "1.25") + monkeypatch.setenv("AIPERF_BASELINE_GATE_ENABLED", "0") + env = _Environment() + assert env.BASELINE.GATE_TIMEOUT_S == 1.25 + assert env.BASELINE.GATE_ENABLED is False diff --git a/tests/unit/controller/test_baseline_coordinator.py b/tests/unit/controller/test_baseline_coordinator.py new file mode 100644 index 000000000..88fac9717 --- /dev/null +++ b/tests/unit/controller/test_baseline_coordinator.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import asyncio + +import pytest + +from aiperf.common.enums import BaselineKind +from aiperf.common.messages import ( + PhaseBaselineAckMessage, + PhaseBaselineRequestMessage, +) +from aiperf.controller.baseline_coordinator import BaselineCoordinator + + +class _Bus: + def __init__(self) -> None: + self.published: list[PhaseBaselineRequestMessage] = [] + + async def publish(self, msg: PhaseBaselineRequestMessage) -> None: + self.published.append(msg) + + +@pytest.fixture +def bus() -> _Bus: + return _Bus() + + +@pytest.fixture +def coord(bus: _Bus) -> BaselineCoordinator: + return BaselineCoordinator(publish=bus.publish, gate_timeout_s=0.05) + + +@pytest.mark.asyncio +async def test_empty_registered_returns_immediately( + coord: BaselineCoordinator, bus: _Bus +) -> None: + await coord.gate_phase("p1", "warmup", BaselineKind.START) + assert bus.published == [] + + +@pytest.mark.asyncio +async def test_happy_path_acks_release_gate( + coord: BaselineCoordinator, bus: _Bus +) -> None: + coord.register("svc-a") + coord.register("svc-b") + + async def _drive() -> None: + await asyncio.sleep(0) # let gate publish first + coord.handle_ack( + PhaseBaselineAckMessage( + service_id="svc-a", phase_id="p1", kind=BaselineKind.START, success=True + ) + ) + coord.handle_ack( + PhaseBaselineAckMessage( + service_id="svc-b", phase_id="p1", kind=BaselineKind.START, success=True + ) + ) + + await asyncio.gather(coord.gate_phase("p1", "warmup", BaselineKind.START), _drive()) + assert len(bus.published) == 1 + assert bus.published[0].kind == BaselineKind.START + + +@pytest.mark.asyncio +async def test_timeout_with_unacked_logs_and_releases( + coord: BaselineCoordinator, bus: _Bus, caplog: pytest.LogCaptureFixture +) -> None: + coord.register("slow-svc") + await coord.gate_phase("p1", "profiling", BaselineKind.START) + assert len(bus.published) == 1 + assert any( + "slow-svc" in rec.getMessage() and "timed out" in rec.getMessage() + for rec in 
caplog.records + ) + assert any( + "AIPERF_BASELINE_GATE_TIMEOUT_S" in rec.getMessage() for rec in caplog.records + ) + + +@pytest.mark.asyncio +async def test_error_ack_counts_as_ack( + coord: BaselineCoordinator, bus: _Bus, caplog: pytest.LogCaptureFixture +) -> None: + coord.register("svc-a") + + async def _drive() -> None: + await asyncio.sleep(0) + coord.handle_ack( + PhaseBaselineAckMessage( + service_id="svc-a", + phase_id="p1", + kind=BaselineKind.START, + success=False, + error="DCGM down", + ) + ) + + await asyncio.gather(coord.gate_phase("p1", "x", BaselineKind.START), _drive()) + assert any("DCGM down" in rec.getMessage() for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_late_ack_after_timeout_dropped_silently( + coord: BaselineCoordinator, +) -> None: + coord.register("svc-a") + await coord.gate_phase("p1", "x", BaselineKind.START) # times out, no acks + coord.handle_ack( + PhaseBaselineAckMessage( + service_id="svc-a", phase_id="p1", kind=BaselineKind.START, success=True + ) + ) + + +def test_re_registration_idempotent(coord: BaselineCoordinator) -> None: + coord.register("svc-a") + coord.register("svc-a") + assert coord.registered_count == 1 + + +def test_unregister_removes(coord: BaselineCoordinator) -> None: + coord.register("svc-a") + coord.unregister("svc-a") + assert coord.registered_count == 0 diff --git a/tests/unit/controller/test_result_join_coordinator.py b/tests/unit/controller/test_result_join_coordinator.py new file mode 100644 index 000000000..f0098fcab --- /dev/null +++ b/tests/unit/controller/test_result_join_coordinator.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from aiperf.controller.result_join_coordinator import ResultJoinCoordinator + + +def test_ready_when_no_result_producers_registered() -> None: + coord = ResultJoinCoordinator() + + assert coord.ready + assert coord.pending_domains == () + + +def test_register_marks_domain_pending_until_service_completes() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + + assert not coord.ready + assert coord.pending_domains == ("telemetry",) + + coord.complete("telemetry", "t1") + + assert coord.ready + assert coord.pending_domains == () + + +def test_multiple_services_in_same_domain_all_must_complete() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + coord.register("telemetry", "t2") + coord.complete("telemetry", "t1") + + assert not coord.ready + assert coord.pending_domains == ("telemetry",) + + coord.complete("telemetry", "t2") + + assert coord.ready + assert coord.pending_domains == () + + +def test_complete_domain_marks_all_participants_complete() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + coord.register("telemetry", "t2") + coord.complete_domain("telemetry") + + assert coord.ready + assert coord.pending_domains == () + + +def test_unregister_removes_pending_participant() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + coord.unregister("telemetry", "t1") + + assert coord.ready + assert coord.pending_domains == () + + +def test_unregister_service_removes_service_from_all_domains() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + coord.register("server_metrics", "t1") + coord.register("profile", "records") + + coord.unregister_service("t1") + + assert coord.pending_domains == ("profile",) + + +def test_complete_unknown_participant_does_not_create_domain() -> None: + coord = ResultJoinCoordinator() + + coord.complete("telemetry", "t1") + + assert coord.ready + 
assert coord.pending_domains == () + + +def test_complete_unknown_domain_does_not_create_domain() -> None: + coord = ResultJoinCoordinator() + + coord.complete_domain("telemetry") + + assert coord.ready + assert coord.pending_domains == () + + +def test_completed_participant_reregistration_stays_complete() -> None: + coord = ResultJoinCoordinator() + + coord.register("telemetry", "t1") + coord.complete("telemetry", "t1") + coord.register("telemetry", "t1") + + assert coord.ready + assert coord.pending_domains == () + + +def test_pending_domains_changed_only_reports_changes() -> None: + coord = ResultJoinCoordinator() + + assert coord.pending_domains_changed() is None + + coord.register("server_metrics", "s1") + assert coord.pending_domains_changed() == ("server_metrics",) + assert coord.pending_domains_changed() is None + + coord.register("telemetry", "t1") + assert coord.pending_domains_changed() == ("server_metrics", "telemetry") + + coord.complete("telemetry", "t1") + assert coord.pending_domains_changed() == ("server_metrics",) + + coord.complete("server_metrics", "s1") + assert coord.pending_domains_changed() == () + assert coord.pending_domains_changed() is None diff --git a/tests/unit/controller/test_system_controller_baseline.py b/tests/unit/controller/test_system_controller_baseline.py new file mode 100644 index 000000000..af6673ee6 --- /dev/null +++ b/tests/unit/controller/test_system_controller_baseline.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""SystemController-level baseline registration wiring tests.""" + +import asyncio +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import pytest + +from aiperf.common.enums import ( + LifecycleState, + ServiceCapability, + make_result_producer_capability, +) +from aiperf.common.messages import ( + BaseServiceErrorMessage, + ProcessRecordsResultMessage, + ProcessServerMetricsResultMessage, + ProcessTelemetryResultMessage, + RegisterServiceCommand, + ServerMetricsStatusMessage, + TelemetryStatusMessage, +) +from aiperf.common.models import ( + ProcessRecordsResult, + ProcessServerMetricsResult, + ProcessTelemetryResult, + ProfileResults, + ServerMetricsResults, + TelemetryExportData, + TelemetrySummary, +) +from aiperf.common.models.error_models import ErrorDetails +from aiperf.controller.baseline_coordinator import BaselineCoordinator +from aiperf.controller.result_join_coordinator import ResultJoinCoordinator +from aiperf.controller.system_controller import SystemController +from aiperf.plugin.enums import ServiceType + + +async def _no_op_publish(msg) -> None: + return None + + +class _ServiceManagerStub: + def __init__(self) -> None: + self.service_id_map = {} + self.service_map = {} + + +class _BaselineCoordinatorStub: + def __init__(self) -> None: + self.registered: list[str] = [] + + def register(self, service_id: str) -> None: + self.registered.append(service_id) + + +def _controller_for_registration() -> SystemController: + controller = SystemController.__new__(SystemController) + controller.service_manager = _ServiceManagerStub() + controller._baseline_coordinator = _BaselineCoordinatorStub() + controller._result_join_coordinator = ResultJoinCoordinator() + controller.debug = lambda _message: None + controller.info = lambda _message: None + return controller + + +def _controller_for_shutdown() -> SystemController: + controller = SystemController.__new__(SystemController) + 
controller._result_join_coordinator = ResultJoinCoordinator() + controller._profile_results = None + controller._telemetry_results = None + controller._server_metrics_results = None + controller._shutdown_triggered = False + controller._shutdown_lock = asyncio.Lock() + controller._exit_errors = [] + controller._telemetry_endpoints_configured = [] + controller._telemetry_endpoints_reachable = [] + controller._server_metrics_endpoints_configured = [] + controller._server_metrics_endpoints_reachable = [] + controller.info_messages = [] + controller.info = controller.info_messages.append + controller.debug = lambda _message: None + controller.error = lambda _message: None + controller.exception = lambda _message: None + controller.trace_or_debug = lambda _trace_message, _debug_message: None + controller.stop = AsyncMock() + return controller + + +def _profile_result_message( + service_id: str = "records-1", +) -> ProcessRecordsResultMessage: + return ProcessRecordsResultMessage( + service_id=service_id, + results=ProcessRecordsResult( + results=ProfileResults( + records=[], + completed=0, + start_ns=1, + end_ns=2, + ) + ), + ) + + +def _telemetry_result_message( + service_id: str = "records-manager", +) -> ProcessTelemetryResultMessage: + return ProcessTelemetryResultMessage( + service_id=service_id, + telemetry_result=ProcessTelemetryResult( + results=TelemetryExportData( + summary=TelemetrySummary( + start_time=datetime.fromtimestamp(1, tz=timezone.utc), + end_time=datetime.fromtimestamp(2, tz=timezone.utc), + ), + endpoints={}, + ) + ), + ) + + +def _server_metrics_result_message( + service_id: str = "records-manager", +) -> ProcessServerMetricsResultMessage: + return ProcessServerMetricsResultMessage( + service_id=service_id, + server_metrics_result=ProcessServerMetricsResult( + results=ServerMetricsResults(start_ns=1, end_ns=2) + ), + ) + + +def test_coordinator_registers_baseline_collector() -> None: + coord = BaselineCoordinator(publish=_no_op_publish, 
gate_timeout_s=0.05) + cmd = RegisterServiceCommand( + command_id="c1", + service_id="svc-a", + service_type=ServiceType.GPU_TELEMETRY_MANAGER, + state=LifecycleState.RUNNING, + capabilities=(ServiceCapability.BASELINE_COLLECTOR,), + ) + if ServiceCapability.BASELINE_COLLECTOR in cmd.capabilities: + coord.register(cmd.service_id) + assert coord.registered_count == 1 + + +def test_coordinator_skips_service_without_capability() -> None: + coord = BaselineCoordinator(publish=_no_op_publish, gate_timeout_s=0.05) + cmd = RegisterServiceCommand( + command_id="c1", + service_id="svc-a", + service_type=ServiceType.WORKER, + state=LifecycleState.RUNNING, + ) + if ServiceCapability.BASELINE_COLLECTOR in cmd.capabilities: + coord.register(cmd.service_id) + assert coord.registered_count == 0 + + +@pytest.mark.asyncio +async def test_register_service_registers_result_producer_domain() -> None: + controller = _controller_for_registration() + cmd = RegisterServiceCommand( + command_id="c1", + service_id="records-1", + service_type=ServiceType.RECORDS_MANAGER, + state=LifecycleState.RUNNING, + capabilities=(make_result_producer_capability("profile"),), + ) + + await controller._handle_register_service_command(cmd) + + assert controller._result_join_coordinator.pending_domains == ("profile",) + + +@pytest.mark.asyncio +async def test_register_service_ignores_unknown_result_capabilities() -> None: + controller = _controller_for_registration() + cmd = RegisterServiceCommand( + command_id="c1", + service_id="worker-1", + service_type=ServiceType.WORKER, + state=LifecycleState.RUNNING, + capabilities=("result_producer", "result_producer:", "unknown:domain"), + ) + + await controller._handle_register_service_command(cmd) + + assert controller._result_join_coordinator.ready + assert controller._baseline_coordinator.registered == [] + + +@pytest.mark.asyncio +async def test_register_service_preserves_baseline_collector_registration() -> None: + controller = _controller_for_registration() 
+ cmd = RegisterServiceCommand( + command_id="c1", + service_id="telemetry-1", + service_type=ServiceType.GPU_TELEMETRY_MANAGER, + state=LifecycleState.RUNNING, + capabilities=( + ServiceCapability.BASELINE_COLLECTOR, + make_result_producer_capability("telemetry"), + ), + ) + + await controller._handle_register_service_command(cmd) + + assert controller._baseline_coordinator.registered == ["telemetry-1"] + assert controller._result_join_coordinator.pending_domains == ("telemetry",) + + +@pytest.mark.asyncio +async def test_shutdown_waits_for_registered_result_domains_with_deduped_logs() -> None: + controller = _controller_for_shutdown() + controller._result_join_coordinator.register("profile", "records-manager") + controller._result_join_coordinator.register("telemetry", "telemetry-manager") + controller._result_join_coordinator.register( + "server_metrics", "server-metrics-manager" + ) + + await controller._on_process_records_result_message(_profile_result_message()) + + assert controller.stop.await_count == 0 + assert controller.info_messages == [ + "Waiting for result domains: server_metrics, telemetry" + ] + + await controller._on_process_telemetry_result_message(_telemetry_result_message()) + + assert controller.stop.await_count == 0 + assert controller.info_messages == [ + "Waiting for result domains: server_metrics, telemetry", + "Waiting for result domains: server_metrics", + ] + + await controller._on_process_server_metrics_result_message( + _server_metrics_result_message() + ) + + assert controller.info_messages == [ + "Waiting for result domains: server_metrics, telemetry", + "Waiting for result domains: server_metrics", + "All results received, initiating shutdown", + ] + controller.stop.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_shutdown_pending_check_does_not_repeat_unchanged_wait_log() -> None: + controller = _controller_for_shutdown() + controller._result_join_coordinator.register("profile", "records-manager") + 
controller._result_join_coordinator.register("telemetry", "telemetry-manager") + + await controller._on_process_records_result_message(_profile_result_message()) + await controller._check_and_trigger_shutdown() + + assert controller.info_messages == ["Waiting for result domains: telemetry"] + assert controller.stop.await_count == 0 + + +@pytest.mark.asyncio +async def test_service_error_unregisters_failed_result_producer_and_allows_shutdown() -> ( + None +): + controller = _controller_for_shutdown() + controller._result_join_coordinator.register("profile", "records-manager") + controller._result_join_coordinator.register("telemetry", "telemetry-manager") + + await controller._on_process_records_result_message( + _profile_result_message("records-manager") + ) + await controller._process_service_error_message( + BaseServiceErrorMessage( + service_id="telemetry-manager", + error=ErrorDetails(type="RuntimeError", message="telemetry failed"), + ) + ) + + assert controller._exit_errors[0].service_id == "telemetry-manager" + assert controller._exit_errors[0].error_details.message == "telemetry failed" + assert controller._result_join_coordinator.pending_domains == () + assert controller.info_messages == [ + "Waiting for result domains: telemetry", + "All results received, initiating shutdown", + ] + controller.stop.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disabled_telemetry_status_unregisters_result_domain_and_allows_shutdown() -> ( + None +): + controller = _controller_for_shutdown() + controller._result_join_coordinator.register("profile", "records-manager") + controller._result_join_coordinator.register("telemetry", "telemetry-manager") + + await controller._on_process_records_result_message( + _profile_result_message("records-manager") + ) + await controller._on_telemetry_status_message( + TelemetryStatusMessage( + service_id="telemetry-manager", + enabled=False, + reason="no DCGM endpoints reachable", + endpoints_configured=[], + 
endpoints_reachable=[], + ) + ) + + assert controller._result_join_coordinator.pending_domains == () + assert controller.info_messages == [ + "Waiting for result domains: telemetry", + "DCGM telemetry skipped: no DCGM endpoints reachable", + "All results received, initiating shutdown", + ] + controller.stop.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disabled_server_metrics_status_unregisters_result_domain_and_allows_shutdown() -> ( + None +): + controller = _controller_for_shutdown() + controller._result_join_coordinator.register("profile", "records-manager") + controller._result_join_coordinator.register( + "server_metrics", "server-metrics-manager" + ) + + await controller._on_process_records_result_message( + _profile_result_message("records-manager") + ) + await controller._on_server_metrics_status_message( + ServerMetricsStatusMessage( + service_id="server-metrics-manager", + enabled=False, + reason="no Prometheus endpoints reachable", + endpoints_configured=[], + endpoints_reachable=[], + ) + ) + + assert controller._result_join_coordinator.pending_domains == () + assert controller.info_messages == [ + "Waiting for result domains: server_metrics", + "Server metrics disabled - no Prometheus endpoints reachable", + "All results received, initiating shutdown", + ] + controller.stop.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_shutdown_triggers_when_no_result_producers_registered() -> None: + controller = _controller_for_shutdown() + + await controller._check_and_trigger_shutdown() + + assert controller.info_messages == ["All results received, initiating shutdown"] + controller.stop.assert_awaited_once() diff --git a/tests/unit/gpu_telemetry/test_baseline_integration.py b/tests/unit/gpu_telemetry/test_baseline_integration.py new file mode 100644 index 000000000..a1974c29c --- /dev/null +++ b/tests/unit/gpu_telemetry/test_baseline_integration.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""GPUTelemetryManager adopts BaselineCollectorMixin for the phase baseline handshake.""" + +from aiperf.common.enums import ServiceCapability, make_result_producer_capability +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin +from aiperf.gpu_telemetry.manager import GPUTelemetryManager + + +def test_gpu_telemetry_uses_mixin() -> None: + assert issubclass(GPUTelemetryManager, BaselineCollectorMixin) + + +def test_gpu_telemetry_advertises_baseline_capability() -> None: + assert ( + ServiceCapability.BASELINE_COLLECTOR in GPUTelemetryManager.extra_capabilities + ) + + +def test_gpu_telemetry_advertises_telemetry_result_producer() -> None: + assert ( + make_result_producer_capability("telemetry") + in GPUTelemetryManager.extra_capabilities + ) + + +def test_gpu_telemetry_implements_collect_baseline() -> None: + """The class must override the abstract method (be instantiable in principle).""" + assert "collect_baseline" in GPUTelemetryManager.__dict__ diff --git a/tests/unit/gpu_telemetry/test_telemetry_manager.py b/tests/unit/gpu_telemetry/test_telemetry_manager.py index e3ef14f69..4256c3232 100644 --- a/tests/unit/gpu_telemetry/test_telemetry_manager.py +++ b/tests/unit/gpu_telemetry/test_telemetry_manager.py @@ -1,13 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -import asyncio from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from aiperf.common.enums import BaselineKind, CommandType from aiperf.common.environment import Environment from aiperf.common.messages import ( + ProfileCompleteCommand, ProfileConfigureCommand, ProfileStartCommand, TelemetryRecordsMessage, @@ -961,13 +962,12 @@ async def test_configure_pynvml_collector_success(self): assert PYNVML_SOURCE_IDENTIFIER in call_args.endpoints_configured assert PYNVML_SOURCE_IDENTIFIER in call_args.endpoints_reachable - # Should have collector registered and baseline-scraped before profiling. assert PYNVML_SOURCE_IDENTIFIER in manager._collectors assert ( manager._collector_id_to_url["pynvml_collector"] == PYNVML_SOURCE_IDENTIFIER ) - mock_collector.initialize.assert_awaited_once() - mock_collector.collect_and_process_metrics.assert_awaited_once() + mock_collector.initialize.assert_not_called() + mock_collector.collect_and_process_metrics.assert_not_called() @pytest.mark.asyncio async def test_configure_pynvml_collector_no_gpus_found(self): @@ -1086,7 +1086,7 @@ def _create_test_manager( return manager @pytest.mark.asyncio - async def test_configure_runtime_local_collector_captures_baseline( + async def test_configure_runtime_local_collector_does_not_capture_baseline( self, ) -> None: fake_name = "fake_baseline_gpu" @@ -1131,8 +1131,8 @@ class FakeBaselineCollector: ) ) - mock_collector.initialize.assert_awaited_once() - mock_collector.collect_and_process_metrics.assert_awaited_once() + mock_collector.initialize.assert_not_called() + mock_collector.collect_and_process_metrics.assert_not_called() finally: GPUTelemetryCollectorType.deregister(fake_enum_member) @@ -1243,91 +1243,37 @@ async def test_configure_amdsmi_collector_success(self): manager._collector_id_to_url["amdsmi_collector"] == AMDSMI_SOURCE_IDENTIFIER ) - # Baseline scrape: configure must call initialize() + one - # collect_and_process_metrics() 
so counter deltas - # (amd_energy_consumption, amd_ecc_uncorrectable) are computed - # against a pre-profile reference, not the first in-window sample. - mock_collector.initialize.assert_awaited_once() - mock_collector.collect_and_process_metrics.assert_awaited_once() + mock_collector.initialize.assert_not_called() + mock_collector.collect_and_process_metrics.assert_not_called() @pytest.mark.asyncio - async def test_configure_amdsmi_collector_continues_when_baseline_scrape_fails( - self, - ): - # If only the baseline scrape raises (transient sensor read error - # after a successful init), the collector is still usable — keep - # it enabled and just lose the reference sample. The periodic - # collection loop still runs; counter deltas degrade to the - # first-in-window-sample fallback for the first interval. + async def test_collect_baseline_reports_scrape_failures(self): manager = self._create_test_manager() - manager.publish = AsyncMock() mock_collector = AsyncMock() - mock_collector.endpoint_url = AMDSMI_SOURCE_IDENTIFIER - mock_collector.is_url_reachable = AsyncMock(return_value=True) - mock_collector.initialize = AsyncMock() # init succeeds mock_collector.collect_and_process_metrics = AsyncMock( side_effect=RuntimeError("transient sensor read error") ) + manager._collectors = {AMDSMI_SOURCE_IDENTIFIER: mock_collector} - MockCollectorClass = MagicMock(return_value=mock_collector) - with patch( - "aiperf.plugin.plugins.get_class", - return_value=MockCollectorClass, - ): - configure_msg = ProfileConfigureCommand( - command_id="test", service_id="system_controller", config={} - ) - await manager._profile_configure_command(configure_msg) - - manager.publish.assert_called_once() - call_args = manager.publish.call_args[0][0] - assert isinstance(call_args, TelemetryStatusMessage) - assert call_args.enabled is True - assert AMDSMI_SOURCE_IDENTIFIER in manager._collectors - manager.warning.assert_called() + with pytest.raises(RuntimeError, match="transient sensor read 
error"): + await manager.collect_baseline(BaselineKind.START, "phase-1", "profiling") @pytest.mark.asyncio - async def test_configure_amdsmi_collector_disables_when_init_fails(self): - # AIPerfLifecycleMixin re-raises hook failures as - # ``asyncio.CancelledError`` (see test_amdsmi_collector.py - # ``test_init_failure_propagates_via_lifecycle``). The baseline path - # must catch that — letting it propagate would cancel the entire - # PROFILE_CONFIGURE flow rather than gracefully disabling telemetry. - # On init failure the collector is unusable, so it must be removed - # from ``_collectors`` and disabled status reported. + async def test_profile_complete_stops_without_final_scrape(self): manager = self._create_test_manager() - manager.publish = AsyncMock() mock_collector = AsyncMock() - mock_collector.endpoint_url = AMDSMI_SOURCE_IDENTIFIER - mock_collector.is_url_reachable = AsyncMock(return_value=True) - mock_collector.initialize = AsyncMock( - side_effect=asyncio.CancelledError( - "Failed to initialize amdsmi: driver gone" - ) - ) + manager._collectors = {AMDSMI_SOURCE_IDENTIFIER: mock_collector} - MockCollectorClass = MagicMock(return_value=mock_collector) - with patch( - "aiperf.plugin.plugins.get_class", - return_value=MockCollectorClass, - ): - configure_msg = ProfileConfigureCommand( - command_id="test", service_id="system_controller", config={} + await manager._handle_profile_complete_command( + ProfileCompleteCommand( + service_id=manager.service_id, command=CommandType.PROFILE_COMPLETE ) - # Must NOT propagate CancelledError out of configure. 
- await manager._profile_configure_command(configure_msg) + ) - manager.publish.assert_called_once() - call_args = manager.publish.call_args[0][0] - assert isinstance(call_args, TelemetryStatusMessage) - assert call_args.enabled is False - assert "amdsmi://localhost initialization failed" in call_args.reason - assert AMDSMI_SOURCE_IDENTIFIER not in manager._collectors - assert "amdsmi_collector" not in manager._collector_id_to_url - # collect_and_process_metrics must NOT be invoked when init failed. mock_collector.collect_and_process_metrics.assert_not_called() + mock_collector.stop.assert_called_once() @pytest.mark.asyncio async def test_configure_amdsmi_collector_no_gpus_found(self): diff --git a/tests/unit/records/test_records_manager.py b/tests/unit/records/test_records_manager.py index 9afad6b63..f44d1d760 100644 --- a/tests/unit/records/test_records_manager.py +++ b/tests/unit/records/test_records_manager.py @@ -2,11 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest -from aiperf.common.enums import CreditPhase +from aiperf.common.enums import CreditPhase, make_result_producer_capability +from aiperf.common.environment import Environment from aiperf.common.exceptions import PostProcessorDisabled from aiperf.common.messages.inference_messages import ( MetricRecordsData, @@ -16,6 +18,7 @@ BranchStats, CreditPhaseStats, MetricResult, + PhaseRecordsStats, ProcessRecordsResult, ProfileResults, ) @@ -35,6 +38,12 @@ from tests.harness import mock_plugin +def test_records_manager_advertises_profile_result_producer() -> None: + assert ( + make_result_producer_capability("profile") in RecordsManager.extra_capabilities + ) + + # Helper functions def create_mock_records_manager( start_time_ns: int, @@ -321,16 +330,25 @@ def _create_manager_for_timing_dispatch() -> RecordsManager: manager._complete_credit_phases = set() manager._phase_branch_stats = {} manager._latest_branch_stats 
= None + manager._credits_complete_received = False + manager._credits_complete_fallback_task = None manager._timing_results_processors = [] manager._send_timing_to_results_processors = AsyncMock() manager._send_results_to_results_processors = AsyncMock() manager.info = MagicMock() manager.notice = MagicMock() + manager.warning = MagicMock() manager.debug = MagicMock() manager.trace = MagicMock() manager.exception = MagicMock() manager.is_enabled_for = MagicMock(return_value=False) manager._handle_all_records_received = AsyncMock() + + def close_background_task(coro): + coro.close() + return MagicMock() + + manager.execute_async = MagicMock(side_effect=close_background_task) return manager @@ -445,6 +463,13 @@ async def test_on_metric_records_records_complete_before_phase_complete_defers_f ) ) + manager._records_tracker.check_and_set_all_records_received_for_phase.assert_not_called() + manager._handle_all_records_received.assert_not_awaited() + + await manager._on_credits_complete( + CreditsCompleteMessage(service_id="timing-manager") + ) + manager._records_tracker.check_and_set_all_records_received_for_phase.assert_called_once_with( CreditPhase.PROFILING ) @@ -525,6 +550,13 @@ async def _record_branch_stats_at_finalization(phase: CreditPhase) -> None: await manager._on_metric_records(_metric_records_message()) + manager._records_tracker.check_and_set_all_records_received_for_phase.assert_not_called() + manager._handle_all_records_received.assert_not_awaited() + + await manager._on_credits_complete( + CreditsCompleteMessage(service_id="timing-manager") + ) + manager._records_tracker.check_and_set_all_records_received_for_phase.assert_called_once_with( CreditPhase.PROFILING ) @@ -605,9 +637,76 @@ async def _block_timing_fanout(stats: CreditPhaseStats) -> None: release_timing_fanout.set() await phase_complete_task + manager._handle_all_records_received.assert_not_awaited() + + await manager._on_credits_complete( + CreditsCompleteMessage(service_id="timing-manager") + 
) + + manager._handle_all_records_received.assert_awaited_once_with( + CreditPhase.PROFILING + ) + + @pytest.mark.asyncio + async def test_finalization_falls_back_when_credits_complete_never_arrives( + self, + ) -> None: + manager = _create_manager_for_timing_dispatch() + manager._records_tracker = RecordsTracker() + manager.execute_async = lambda coro: asyncio.create_task(coro) + + original_timeout = Environment.RECORD.CREDITS_COMPLETE_FALLBACK_TIMEOUT + Environment.RECORD.CREDITS_COMPLETE_FALLBACK_TIMEOUT = 0.0 + try: + await manager._on_credit_phase_complete( + CreditPhaseCompleteMessage( + service_id="timing-manager", + stats=_create_credit_phase_stats().model_copy( + update={"final_requests_completed": 1} + ), + ) + ) + await manager._on_metric_records(_metric_records_message()) + await manager._credits_complete_fallback_task + finally: + Environment.RECORD.CREDITS_COMPLETE_FALLBACK_TIMEOUT = original_timeout + manager._handle_all_records_received.assert_awaited_once_with( CreditPhase.PROFILING ) + manager.warning.assert_called() + + @pytest.mark.asyncio + async def test_finalize_processes_results_without_server_metrics_flush_wait( + self, + ) -> None: + manager = RecordsManager.__new__(RecordsManager) + stats = PhaseRecordsStats( + phase=CreditPhase.PROFILING, + start_ns=1_000_000_000, + requests_end_ns=time.time_ns(), + ) + manager.service_id = "records-manager" + manager._records_tracker = MagicMock() + manager._records_tracker.create_stats_for_phase.return_value = stats + manager.publish = AsyncMock() + manager.send_command_and_wait_for_response = AsyncMock(return_value=object()) + manager._process_results = AsyncMock() + manager.debug = MagicMock() + manager.warning = MagicMock() + manager.info = MagicMock() + + with patch( + "aiperf.records.records_manager.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + await manager._finalize_and_process_results( + CreditPhase.PROFILING, cancelled=False + ) + + mock_sleep.assert_not_awaited() + 
manager._process_results.assert_awaited_once_with( + phase=CreditPhase.PROFILING, cancelled=False + ) @pytest.mark.asyncio async def test_send_timing_to_results_processors_ignores_empty_processor_list( diff --git a/tests/unit/server_metrics/test_baseline_integration.py b/tests/unit/server_metrics/test_baseline_integration.py new file mode 100644 index 000000000..cc931b5ab --- /dev/null +++ b/tests/unit/server_metrics/test_baseline_integration.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""ServerMetricsManager adopts BaselineCollectorMixin for the phase baseline handshake.""" + +from aiperf.common.enums import ServiceCapability, make_result_producer_capability +from aiperf.common.mixins.baseline_collector_mixin import BaselineCollectorMixin +from aiperf.server_metrics.manager import ServerMetricsManager + + +def test_server_metrics_uses_mixin() -> None: + assert issubclass(ServerMetricsManager, BaselineCollectorMixin) + + +def test_server_metrics_advertises_baseline_capability() -> None: + assert ( + ServiceCapability.BASELINE_COLLECTOR in ServerMetricsManager.extra_capabilities + ) + + +def test_server_metrics_advertises_server_metrics_result_producer() -> None: + assert ( + make_result_producer_capability("server_metrics") + in ServerMetricsManager.extra_capabilities + ) + + +def test_server_metrics_implements_collect_baseline() -> None: + assert "collect_baseline" in ServerMetricsManager.__dict__ diff --git a/tests/unit/server_metrics/test_server_metrics_manager.py b/tests/unit/server_metrics/test_server_metrics_manager.py index cd8789ca0..f4f2fde99 100644 --- a/tests/unit/server_metrics/test_server_metrics_manager.py +++ b/tests/unit/server_metrics/test_server_metrics_manager.py @@ -5,7 +5,7 @@ import pytest -from aiperf.common.enums import CommandType +from aiperf.common.enums import BaselineKind, CommandType from aiperf.common.messages import 
ProfileConfigureCommand, ProfileStartCommand from aiperf.common.messages.server_metrics_messages import ServerMetricsRecordMessage from aiperf.common.models import ErrorDetails @@ -139,6 +139,7 @@ async def test_configure_with_reachable_endpoints( ) assert len(manager._collectors) > 0 + assert mock_collector.initialize.await_count == len(manager._collectors) @pytest.mark.asyncio async def test_configure_with_unreachable_endpoints( @@ -462,12 +463,11 @@ async def test_exception_during_reachability_check( assert len(manager._collectors) == 0 @pytest.mark.asyncio - async def test_exception_during_baseline_capture( + async def test_configure_does_not_capture_baseline( self, cli_config: CLIConfig, cfg_with_endpoint: CLIConfig, ): - """Test that exceptions during baseline capture are logged but don't fail configuration.""" manager = ServerMetricsManager( run=make_run_from_cli(cfg_with_endpoint), ) @@ -477,10 +477,6 @@ async def test_exception_during_baseline_capture( ) as mock_collector_class: mock_collector = AsyncMock() mock_collector.is_url_reachable = AsyncMock(return_value=True) - mock_collector.initialize = AsyncMock() - mock_collector.collect_and_process_metrics.side_effect = Exception( - "Baseline failed" - ) mock_collector_class.return_value = mock_collector await manager._profile_configure_command( @@ -491,8 +487,8 @@ async def test_exception_during_baseline_capture( ) ) - # Collector should still be added despite baseline failure assert len(manager._collectors) > 0 + mock_collector.collect_and_process_metrics.assert_not_called() class TestPartialStartup: @@ -536,12 +532,11 @@ class TestProfileCompleteAndCancel: """Test profile completion and cancellation scenarios.""" @pytest.mark.asyncio - async def test_profile_complete_triggers_final_scrape( + async def test_profile_complete_stops_without_final_scrape( self, cli_config: CLIConfig, cfg_with_endpoint: CLIConfig, ): - """Test that profile complete triggers final metrics scrape.""" from aiperf.common.messages 
import ProfileCompleteCommand manager = ServerMetricsManager( @@ -557,38 +552,30 @@ async def test_profile_complete_triggers_final_scrape( ) ) - # Should call final scrape - mock_collector.collect_and_process_metrics.assert_called_once() - # Should stop collector after final scrape + mock_collector.collect_and_process_metrics.assert_not_called() mock_collector.stop.assert_called_once() @pytest.mark.asyncio - async def test_profile_complete_handles_final_scrape_failure( + async def test_collect_baseline_reports_scrape_failures( self, cli_config: CLIConfig, cfg_with_endpoint: CLIConfig, ): - """Test that profile complete handles final scrape failures gracefully.""" - from aiperf.common.messages import ProfileCompleteCommand - manager = ServerMetricsManager( run=make_run_from_cli(cfg_with_endpoint), ) + manager.info = MagicMock() + manager.warning = MagicMock() + manager.debug = MagicMock() mock_collector = AsyncMock() - mock_collector.collect_and_process_metrics.side_effect = Exception( - "Final scrape failed" + mock_collector.collect_and_process_metrics.side_effect = RuntimeError( + "baseline failed" ) manager._collectors = {"endpoint1": mock_collector} - await manager._handle_profile_complete_command( - ProfileCompleteCommand( - service_id=manager.id, command=CommandType.PROFILE_COMPLETE - ) - ) - - # Should still stop collector even if final scrape fails - mock_collector.stop.assert_called_once() + with pytest.raises(RuntimeError, match="baseline failed"): + await manager.collect_baseline(BaselineKind.START, "phase-1", "profiling") @pytest.mark.asyncio async def test_profile_complete_when_already_stopped( diff --git a/tests/unit/timing/phase/test_phase_gate_client.py b/tests/unit/timing/phase/test_phase_gate_client.py new file mode 100644 index 000000000..af1cc6690 --- /dev/null +++ b/tests/unit/timing/phase/test_phase_gate_client.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import pytest + +from aiperf.common.messages import ( + CommandMessage, + PhaseEndGateCommand, + PhaseGateGrantedResponse, + PhaseStartGateCommand, +) +from aiperf.timing.phase.phase_gate import PhaseGateClient + + +class _StubSender: + def __init__(self) -> None: + self.sent: list[CommandMessage] = [] + + async def send_command_and_wait_for_response( + self, cmd: CommandMessage, timeout: float | None = None + ) -> PhaseGateGrantedResponse: + self.sent.append(cmd) + return PhaseGateGrantedResponse( + command_id=cmd.command_id, + service_id="system_controller", + command=cmd.command, + phase_id=cmd.phase_id, + ) + + +@pytest.mark.asyncio +async def test_before_phase_sends_start_gate() -> None: + sender = _StubSender() + gate = PhaseGateClient( + sender=sender, + service_id="timing_manager_test", + enabled=True, + timeout_s=5.0, + ) + await gate.before_phase("p1", "profiling") + assert len(sender.sent) == 1 + assert isinstance(sender.sent[0], PhaseStartGateCommand) + assert sender.sent[0].phase_id == "p1" + assert sender.sent[0].phase_name == "profiling" + assert sender.sent[0].service_id == "timing_manager_test" + + +@pytest.mark.asyncio +async def test_after_phase_sends_end_gate() -> None: + sender = _StubSender() + gate = PhaseGateClient( + sender=sender, + service_id="timing_manager_test", + enabled=True, + timeout_s=5.0, + ) + await gate.after_phase("p1", "profiling") + assert len(sender.sent) == 1 + assert isinstance(sender.sent[0], PhaseEndGateCommand) + assert sender.sent[0].phase_id == "p1" + assert sender.sent[0].phase_name == "profiling" + assert sender.sent[0].service_id == "timing_manager_test" + + +@pytest.mark.asyncio +async def test_disabled_short_circuits() -> None: + sender = _StubSender() + gate = PhaseGateClient( + sender=sender, + service_id="timing_manager_test", + enabled=False, + timeout_s=5.0, + ) + await gate.before_phase("p1", "profiling") + await gate.after_phase("p1", "profiling") + assert 
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Ordering invariants for the phase baseline handshake.

Proves at the unit level:
  - PhaseGateClient.before_phase is awaited BEFORE
    phase_publisher.publish_phase_start is called (i.e. baseline collectors get
    their reading slot before any credit is published).
  - PhaseGateClient.after_phase fires AFTER _wait_for_returning_complete settles,
    once the phase has been finalised (asserted here as: the END gate command is
    recorded after the last publish_phase_complete).
  - When AIPERF_BASELINE_GATE_ENABLED=false, PhaseGateClient is a strict no-op
    and never touches the command bus (exercised here with ``enabled=False``).

Uses the existing make_runner / cfg test scaffolding from test_runner.py, plus a
recording sender and stamping publisher mocks that all append to one shared
event list. Every event carries a time.monotonic_ns() stamp (the autouse
no_sleep fixture only mocks asyncio.sleep, not time.monotonic), but the
assertions compare positions in that list, so the ordering stays unambiguous
even if two stamps happen to be equal.
"""

from __future__ import annotations

import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from aiperf.common.messages import (
    CommandMessage,
    PhaseEndGateCommand,
    PhaseGateGrantedResponse,
    PhaseStartGateCommand,
)
from aiperf.timing.phase.phase_gate import PhaseGateClient
from tests.unit.timing.phase.test_runner import MockStrategy, cfg, make_runner


class _RecordingSender:
    """Stand-in command sender that records each gate command it is handed.

    Every call appends ``(label, time.monotonic_ns())`` to the shared event list
    and immediately answers with a ``PhaseGateGrantedResponse``, so the gate is
    granted without a real command bus.
    """

    def __init__(self, events: list[tuple[str, int]]) -> None:
        self._events = events

    async def send_command_and_wait_for_response(
        self, message: CommandMessage, timeout: float | None = None
    ) -> PhaseGateGrantedResponse:
        """Record ``message`` and grant the gate.

        START gate commands are labelled ``"before_phase"``, END gate commands
        ``"after_phase"``; anything else is labelled with its class name.
        ``timeout`` is accepted for signature compatibility and ignored.
        """
        if isinstance(message, PhaseStartGateCommand):
            label = "before_phase"
        elif isinstance(message, PhaseEndGateCommand):
            label = "after_phase"
        else:
            label = type(message).__name__
        self._events.append((label, time.monotonic_ns()))
        return PhaseGateGrantedResponse(
            command_id=message.command_id,
            service_id="system_controller",
            command=message.command,
            phase_id=message.phase_id,
        )


def _stamping_async_mock(events: list[tuple[str, int]], label: str) -> AsyncMock:
    """Return an AsyncMock that appends ``(label, monotonic_ns)`` on every await."""

    async def _impl(*_a: object, **_k: object) -> None:
        events.append((label, time.monotonic_ns()))

    return AsyncMock(side_effect=_impl)


def _last_index(labels: list[str], wanted: str) -> int:
    """Return the position of the last occurrence of ``wanted`` in ``labels``."""
    return max(idx for idx, lbl in enumerate(labels) if lbl == wanted)


async def _run_final_phase(
    *,
    gate_enabled: bool,
    sender: _RecordingSender,
    conv_src: MagicMock,
    publisher: MagicMock,
    router: MagicMock,
    conc: MagicMock,
    cancel_pol: MagicMock,
    callback_handler: MagicMock,
) -> None:
    """Build a runner wired to a PhaseGateClient and run one final phase.

    Shared by all three tests so that they differ only in ``gate_enabled`` and
    in what they assert about the recorded events.
    """
    gate = PhaseGateClient(
        sender=sender,
        service_id="timing_manager_test",
        enabled=gate_enabled,
        timeout_s=1.0,
    )
    runner = make_runner(
        cfg(),
        conv_src,
        publisher,
        router,
        conc,
        cancel_pol,
        callback_handler,
    )
    runner._phase_gate = gate

    # Replace the strategy lookup with a factory that always yields MockStrategy.
    with patch(
        "aiperf.timing.phase.runner.plugins.get_class",
        return_value=lambda **_kw: MockStrategy(),
    ):
        # Both progress events are set up front; presumably this lets run()
        # proceed without real credit traffic -- confirm against PhaseRunner.
        runner._progress.all_credits_sent_event.set()
        runner._progress.all_credits_returned_event.set()
        await runner.run(is_final_phase=True)


@pytest.fixture
def recording_sender_events() -> list[tuple[str, int]]:
    """Shared, initially empty event list written by the sender and publisher."""
    return []


@pytest.fixture
def recording_sender(
    recording_sender_events: list[tuple[str, int]],
) -> _RecordingSender:
    """A ``_RecordingSender`` bound to the shared event list."""
    return _RecordingSender(recording_sender_events)


@pytest.fixture
def stamping_pub(recording_sender_events: list[tuple[str, int]]) -> MagicMock:
    """Phase publisher mock whose publish_* calls are stamped into the event list.

    ``publish_progress`` is a plain AsyncMock and is deliberately not recorded.
    """
    m = MagicMock()
    for name in (
        "publish_phase_start",
        "publish_phase_sending_complete",
        "publish_phase_complete",
        "publish_credits_complete",
    ):
        setattr(m, name, _stamping_async_mock(recording_sender_events, name))
    m.publish_progress = AsyncMock()
    return m


@pytest.fixture
def conv_src() -> MagicMock:
    """Conversation source mock exposing a synchronous ``next``."""
    m = MagicMock()
    m.next = MagicMock()
    return m


@pytest.fixture
def router() -> MagicMock:
    """Credit router mock.

    Each async method gets its own AsyncMock so call records of ``send_credit``
    and ``cancel_all_credits`` are never conflated.
    """
    m = MagicMock()
    m.send_credit = AsyncMock()
    m.cancel_all_credits = AsyncMock()
    m.mark_credits_complete = MagicMock()
    return m


@pytest.fixture
def conc() -> MagicMock:
    """Concurrency manager mock: slot acquisition always succeeds.

    Release/limit methods each get their own MagicMock so their call records
    stay independent.
    """
    m = MagicMock()
    m.configure_for_phase = MagicMock()
    m.acquire_session_slot = AsyncMock(return_value=True)
    m.acquire_prefill_slot = AsyncMock(return_value=True)
    m.release_session_slot = MagicMock()
    m.release_prefill_slot = MagicMock()
    m.set_session_limit = MagicMock()
    m.set_prefill_limit = MagicMock()
    m.release_stuck_slots = MagicMock(return_value=(0, 0))
    return m


@pytest.fixture
def cancel_pol() -> MagicMock:
    """Cancellation policy mock that never schedules a cancellation delay."""
    m = MagicMock()
    m.next_cancellation_delay_ns = MagicMock(return_value=None)
    return m


@pytest.fixture
def callback_handler() -> MagicMock:
    """Callback handler mock with one independent mock per method."""
    m = MagicMock()
    m.register_phase = MagicMock()
    m.unregister_phase = MagicMock()
    m.on_credit_return = AsyncMock()
    m.on_first_token = AsyncMock()
    return m


@pytest.mark.asyncio
async def test_before_phase_fires_before_publish_phase_start(
    conv_src: MagicMock,
    stamping_pub: MagicMock,
    router: MagicMock,
    conc: MagicMock,
    cancel_pol: MagicMock,
    callback_handler: MagicMock,
    recording_sender: _RecordingSender,
    recording_sender_events: list[tuple[str, int]],
) -> None:
    """The START gate must release before publish_phase_start is awaited."""
    await _run_final_phase(
        gate_enabled=True,
        sender=recording_sender,
        conv_src=conv_src,
        publisher=stamping_pub,
        router=router,
        conc=conc,
        cancel_pol=cancel_pol,
        callback_handler=callback_handler,
    )

    labels = [label for label, _ts in recording_sender_events]
    assert "before_phase" in labels, f"no before_phase event recorded; got {labels}"
    assert "publish_phase_start" in labels, (
        f"no publish_phase_start event; got {labels}"
    )

    first_before = labels.index("before_phase")
    first_publish_start = labels.index("publish_phase_start")
    assert first_before < first_publish_start, (
        f"before_phase must precede publish_phase_start; got order: {labels}"
    )


@pytest.mark.asyncio
async def test_after_phase_fires_after_publish_phase_complete(
    conv_src: MagicMock,
    stamping_pub: MagicMock,
    router: MagicMock,
    conc: MagicMock,
    cancel_pol: MagicMock,
    callback_handler: MagicMock,
    recording_sender: _RecordingSender,
    recording_sender_events: list[tuple[str, int]],
) -> None:
    """The END gate must fire on the synchronous (final, non-seamless) path.

    It must also be recorded after the last publish_phase_complete.
    """
    await _run_final_phase(
        gate_enabled=True,
        sender=recording_sender,
        conv_src=conv_src,
        publisher=stamping_pub,
        router=router,
        conc=conc,
        cancel_pol=cancel_pol,
        callback_handler=callback_handler,
    )

    labels = [label for label, _ts in recording_sender_events]
    assert "after_phase" in labels, f"END gate never fired; events were {labels}"
    assert "publish_phase_complete" in labels, (
        f"phase never completed publish; events were {labels}"
    )

    last_complete = _last_index(labels, "publish_phase_complete")
    last_after = _last_index(labels, "after_phase")
    assert last_after > last_complete, (
        f"after_phase must follow publish_phase_complete; got order: {labels}"
    )


@pytest.mark.asyncio
async def test_disabled_gate_does_not_send_commands(
    conv_src: MagicMock,
    stamping_pub: MagicMock,
    router: MagicMock,
    conc: MagicMock,
    cancel_pol: MagicMock,
    callback_handler: MagicMock,
    recording_sender: _RecordingSender,
    recording_sender_events: list[tuple[str, int]],
) -> None:
    """When AIPERF_BASELINE_GATE_ENABLED=false the gate is a strict no-op."""
    await _run_final_phase(
        gate_enabled=False,
        sender=recording_sender,
        conv_src=conv_src,
        publisher=stamping_pub,
        router=router,
        conc=conc,
        cancel_pol=cancel_pol,
        callback_handler=callback_handler,
    )

    gate_labels = [
        lbl
        for lbl, _ts in recording_sender_events
        if lbl in ("before_phase", "after_phase")
    ]
    assert gate_labels == [], (
        f"disabled gate sent commands; recorded gate events: {gate_labels}"
    )
    # Sanity: phase still ran end-to-end via the regular publish path.
    pub_labels = [
        lbl for lbl, _ts in recording_sender_events if lbl.startswith("publish_")
    ]
    assert "publish_phase_start" in pub_labels
    assert "publish_phase_complete" in pub_labels


# ---------------------------------------------------------------------------
# Next file in this patch: tests/unit/timing/phase/test_runner_gates.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit-level test: PhaseRunner accepts a PhaseGateClient via constructor."""

import inspect

from aiperf.timing.phase.runner import PhaseRunner


def test_phase_runner_accepts_phase_gate_kwarg() -> None:
    """``PhaseRunner.__init__`` exposes an optional ``phase_gate`` parameter.

    Checked purely through signature introspection: the parameter must exist
    and its default must be ``None``.
    """
    init_signature = inspect.signature(PhaseRunner.__init__)
    assert "phase_gate" in init_signature.parameters

    phase_gate_param = init_signature.parameters["phase_gate"]
    # A None default means the argument is not required.
    assert phase_gate_param.default is None