Skip to content

Commit 9ad23fa

Browse files
authored
[T2-2-3] blkmjsian
- dequantize awq - rope v2
1 parent b317033 commit 9ad23fa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+3227
-3
lines changed

include/infiniop.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
#include "infiniop/ops/causal_softmax.h"
88
#include "infiniop/ops/clip.h"
99
#include "infiniop/ops/conv.h"
10+
#include "infiniop/ops/dequantize.h"
1011
#include "infiniop/ops/gemm.h"
1112
#include "infiniop/ops/mul.h"
1213
#include "infiniop/ops/random_sample.h"
1314
#include "infiniop/ops/rearrange.h"
1415
#include "infiniop/ops/relu.h"
1516
#include "infiniop/ops/rms_norm.h"
1617
#include "infiniop/ops/rope.h"
18+
#include "infiniop/ops/rope_v2.h"
1719
#include "infiniop/ops/sub.h"
1820
#include "infiniop/ops/swiglu.h"
21+
#include "infiniop/ops/topkrouter.h"
1922
#include "infiniop/tensor_descriptor.h"
2023

2124
#endif // __INFINIOP_API_H__

include/infiniop/ops/dequantize.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef __INFINIOP_DEQUANTIZE_API_H__
2+
#define __INFINIOP_DEQUANTIZE_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopDequantizeDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateDequantizeDescriptor(infiniopHandle_t handle,
9+
infiniopDequantizeDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t out_desc,
11+
infiniopTensorDescriptor_t qweight_desc,
12+
infiniopTensorDescriptor_t scales_desc,
13+
infiniopTensorDescriptor_t zeros_desc);
14+
15+
__C __export infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescriptor_t desc, size_t *size);
16+
17+
__C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t desc,
18+
void *workspace,
19+
size_t workspace_size,
20+
void *out,
21+
const void *qweight,
22+
const void *scales,
23+
const void *zeros,
24+
size_t split_k_iters,
25+
size_t thx,
26+
size_t thy,
27+
void *stream);
28+
29+
__C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc);
30+
31+
#endif

include/infiniop/ops/rope_v2.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef __INFINIOP_ROPE_V2_API_H__
2+
#define __INFINIOP_ROPE_V2_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopRoPEv2Descriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateRoPEv2Descriptor(
9+
infiniopHandle_t handle,
10+
infiniopRoPEv2Descriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t y,
12+
infiniopTensorDescriptor_t x,
13+
infiniopTensorDescriptor_t pos_ids,
14+
infiniopTensorDescriptor_t sin_table,
15+
infiniopTensorDescriptor_t cos_table);
16+
17+
__C __export infiniStatus_t infiniopGetRoPEv2WorkspaceSize(infiniopRoPEv2Descriptor_t desc, size_t *size);
18+
19+
__C __export infiniStatus_t infiniopRoPEv2(
20+
infiniopRoPEv2Descriptor_t desc,
21+
void *workspace,
22+
size_t workspace_size,
23+
void *y,
24+
const void *x,
25+
void const *pos_ids,
26+
void const *sin_table,
27+
void const *cos_table,
28+
void *stream);
29+
30+
__C __export infiniStatus_t infiniopDestroyRoPEv2Descriptor(infiniopRoPEv2Descriptor_t desc);
31+
32+
#endif

include/infiniop/ops/topkrouter.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef __INFINIOP_TOPKRouter_API_H__
2+
#define __INFINIOP_TOPKRouter_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopTopkrouterDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateTopkrouterDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopTopkrouterDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t x_desc,
12+
infiniopTensorDescriptor_t correction_bias_desc);
13+
14+
__C __export infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescriptor_t desc, size_t *size);
15+
16+
__C __export infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void *workspace, size_t workspace_size,
17+
void *values, void *indices, void *x, void *correction_bias, float routed_scaling_factor, size_t topk, void *stream);
18+
19+
__C __export infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescriptor_t desc);
20+
21+
#endif
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef __DEQUANTIZE_H__
2+
#define __DEQUANTIZE_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../operator.h"
6+
#include "../../tensor.h"
7+
#include "info.h"
8+
9+
/*
 * DESCRIPTOR(NAMESPACE) declares op::dequantize::NAMESPACE::Descriptor.
 *
 * The generated class derives from InfiniopDescriptor and stores a pointer to
 * a nested, forward-declared `Opaque` struct (presumably defined by each
 * backend), a DequantizeInfo copy, and the workspace size. The constructor is
 * private; instances are obtained through the static `create` factory.
 * `~Descriptor`, `create` and `calculate` are only declared here.
 */
