Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/aiperf/common/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import contextlib
import multiprocessing
import os
import platform
import signal
import sys
import uuid
import warnings
from typing import TYPE_CHECKING

from aiperf.common.constants import IS_MACOS, IS_WINDOWS
from aiperf.common.enums import LifecycleState
from aiperf.common.environment import Environment
from aiperf.plugin.enums import ServiceType
Expand Down Expand Up @@ -121,7 +121,7 @@ async def _run_service():
# processes inherit terminal file descriptors and interfere with Textual's
# terminal management, causing ASCII garbage and freezing when mouse events occur.
# Only apply this in spawned child processes, NOT in the main process where Textual runs.
if platform.system() == "Darwin" and is_child_process:
if IS_MACOS and is_child_process:
_redirect_stdio_to_devnull()

# Initialize global RandomGenerator for reproducible random number generation
Expand All @@ -143,6 +143,8 @@ async def _run_service():

_exit_if_service_failed(service)

_configure_event_loop_policy_for_platform()

with contextlib.suppress(asyncio.CancelledError):
if not Environment.SERVICE.DISABLE_UVLOOP:
import uvloop
Expand All @@ -152,6 +154,23 @@ async def _run_service():
asyncio.run(_run_service())


def _configure_event_loop_policy_for_platform() -> None:
    """Install the selector event-loop policy when running on Windows.

    pyzmq's async sockets depend on ``loop.add_reader()`` and
    ``loop.add_writer()``, which the ``ProactorEventLoop`` used by default on
    Windows does not implement. The policy has to be in place before
    ``asyncio.run()``/``uvloop.run()`` builds the loop, so call this first.

    uvloop is already auto-disabled on Windows via ``environment.py``, so
    there this only affects the plain asyncio path. On every other platform
    the function returns without touching the (already suitable) default
    policy.
    """
    if not IS_WINDOWS:
        # Nothing to do: the default policy is already correct here.
        return

    selector_policy = asyncio.WindowsSelectorEventLoopPolicy()
    asyncio.set_event_loop_policy(selector_policy)


def _exit_if_service_failed(service) -> None:
"""Surface accumulated service failures as a non-zero SystemExit.

Expand Down
7 changes: 7 additions & 0 deletions src/aiperf/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import platform as _platform

# Platform detection — the OS name is queried once at import time and every
# flag below is derived from that single value.
_SYSTEM_NAME: str = _platform.system()
IS_WINDOWS: bool = _SYSTEM_NAME == "Windows"
IS_MACOS: bool = _SYSTEM_NAME == "Darwin"
IS_LINUX: bool = _SYSTEM_NAME == "Linux"

NANOS_PER_SECOND = 1_000_000_000
NANOS_PER_MILLIS = 1_000_000
MILLIS_PER_SECOND = 1000
Expand Down
21 changes: 15 additions & 6 deletions src/aiperf/common/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
print(f"Workers: {Environment.WORKER.CPU_UTILIZATION_FACTOR}")
"""

import platform
from pathlib import Path
from typing import Annotated, Literal

Expand All @@ -46,6 +45,7 @@
from typing_extensions import Self

