@@ -730,8 +730,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
730730 (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
731731 lane_smem_b_n) * sizeof (half)
732732 );
733+ // TODO: could use .x4.trans to load all 4 matrices for the reg double buffers at once?
733734 LDMATRIX_X2_T (RB[reg_store_idx][j][0 ], RB[reg_store_idx][j][1 ],
734735 lane_smem_b_ptr);
736+ // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
737+ // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
738+ // uint32_t lane_smem_b_ptr = (
739+ // smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
740+ // (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
741+ // lane_smem_b_n) * sizeof(half)
742+ // );
743+ // // TRICK: use .x4.trans to load all 4 matrices for the reg double buffers at once.
744+ // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
745+ // RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
746+ // lane_smem_b_ptr);
735747 }
736748 }
737749
@@ -805,6 +817,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
805817 (smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
806818 lane_smem_b_n) * sizeof (half)
807819 );
820+ // TODO: could use .x4.trans to load all 4 matrices for the reg double buffers at once?
808821 LDMATRIX_X2_T (RB[reg_store_idx][j][0 ], RB[reg_store_idx][j][1 ],
809822 lane_smem_b_ptr);
810823 }
@@ -841,7 +854,6 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
841854 }
842855 }
843856
844-
845857 CP_ASYNC_WAIT_GROUP (K_STAGE-2 );
846858 __syncthreads ();
847859
@@ -874,8 +886,20 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
874886 lane_smem_b_k * (BN + B_PAD) +
875887 lane_smem_b_n) * sizeof (half)
876888 );
889+ // TODO: could use .x4.trans to load all 4 matrices for the reg double buffers at once?
877890 LDMATRIX_X2_T (RB[reg_store_idx][j][0 ], RB[reg_store_idx][j][1 ],
878891 lane_smem_b_ptr);
892+ // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
893+ // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
894+ // uint32_t lane_smem_b_ptr = (
895+ // smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
896+ // (smem_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
897+ // lane_smem_b_n) * sizeof(half)
898+ // );
899+ // // could use .x4.trans to load all 4 matrices for the reg double buffers at once?
900+ // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
901+ // RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
902+ // lane_smem_b_ptr);
879903 }
880904 }
881905
@@ -920,6 +944,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
920944 (stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
921945 lane_smem_b_n) * sizeof (half)
922946 );
947+ // TODO: could use .x4.trans to load all 4 matrices for the reg double buffers at once?
923948 LDMATRIX_X2_T (RB[reg_store_idx][j][0 ], RB[reg_store_idx][j][1 ],
924949 lane_smem_b_ptr);
925950 }
@@ -988,6 +1013,17 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_kernel(
9881013 );
9891014 LDMATRIX_X2_T (RB[reg_store_idx][j][0 ], RB[reg_store_idx][j][1 ],
9901015 lane_smem_b_ptr);
1016+ // int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
1017+ // int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
1018+ // uint32_t lane_smem_b_ptr = (
1019+ // smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) * (lane_id / 16) +
1020+ // (stage_sel_reg * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
1021+ // lane_smem_b_n) * sizeof(half)
1022+ // );
1023+ // // could use .x4.trans to load all 4 matrices for the reg double buffers at once?
1024+ // LDMATRIX_X4_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
1025+ // RB[reg_load_idx][j][0], RB[reg_load_idx][j][1],
1026+ // lane_smem_b_ptr);
9911027 }
9921028 }
9931029 }
0 commit comments