
[Feature Request] Vectorize float/bfloat16/float16 into float2/bfloat162/half2 in TileOps #1847

@bucket-xv


Required prerequisites

  • I have searched the Issue Tracker to confirm this hasn't already been reported. (Comment there if it has.)

Motivation

Logically, T.float32, T.bfloat16, and T.float16 are the basic units that TileLang operates on. However, testing with NVCC, the CUDA compiler, shows that it does not automatically pack two bfloat16 values into one 32-bit register (hardware registers are 32 bits wide) and use them as a single operand.
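To make "packing two bfloat16 values into one register" concrete, here is a minimal pure-Python sketch of the bit-level layout. The helper names (float_to_bf16_bits, pack_bf16x2, and so on) are hypothetical, and it assumes the convention that the first element occupies the low 16 bits of the 32-bit word, as the .x member of a packed pair does on little-endian hardware. Rounding goes through float32 first, so NaN/Inf and rare double-rounding cases are not handled carefully.

```python
import struct

def float_to_bf16_bits(x: float) -> int:
    """Round x to bfloat16 (round-to-nearest-even) and return its 16-bit pattern.

    Sketch only: converts to float32 first; NaN/Inf are not treated specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # bfloat16 keeps the upper 16 bits of a float32; add a ties-to-even bias first.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_bits_to_float(b: int) -> float:
    """Widen a 16-bit bfloat16 pattern back to a Python float."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def pack_bf16x2(lo: float, hi: float) -> int:
    """Pack two bfloat16 values into one 32-bit word: lo in bits 0-15, hi in bits 16-31."""
    return float_to_bf16_bits(lo) | (float_to_bf16_bits(hi) << 16)

def unpack_bf16x2(word: int) -> tuple:
    """Split a 32-bit word back into its two bfloat16 lanes (low lane first)."""
    return bf16_bits_to_float(word & 0xFFFF), bf16_bits_to_float(word >> 16)

word = pack_bf16x2(1.0, -2.5)
print(hex(word))            # 0xc0203f80
print(unpack_bf16x2(word))  # (1.0, -2.5)
```

A single 32-bit operand like this is what the two-way instructions consume; without packing, each bfloat16 occupies its own register.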

As a result, the generated code does not use the vectorized functions designed to handle two bfloat16/float16 values at once, such as habs2, hmax2, hmul2, and hadd2 (https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH____BFLOAT162__ARITHMETIC.html), nor the ones designed for two float values, such as ffma2 and fadd2 (https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__SINGLE.html).
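The semantics of a two-way add can be sketched in Python: one operation on two packed 32-bit words performs two independent bfloat16 additions, one per 16-bit lane, each rounded back to bfloat16. This is an emulation for illustration only; hadd2_emulated is a hypothetical name, not the CUDA intrinsic, and rounding goes through float32 first.

```python
import struct

def to_bf16_bits(x: float) -> int:
    # float -> float32 -> bfloat16 bits, round-to-nearest-even (NaN/Inf not handled)
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def from_bf16_bits(b: int) -> float:
    # Widen the low 16 bits of b (a bfloat16 pattern) to a Python float
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def hadd2_emulated(a: int, b: int) -> int:
    """Lane-wise add of two packed bfloat16 pairs (hypothetical helper).

    Each 16-bit lane is added independently and rounded back to bfloat16,
    so one call covers two elements.
    """
    out = 0
    for shift in (0, 16):  # low lane, then high lane
        x = from_bf16_bits(a >> shift)
        y = from_bf16_bits(b >> shift)
        out |= to_bf16_bits(x + y) << shift
    return out

a = to_bf16_bits(1.5) | (to_bf16_bits(2.0) << 16)    # lanes (1.5, 2.0)
b = to_bf16_bits(0.25) | (to_bf16_bits(-4.0) << 16)  # lanes (0.25, -4.0)
r = hadd2_emulated(a, b)
print(from_bf16_bits(r), from_bf16_bits(r >> 16))    # 1.75 -2.0
```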

Solution

TileOps for T.bfloat16/T.float16/T.float32 could use bfloat162/half2/float2 by explicitly packing their operands. With explicit packing in a TileOp, the layout needs care, because only physically adjacent elements can share one register without extra shuffling. Ideally, once this optimization is in place, arithmetic-bound bfloat16 kernels should reach nearly double the throughput for the same number of elements.
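The layout concern can be illustrated with a small model: given the global element indices one thread holds, in register order, which register-slot pairs hold physically adjacent elements and could therefore be packed straight from a 32-bit (or wider) load without shuffles? The packable_pairs helper and both layouts are hypothetical; the sketch assumes 2-byte elements and that register slots 2k and 2k+1 share one 32-bit register. It is not TileLang's layout API.

```python
def packable_pairs(owned):
    """Return register-slot pairs (2k, 2k+1) whose elements are adjacent in memory.

    owned[i] is the global element index held in register slot i of one thread.
    Assumes 2-byte elements, with slots 2k and 2k+1 sharing a 32-bit register.
    """
    pairs = []
    for k in range(0, len(owned) - 1, 2):
        if owned[k + 1] == owned[k] + 1:
            pairs.append((k, k + 1))
    return pairs

tid, threads, per_thread = 3, 32, 16
# Contiguous layout: thread t holds elements t*16 .. t*16+15.
contiguous = [tid * per_thread + c for c in range(per_thread)]
# Strided layout: thread t holds elements t, t+32, t+64, ...
strided = [tid + threads * c for c in range(per_thread)]

print(len(packable_pairs(contiguous)))  # 8: every register pair can be packed
print(len(packable_pairs(strided)))     # 0: neighbours live in other threads
```

Under a strided layout, packing would first require exchanging data between threads, which is why a TileOp that packs explicitly has to choose or respect a layout that keeps neighbours in the same thread.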

Alternatives

  • Use T.vectorize(2) as a hint to the TileLang compiler? Add an optional vec flag to T.Parallel and T.reduce that tells TileLang to use these two-way instructions? Or add a pass config to control the behavior?
  • Let the user pack T.bfloat16 into T.bfloat16_x2? However, this hurts the user experience and makes the language confusing. For example, what does the user mean when they call T.reduce on a T.bfloat16_x2 type?

Additional context

Below, an example is given to clarify the motivation and the desired result.

Example program:

```python
import tilelang
from tilelang import language as T
import torch

@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True,
    }
)
def get_sample_kernel(N: int):
    num_threads = 32

    @T.prim_func
    def sample_kernel(a: T.Tensor[(N,), T.bfloat16], out: T.Tensor[(1,), T.bfloat16]):
        with T.Kernel(1, threads=num_threads) as block_idx:
            # The fragment is distributed over the 32 threads (N / 32 values each).
            a_fragment = T.alloc_fragment((N,), T.bfloat16)
            b_fragment = T.alloc_fragment((1,), T.bfloat16)
            T.copy(a, a_fragment)
            # Sum-reduce the whole fragment into a single bfloat16 value.
            T.reduce_sum(a_fragment, b_fragment)
            T.copy(b_fragment, out)

    return sample_kernel

N = 2048
t = torch.randn((N,), dtype=torch.bfloat16, device='cuda')
out = torch.empty((1,), dtype=torch.bfloat16, device='cuda')
kernel = get_sample_kernel(N)
print(kernel.get_kernel_source())  # generated CUDA source
kernel.export_ptx('kernel.ptx')
kernel.export_sass('kernel.sass')
kernel(t, out)
```

The generated CUDA code:

```cuda
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
#ifdef ENABLE_BF16
#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>
#endif

extern "C" __global__ void sample_kernel_kernel(const bfloat16_t* __restrict__ a, bfloat16_t* __restrict__ out);
extern "C" __global__ void __launch_bounds__(32, 1) sample_kernel_kernel(const bfloat16_t* __restrict__ a, bfloat16_t* __restrict__ out) {
  bfloat16_t a_fragment[64];
  bfloat16_t b_fragment[1];
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(ulonglong4*)(a_fragment + (i * 16)) = tl::load_global_256(&(*(ulonglong4*)(a + ((i * 512) + (((int)threadIdx.x) * 16)))));
  }
  b_fragment[0] = bfloat16_t(0x0p+0f/*0.000000e+00*/);
  #pragma unroll
  for (int rv = 0; rv < 64; ++rv) {
    b_fragment[0] = (b_fragment[0] + a_fragment[(((rv & 3) * 16) + (rv >> 2))]);
  }
  b_fragment[0] = tl::AllReduce<tl::SumOp, 32, 1, 0, 32>::run_hopper(b_fragment[0]);
  if (((int)threadIdx.x) == 0) {
    out[0] = b_fragment[0];
  }
}
```
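The reduce loop in the generated code reads a_fragment[((rv & 3) * 16) + (rv >> 2)] into a single accumulator. The Python sketch below only replays that index expression (it does not model what ptxas does). It assumes 2-byte elements, with a_fragment[2k] and a_fragment[2k+1] sharing one 32-bit register. Going by the load loop, each thread already holds contiguous 16-element chunks, so neighbours are in the same thread. However, the loop visits them four steps apart through one sequential accumulator, so consecutive adds never touch the same 32-bit word.

```python
def visit_order(n_iters=64):
    """Index into a_fragment visited at each step rv of the generated reduce loop."""
    return [((rv & 3) * 16) + (rv >> 2) for rv in range(n_iters)]

order = visit_order()
# Every element of the 64-element per-thread fragment is visited exactly once.
assert sorted(order) == list(range(64))

# With 2-byte elements, elements 2k and 2k+1 share 32-bit word k (assumed alignment).
words = [idx // 2 for idx in order]
same_word_back_to_back = sum(1 for p, q in zip(words, words[1:]) if p == q)
print(order[:8])               # [0, 16, 32, 48, 1, 17, 33, 49]
print(same_word_back_to_back)  # 0

# A layout-aware loop would instead consume one word (2w, 2w+1) per step.
paired_order = [(2 * w, 2 * w + 1) for w in range(32)]
print(paired_order[:3])        # [(0, 1), (2, 3), (4, 5)]
```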

PTX

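In short, the per-thread part of the reduction is 64 scalar instructions of the form

```
{ add.bf16 %rs12,%rs9,%rs14; }
```

followed by the reduction across threads.

What we want to see instead is add.bf16x2 (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add), which adds two packed pairs of bfloat16 values in one instruction and would roughly halve the number of per-thread additions.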
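A pure-Python model can compare the current scalar chain with a two-lane add.bf16x2 chain. This sketch is not measured on a GPU; the counts are modeled additions, not PTX instructions, and bf16 is a hypothetical rounding helper that goes through float32. It also shows a side effect worth keeping in mind: splitting the sum into two lanes changes the order of the additions, so the bfloat16 result can differ from the sequential sum. A compiler is generally not allowed to reorder floating-point additions like this on its own. The paired order is not guaranteed to be more accurate in general; this input just happens to favour it.

```python
import struct

def bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even) via float32; NaN/Inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) & 0xFFFF) << 16))[0]

def scalar_chain(xs):
    """Model of the current code: one add.bf16 per element into one accumulator."""
    acc, n_add = 0.0, 0
    for x in xs:
        acc = bf16(acc + x)
        n_add += 1
    return acc, n_add

def bf16x2_chain(xs):
    """Model of the desired code: one two-way add per pair of adjacent elements,
    then one scalar add to fold the two lanes together."""
    lo, hi, n_add = 0.0, 0.0, 0
    for k in range(0, len(xs), 2):
        lo, hi = bf16(lo + xs[k]), bf16(hi + xs[k + 1])  # both lanes in one instruction
        n_add += 1
    return bf16(lo + hi), n_add + 1

xs = [256.0] + [1.0] * 63  # 64 elements per thread; exact sum is 319
print(scalar_chain(xs))    # (256.0, 64): 256 + 1 rounds back to 256 in bfloat16
print(bf16x2_chain(xs))    # (288.0, 33)
```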

Labels: enhancement (New feature or request)
