
dtype_to_str: unsupported dtype raises ValueError for fp16/float8_e5m2 #4

@pandacooming


Problem

In tile_kernels/testing/bench.py, the dtype_to_str() function does not cover every dtype the repository uses:

```python
import torch

def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {
        torch.float32: 'fp32',
        torch.bfloat16: 'bf16',
        torch.float8_e4m3fn: 'e4m3',
        torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
    }
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}. Only fp32, bf16, e4m3, and int8(e2m1) are supported')
    return mapping[dtype]
```

Missing mappings:

  • torch.float16 → 'fp16'
  • torch.float8_e5m2 → 'e5m2'

Passing torch.float16 or torch.float8_e5m2 therefore falls through to the `dtype not in mapping` branch and raises ValueError.

Context

TileKernels includes quantization kernels (tile_kernels/quant/) that use the float16 and float8_e5m2 dtypes. dtype_to_str is called when formatting benchmark output (via _format_value → make_param_id), so any benchmark parametrized over one of these dtypes raises ValueError instead of producing a human-readable label.
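To illustrate how the error surfaces during benchmark ID generation, here is a sketch of that call path. The bodies of `_format_value` and `make_param_id` below are hypothetical stand-ins written for this example; the real helpers in bench.py may be structured differently.

```python
import torch


# Trimmed stand-in for the current dtype_to_str: float16 is not in the mapping.
def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {torch.float32: 'fp32', torch.bfloat16: 'bf16'}
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}')
    return mapping[dtype]


# Hypothetical: dtype values are routed through dtype_to_str, everything else through str().
def _format_value(value) -> str:
    if isinstance(value, torch.dtype):
        return dtype_to_str(value)
    return str(value)


# Hypothetical: builds a readable ID such as 'dtype=bf16-n=4096' from benchmark params.
def make_param_id(params: dict) -> str:
    return '-'.join(f'{key}={_format_value(value)}' for key, value in params.items())
```

With this shape, the ValueError from dtype_to_str propagates straight out of make_param_id, so the benchmark fails before any kernel runs.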

Expected Fix

Add the two missing mappings to the mapping dict and update the error message accordingly:

```python
mapping = {
    torch.float32: 'fp32',
    torch.float16: 'fp16',
    torch.bfloat16: 'bf16',
    torch.float8_e4m3fn: 'e4m3',
    torch.float8_e5m2: 'e5m2',
    torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
}
```
