Problem
In tile_kernels/testing/bench.py, the dtype_to_str() function is incomplete:
```python
def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {
        torch.float32: 'fp32',
        torch.bfloat16: 'bf16',
        torch.float8_e4m3fn: 'e4m3',
        torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
    }
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}. Only fp32, bf16, e4m3, and int8(e2m1) are supported')
    return mapping[dtype]
```
Missing mappings:
torch.float16 → 'fp16'
torch.float8_e5m2 → 'e5m2'
When torch.float16 or torch.float8_e5m2 is passed, it raises a ValueError.
Context
TileKernels includes quantization kernels (tile_kernels/quant/) that use float16 and float8_e5m2 dtypes. The dtype_to_str function is used in benchmark output formatting (via _format_value → make_param_id). If these dtypes are used in a benchmark, the function will crash instead of producing a human-readable string.
Expected Fix
Add the two missing mappings to the mapping dict and update the error message accordingly:
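```python
def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {
        torch.float32: 'fp32',
        torch.float16: 'fp16',
        torch.bfloat16: 'bf16',
        torch.float8_e4m3fn: 'e4m3',
        torch.float8_e5m2: 'e5m2',
        torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
    }
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}. Only fp32, fp16, bf16, e4m3, e5m2, and int8(e2m1) are supported')
    return mapping[dtype]
```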