@@ -271,13 +271,17 @@ def run(
271271 if r == 0 :
272272 store_all_hashes (hashes [:batch_size ])
273273
274- w_bw_list .append (w_bw )
275- w_time_list .append (w_time )
276- w_size_sum += w_size
274+ if r != 0 :
275+ w_bw_list .append (w_bw )
276+ w_time_list .append (w_time )
277+ w_size_sum += w_size
277278
278279 if operation_mode == "write_only" :
279280 del kvcaches , hashes
280- torch .cuda .empty_cache ()
281+ if torch .cuda .is_available ():
282+ torch .cuda .empty_cache ()
283+ elif hasattr (torch , "npu" ) and torch .npu .is_available ():
284+ torch .npu .empty_cache ()
281285
282286 if operation_mode in ["read_only" , "both" ]:
283287 if operation_mode == "read_only" :
@@ -310,16 +314,23 @@ def run(
310314 mla ,
311315 )
312316
313- r_bw_list .append (r_bw )
314- r_time_list .append (r_time )
315- r_size_sum += r_size
317+ if r != 0 :
318+ r_bw_list .append (r_bw )
319+ r_time_list .append (r_time )
320+ r_size_sum += r_size
316321
317322 if operation_mode == "read_only" :
318323 del kvcaches
319- torch .cuda .empty_cache ()
324+ if torch .cuda .is_available ():
325+ torch .cuda .empty_cache ()
326+ elif hasattr (torch , "npu" ) and torch .npu .is_available ():
327+ torch .npu .empty_cache ()
320328 else :
321329 del kvcaches , hashes
322- torch .cuda .empty_cache ()
330+ if torch .cuda .is_available ():
331+ torch .cuda .empty_cache ()
332+ elif hasattr (torch , "npu" ) and torch .npu .is_available ():
333+ torch .npu .empty_cache ()
323334
324335 del store
325336 avg_w_bw = sum (w_bw_list ) / len (w_bw_list ) if w_bw_list else 0.0
@@ -330,3 +341,33 @@ def run(
330341 avg_r_size = r_size_sum / (1024 ** 3 ) / len (r_time_list ) if r_time_list else 0.0
331342
332343 return avg_w_size , avg_w_time , avg_w_bw , avg_r_time , avg_r_bw , avg_r_size
344+
345+
if __name__ == "__main__":
    # Script entry point: run a single benchmark pass with fixed parameters
    # and debug-level logging enabled via the UC_LOGGER_LEVEL variable.
    os.environ["UC_LOGGER_LEVEL"] = "debug"

    try:
        # NOTE(review): run() only records bandwidth/time/size samples when
        # its iteration counter `r` is non-zero. If `r` iterates over
        # range(repeat), repeat=1 records nothing and every average comes
        # back as 0.0 -- confirm against the loop bounds in run().
        result = run(
            storage_backends=".",
            device_id=1,
            repeat=1,
            num_head=1,
            block_len=128,
            transferStreamNumber=32,
            num_tokens=4096,
            block_layer=61,
            head_size=576,
            block_elem_size=2,
            kv=1,
            mla=True,
            transferIoDirect=False,
            operation_mode="both",
        )

        # run() returns six averages; unpacking here fails loudly (and is
        # reported by the handler below) if that contract ever changes.
        avg_w_size, avg_w_time, avg_w_bw, avg_r_time, avg_r_bw, avg_r_size = result

    except Exception as e:
        print(f"Error: {e} ")
        import traceback

        traceback.print_exc()
        # Propagate the failure through the process exit status so shells
        # and CI jobs do not mistake a failed benchmark for a success.
        raise SystemExit(1)
0 commit comments