Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/benchmarks/lib/test_endpoint_request_func_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test cases for endpoint_request_func.py
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/benchmarks/lib/test_utils_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
Expand Down
14 changes: 14 additions & 0 deletions tests/benchmarks/test_datasets_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
from argparse import ArgumentParser, Namespace
Expand Down
255 changes: 173 additions & 82 deletions tests/cache_manager/test_cache_transfer_manager.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import unittest
from unittest.mock import MagicMock, patch

import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
import fastdeploy.cache_manager.transfer_factory.rdma_cache_transfer as rdma_module
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager


# ==========================
# 测试用 Args
# ==========================
# Test Configuration
class Args:
"""Test configuration class to simulate input arguments for CacheTransferManager."""

rank = 0
local_data_parallel_id = 0
mp_num = 1
Expand All @@ -27,137 +43,212 @@ class Args:
create_cache_tensor = False


# ==========================
# 测试类
# ==========================
# RDMA Test Utilities
def create_rdma_manager(rdma_comm, splitwise_role="prefill"):
"""Factory function to create RDMACommManager instance with default test parameters.

Args:
rdma_comm: Mocked rdma_comm module or None
splitwise_role (str): Splitwise role, default to "prefill"

Returns:
rdma_module.RDMACommManager: Initialized RDMACommManager instance
"""
return rdma_module.RDMACommManager(
splitwise_role=splitwise_role,
rank=0,
gpu_id=0,
cache_k_ptr_list=[1, 2],
cache_v_ptr_list=[3, 4],
max_block_num=10,
block_bytes=1024,
rdma_port=20000,
)


# CacheTransferManager Test Cases
class TestCacheTransferManager(unittest.TestCase):
"""Unit test suite for CacheTransferManager class."""

def setUp(self):
# --------------------------
# mock logger
# --------------------------
"""Set up test fixtures before each test method.

Mocks dependencies, initializes test objects, and configures test environment.
"""
# Mock logger
cache_transfer_manager.logger = MagicMock()

# --------------------------
# mock current_platform
# --------------------------
# Mock current platform detection
class DummyPlatform:
"""Mock platform class to disable specific hardware checks in tests."""

@staticmethod
def is_iluvatar():
return False

@staticmethod
def is_xpu():
# 测试环境下不使用 XPU,返回 False
return False
return False # Disable XPU in test environment

@staticmethod
def is_cuda():
# 测试环境下不使用 CUDA,返回 False
return False
return False # Disable CUDA in test environment

cache_transfer_manager.current_platform = DummyPlatform()

# --------------------------
# mock EngineCacheQueue
# --------------------------
patcher1 = patch("fastdeploy.cache_manager.cache_transfer_manager.EngineCacheQueue", new=MagicMock())
patcher1.start()
self.addCleanup(patcher1.stop)

# --------------------------
# mock IPCSignal
# --------------------------
patcher2 = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock())
patcher2.start()
self.addCleanup(patcher2.stop)

# --------------------------
# mock _init_cpu_cache 和 _init_gpu_cache
# --------------------------
patcher3 = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None)
patcher4 = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None)
patcher3.start()
patcher4.start()
self.addCleanup(patcher3.stop)
self.addCleanup(patcher4.stop)

# --------------------------
# 创建 manager
# --------------------------
# Mock EngineCacheQueue class
self.engine_cache_queue_patcher = patch(
"fastdeploy.cache_manager.cache_transfer_manager.EngineCacheQueue", new=MagicMock()
)
self.engine_cache_queue_patcher.start()

# Mock IPCSignal class
self.ipc_signal_patcher = patch("fastdeploy.cache_manager.cache_transfer_manager.IPCSignal", new=MagicMock())
self.ipc_signal_patcher.start()

# Mock cache initialization methods to avoid actual resource allocation
self.init_cpu_cache_patcher = patch.object(CacheTransferManager, "_init_cpu_cache", lambda self, args: None)
self.init_gpu_cache_patcher = patch.object(CacheTransferManager, "_init_gpu_cache", lambda self, args: None)
self.init_cpu_cache_patcher.start()
self.init_gpu_cache_patcher.start()

# Initialize CacheTransferManager with test configuration
self.manager = CacheTransferManager(Args())

# --------------------------
# mock worker_healthy_live_signal
# --------------------------
# Mock worker health check signal
class DummySignal:
"""Mock signal class to simulate worker health status."""

def __init__(self):
self.value = [0]
self.value = [0] # Default to unhealthy initial state

self.manager.worker_healthy_live_signal = DummySignal()

# --------------------------
# mock swap thread pools
# --------------------------
# Mock thread pools for swap operations
self.manager.swap_to_cpu_thread_pool = MagicMock()
self.manager.swap_to_gpu_thread_pool = MagicMock()

# --------------------------
# mock cache_task_queue
# --------------------------
# Mock cache task queue with test data
self.manager.cache_task_queue = MagicMock()
self.manager.cache_task_queue.empty.return_value = False
self.manager.cache_task_queue.get_transfer_task.return_value = (([0], 0, 0, MagicMock(value=0), 0), True)
self.manager.cache_task_queue.barrier1 = MagicMock()
self.manager.cache_task_queue.barrier2 = MagicMock()
self.manager.cache_task_queue.barrier3 = MagicMock()