#define DESCRIPTOR(NAMESPACE)                                      \
                                                                   \
    namespace op::dequantize::NAMESPACE {                          \
    class Descriptor final : public InfiniopDescriptor {           \
        struct Opaque;                                             \
        Opaque *_opaque;                                           \
        DequantizeInfo _info;                                      \
        size_t _workspace_size;                                    \
                                                                   \
        /* Private: only the static create() factory builds one. */ \
        Descriptor(                                                \
            size_t workspace_bytes,                                \
            Opaque *backend_state,                                 \
            DequantizeInfo dequant_info,                           \
            infiniDevice_t device_type,                            \
            int device_id)                                         \
            : InfiniopDescriptor{device_type, device_id},          \
              _opaque(backend_state),                              \
              _info(dequant_info),                                 \
              _workspace_size(workspace_bytes) {}                  \
                                                                   \
    public:                                                        \
        ~Descriptor();                                             \
                                                                   \
        /* Workspace size recorded at construction time. */        \
        size_t workspaceSize() const { return _workspace_size; }   \
                                                                   \
        static infiniStatus_t create(                              \
            infiniopHandle_t handle,                               \
            Descriptor **desc_ptr,                                 \
            infiniopTensorDescriptor_t out_desc,                   \
            infiniopTensorDescriptor_t qweight_desc,               \
            infiniopTensorDescriptor_t scales_desc,                \
            infiniopTensorDescriptor_t zeros_desc);                \
                                                                   \
        infiniStatus_t calculate(                                  \
            void *workspace,                                       \
            size_t workspace_size,                                 \
            void *out,                                             \
            const void *qweight,                                   \
            const void *scales,                                    \
            const void *zeros,                                     \
            int split_k_iters,                                     \
            int thx,                                               \
            int thy,                                               \
            void *stream) const;                                   \
    };                                                             \
    }
55+
#endif

src/infiniop/ops/dequantize/info.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef __DEQUANTIZE_INFO_H__
2+
#define __DEQUANTIZE_INFO_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../tensor.h"
6+
#include <vector>
7+
8+
namespace op::dequantize {
9+
10+
class DequantizeInfo {
11+
DequantizeInfo() = default;
12+
13+
public:
14+
int _in_c, _qout_c, _G;
15+
16+
int in_c() const { return _in_c; }
17+
int qout_c() const { return _qout_c; }
18+
int G() const { return _G; }
19+
20+
static utils::Result<DequantizeInfo> create(
21+
infiniopTensorDescriptor_t out_desc,
22+
infiniopTensorDescriptor_t qweight_desc,
23+
infiniopTensorDescriptor_t scales_desc,
24+
infiniopTensorDescriptor_t zeros_desc) {
25+
26+
int _in_c = qweight_desc->dim(0);
27+
int _qout_c = qweight_desc->dim(1);
28+
int _G = scales_desc->dim(0);
29+
30+
return utils::Result<DequantizeInfo>(DequantizeInfo{
31+
_in_c,
32+
_qout_c,
33+
_G});
34+
}
35+
};
36+
37+
} // namespace op::dequantize
38+
39+
#endif // __DEQUANTIZE_INFO_H__
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#pragma once
2+
3+
/**
 * Converts the eight unsigned 4-bit values packed in `source` into eight fp16
 * values, returned as four packed f16x2 words in a uint4.
 *
 * Output layout (low half / high half of each 32-bit word), by source bits:
 *   word 0: bits  0-3  / bits 16-19   ("elt_01")
 *   word 1: bits  4-7  / bits 20-23   ("elt_23")
 *   word 2: bits  8-11 / bits 24-27   ("elt_45")
 *   word 3: bits 12-15 / bits 28-31   ("elt_67")
 * Values are left unsigned (0..15); no zero-point offset is applied here.
 *
 * Declared `inline` because the definition lives in a header: without it,
 * including the header from several translation units yields duplicate
 * definitions when device code is linked across them.
 *
 * On device architectures below sm_75 the conversion is not provided; the
 * function asserts and then traps, so the failure is still explicit when
 * assertions are compiled out.
 */
__device__ inline uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
    assert(false);
    // assert() vanishes under NDEBUG; trap so control never falls through to
    // the unreachable marker below.
    __trap();
#else
    uint4 result;

    uint32_t *h = reinterpret_cast<uint32_t *>(&result);
    uint32_t const i4s = source;

    // First, we extract the i4s and construct an intermediate fp16 number.
    // immLut selects the lop3 function (a & b) | c.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is
    // thanks to the register packing format and the fact that we force our
    // integers to be unsigned, and account for this in the fp16 subtractions.
    // In addition, sub and fma have the same throughput, which is exploited to
    // convert elt_23 and elt_67 to fp16 without having to shift them to the
    // bottom bits beforehand.

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide
    // RAW dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                   "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                   "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                   "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
                   "n"(immLut));

    // Inline PTX is used below because it is not certain the compiler would
    // avoid float2half instructions with the half2 constructor; performance
    // reliability was chosen over readability.

    // half2 {1024, 1024}: subtracted from the bottom-nibble words, leaving the
    // unsigned value (no mapping to [-8, 7]).
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // half2 {-64, -64}: (1024 + 16 * v) / 16 - 64 == v for the top-nibble
    // words.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(h[0])
                 : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(h[1])
                 : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm volatile("sub.f16x2 %0, %1, %2;\n"
                 : "=r"(h[2])
                 : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
                 : "=r"(h[3])
                 : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
#endif
    __builtin_unreachable(); // Suppress missing return statement warning
}

0 commit comments

Comments
 (0)