@@ -385,28 +385,41 @@ void performTest(const ProcessingMethod processing_method,
385385
386386 NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast <NVTEDType>(itype), logical_shape_};
387387 NVTEBasicTensor in_data_tensor = {in_data_d, static_cast <NVTEDType>(itype), logical_shape_};
388- nvte_set_grouped_tensor_param (&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData , &in_data_tensor);
389- nvte_set_grouped_tensor_param (&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData , &grad_data_tensor);
388+ nvte_set_grouped_tensor_param (in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData ,
389+ &in_data_tensor, sizeof (in_data_tensor));
390+ nvte_set_grouped_tensor_param (grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData ,
391+ &grad_data_tensor, sizeof (grad_data_tensor));
390392
391393 if ((shape_rep == VARYING_FIRST_DIM ) || (shape_rep == VARYING_BOTH_DIMS )) {
392394 NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64 , first_dims_shape_};
393- nvte_set_grouped_tensor_param (&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims , &first_dims_tensor);
394- nvte_set_grouped_tensor_param (&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims , &first_dims_tensor);
395- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims , &first_dims_tensor);
395+ nvte_set_grouped_tensor_param (grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims ,
396+ &first_dims_tensor, sizeof (first_dims_tensor));
397+ nvte_set_grouped_tensor_param (in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims ,
398+ &first_dims_tensor, sizeof (first_dims_tensor));
399+ nvte_set_grouped_tensor_param (out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims ,
400+ &first_dims_tensor, sizeof (first_dims_tensor));
396401 }
397402
398403 if ((shape_rep == VARYING_LAST_DIM ) || (shape_rep == VARYING_BOTH_DIMS )) {
399404 NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64 , last_dims_shape_};
400- nvte_set_grouped_tensor_param (&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims , &last_dims_tensor);
401- nvte_set_grouped_tensor_param (&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims , &last_dims_tensor);
402- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims , &last_dims_tensor);
405+ nvte_set_grouped_tensor_param (grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims ,
406+ &last_dims_tensor, sizeof (last_dims_tensor));
407+ nvte_set_grouped_tensor_param (in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims ,
408+ &last_dims_tensor, sizeof (last_dims_tensor));
409+ nvte_set_grouped_tensor_param (out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims ,
410+ &last_dims_tensor, sizeof (last_dims_tensor));
403411 }
404412
405413 if (shape_rep != SAME_BOTH_DIMS ) {
406414 NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64 , offsets_shape_};
407- nvte_set_grouped_tensor_param (&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets , &offsets_tensor);
408- nvte_set_grouped_tensor_param (&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets , &offsets_tensor);
409- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets , &offsets_tensor);
415+ nvte_set_grouped_tensor_param (grad_group_tensor,
416+ NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets ,
417+ &offsets_tensor, sizeof (offsets_tensor));
418+ nvte_set_grouped_tensor_param (in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets ,
419+ &offsets_tensor, sizeof (offsets_tensor));
420+ nvte_set_grouped_tensor_param (out_group_tensor,
421+ NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets ,
422+ &offsets_tensor, sizeof (offsets_tensor));
410423 }
411424
412425 if (rowwise) {
@@ -417,8 +430,11 @@ void performTest(const ProcessingMethod processing_method,
417430 NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast <NVTEDType>(otype), logical_shape_};
418431 NVTEShape scales_rowwise_shape_ = nvte_make_shape (scales_rowwise_shape.data (), scales_rowwise_shape.size ());
419432 NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0 , scales_rowwise_shape_};
420- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData , &out_data_rowwise_tensor);
421- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv , &out_scales_rowwise_tensor);
433+ nvte_set_grouped_tensor_param (out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData ,
434+ &out_data_rowwise_tensor, sizeof (out_data_rowwise_tensor));
435+ nvte_set_grouped_tensor_param (out_group_tensor,
436+ NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv ,
437+ &out_scales_rowwise_tensor, sizeof (out_scales_rowwise_tensor));
422438 }
423439
424440 if (colwise) {
@@ -429,8 +445,12 @@ void performTest(const ProcessingMethod processing_method,
429445 NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast <NVTEDType>(otype), logical_shape_};
430446 NVTEShape scales_colwise_shape_ = nvte_make_shape (scales_colwise_shape.data (), scales_colwise_shape.size ());
431447 NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0 , scales_colwise_shape_};
432- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData , &out_data_colwise_tensor);
433- nvte_set_grouped_tensor_param (&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv , &out_scales_colwise_tensor);
448+ nvte_set_grouped_tensor_param (out_group_tensor,
449+ NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData ,
450+ &out_data_colwise_tensor, sizeof (out_data_colwise_tensor));
451+ nvte_set_grouped_tensor_param (out_group_tensor,
452+ NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv ,
453+ &out_scales_colwise_tensor, sizeof (out_scales_colwise_tensor));
434454 }
435455
436456 Tensor output_dbias (" output_dbias" , std::vector<size_t >{ cols }, itype);
@@ -695,7 +715,10 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
695715 }
696716 offsets[t+1 ] = offsets[t] + first_dims[t] * last_dims[t];
697717 // Skips tests if tensor shape is not as required by the kernel
698- if ((first_dims[t] % 128 != 0 ) || (last_dims[t] % 32 != 0 )) {
718+ if (first_dims[t] % 128 != 0 ) {
719+ GTEST_SKIP ();
720+ }
721+ if (!is_single_tensor && (last_dims[t] % 128 != 0 )) {
699722 GTEST_SKIP ();
700723 }
701724 }
0 commit comments