|
25 | 25 | import numpy as np |
26 | 26 | import paddle |
27 | 27 |
|
| 28 | +from fastdeploy.cache_manager.ops import ( |
| 29 | + get_output_kv_signal, |
| 30 | + get_peer_mem_addr, |
| 31 | + memory_allocated, |
| 32 | + set_data_ipc, |
| 33 | + set_device, |
| 34 | +) |
28 | 35 | from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager |
29 | 36 | from fastdeploy.config import SpeculativeConfig |
30 | 37 | from fastdeploy.inter_communicator import ( |
31 | 38 | EngineWorkerQueue, |
32 | 39 | IPCSignal, |
33 | 40 | shared_memory_exists, |
34 | 41 | ) |
35 | | -from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc |
36 | 42 | from fastdeploy.utils import envs, get_logger |
37 | 43 |
|
38 | 44 | logger = get_logger("cache_messager", "cache_messager.log") |
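
The import hunk above swaps the GPU-only kernel imports for a single device-agnostic facade. A minimal sketch of what such a dispatch module could look like, assuming `fastdeploy.cache_manager.ops` simply re-exports the matching backend bindings — the XPU import path and the helper bodies below are assumptions, not the module's actual source:

```python
# Hypothetical sketch of a backend-dispatch module like the
# fastdeploy.cache_manager.ops facade imported above; the real module is not
# shown in this diff, so the XPU import path and structure are assumptions.
import paddle

if paddle.is_compiled_with_xpu():
    # Assumed parallel namespace for the XPU kernel bindings.
    from fastdeploy.model_executor.ops.xpu import (  # hypothetical path
        get_output_kv_signal,
        get_peer_mem_addr,
        set_data_ipc,
    )
else:
    # This GPU path matches the import removed from cache_messager.py above.
    from fastdeploy.model_executor.ops.gpu import (
        get_output_kv_signal,
        set_data_ipc,
    )


def set_device(device_id: int) -> None:
    """Bind this process to the given accelerator index on either backend."""
    prefix = "xpu" if paddle.is_compiled_with_xpu() else "gpu"
    paddle.set_device(f"{prefix}:{device_id}")


def memory_allocated() -> int:
    """Bytes currently held by the device allocator.

    Only the CUDA query is shown; the XPU build is assumed to supply its own
    equivalent through the backend bindings.
    """
    return paddle.device.cuda.memory_allocated()
```

Keeping the branch in one module means call sites such as `main()` only need `set_device(device)` instead of hard-coding `paddle.set_device(f"gpu:{device}")`.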
@@ -157,16 +163,20 @@ def __init__( |
157 | 163 | val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] |
158 | 164 | cache_k.append(key_cache) |
159 | 165 | cache_v.append(val_cache) |
160 | | - cache_k_ptr_list.append(key_cache.data_ptr()) |
161 | | - cache_v_ptr_list.append(val_cache.data_ptr()) |
| 166 | + if paddle.is_compiled_with_xpu(): |
| 167 | + cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr())) |
| 168 | + cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) |
| 169 | + else: |
| 170 | + cache_k_ptr_list.append(key_cache.data_ptr()) |
| 171 | + cache_v_ptr_list.append(val_cache.data_ptr()) |
162 | 172 | cache_k_ptr_list = np.array(cache_k_ptr_list) |
163 | 173 | cache_v_ptr_list = np.array(cache_v_ptr_list) |
164 | 174 |
|
165 | 175 | # 2. initialize the block_bytes |
166 | 176 | cache_shape = key_cache.shape |
167 | 177 | max_block_num = cache_shape[0] |
168 | 178 | block_bytes = math.prod(cache_shape[1:]) |
169 | | - if key_cache.dtype == paddle.bfloat16: |
| 179 | + if key_cache.dtype == paddle.bfloat16 or key_cache.dtype == paddle.float16: |
170 | 180 | block_bytes *= 2 |
171 | 181 | logger.info( |
172 | 182 | f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " |
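
For reference, a small worked example of the block-size arithmetic in this hunk: `block_bytes` starts as the element count of one block (everything past the leading `max_block_num` dimension) and is doubled for the 2-byte `bfloat16`/`float16` dtypes. The cache shape below is made up purely for illustration:

```python
# Worked example of the block_bytes computation above, using a made-up
# cache shape [max_block_num, kv_num_heads, block_size, head_dim]; the real
# layout comes from the engine config and may differ.
import math

import paddle

cache_shape = [1000, 8, 64, 128]          # hypothetical values
max_block_num = cache_shape[0]            # 1000 blocks
block_bytes = math.prod(cache_shape[1:])  # 8 * 64 * 128 = 65536 elements

# bfloat16 and float16 both occupy 2 bytes per element, so the element count
# is doubled to get bytes; a 1-byte (e.g. uint8) cache would stay as-is.
dtype = paddle.bfloat16
if dtype == paddle.bfloat16 or dtype == paddle.float16:
    block_bytes *= 2                      # 131072 bytes per block

print(max_block_num, block_bytes)
```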
@@ -452,8 +462,12 @@ def __init__( |
452 | 462 | val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"] |
453 | 463 | cache_k.append(key_cache) |
454 | 464 | cache_v.append(val_cache) |
455 | | - cache_k_ptr_list.append(key_cache.data_ptr()) |
456 | | - cache_v_ptr_list.append(val_cache.data_ptr()) |
| 465 | + if paddle.is_compiled_with_xpu(): |
| 466 | + cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr())) |
| 467 | + cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr())) |
| 468 | + else: |
| 469 | + cache_k_ptr_list.append(key_cache.data_ptr()) |
| 470 | + cache_v_ptr_list.append(val_cache.data_ptr()) |
457 | 471 | cache_k_ptr_list = np.array(cache_k_ptr_list) |
458 | 472 | cache_v_ptr_list = np.array(cache_v_ptr_list) |
459 | 473 |
|
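The same pointer-gathering branch now appears in both `CacheMessager` and `CacheMessagerV1`. A hedged sketch of how it could be factored into one helper, assuming `get_peer_mem_addr` translates a local XPU device pointer into an address the transfer peer can reach (the helper name itself is hypothetical):

```python
# Minimal sketch of the pointer-gathering pattern repeated in both hunks
# above, pulled into a helper; get_peer_mem_addr is assumed to map a local
# XPU device pointer to a peer-visible address for the KV-cache transfer.
import numpy as np
import paddle

from fastdeploy.cache_manager.ops import get_peer_mem_addr


def gather_cache_ptrs(cache_tensors):
    """Return a numpy array of transfer-ready addresses for the given caches."""
    ptrs = []
    for t in cache_tensors:
        if paddle.is_compiled_with_xpu():
            ptrs.append(get_peer_mem_addr(t.data_ptr()))
        else:
            ptrs.append(t.data_ptr())
    return np.array(ptrs)
```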
@@ -763,7 +777,7 @@ def _handle_connect_task(self): |
763 | 777 | def main(): |
764 | 778 | device = args.device_id |
765 | 779 | rank = args.rank |
766 | | - paddle.set_device(f"gpu:{device}") |
| 780 | + set_device(device) |
767 | 781 | cache_type = args.cache_dtype |
768 | 782 | speculative_config = SpeculativeConfig(args.speculative_config) |
769 | 783 | num_extra_layers = speculative_config.num_extra_cache_layer |
@@ -823,7 +837,7 @@ def main(): |
823 | 837 | cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()]) |
824 | 838 | logger.info(f"device :{device}") |
825 | 839 | logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}") |
826 | | - logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}") |
| 840 | +    logger.info(f"done init cache (full) gmem alloc : {memory_allocated()}") |
827 | 841 |
|
828 | 842 | if envs.ENABLE_V1_KVCACHE_SCHEDULER: |
829 | 843 | cache_messager = CacheMessagerV1( |
@@ -875,7 +889,6 @@ def main(): |
875 | 889 | args = parse_args() |
876 | 890 | rank_id = args.rank + args.local_data_parallel_id * args.mp_num |
877 | 891 | logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log") |
878 | | - |
879 | 892 | logger.info("create cache messager...") |
880 | 893 | logger.info(f"{args}") |
881 | 894 | main() |