from aiperf.common.aiperf_logger import AIPerfLogger
from aiperf.common.constants import IS_WINDOWS
from aiperf.config.loader.parsing import (
parse_service_types,
parse_str_or_csv_list,
Expand Down Expand Up @@ -878,11 +878,20 @@ class _ServiceSettings(BaseSettings):

@model_validator(mode="after")
def auto_disable_uvloop_on_windows(self) -> Self:
"""Automatically disable uvloop on Windows as it's not supported."""
if platform.system() == "Windows" and not self.DISABLE_UVLOOP:
_logger.info(
"Windows detected: automatically disabling uvloop (not supported on Windows)"
)
"""Automatically disable uvloop on Windows as it's not supported.

Validator fires on every ``_ServiceSettings()`` construction, which
runs once in the main process AND once per spawned child service.
Gate the log line to the main process so the user sees it once, not
~9 times per aiperf run on Windows.
"""
if IS_WINDOWS and not self.DISABLE_UVLOOP:
import multiprocessing

if multiprocessing.parent_process() is None:
_logger.info(
"Windows detected: automatically disabling uvloop (not supported on Windows)"
)
self.DISABLE_UVLOOP = True
return self

Expand Down
6 changes: 3 additions & 3 deletions src/aiperf/common/tokenizer_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ def log_tokenizer_validation_results(
for entry in results:
if entry.was_resolved:
logger.info(
f" Tokenizer {entry.resolved_name} detected for {entry.original_name}"
f"[OK] Tokenizer {entry.resolved_name} detected for {entry.original_name}"
)
else:
logger.info(f" Tokenizer {entry.resolved_name} detected")
logger.info(f"[OK] Tokenizer {entry.resolved_name} detected")

total = len(results)
resolved = sum(1 for e in results if e.was_resolved)
Expand All @@ -260,7 +260,7 @@ def log_tokenizer_validation_results(
parts.append(f"{resolved} resolved")
if elapsed_seconds is not None:
parts.append(f"{elapsed_seconds:.1f}s")
logger.info(" ".join(parts))
logger.info(" | ".join(parts))


def _display_panel(title: str, content: str, console: Console | None = None) -> None:
Expand Down
11 changes: 5 additions & 6 deletions src/aiperf/config/comm/dual_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Self

from aiperf.config.comm.base import BaseZMQCommunicationConfig, BaseZMQProxyConfig
from aiperf.config.comm.ipc import _build_socket_address
from aiperf.plugin.enums import CommunicationBackend


Expand Down Expand Up @@ -61,10 +62,8 @@ class ZMQDualBindProxyConfig(BaseZMQProxyConfig):
enable_capture: bool = Field(default=False, description="Enable capture socket")

def _ipc_addr(self, endpoint: str) -> str:
"""Build an IPC address for the given endpoint."""
if self.ipc_path is None:
raise ValueError("IPC path is required for dual-bind transport")
return f"ipc://{self.ipc_path / self.name}_{endpoint}.ipc"
"""Build an address for the given endpoint (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.ipc_path, f"{self.name}_{endpoint}.ipc")

def _tcp_addr(self, port: int) -> str:
"""Build a TCP address for the given port (bind-side)."""
Expand Down Expand Up @@ -228,13 +227,13 @@ def validate_paths(self) -> Self:
)

def _ipc_addr(self, name: str) -> str:
"""Build an IPC address for the given endpoint name."""
"""Build an address for the given endpoint (ipc:// on POSIX, tcp:// on Windows)."""
if not self.ipc_path:
raise ValueError(
f"Dual-bind IPC address for endpoint {name!r} requires comm.ipc_path; "
"set comm.ipc_path or configure controller_host for TCP addresses."
)
return f"ipc://{self.ipc_path / name}.ipc"
return _build_socket_address(self.ipc_path, f"{name}.ipc")

@property
def records_push_pull_address(self) -> str:
Expand Down
76 changes: 52 additions & 24 deletions src/aiperf/config/comm/ipc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import hashlib
import tempfile
from pathlib import Path
from typing import Annotated, ClassVar

from pydantic import Field, model_validator
from typing_extensions import Self

from aiperf.common.constants import IS_WINDOWS
from aiperf.config.comm.base import BaseZMQCommunicationConfig, BaseZMQProxyConfig
from aiperf.plugin.enums import CommunicationBackend

# Windows fallback: ZMQ does not support ipc:// on Windows. Use TCP loopback
# with a deterministic port derived from a hash of the would-be IPC path, so
# bind and connect sides agree without explicit coordination.
#
# Range chosen to:
# - stay below the default Windows ephemeral-port range (49152-65535);
# Linux defaults to 32768-60999, but this fallback only runs on Windows
# - sit above the cluster of common service ports (HTTP/Prometheus/vLLM/
# Ollama/OTLP/etc.) so a co-running service on localhost is unlikely to
# have already bound a port we hash to
# - keep birthday-paradox collision probability low for AIPerf's ~15 sockets:
# P(collision) ≈ 1 - exp(-n^2 / (2 * RANGE)). At RANGE=20000, n=15 → ~0.56%.
#
# Per-aiperf-run uniqueness is provided by ``tempfile.mkdtemp()`` randomness
# in ``ZMQIPCConfig.validate_path`` — two concurrent aiperf processes get
# different ipc paths, which feed into the salt, which produces different
# port distributions.
_WINDOWS_TCP_BASE_PORT = 28000
_WINDOWS_TCP_PORT_RANGE = 20000


def _build_socket_address(path: Path | None, ipc_filename: str) -> str:
"""Build a ZMQ socket address for an inter-service connection.

On Linux/macOS: returns ipc://{path}/{ipc_filename} (Unix domain socket).
On Windows: returns tcp://127.0.0.1:<port> with a deterministic port
derived from sha256(path/ipc_filename), since Windows ZMQ does not
support ipc://. Path is required on every platform so callers maintain
a consistent contract and the hash inputs are stable.
"""
if path is None:
raise ValueError("IPC path is required for socket address derivation")
if IS_WINDOWS:
salt = f"{path}/{ipc_filename}"
digest = hashlib.sha256(salt.encode()).hexdigest()
port_offset = int(digest[:8], 16) % _WINDOWS_TCP_PORT_RANGE
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hashing each Windows IPC endpoint into a fixed 20,000-port range can map two production sockets in the same run to the same TCP port, causing intermittent "address already in use" startup failures. Fix: allocate and share a collision-free per-run port map, or resolve collisions by probing/reserving ports before services start.

return f"tcp://127.0.0.1:{_WINDOWS_TCP_BASE_PORT + port_offset}"
return f"ipc://{path / ipc_filename}"


class ZMQIPCProxyConfig(BaseZMQProxyConfig):
"""Configuration for IPC proxy."""
Expand All @@ -21,10 +61,8 @@ class ZMQIPCProxyConfig(BaseZMQProxyConfig):
enable_capture: bool = Field(default=False, description="Enable capture socket")

def _addr(self, endpoint: str) -> str:
"""Build an IPC address for the given endpoint."""
if self.path is None:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / self.name}_{endpoint}.ipc"
"""Build an address for the given endpoint (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, f"{self.name}_{endpoint}.ipc")

@property
def frontend_address(self) -> str:
Expand Down Expand Up @@ -97,35 +135,25 @@ def validate_path(self) -> Self:

@property
def records_push_pull_address(self) -> str:
"""Get the records push/pull address based on protocol configuration."""
if not self.path:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / 'records_push_pull.ipc'}"
"""Get the records push/pull address (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, "records_push_pull.ipc")

@property
def credit_router_address(self) -> str:
"""Get the credit router address for streaming ROUTER-DEALER."""
if not self.path:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / 'credit_router.ipc'}"
"""Get the credit router address (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, "credit_router.ipc")

@property
def credit_return_router_address(self) -> str:
"""Get the credit return router address for dedicated return channel."""
if not self.path:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / 'credit_return_router.ipc'}"
"""Get the credit return router address (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, "credit_return_router.ipc")

@property
def control_address(self) -> str:
"""Get the control channel address."""
if not self.path:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / 'control.ipc'}"
"""Get the control channel address (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, "control.ipc")

@property
def group_lifecycle_address(self) -> str:
"""Get the group-local lifecycle channel address."""
if not self.path:
raise ValueError("Path is required for IPC transport")
return f"ipc://{self.path / 'group_lifecycle.ipc'}"
"""Get the group-local lifecycle channel address (ipc:// on POSIX, tcp:// on Windows)."""
return _build_socket_address(self.path, "group_lifecycle.ipc")
3 changes: 1 addition & 2 deletions src/aiperf/controller/multiprocess_service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import multiprocessing
import uuid
from multiprocessing import Process
from multiprocessing.context import ForkProcess, SpawnProcess

from pydantic import BaseModel, ConfigDict, Field

Expand All @@ -21,7 +20,7 @@ class MultiProcessRunInfo(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

process: Process | SpawnProcess | ForkProcess | None = Field(default=None)
process: Process | None = Field(default=None)
service_type: ServiceTypeT = Field(
...,
description="Type of service running in the process",
Expand Down
44 changes: 38 additions & 6 deletions src/aiperf/controller/system_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import signal
from collections.abc import Callable, Coroutine

from aiperf.common.constants import IS_WINDOWS
from aiperf.common.mixins import AIPerfLoggerMixin


Expand All @@ -24,11 +25,42 @@ def setup_signal_handlers(self, callback: Callable[[int], Coroutine]) -> None:
loop = asyncio.get_running_loop()
self.debug(f"Setting up SIGINT handler on loop {loop}")

def signal_handler(sig: int) -> None:
self.warning(f"Signal {sig} received, initiating graceful shutdown")
task = asyncio.create_task(callback(sig))
self._signal_tasks.add(task)
task.add_done_callback(self._signal_tasks.discard)
# Windows ProactorEventLoop does not implement add_signal_handler.
# Fall back to signal.signal(), which Windows supports for SIGINT.
# The handler is dispatched on the main thread and re-enters the loop
# via run_coroutine_threadsafe to invoke the async callback. The
# scheduled task is held by the loop, so no _signal_tasks tracking
# is needed on Windows.
#
# Limitation: on Windows the handler runs in main-thread Python-level
# interrupt context. CPython on Windows cannot interrupt a blocking C
# extension call (e.g. zmq.poll, aiohttp blocking I/O) until that
# call returns control to Python, so the user may observe sub-second
# Ctrl+C lag during heavy I/O. This is a CPython-on-Windows constraint,
# not a bug in the handler.
if IS_WINDOWS:

loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT)
def windows_signal_handler(sig: int, _frame: object) -> None:
self.warning(f"Signal {sig} received, initiating graceful shutdown")
asyncio.run_coroutine_threadsafe(callback(sig), loop)

signal.signal(signal.SIGINT, windows_signal_handler)
# SIGBREAK (Ctrl+Break / CTRL_BREAK_EVENT) is distinct from SIGINT
# on Windows. Some scripted runs, CI runners, and remote consoles
# send CTRL_BREAK_EVENT instead of CTRL_C_EVENT — without this
# handler the process dies without graceful shutdown. Guarded
# by getattr because signal.SIGBREAK is undefined on POSIX and
# this branch is unit-tested on Linux/macOS via IS_WINDOWS mock.
sigbreak = getattr(signal, "SIGBREAK", None)
if sigbreak is not None:
signal.signal(sigbreak, windows_signal_handler)
else:

def signal_handler(sig: int) -> None:
self.warning(f"Signal {sig} received, initiating graceful shutdown")
task = asyncio.create_task(callback(sig))
self._signal_tasks.add(task)
task.add_done_callback(self._signal_tasks.discard)

loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT)
self.debug("SIGINT handler installed successfully")
4 changes: 2 additions & 2 deletions src/aiperf/plugin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ def run_validate() -> None:
all_passed = True
for label, errors, fmt in checks:
if errors:
console.print(f"[red][/red] {label}")
console.print(f"[red](FAIL)[/red] {label}")
for cat, items in errors.items():
for line in fmt(cat, items):
console.print(f" {line}")
all_passed = False
else:
console.print(f"[green][/green] {label}")
console.print(f"[green](OK)[/green] {label}")

color = "green" if all_passed else "red"
msg = "All checks passed" if all_passed else "Validation failed"
Expand Down
Loading
Loading