#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+ #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

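// Editor's illustration (not part of this file): the macros above alias the address
// of one element as a wider vector type, so a single instruction moves 128 bits
// (float4) or 32 bits (half2); the element address must be aligned to the vector
// width. A minimal sketch with a hypothetical helper name:
__device__ __forceinline__ void copy_float4_example(float* src, float* dst, int i) {
  float4 v = FLOAT4(src[i]); // one 128-bit load of src[i..i+3]
  FLOAT4(dst[i]) = v;        // one 128-bit store to dst[i..i+3]
}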
// -------------------------------------- FP32 --------------------------------------
// Relu x: N, y: N, y = max(0, x)
@@ -81,6 +82,24 @@ __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
  if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

+ __global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
+   int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+   const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
+   // temporary register (memory), .local space in PTX, addressable
+   half pack_x[8], pack_y[8]; // 8 x 16 bits = 128 bits
+   // reinterpret as float4 and load 128 bits in one memory issue
+   LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+ #pragma unroll
+   for (int i = 0; i < 8; i += 2) {
+     // __hmax2 for half2 x 4
+     HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
+   }
+   // reinterpret as float4 and store 128 bits in one memory issue
+   if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+ }
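// Editor's note (a sketch, not part of this commit): relu_f16x8_pack_kernel issues
// the 128-bit load unconditionally and only guards the store, so when N is not a
// multiple of 8 the trailing partial pack may be read out of bounds and is never
// written back. One hedged way to handle the tail is a scalar fallback; the name
// relu_f16x8_pack_safe_kernel is hypothetical.
__global__ void relu_f16x8_pack_safe_kernel(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  if ((idx + 7) < N) {
    // fast path: full 8-element pack, same idea as relu_f16x8_pack_kernel above
    half pack_x[8], pack_y[8];
    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);
#pragma unroll
    for (int i = 0; i < 8; i += 2) {
      HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), __float2half2_rn(0.0f));
    }
    LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]);
  } else {
    // slow path: element-wise guard for the last (partial) pack only
    for (int i = idx; i < N; ++i) {
      float v = __half2float(x[i]);
      y[i] = __float2half(v > 0.0f ? v : 0.0f);
    }
  }
}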
+
+
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -92,61 +111,54 @@ if(((T).options().dtype() != (th_type))) { \
  throw std::runtime_error("values must be " #th_type); \
}

- #define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-   if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
-
#define TORCH_BINDING_RELU(packed_type, th_type, element_type, n_elements) \
- torch::Tensor relu_##packed_type(torch::Tensor x) { \
-   CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
-   auto options = torch::TensorOptions().dtype((th_type)).device( \
-     torch::kCUDA, 0); \
-   const int N = x.size(0); \
-   auto y = torch::zeros({N}, options); \
-   static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-   const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-   dim3 block(NUM_THREADS_PER_BLOCK); \
-   dim3 grid(NUM_BLOCKS); \
-   relu_##packed_type##_kernel<<<grid, block>>>( \
-     reinterpret_cast<element_type*>(x.data_ptr()), \
-     reinterpret_cast<element_type*>(y.data_ptr()), N); \
-   return y; \
- }
-
- #define TORCH_BINDING_RELU_V2(packed_type, th_type, element_type, n_elements) \
- void relu_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \
+ void relu_##packed_type(torch::Tensor x, torch::Tensor y) { \
  CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
  CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
-   const int N = x.size(0); \
-   CHECK_TORCH_TENSOR_SHAPE(y, N) \
-   static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-   const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-   dim3 block(NUM_THREADS_PER_BLOCK); \
-   dim3 grid(NUM_BLOCKS); \
-   relu_##packed_type##_kernel<<<grid, block>>>( \
+   const int ndim = x.dim(); \
+   if (ndim != 2) { \
+     int N = 1; \
+     for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+     dim3 block(256 / (n_elements)); \
+     dim3 grid((N + 256 - 1) / 256); \
+     relu_##packed_type##_kernel<<<grid, block>>>( \
      reinterpret_cast<element_type*>(x.data_ptr()), \
      reinterpret_cast<element_type*>(y.data_ptr()), N); \
+   } else { \
+     const int S = x.size(0); \
+     const int K = x.size(1); \
+     const int N = S * K; \
+     if ((K / (n_elements)) <= 1024) { \
+       dim3 block(K / (n_elements)); \
+       dim3 grid(S); \
+       relu_##packed_type##_kernel<<<grid, block>>>( \
+         reinterpret_cast<element_type*>(x.data_ptr()), \
+         reinterpret_cast<element_type*>(y.data_ptr()), N); \
+     } else { \
+       int N = 1; \
+       for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+       dim3 block(256 / (n_elements)); \
+       dim3 grid((N + 256 - 1) / 256); \
+       relu_##packed_type##_kernel<<<grid, block>>>( \
+         reinterpret_cast<element_type*>(x.data_ptr()), \
+         reinterpret_cast<element_type*>(y.data_ptr()), N); \
+     } \
+   } \
}

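// Editor's illustration (not part of this commit): the launch-configuration choice
// made by the TORCH_BINDING_RELU macro above, written out as a plain helper for
// readability. The names LaunchConfig and pick_relu_launch are made up here.
struct LaunchConfig { dim3 grid; dim3 block; };

static LaunchConfig pick_relu_launch(int ndim, int S, int K, int N, int n_elements) {
  if (ndim == 2 && (K / n_elements) <= 1024) {
    // 2D input whose row fits in one block: one block per row,
    // each thread covering n_elements contiguous values of that row.
    return {dim3(S), dim3(K / n_elements)};
  }
  // generic fallback: flatten to 1D, 256 elements (256 / n_elements threads) per block
  return {dim3((N + 256 - 1) / 256), dim3(256 / n_elements)};
}
// Example: x of shape (1024, 2048) dispatched through relu_f16x8_pack (n_elements = 8):
// K / 8 = 256 <= 1024, so grid = (1024) blocks and block = (256) threads.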
- TORCH_BINDING_RELU(f32,        torch::kFloat32, float, 1)
- TORCH_BINDING_RELU(f32x4,      torch::kFloat32, float, 4)
- TORCH_BINDING_RELU(f16,        torch::kHalf,    half,  1)
- TORCH_BINDING_RELU(f16x2,      torch::kHalf,    half,  2)
- TORCH_BINDING_RELU(f16x8,      torch::kHalf,    half,  8)
- TORCH_BINDING_RELU_V2(f32,     torch::kFloat32, float, 1)
- TORCH_BINDING_RELU_V2(f32x4,   torch::kFloat32, float, 4)
- TORCH_BINDING_RELU_V2(f16,     torch::kHalf,    half,  1)
- TORCH_BINDING_RELU_V2(f16x2,   torch::kHalf,    half,  2)
- TORCH_BINDING_RELU_V2(f16x8,   torch::kHalf,    half,  8)
+
+ TORCH_BINDING_RELU(f32,        torch::kFloat32, float, 1)
+ TORCH_BINDING_RELU(f32x4,      torch::kFloat32, float, 4)
+ TORCH_BINDING_RELU(f16,        torch::kHalf,    half,  1)
+ TORCH_BINDING_RELU(f16x2,      torch::kHalf,    half,  2)
+ TORCH_BINDING_RELU(f16x8,      torch::kHalf,    half,  8)
+ TORCH_BINDING_RELU(f16x8_pack, torch::kHalf,    half,  8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(relu_f32)
  TORCH_BINDING_COMMON_EXTENSION(relu_f32x4)
  TORCH_BINDING_COMMON_EXTENSION(relu_f16)
  TORCH_BINDING_COMMON_EXTENSION(relu_f16x2)
  TORCH_BINDING_COMMON_EXTENSION(relu_f16x8)
-   TORCH_BINDING_COMMON_EXTENSION(relu_f32_v2)
-   TORCH_BINDING_COMMON_EXTENSION(relu_f32x4_v2)
-   TORCH_BINDING_COMMON_EXTENSION(relu_f16_v2)
-   TORCH_BINDING_COMMON_EXTENSION(relu_f16x2_v2)
-   TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_v2)
+   TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_pack)
}