Commit b36dfdb

[Fix] Set the multiprocessing start method of the test tool to 'spawn'. (#447)
Set the multiprocessing start method of the test tool to 'spawn' and add NPU cleanup.
1 parent d613e22 commit b36dfdb

File tree: 2 files changed, +57 -10 lines


ucm/store/test/e2e/nfsstore_embed_fetch.py

Lines changed: 50 additions & 9 deletions
@@ -271,13 +271,17 @@ def run(
         if r == 0:
             store_all_hashes(hashes[:batch_size])

-        w_bw_list.append(w_bw)
-        w_time_list.append(w_time)
-        w_size_sum += w_size
+        if r != 0:
+            w_bw_list.append(w_bw)
+            w_time_list.append(w_time)
+            w_size_sum += w_size

         if operation_mode == "write_only":
             del kvcaches, hashes
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif hasattr(torch, "npu") and torch.npu.is_available():
+                torch.npu.empty_cache()

         if operation_mode in ["read_only", "both"]:
             if operation_mode == "read_only":
@@ -310,16 +314,23 @@ def run(
                 mla,
             )

-            r_bw_list.append(r_bw)
-            r_time_list.append(r_time)
-            r_size_sum += r_size
+            if r != 0:
+                r_bw_list.append(r_bw)
+                r_time_list.append(r_time)
+                r_size_sum += r_size

         if operation_mode == "read_only":
             del kvcaches
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif hasattr(torch, "npu") and torch.npu.is_available():
+                torch.npu.empty_cache()
         else:
             del kvcaches, hashes
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif hasattr(torch, "npu") and torch.npu.is_available():
+                torch.npu.empty_cache()

     del store
     avg_w_bw = sum(w_bw_list) / len(w_bw_list) if w_bw_list else 0.0
@@ -330,3 +341,33 @@ def run(
     avg_r_size = r_size_sum / (1024**3) / len(r_time_list) if r_time_list else 0.0

     return avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size
+
+
+if __name__ == "__main__":
+    os.environ["UC_LOGGER_LEVEL"] = "debug"
+
+    try:
+        result = run(
+            storage_backends=".",
+            device_id=1,
+            repeat=1,
+            num_head=1,
+            block_len=128,
+            transferStreamNumber=32,
+            num_tokens=4096,
+            block_layer=61,
+            head_size=576,
+            block_elem_size=2,
+            kv=1,
+            mla=True,
+            transferIoDirect=False,
+            operation_mode="both",
+        )
+
+        avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result
+
+    except Exception as e:
+        print(f"Error: {e}")
+        import traceback
+
+        traceback.print_exc()

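Note: the CUDA-or-NPU cleanup added above appears three times in this file. As a minimal sketch (not part of this commit; it assumes the same torch / torch_npu setup the test already relies on), the pattern could be factored into one helper:

import torch


def empty_device_cache() -> None:
    # Hypothetical helper mirroring the commit's CUDA-first, NPU-fallback pattern.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "npu") and torch.npu.is_available():
        # torch.npu is exposed by the torch_npu plugin on Ascend devices.
        torch.npu.empty_cache()

Each repeated cleanup branch in the diff would then collapse to a single empty_device_cache() call.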
ucm/store/test/e2e/nfsstore_embed_fetch_run.py

Lines changed: 7 additions & 1 deletion
@@ -47,9 +47,15 @@ def get_user_input(prompt, default=None):


 def main():
+
+    try:
+        multiprocessing.set_start_method("spawn", force=True)
+    except RuntimeError:
+        pass
+
     storage_backends = "."
     device_id = 1
-    repeat = 3
+    repeat = 3  # This parameter must be greater than 1; the results from the first round of testing are not included in the bandwidth calculation.
     num_tokens_list = [2048, 4096, 8192, 16384, 32768]
     transferStreamNumbers = [32, 64, 128]

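For context on why the start method is forced to 'spawn', here is a minimal, hypothetical sketch (not part of this commit): with the default 'fork' method on Linux, child processes inherit an already-initialized CUDA/NPU context from the parent, which can leave device APIs in a broken state, while 'spawn' starts each worker in a fresh interpreter so the device runtime is initialized from scratch.

import multiprocessing


def worker(rank: int) -> None:
    # Import torch inside the worker so each spawned process builds its
    # own device context instead of inheriting one from the parent.
    import torch

    if torch.cuda.is_available():
        # Pick a device per worker; the modulo keeps this safe on single-GPU hosts.
        torch.cuda.set_device(rank % torch.cuda.device_count())
    print(f"worker {rank} started in a fresh interpreter")


if __name__ == "__main__":
    # force=True matches the commit: override any start method a
    # previously imported library may already have configured.
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    procs = [multiprocessing.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

The worker function and rank values here are illustrative only; the actual test tool passes its own device_id into run().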