1+ #pragma once
2+
/// Dequantize eight packed 4-bit unsigned integers (one 32-bit word) into
/// eight fp16 values, returned as a uint4 holding four packed half2 registers.
///
/// Nibble -> output mapping (bit positions within `source`, per the masks and
/// the single >>8 shift below):
///   h[0] <- bits [ 3: 0], [19:16]   (BOTTOM_MASK of i4s)
///   h[1] <- bits [ 7: 4], [23:20]   (TOP_MASK of i4s, rescaled by 1/16)
///   h[2] <- bits [11: 8], [27:24]   (BOTTOM_MASK of i4s >> 8)
///   h[3] <- bits [15:12], [31:28]   (TOP_MASK of i4s >> 8)
///
/// Values are treated as unsigned [0, 15] (no -8 offset; see the commented-out
/// magic numbers below). Requires SM75+; asserts (traps) on older arch.
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  // lop3.b32 with immLut plus f16x2 sub/fma as used below need Turing or newer.
  assert(false);
#else
  uint4 result;

  // View the four half2 outputs as raw 32-bit registers for the inline PTX.
  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // LOP3 truth-table immediate for (a & b) | c: mask out one nibble pair and
  // merge in the fp16 exponent bits in a single instruction.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  // 0x6400 is fp16 1024.0; ORing a 4-bit value into its mantissa produces the
  // fp16 number (1024 + value), which the sub/fma below undoes.
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                 "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit
  // float2half instructions if I use the half2 ctor. In this case, I chose
  // performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
  // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-72, -72} represented as an integer.
  // static constexpr uint32_t NEG_72 = 0xd480d480;
  // Haotian: Let's use {-64, -64}.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01: (1024 + v) - 1024 = v.
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[0])
               : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23: (1024 + 16*v) * 1/16 - 64 = v (nibbles sat in bits 7:4).
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[1])
               : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(h[2])
               : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(h[3])
               : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}