# --------------------------
# 避免 sleep 阻塞测试
# --------------------------
self.sleep_patch = patch("time.sleep", lambda x: None)
self.sleep_patch.start()
self.addCleanup(self.sleep_patch.stop)
# Mock time.sleep to prevent test blocking
self.sleep_patcher = patch("time.sleep", lambda x: None)
self.sleep_patcher.start()

def tearDown(self):
"""Clean up test fixtures after each test method."""
self.engine_cache_queue_patcher.stop()
self.ipc_signal_patcher.stop()
self.init_cpu_cache_patcher.stop()
self.init_gpu_cache_patcher.stop()
self.sleep_patcher.stop()

# ==========================
# check_work_status 测试
# ==========================
def test_check_work_status_no_signal(self):
"""Test check_work_status when no health signal is set.

Verifies that the method returns healthy status with empty message
when the health signal value is 0 (initial state).
"""
healthy, msg = self.manager.check_work_status()
self.assertTrue(healthy)
self.assertEqual(msg, "")

def test_check_work_status_healthy(self):
"""Test check_work_status with valid (recent) health signal.

Verifies that the method returns healthy status when the health signal
is set to current time (within threshold).
"""
self.manager.worker_healthy_live_signal.value[0] = int(time.time())
healthy, msg = self.manager.check_work_status()
self.assertTrue(healthy)
self.assertEqual(msg, "")

def test_check_work_status_unhealthy(self):
"""Test check_work_status with expired health signal.

Verifies that the method returns unhealthy status with appropriate
message when the health signal is older than the threshold.
"""
self.manager.worker_healthy_live_signal.value[0] = int(time.time()) - 1000
healthy, msg = self.manager.check_work_status(time_interval_threashold=10)
self.assertFalse(healthy)
self.assertIn("Not Healthy", msg)

# ==========================
# do_data_transfer 异常处理测试
# ==========================
def test_do_data_transfer_broken_pipe(self):
# mock get_transfer_task 抛出 BrokenPipeError
self.manager.cache_task_queue.get_transfer_task.side_effect = BrokenPipeError("mock broken pipe")

# mock check_work_status 返回 False,触发 break
self.manager.check_work_status = MagicMock(return_value=(False, "Not Healthy"))

# patch do_data_transfer 本身,避免死循环
with patch.object(self.manager, "do_data_transfer") as mock_transfer:
mock_transfer.side_effect = lambda: None # 直接返回,不执行死循环
self.manager.do_data_transfer()

# 验证 check_work_status 已被调用
self.assertTrue(self.manager.check_work_status.called or True)
# 验证 logger 调用
self.assertTrue(cache_transfer_manager.logger.error.called or True)
self.assertTrue(cache_transfer_manager.logger.critical.called or True)

# RDMACommManager Test Cases
class TestRDMACommManager(unittest.TestCase):
"""Unit test suite for RDMACommManager class."""

def test_init_with_rdma_comm(self):
"""Test RDMACommManager initialization with valid rdma_comm module.

Verifies that the messager is created using RDMACommunicator and
instance attributes are set correctly.
"""
mock_comm = MagicMock()
with patch.dict(sys.modules, {"rdma_comm": mock_comm}):
manager = create_rdma_manager(mock_comm)

mock_comm.RDMACommunicator.assert_called_once()
self.assertTrue(hasattr(manager, "messager"))
self.assertEqual(manager.splitwise_role, "prefill")

def test_connect_nominal(self):
"""Test connect method with valid prefill role.

Verifies that connect succeeds (returns True) when called with
prefill role and RDMA connection is successful.
"""
mock_comm = MagicMock()
mock_instance = MagicMock()
mock_instance.is_connected.return_value = False
mock_instance.connect.return_value = 0 # Simulate successful connection
mock_comm.RDMACommunicator.return_value = mock_instance

with patch.dict(sys.modules, {"rdma_comm": mock_comm}):
manager = create_rdma_manager(mock_comm, splitwise_role="prefill")

result = manager.connect("127.0.0.1", 5001)
self.assertTrue(result)
mock_instance.connect.assert_called_once_with("127.0.0.1", "5001")

def test_connect_invalid_role(self):
"""Test connect method with invalid role (decode).

Verifies that an AssertionError is raised when connect is called
with a role other than prefill.
"""
mock_comm = MagicMock()
mock_comm.RDMACommunicator.return_value = MagicMock()

with patch.dict(sys.modules, {"rdma_comm": mock_comm}):
manager = create_rdma_manager(mock_comm, splitwise_role="decode")

with self.assertRaises(AssertionError):
manager.connect("1.2.3.4", 1234)

def test_write_cache(self):
"""Test write_cache method parameter passing.

Verifies that write_cache correctly forwards all parameters to
the underlying messager's write_cache method.
"""
mock_comm = MagicMock()
mock_instance = MagicMock()
mock_comm.RDMACommunicator.return_value = mock_instance

with patch.dict(sys.modules, {"rdma_comm": mock_comm}):
manager = create_rdma_manager(mock_comm)

manager.write_cache("1.1.1.1", 9999, [1], [2], 3)

mock_instance.write_cache.assert_called_once_with("1.1.1.1", "9999", [1], [2], 3)


if __name__ == "__main__":
Expand Down
14 changes: 14 additions & 0 deletions tests/ce/deploy/deploy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import json
import os
Expand Down
Loading
Loading