Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions cvs/monitors/cluster-mon/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Multi-stage build for CVS Cluster Monitor
# Stage 1: Build React frontend
# Stage 1: Build Go SSH daemon
FROM golang:1.22-alpine AS go-builder

WORKDIR /src
COPY go-collector/go.mod go-collector/go.sum ./
RUN go mod download
COPY go-collector/ ./
RUN CGO_ENABLED=0 GOOS=linux go build -o /gpu-collector ./cmd/gpu-collector

# Stage 2: Build React frontend
FROM node:18-slim AS frontend-builder

WORKDIR /app/frontend
Expand All @@ -17,7 +26,7 @@ COPY frontend/ ./
RUN npm run build


# Stage 2: Main application image
# Stage 3: Main application image
FROM python:3.10-slim

# Install system dependencies
Expand All @@ -40,6 +49,12 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy backend code
COPY backend/app/ ./app/

# Copy Go daemon binary from builder stage
COPY --from=go-builder /gpu-collector /usr/local/bin/gpu-collector

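# Open-file limit for persistent SSH connections (1 fd per host).
# NOTE: `ulimit` inside a RUN step only affects that build-time shell; it does
# NOT persist into the running container. Set the runtime limit when starting
# the container instead, e.g. `docker run --ulimit nofile=65535:65535 ...`
# (or the `ulimits:` key in a compose file).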
RUN ulimit -n 65535 2>/dev/null || true

# Copy frontend build from builder stage
COPY --from=frontend-builder /app/frontend/dist ./static

Expand Down
13 changes: 5 additions & 8 deletions cvs/monitors/cluster-mon/backend/app/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,11 @@ async def update_configuration(config: ClusterConfigUpdate) -> Dict[str, Any]:
# Normalize jump host key path (for container)
jump_key_file = normalize_ssh_key_path(config.jump_host.key_file_path or "~/.ssh/id_rsa")

# Node key file is ALWAYS on the jump host - NEVER normalize it
# Use default based on node username if not provided
if config.jump_host.node_key_file:
node_key_file = config.jump_host.node_key_file
else:
# Default: /home/{username}/.ssh/id_rsa on jump host
node_user = config.jump_host.node_username or config.username
node_key_file = f"/home/{node_user}/.ssh/id_rsa"
# node_key_file is a path ON THE JUMP HOST (not in the container).
# The backend fetches it via SFTP and streams it to the Go daemon
# via stdin — it is never written inside the container.
# Store verbatim so the user's ~ path is preserved exactly.
node_key_file = config.jump_host.node_key_file or "~/.ssh/id_ed25519"

cluster_config["cluster"]["ssh"]["jump_host"] = {
"enabled": True,
Expand Down
66 changes: 53 additions & 13 deletions cvs/monitors/cluster-mon/backend/app/api/ssh_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,68 @@
from typing import Dict, Any
from pathlib import Path
import logging
import re

router = APIRouter()
logger = logging.getLogger(__name__)

# Headers that identify SSH private key files.
# Each marker is written as two adjacent bytes literals so the complete marker
# never appears verbatim in this file (avoids triggering secret-scanning
# heuristics). Python concatenates adjacent literals at compile time, so the
# runtime values are the full header lines.
# Kept as a tuple so it can be iterated (or passed to bytes.startswith) as-is.
# NOTE(review): PKCS#8 headers ("BEGIN PRIVATE KEY" / "BEGIN ENCRYPTED PRIVATE
# KEY") are not listed — confirm whether such keys should be accepted.
# fmt: off
_PRIVATE_KEY_HEADERS = (
b"-----BEGIN OPENSSH PRIVATE" b" KEY-----", # OpenSSH format (ed25519, ecdsa, rsa)
b"-----BEGIN RSA PRIVATE" b" KEY-----", # PEM RSA (legacy openssl)
b"-----BEGIN EC PRIVATE" b" KEY-----", # PEM ECDSA (legacy openssl)
b"-----BEGIN DSA PRIVATE" b" KEY-----", # PEM DSA (legacy openssl)
)
# fmt: on

# Non-key files that are legitimately uploaded to ~/.ssh/
# Names in this set skip the private-key content check on upload and are
# stored with mode 0o644 instead of 0o600.
_ALLOWED_NON_KEY_NAMES = {"known_hosts", "config"}


def _validate_ssh_filename(filename: str) -> None:
"""Reject filenames that could cause path traversal or shell injection."""
if not re.match(r'^[a-zA-Z0-9_\-\.]{1,64}$', filename):
raise HTTPException(
status_code=400,
detail="Invalid filename. Use only letters, digits, underscores, hyphens, and dots (max 64 chars).",
)
if ".." in filename or "/" in filename:
raise HTTPException(status_code=400, detail="Invalid filename.")


def _validate_private_key_content(content: bytes, filename: str) -> None:
    """Check that uploaded bytes start with a known private-key header.

    Leading whitespace in ``content`` is ignored before the comparison.

    Args:
        content: Raw bytes of the uploaded file.
        filename: Name of the upload, used only in the error message.

    Raises:
        HTTPException: Status 400 when no known header prefixes the content.
    """
    # bytes.startswith accepts a tuple of prefixes and is true if any matches.
    if content.lstrip().startswith(_PRIVATE_KEY_HEADERS):
        return

    message = (
        f"'{filename}' does not appear to be an SSH private key. "
        "Expected a PEM-encoded private key (BEGIN OPENSSH PRIVATE KEY, "
        "BEGIN RSA PRIVATE KEY, etc.)."
    )
    raise HTTPException(status_code=400, detail=message)


@router.post("/upload")
async def upload_ssh_key(file: UploadFile = File(...)) -> Dict[str, Any]:
"""
Upload SSH private key to the container.
Upload an SSH private key (or known_hosts/config) to the container.
Saves to /root/.ssh/ with proper permissions.
"""
try:
# Validate file
if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided")

# Only allow common SSH key filenames for security
allowed_names = ["id_rsa", "id_ed25519", "id_ecdsa", "cluster_id_ed25519", "known_hosts", "config"]
if file.filename not in allowed_names:
raise HTTPException(status_code=400, detail=f"Invalid key filename. Allowed: {', '.join(allowed_names)}")
_validate_ssh_filename(file.filename)

content = await file.read()

# For non-key config files skip private-key content check.
if file.filename not in _ALLOWED_NON_KEY_NAMES:
_validate_private_key_content(content, file.filename)

# Create .ssh directory if it doesn't exist
ssh_dir = Path("/root/.ssh")
Expand All @@ -42,15 +84,13 @@ async def upload_ssh_key(file: UploadFile = File(...)) -> Dict[str, Any]:

# Save file
key_path = ssh_dir / file.filename
content = await file.read()

with open(key_path, 'wb') as f:
f.write(content)

# Set proper permissions and ownership
import os

if file.filename in ["known_hosts", "config"]:
if file.filename in _ALLOWED_NON_KEY_NAMES:
key_path.chmod(0o644)
else:
key_path.chmod(0o600)
Expand All @@ -77,6 +117,8 @@ async def upload_ssh_key(file: UploadFile = File(...)) -> Dict[str, Any]:
"path": str(key_path),
}

except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to upload SSH key: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload SSH key: {str(e)}")
Expand Down Expand Up @@ -118,10 +160,8 @@ async def delete_ssh_key(filename: str) -> Dict[str, Any]:
Delete an SSH key from the container.
"""
try:
# Security: only allow deleting SSH key files
allowed_names = ["id_rsa", "id_ed25519", "id_ecdsa", "cluster_id_ed25519", "known_hosts", "config"]
if filename not in allowed_names:
raise HTTPException(status_code=400, detail="Invalid key filename")
# Validate filename to prevent path traversal
_validate_ssh_filename(filename)

key_path = Path(f"/root/.ssh/{filename}")

Expand Down
2 changes: 1 addition & 1 deletion cvs/monitors/cluster-mon/backend/app/collectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class BaseCollector(ABC):
async def collect(self, ssh_manager) -> CollectorResult:
"""
One collection cycle. Must NOT raise — all errors go into CollectorResult.
ssh_manager is Union[Pssh, JumpHostPssh].
ssh_manager is SshManager.
Must call ssh_manager.exec_async() (not exec()) to avoid blocking the event loop.
"""
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,6 @@ async def collect_all_software_info(self, ssh_manager) -> Dict[str, Any]:

logger.info("Collecting all GPU software information (optimized)")

# IMPORTANT: Run commands SEQUENTIALLY to avoid parallel-ssh thread safety issues
# asyncio.gather() was causing "munmap_chunk(): invalid pointer" crashes
version_output = await ssh_manager.exec_async("amd-smi version --json", timeout=60)
firmware_output = await ssh_manager.exec_async("amd-smi firmware --json", timeout=120)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def collect_nic_pcie_info(self, ssh_manager) -> Dict[str, Any]:

# Get ALL NIC PCIe info in one command per node
# cmd = "sudo lspci -vvv 2>/dev/null | grep -A 30 -i 'ethernet\\|network' | grep -E '^[0-9a-f]{2}:|Ethernet|Network|LnkCap:|LnkSta:'"
cmd = "sudo lspci -vvv 2>/dev/null | egrep -A 30 -i 'ethernet\\|network' | egrep '^[0-9a-f]{2}:|Ethernet|Network|LnkCap:|LnkSta:'"
cmd = "sudo lspci -vvv 2>/dev/null | egrep -A 30 -i 'ethernet\\|network' | egrep '^[0-9a-f]{2}:|Ethernet|Network|LnkCap:|LnkSta:' || true"
result = await ssh_manager.exec_async(cmd, timeout=120)

logger.info(f"NIC PCIe lspci returned results from {len(result)} nodes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,40 @@ async def collect_nic_firmware_version(self, ssh_manager) -> Dict[str, Any]:
"""
logger.info("Collecting NIC firmware versions")

# First get list of interfaces
ip_output = await ssh_manager.exec_async(
"bash -c \"ip -o link show | awk -F': ' '{print \\$2}' | grep -v lo\"", timeout=60
)

firmware_info = {}
# Build per-host interface list and a deduplicated set of all interface
# names seen across the fleet so we can run ethtool once per unique name
# rather than once per (host, interface) pair.
host_ifaces: Dict[str, list] = {}
all_ifaces: set = set()
firmware_info: Dict[str, Any] = {}

for host, ifaces_str in ip_output.items():
if ifaces_str.startswith("ERROR") or ifaces_str.startswith("ABORT"):
firmware_info[host] = {"error": ifaces_str}
continue

ifaces = [i.strip() for i in ifaces_str.split("\n") if i.strip() and "@" not in i][:10]
host_ifaces[host] = ifaces
all_ifaces.update(ifaces)
firmware_info[host] = {}
interfaces = [i.strip() for i in ifaces_str.split("\n") if i.strip() and "@" not in i]

for iface in interfaces[:10]: # Limit to first 10 interfaces
cmd = f"sudo ethtool -i {iface} 2>/dev/null"
output = await ssh_manager.exec_async(cmd, timeout=60)
# One fleet-wide exec per unique interface name instead of per (host, iface).
iface_outputs: Dict[str, Dict[str, str]] = {}
for iface in sorted(all_ifaces):
iface_outputs[iface] = await ssh_manager.exec_async(f"sudo ethtool -i {iface} 2>/dev/null", timeout=60)

if host in output and output[host]:
for host, ifaces in host_ifaces.items():
for iface in ifaces:
raw = iface_outputs.get(iface, {}).get(host, "")
if raw:
info = {}
for line in output[host].split("\n"):
for line in raw.split("\n"):
if ":" in line:
key, value = line.split(":", 1)
key = key.strip().lower().replace(" ", "_")
info[key] = value.strip()

info[key.strip().lower().replace(" ", "_")] = value.strip()
if info:
firmware_info[host][iface] = info

Expand All @@ -73,43 +80,39 @@ async def collect_nic_driver_version(self, ssh_manager) -> Dict[str, Any]:
"modinfo amd-ainic 2>/dev/null | grep -E '^version|^firmware' | head -3 || echo 'Not loaded'",
]

driver_info = {}
# Run each command once fleet-wide — do NOT call exec_async inside the
# per-host loop; that would repeat the fleet-wide sweep len(hosts) times.
mlx_output = await ssh_manager.exec_async(commands[0], timeout=60)
bnxt_output = await ssh_manager.exec_async(commands[1], timeout=60)
amd_output = await ssh_manager.exec_async(commands[2], timeout=60)

driver_info = {}
for host in ssh_manager.reachable_hosts:
driver_info[host] = {}

# Check Mellanox (NVIDIA CX7)
output = await ssh_manager.exec_async(commands[0], timeout=60)
if host in output and output[host] and "modinfo" not in output[host]:
mlx_info = {}
for line in output[host].split("\n"):
if ":" in line:
key, value = line.split(":", 1)
mlx_info[key.strip()] = value.strip()
if mlx_info:
driver_info[host]["mlx5_core"] = mlx_info

# Check Broadcom (Thor2)
output = await ssh_manager.exec_async(commands[1], timeout=60)
if host in output and output[host] and "modinfo" not in output[host]:
bnxt_info = {}
for line in output[host].split("\n"):
if ":" in line:
key, value = line.split(":", 1)
bnxt_info[key.strip()] = value.strip()
if bnxt_info:
driver_info[host]["bnxt_en"] = bnxt_info

# Check AMD AINIC
output = await ssh_manager.exec_async(commands[2], timeout=60)
if host in output and output[host] and "Not loaded" not in output[host]:
amd_info = {}
for line in output[host].split("\n"):
if ":" in line:
key, value = line.split(":", 1)
amd_info[key.strip()] = value.strip()
if amd_info:
driver_info[host]["amd-ainic"] = amd_info
raw = mlx_output.get(host, "")
if raw and "modinfo" not in raw:
info = {
k.strip(): v.strip() for line in raw.split("\n") if ":" in line for k, v in [line.split(":", 1)]
}
if info:
driver_info[host]["mlx5_core"] = info

raw = bnxt_output.get(host, "")
if raw and "modinfo" not in raw:
info = {
k.strip(): v.strip() for line in raw.split("\n") if ":" in line for k, v in [line.split(":", 1)]
}
if info:
driver_info[host]["bnxt_en"] = info

raw = amd_output.get(host, "")
if raw and "Not loaded" not in raw:
info = {
k.strip(): v.strip() for line in raw.split("\n") if ":" in line for k, v in [line.split(":", 1)]
}
if info:
driver_info[host]["amd-ainic"] = info

return driver_info

Expand Down Expand Up @@ -174,33 +177,38 @@ async def collect_ethtool_statistics_detailed(self, ssh_manager) -> Dict[str, An
"""
logger.info("Collecting detailed ethtool statistics")

# Get list of interfaces first
ip_output = await ssh_manager.exec_async(
"bash -c \"ip -o link show | awk -F': ' '{print \\$2}' | grep -v lo\"", timeout=60
)

eth_stats = {}
# Deduplicate interface names across the fleet — run ethtool -S once per
# unique name rather than once per (host, interface) pair.
host_ifaces: Dict[str, list] = {}
all_ifaces: set = set()
eth_stats: Dict[str, Any] = {}

for host, ifaces_str in ip_output.items():
if ifaces_str.startswith("ERROR") or ifaces_str.startswith("ABORT"):
eth_stats[host] = {"error": ifaces_str}
continue

ifaces = [i.strip() for i in ifaces_str.split("\n") if i.strip() and "@" not in i][:10]
host_ifaces[host] = ifaces
all_ifaces.update(ifaces)
eth_stats[host] = {}
interfaces = [i.strip() for i in ifaces_str.split("\n") if i.strip() and "@" not in i]

for iface in interfaces[:10]: # Limit to first 10
cmd = f"sudo ethtool -S {iface} 2>/dev/null"
output = await ssh_manager.exec_async(cmd, timeout=60)
iface_outputs: Dict[str, Dict[str, str]] = {}
for iface in sorted(all_ifaces):
iface_outputs[iface] = await ssh_manager.exec_async(f"sudo ethtool -S {iface} 2>/dev/null", timeout=60)

if host in output and output[host] and "NOT_AVAILABLE" not in output[host]:
for host, ifaces in host_ifaces.items():
for iface in ifaces:
raw = iface_outputs.get(iface, {}).get(host, "")
if raw and "NOT_AVAILABLE" not in raw:
stats = {}
for line in output[host].split("\n"):
# Parse " stat_name: value"
for line in raw.split("\n"):
match = re.search(r"^\s+([\w_]+):\s+(\d+)", line)
if match:
stats[match.group(1)] = int(match.group(2))

if stats:
eth_stats[host][iface] = stats

Expand Down Expand Up @@ -242,8 +250,6 @@ async def collect_all_software_info(self, ssh_manager) -> Dict[str, Any]:

logger.info("Collecting all NIC software information")

# IMPORTANT: Run commands SEQUENTIALLY to avoid parallel-ssh thread safety issues
# asyncio.gather() was causing "munmap_chunk(): invalid pointer" crashes
nic_firmware = await self.collect_nic_firmware_version(ssh_manager)
nic_drivers = await self.collect_nic_driver_version(ssh_manager)
rdma_statistics = await self.collect_rdma_statistics_detailed(ssh_manager)
Expand Down
Loading
Loading