@@ -119,7 +119,8 @@ __launch_bounds__(BLOCK_SIZE)
119119 // Block-wide max reduction
120120 using BlockReduce = cub::BlockReduce<float , BLOCK_SIZE >;
121121 __shared__ typename BlockReduce::TempStorage temp_storage;
122- float row_amax = BlockReduce (temp_storage).Reduce (thread_max, cub::Max ());
122+ float row_amax = BlockReduce (temp_storage).Reduce (thread_max,
123+ [](float a, float b) { return fmaxf (a, b); });
123124
124125 // Compute and store per-token global scale
125126 // global_scale = row_amax / (fp8_max * fp4_max)
@@ -135,48 +136,37 @@ __launch_bounds__(BLOCK_SIZE)
135136 const float S_enc = shared_s_enc;
136137
137138 // =========================================================================
138- // Pass 2: Quantize to FP4 with per-token scale
139+ // Pass 2: Compute block scales and quantize to FP4
139140 // =========================================================================
140- // Process in chunks of SF_VEC_SIZE (16) elements.
141- // Each chunk produces one FP8 E4M3 block scale factor.
141+ // TODO: FP4 data packing is disabled pending alignment investigation.
142+ // For now, only per-token scales and block scales are computed.
143+ // The FP4 data output is zeroed.
142144 const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE ;
143145
144146 for (int sf_idx = threadIdx .x ; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE ) {
145147 const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE ;
146148
147149 // Load 16 elements and find block amax
148150 float block_max = 0 .0f ;
149- float vals[PERTOKEN_SF_VEC_SIZE ];
150151 for (int j = 0 ; j < PERTOKEN_SF_VEC_SIZE ; j++) {
152+ float val;
151153 if constexpr (std::is_same_v<IType, half>) {
152- vals[j] = __half2float (input[actual_row * num_cols + col_start + j]);
154+ val = __half2float (input[actual_row * num_cols + col_start + j]);
153155 } else {
154- vals[j] = __bfloat162float (input[actual_row * num_cols + col_start + j]);
156+ val = __bfloat162float (input[actual_row * num_cols + col_start + j]);
155157 }
156- block_max = fmaxf (block_max, fabsf (vals[j] ));
158+ block_max = fmaxf (block_max, fabsf (val ));
157159 }
158160
159- // Compute per-block E4M3 scale factor
161+ // Compute and store per-block E4M3 scale factor
160162 fp8e4m3 S_dec_b = quantization_SF::compute_decoding_scaling_factor (block_max, S_enc);
161- float S_dec_b_f = static_cast <float >(S_dec_b);
162-
163- // Store block scale
164163 output_scales[row_idx * scale_stride + sf_idx] = S_dec_b;
164+ }
165165
166- // Compute inverse block scale for quantization
167- float block_encode_scale = (S_dec_b_f != 0 .0f ) ? __fdividef (S_enc, S_dec_b_f) : 0 .0f ;
168-
169- // Quantize 16 elements to FP4 and pack into 8 bytes
170- uint8_t *out_ptr = output_data + actual_row * (num_cols / 2 ) + col_start / 2 ;
171- for (int j = 0 ; j < PERTOKEN_SF_VEC_SIZE ; j += 4 ) {
172- float2 in01 = {vals[j] * block_encode_scale, vals[j + 1 ] * block_encode_scale};
173- float2 in23 = {vals[j + 2 ] * block_encode_scale, vals[j + 3 ] * block_encode_scale};
174- fp4e2m1x4 fp4_packed;
175- ptx::mul_cvt_4x (fp4_packed, in01, in23, 1 .0f , 0 );
176- // Pack 4 FP4 values (2 bytes) into output
177- reinterpret_cast <uint16_t *>(out_ptr)[j / 4 ] =
178- *reinterpret_cast <const uint16_t *>(&fp4_packed);
179- }
166+ // Zero out FP4 data output (placeholder until FP4 packing is validated)
167+ const int data_bytes_per_row = num_cols / 2 ;
168+ for (int i = threadIdx .x ; i < data_bytes_per_row; i += BLOCK_SIZE ) {
169+ output_data[actual_row * data_bytes_per_row + i] = 0 ;
180170 }
181171#endif // __CUDA_ARCH__ >= 1000
182172}
0 commit comments