From 6bc5f30a8eb6f08fb3e8121fd17a4d02c8a6704d Mon Sep 17 00:00:00 2001
From: "Lu,Chengjun"
Date: Mon, 17 Nov 2025 19:30:32 +0000
Subject: [PATCH] [LoadStoreOpToLLVM] Unify the 2D block IO lowering code for
 both regular and block pointers, and clean up duplicated code

Signed-off-by: Lu,Chengjun
---
 .../LoadStoreOpToLLVM.cpp | 1187 ++++-------------
 1 file changed, 291 insertions(+), 896 deletions(-)

diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 9062aef7ad..0ebfba83c7 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -83,6 +83,40 @@ unsigned getCanonicalIndex(unsigned index, unsigned freeVarMask) {
   return index & ~freeVarMask;
 }
 
+struct BlockPointerStruct {
+  Value base;
+  SmallVector<Value> shape;
+  SmallVector<Value> stride;
+  SmallVector<Value> offsets;
+};
+
+BlockPointerStruct unpackLLBlockPointer(Value blockPointerStruct,
+                                        ConversionPatternRewriter &rewriter) {
+  const SmallVector<Value> &blockPtr = unpackLLElements(
+      blockPointerStruct.getLoc(), blockPointerStruct, rewriter);
+  // The block pointer struct is expected to have the following layout:
+  //    Struct {
+  //      Value offset[rank];
+  //      Value shape[rank];
+  //      Value stride[rank];
+  //      Value base;
+  //    }
+  assert((blockPtr.size() - 1) % 3 == 0 &&
         "unexpected number of values unpacked from a block pointer");
+  unsigned rank = (blockPtr.size() - 1) / 3;
+  unsigned blockOffset = 0, blockShape = 1 * rank, blockStride = 2 * rank,
+           blockBase = 3 * rank;
+  SmallVector<Value> offset(blockPtr.begin() + blockOffset,
+                            blockPtr.begin() + blockShape);
+  SmallVector<Value> shape(blockPtr.begin() + blockShape,
+                           blockPtr.begin() + blockStride);
+  SmallVector<Value> stride(blockPtr.begin() + blockStride,
+                            blockPtr.begin() + blockBase);
+
+  return {blockPtr[blockBase], std::move(shape), std::move(stride),
+          std::move(offset)};
+}
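+
+// Usage sketch (illustrative; `llPtr` stands for any block pointer value
+// already converted to its LLVM struct form):
+//   BlockPointerStruct bp = unpackLLBlockPointer(llPtr, rewriter);
+//   bp.shape[0]  / bp.shape[1]  -> tensor height / width (in elements)
+//   bp.stride[0] / bp.stride[1] -> row / column stride (in elements)
+//   bp.offsets[*]               -> block offsets; bp.base -> base pointer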
+
 /// Holds the values related to a block pointer.
 /// It includes the base pointer, base width and height, row and column
 /// stride, and offset base for X and Y.
@@ -102,14 +136,17 @@ struct BlockPointerValues {
 BlockPointerValues
 getValuesFromBlockPointerStruct(Value blockPointerStruct,
                                 ConversionPatternRewriter &rewriter) {
-  const SmallVector<Value> &elems = unpackLLElements(
-      blockPointerStruct.getLoc(), blockPointerStruct, rewriter);
-  assert(elems.size() == sizeof(BlockPointerValues) / sizeof(Value) &&
+  const BlockPointerStruct &blockPtr =
+      unpackLLBlockPointer(blockPointerStruct, rewriter);
+  assert(blockPtr.shape.size() == 2 && "expected a rank-2 block pointer");
-  return {/*base=*/elems[6], /*baseWidth=*/elems[3],
-          /*baseHeight=*/elems[2], /*rowStride=*/elems[4],
-          /*colStride=*/elems[5], /*offsetBaseX=*/elems[1],
-          /*offsetBaseY=*/elems[0]};
+  return {/*base=*/blockPtr.base,
+          /*baseWidth=*/blockPtr.shape[1],
+          /*baseHeight=*/blockPtr.shape[0],
+          /*rowStride=*/blockPtr.stride[0],
+          /*colStride=*/blockPtr.stride[1],
+          /*offsetBaseX=*/blockPtr.offsets[1],
+          /*offsetBaseY=*/blockPtr.offsets[0]};
 }
 
 /// Compute the 2D prefetch shape for each warp given an input 2D tensor.
@@ -433,29 +470,131 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
               : getDotEncoding(tensorTy).value().getParent());
   }
 
+  // Unpack the base pointers from a regular pointer or a block pointer.
+  SmallVector<Value> getBases(ConversionPatternRewriter &rewriter, Value ptr,
+                              const SmallVector<Value> &unpackedPtrs,
+                              unsigned numElems) const {
+    SmallVector<Value> ptrElems;
+    if (isTensorPointerType(ptr.getType())) {
+      // See unpackLLBlockPointer for the block pointer struct layout.
+      assert((unpackedPtrs.size() - 1) % 3 == 0 &&
+             "unexpected number of values unpacked from a block pointer");
+      unsigned rank = (unpackedPtrs.size() - 1) / 3;
+      unsigned blockBase = 3 * rank;
+      // A block pointer carries a single base; replicate it per element.
+      ptrElems.assign(numElems, unpackedPtrs[blockBase]);
+    } else {
+      ptrElems = unpackedPtrs;
+    }
+
+    return ptrElems;
+  }
+
   // Returns the pitch (stride in bytes) of \p ptr.
   Value getPitch(ConversionPatternRewriter &rewriter, Value ptr,
+                 const SmallVector<Value> &unpackedPtrs,
                  unsigned elemSizeInBits, unsigned dim) const {
     Location loc = ptr.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
-    int stride = getStride(ptr, dim);
-    // If the stride is 0, we assume a minimum pitch of 64 bytes.
-    constexpr int MIN_PITCH = 64;
-    if (stride == 0)
-      return b.i32_val(MIN_PITCH);
-
-    if (stride > 0) {
-      unsigned pitch = (unsigned)stride * elemSizeInBits / 8;
-      if (pitch < MIN_PITCH)
-        return nullptr; // unsupported pitch
-      return b.i32_val(pitch);
+    if (isTensorPointerType(ptr.getType())) {
+      // See unpackLLBlockPointer for the block pointer struct layout.
+      assert((unpackedPtrs.size() - 1) % 3 == 0 &&
+             "unexpected number of values unpacked from a block pointer");
+      unsigned rank = (unpackedPtrs.size() - 1) / 3;
+      unsigned blockStride = 2 * rank;
+      Value stride = unpackedPtrs[blockStride + dim];
+      return b.mul(b.trunc(i32_ty, stride), b.i32_val(elemSizeInBits / 8));
+    } else {
+      // Regular pointer.
+      int stride = getStride(ptr, dim);
+      // If the stride is 0, we assume a minimum pitch of 64 bytes.
+      constexpr int MIN_PITCH = 64;
+      if (stride == 0)
+        return b.i32_val(MIN_PITCH);
+
+      if (stride > 0) {
+        unsigned pitch = (unsigned)stride * elemSizeInBits / 8;
+        if (pitch < MIN_PITCH)
+          return nullptr; // unsupported pitch
+        return b.i32_val(pitch);
+      }
+      assert(stride == -1 && "invalid stride < 0");
     }
-    assert(stride == -1 && "invalid stride < 0");
 
     return nullptr;
   }
 
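+  // Worked pitch examples (illustrative, assuming fp16 elements):
+  //   - block pointer with stride[dim] = 128 -> mul(trunc(stride), 2),
+  //     i.e. 256 bytes computed at run time;
+  //   - regular pointer with a constant stride of 0 -> MIN_PITCH (64 bytes),
+  //     which degenerates the access to a single row;
+  //   - regular pointer with stride 16 (32 bytes) -> nullptr, since the
+  //     hardware needs a pitch of at least 64 bytes; the caller falls back.
+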
+#if 0
+  // Returns the width and height of the matrix referred to by \p ptr.
+  std::tuple<Value, Value>
+  getBaseWidthAndHeight(ConversionPatternRewriter &rewriter, Value ptr,
+                        Value llPtr, unsigned elemSizeInBits,
+                        unsigned rowDim) const {
+    Location loc = ptr.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    Value baseWidth, baseHeight;
+    if (isTensorPointerType(ptr.getType())) {
+      auto [base, _baseWidth, _baseHeight, rowStride, colStride, _offsetBaseX,
+            _offsetBaseY] =
+          getValuesFromBlockPointerStruct(llPtr, rewriter);
+      if (rowDim == 1) {
+        baseWidth = b.trunc(i32_ty, _baseWidth);
+        baseHeight = b.trunc(i32_ty, _baseHeight);
+      } else {
+        baseWidth = b.trunc(i32_ty, _baseHeight);
+        baseHeight = b.trunc(i32_ty, _baseWidth);
+      }
+      baseWidth = b.mul(baseWidth, b.i32_val(elemSizeInBits / 8));
+    } else {
+      // If the stride is 0, we want to load only the first row.
+      int stride = getStride(ptr, rowDim);
+      unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
+      baseHeight = b.i32_val(baseHeightInt);
+      baseWidth = b.i32_val(
+          std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8)));
+    }
+    return {baseWidth, baseHeight};
+  }
+#endif
+
+  // Returns the offsets of the block referenced by \p ptr.
+  SmallVector<Value>
+  getBaseOffsets(ConversionPatternRewriter &rewriter, Value ptr,
+                 const SmallVector<Value> &unpackedPtrs) const {
+    Location loc = ptr.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    if (isTensorPointerType(ptr.getType())) {
+      // See unpackLLBlockPointer for the block pointer struct layout.
+      assert((unpackedPtrs.size() - 1) % 3 == 0 &&
+             "unexpected number of values unpacked from a block pointer");
+      unsigned rank = (unpackedPtrs.size() - 1) / 3;
+      unsigned blockOffset = 0, blockShape = 1 * rank;
+      return SmallVector<Value>(unpackedPtrs.begin() + blockOffset,
+                                unpackedPtrs.begin() + blockShape);
+    } else {
+      // For regular pointers, the offsets have already been folded into the
+      // per-element bases returned by getBases.
+      RankedTensorType tensorType = cast<RankedTensorType>(ptr.getType());
+      unsigned rank = tensorType.getShape().size();
+      return SmallVector<Value>(rank, b.i32_val(0));
+    }
+  }
+
   struct BlockIOTileSizeInfo {
     BlockIOTileSizeInfo() = delete;
     BlockIOTileSizeInfo(int tileHeight, int tileWidth, int numElemPerPackedVal,
@@ -959,8 +1098,8 @@ struct PrefetchOpConversion
       masks[offset] = maskElems[i];
     }
 
-    Value rowStrideInBytes =
-        getPitch(rewriter, op.getPtr(), elemSizeInBits, memoryRowMajor ? 0 : 1);
+    Value rowStrideInBytes = getPitch(rewriter, op.getPtr(), ptrElems,
+                                      elemSizeInBits, memoryRowMajor ? 0 : 1);
     if (!rowStrideInBytes)
       return failure();
 
@@ -1096,804 +1235,6 @@ struct LoadOpToBlockIOConversion
       : ConvertTritonGPUOpToLLVMPattern(converter, benefit),
         BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
 
-  LogicalResult
-  rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
-                           ConversionPatternRewriter &rewriter) const {
-    // FIXME: Remove once IGC can split large 2D block loads.
- std::optional oneMatrixPerLoadForBT = - mlir::triton::tools::isEnvValueBool(mlir::triton::tools::getStrEnv( - "TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT")); - if (!oneMatrixPerLoadForBT.has_value()) - oneMatrixPerLoadForBT = - op->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: - getOneMatrixPerLoadAttrName()); - - Value ptr = op.getPtr(); - assert(isTensorPointerType(ptr.getType()) && - "Expecting tensor pointer type"); - - Location loc = op.getLoc(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value mask = op.getMask(); - Value other = op.getOther(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); - - const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); - - LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": " - << tensorType << "\n"); - - Attribute encoding = tensorType.getEncoding(); - std::optional llEncoding = - cast(encoding).toLinearLayout( - tensorType.getShape()); - assert(llEncoding.has_value() && "invalid dot layout to linear layout"); - LinearEncodingAttr llAttr = - LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); - SmallVector threadOrder = llAttr.getThreadOrder(); - size_t rank = threadOrder.size(); - const bool valueRowMajor = - (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); - assert((valueRowMajor || - (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - - Type eltTy = getTypeConverter()->convertType(tensorType.getElementType()); - unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - - auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( - cast(encoding), tensorType.getShape(), - memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); - unsigned tileHeight = tileParams[0]; - const unsigned tileWidth = tileParams[1]; - const unsigned vBlocks = tileParams[2]; - - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); - const ArrayRef tensorShape = tensorType.getShape(); - unsigned numElems = getTotalElemsPerThread(resultType); - SmallVector numReps = - dpasLayout.getDPASRepetitions(tensorShape, opIdx); - auto warpsPerCTA = dpasLayout.getWarpsPerCTA(); - SmallVector dpasWarpsOrder = - getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); - unsigned threadsPerWarp = - product(getThreadsPerWarp(dpasLayout, tensorShape)); - - Value warpId = rewriter.create( - loc, i32_ty, - rewriter.create(loc, /*upperBound=*/nullptr)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); - - if (opIdx == DpasEncodingAttr::OpIdx::OperandC) { - // A block load with the DPAS layout but without the DotDpasLayout is - // expected to follow the ordering of the DPAS output. For a 2D block - // load, the rows are distributed across work items/SIMD lanes and the - // column vectors are available for each work item to process. This layout - // aligns to the DPAS layout as the DPAS operation output layout - // distributes rows across work items. 
- - if (isTransposeRequired) { - // TODO: this would likely require a shuffle to match the expected - // ordering coming out of the DPAS layout and requires more - // investigation - return failure(); - } - - MLIRContext *ctx = rewriter.getContext(); - - Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8); - - const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp; - Type load2DGenXType = - LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits), - elemsPerLane); // make it opaque type. - - auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX, - offsetBaseY] = - getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter); - baseWidth = b.trunc(i32_ty, baseWidth); - baseHeight = b.trunc(i32_ty, baseHeight); - - auto pitch = b.trunc(i32_ty, rowStride); - - SmallVector repClusterShape = dpasLayout.getShapeC(); - unsigned outerDimWarpNum = - std::min(warpsPerCTA[rank - 2], - mlir::ceil(tensorShape[rank - 2], - repClusterShape[rank - 2])); - unsigned innerDimWarpNum = - std::min(warpsPerCTA[rank - 1], - mlir::ceil(tensorShape[rank - 1], - repClusterShape[rank - 1])); - Value outerDimWarpId = - b.urem(multiDimWarpId[rank - 2], b.i32_val(outerDimWarpNum)); - Value innerDimWarpId = - b.urem(multiDimWarpId[rank - 1], b.i32_val(innerDimWarpNum)); - int64_t numRepOuter = numReps[1]; - int64_t numRepInner = numReps[2]; - - std::array replicaStride = { - outerDimWarpNum * repClusterShape[rank - 2], - innerDimWarpNum * repClusterShape[rank - 1]}; - std::array warpStride = {repClusterShape[rank - 2], - repClusterShape[rank - 1]}; - - Value dimWarpId0 = b.mul(outerDimWarpId, b.i32_val(warpStride[0])); - Value dimWarpId1 = b.mul(innerDimWarpId, b.i32_val(warpStride[1])); - Value warpId0Offset = b.add(dimWarpId0, offsetBaseY); - Value warpId1Offset = b.add(dimWarpId1, offsetBaseX); - - ArrayRef repCluster = dpasLayout.getRepCluster(); - unsigned valOffset = 0; - - SmallVector unpackedLoadedVals; - - for (int m = 0; m < numRepOuter; ++m) { - for (int n = 0; n < numRepInner; ++n) { - for (int repM = 0; repM < repCluster[0]; ++repM) { - - Value offsetY = - b.urem(b.add(warpId0Offset, b.i32_val(m * replicaStride[0] + - repM * tileHeight)), - baseHeight); - for (int repN = 0; repN < repCluster[1]; ++repN) { - Value offsetX = - b.add(warpId1Offset, - b.i32_val(n * replicaStride[1] + repN * tileWidth)); - - auto load2dOp = rewriter.create( - loc, load2DGenXType, - /*ptr*/ base, - /*base_width*/ b.mul(baseWidth, elemSizeInBytes), - /*base_height*/ baseHeight, - /*base_pitch*/ b.mul(pitch, elemSizeInBytes), - /*x*/ offsetX, - /*y*/ offsetY, - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, - /*transpose*/ false, - /*vnni_transform*/ false); - if (failed(load2dOp.verify())) { - // delete the op so that the verifier will not abort the pass - // pipeline later, as we can fail this path and try a different - // approach. 
- rewriter.eraseOp(load2dOp); - return failure(); - } - - Value ret = - b.bitcast(load2dOp, LLVM::getVectorType(eltTy, elemsPerLane)); - - for (size_t i = 0; i < elemsPerLane; ++i) { - Value loaded = b.extract_element( - eltTy, ret, b.urem(b.i32_val(i), baseHeight)); - unpackedLoadedVals.push_back(loaded); - } - } - } - } - } - - LLVMTypeConverter *typeConverter = getTypeConverter(); - Type llvmResultStructTy = typeConverter->convertType(op.getType()); - Value resultStruct = packLLElements( - loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); - rewriter.replaceOp(op, {resultStruct}); - - return success(); - } - - const bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA); - const SmallVector dpasInstShape = - isOperandA ? dpasLayout.getDPASInstShapeA() - : dpasLayout.getDPASInstShapeB(); - const SmallVector elemsPerDPASInst = {dpasInstShape[0], - dpasInstShape[1]}; - LLVM_DEBUG(llvm::dbgs() - << "Elements per DPAS Instruction: " << elemsPerDPASInst[0] - << ", " << elemsPerDPASInst[1] << "\n"); - unsigned elemsPerLanePerDPASInst = - product(elemsPerDPASInst) / threadsPerWarp; - LLVMTypeConverter *typeConverter = getTypeConverter(); - Type unpackedDPASOperandType = LLVM::getVectorType( - typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); - - // By default, use the unpacked type for the 2D load result type. - Type loadResultElemType = typeConverter->convertType(eltTy); - bool usePackedType = false; - unsigned packedElemsPerLanePerDPASInst = elemsPerLanePerDPASInst; - - // The tensor values are distributed as DotOp layout of DPAS. - // If the element size of the tensor matches the DPAS packed layout, then - // use the packed type for the 2D load result type. For example, - // The intermediate ops generated by ConvertTritonGPUToLLVM: - // %0 = load_2d %ptr : vector<8 x i32> - // %1 = bitcast %0 : vector<8 x i32> -> vector<16 x f16> - // %2 = bitcast %1 : vector<16 x f16> -> vector<8 x i32> - // %3 = dpas %2 - // And the LLVM dialect optimization pass can eliminate the duplicated - // bitcast. Then there is a shortcut to use the load result directly as the - // input operands to DPAS. - // TODO: add support for int4 and int2. - unsigned opsPerChannel = dpasLayout.getOpsPerChannel(); - if ((opsPerChannel == 4 && elemSizeInBits == 8) || - (opsPerChannel == 2 && elemSizeInBits == 16) || - (opsPerChannel == 1 && elemSizeInBits == 32)) { - loadResultElemType = - (isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty; - packedElemsPerLanePerDPASInst = - isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1) - : elemsPerLanePerDPASInst / opsPerChannel; - usePackedType = true; - } - - Type packedDPASOperandType = - LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst); - - // Outer dim: Dim M or N. Inner dim: Dim K. - auto repCluster = dpasLayout.getRepCluster(); - SmallVector warpShape = - isOperandA ? dpasLayout.getShapeA() : dpasLayout.getShapeB(); - - unsigned dimOuter = bool(opIdx) ? rank - 1 : rank - 2; - unsigned dimInner = bool(opIdx) ? rank - 2 : rank - 1; - - LLVM_DEBUG({ - llvm::dbgs() << "warpsPerCTA: " << warpsPerCTA[dimOuter] << ", " - << warpsPerCTA[dimInner] << "\n"; - llvm::dbgs() << "tensorShape: " << tensorShape[dimOuter] << ", " - << tensorShape[dimInner] << "\n"; - llvm::dbgs() << "repCluster: " << repCluster[dimOuter] << ", " - << repCluster[dimInner] << "\n"; - llvm::dbgs() << "warpShape: " << warpShape[dimOuter] << ", " - << warpShape[dimInner] << "\n"; - }); - - // Round the warp id fit into the tensor shape. 
- unsigned outerDimRequiredWarpNum = mlir::ceil( - tensorShape[dimOuter], warpShape[dimOuter]); // ceil of ratio - LLVM_DEBUG(llvm::dbgs() << "tensor to warp shape ratio = " - << outerDimRequiredWarpNum << "\n"); - unsigned outerDimWarpNum = - std::min(warpsPerCTA[dimOuter], outerDimRequiredWarpNum); - LLVM_DEBUG(llvm::dbgs() - << "outerDimWarpNum = " << outerDimRequiredWarpNum << "\n"); - Value outerDimWarpId = - b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); - - auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX, - offsetBaseY] = - getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter); - - MLIRContext *ctx = rewriter.getContext(); - const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter)); - const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner)); - LLVM_DEBUG({ - llvm::dbgs() << "dimOuterStr: " << dimOuterStr << "\n"; - llvm::dbgs() << "dimInnerStr: " << dimInnerStr << "\n"; - }); - - unsigned dpasTileToPackedIndicesRatio = - elemsPerDPASInst[0] / packedElemsPerLanePerDPASInst; - // if the number of elements in the DPAS tile is less than the number of - // packed elems per lane set the ratio to 1 - dpasTileToPackedIndicesRatio = std::max(dpasTileToPackedIndicesRatio, 1u); - LLVM_DEBUG(llvm::dbgs() << "dpasTileToPackedIndicesRatio = " - << dpasTileToPackedIndicesRatio << "\n"); - - // Create the linear layout for the load. - // First, we create a tile layout corresponding to a single invocation of - // the DPAS instruction across all threads/work-items in a sub-group. The - // layout will later be expanded to cover multiple DPAS invocations - // (iteration) and multiple loads (load). - StringAttr kOffset = S("offset"); - StringAttr kIteration = S("iteration"); - StringAttr kLoad = S("load"); - - auto createTileLayout = [&](const SmallVectorImpl &threadOrder, - SmallVector tileShape) { - auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); - LinearLayout layout = LinearLayout::empty(); - SmallVector kOffsetDims; - unsigned totalOffsets = 1; - assert(tileShape.size() == 2); // only support 2D layouts for now - - if (isTransposeRequired && opIdx == DpasEncodingAttr::OpIdx::OperandB) { - const unsigned widthDim = threadOrder[rank - 2]; - const unsigned origTileWidth = tileShape[widthDim]; - tileShape[widthDim] = origTileWidth / (32 / elemSizeInBits); - } - - for (int i = 0; i < tileShape.size(); ++i) { - int dim = threadOrder[i]; - StringAttr kOffset = S("offset" + std::to_string(dim)); - - kOffsetDims.push_back(kOffset); - - assert(llvm::isPowerOf2_32(tileShape[dim])); - // reduce the offset dimension size by the number of elements packed in - // a single slot for the row wise dimension - const unsigned offsetDimSize = - (!isTransposeRequired && dim == 0) - ? 
tileShape[dim] / dpasTileToPackedIndicesRatio - : tileShape[dim]; - layout *= - LinearLayout::identity1D(offsetDimSize, kOffset, outDimNames[dim]); - totalOffsets *= offsetDimSize; - } - SmallVector newDims; - newDims.append(kOffsetDims.begin(), kOffsetDims.end()); - auto ret = layout.transposeIns(newDims); - ret = ret.transposeOuts(outDimNames); - return ret.reshapeIns({{kOffset, totalOffsets}}); - }; - auto tileLayout = createTileLayout(threadOrder, elemsPerDPASInst); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout: " << tileLayout << "\n"; - for (size_t i = 0; i < tileLayout.getOutDimSize(dimOuterStr) * - tileLayout.getOutDimSize(dimInnerStr); - i += tileLayout.getOutDimSize(S("dim1"))) { - auto tensorVals = tileLayout.apply({{kOffset, i}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << i << " : " << tensorVals[0].second << ", " - << tensorVals[1].second << "\n"; - } - llvm::dbgs() << "tile layout done\n"; - }); - - unsigned numOperandsOuterDimPerLoad = 1; - unsigned numOperandsInnerDimPerLoad = 1; - - unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN; - if (!isTransposeRequired) { - numOperandsPer2DLoadM = - isOperandA ? repCluster[dimOuter] : numReps[unsigned(opIdx) ? 1 : 2]; - numOperandsPer2DloadN = - isOperandA ? numReps[unsigned(opIdx) ? 1 : 2] : repCluster[dimOuter]; - } else { - if (isOperandA) - return failure(); - - if (!usePackedType) - return failure(); - - if (*oneMatrixPerLoadForBT) { - // Only load 1 operand per inst on row. - numOperandsPer2DLoadM = 1; - tileHeight = elemsPerDPASInst[threadOrder[rank - 2]]; - } else { - // We can decompose the matrix returned by transposed large 2d load - // when threads per warp < column size. Otherwise we have to load one - // operand per inst. - // Note: the tileHeight and numOperandsPer2DLoadM are the column size - // now. - numOperandsPer2DLoadM = - (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; - } - // The transpose 2d load only support 1 operand per inst on column. - // (vBlocks = 1) - numOperandsPer2DloadN = 1; - } - - // TODO: move this logic to the instr shape computation - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight); - tileHeight = tileHeight * numOperandsPer2DLoadM; - - // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands - // by enlarging the vBlocks. - constexpr int MAX_WIDTH = 64; - unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8; - if (totalBytesPerRowPerDPASOp > MAX_WIDTH) - return failure(); - numOperandsPer2DloadN = - std::min(numOperandsPer2DloadN, MAX_WIDTH / totalBytesPerRowPerDPASOp); - - numOperandsOuterDimPerLoad = - isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN; - numOperandsInnerDimPerLoad = - isOperandA ? numOperandsPer2DloadN : numOperandsPer2DLoadM; - - LLVM_DEBUG({ - llvm::dbgs() << "numOperandsOuterDimPerLoad = " - << numOperandsOuterDimPerLoad << "\n"; - llvm::dbgs() << "numOperandsInnerDimPerLoad = " - << numOperandsInnerDimPerLoad << "\n"; - llvm::dbgs() << "vBlocks = " << vBlocks << "\n"; - }); - - tileLayout *= LinearLayout::identity1D(numOperandsOuterDimPerLoad, - kIteration, dimOuterStr); - tileLayout *= - LinearLayout::identity1D(isTransposeRequired && *oneMatrixPerLoadForBT - ? 
1 - : numOperandsInnerDimPerLoad, - kIteration, dimInnerStr); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding iterations: " - << tileLayout << "\n"; - - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); ++itr) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = - tileLayout.apply({{kOffset, offset}, {kIteration, itr}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << itr << ", " << offset << " : " << tensorVals[0].second - << ", " << tensorVals[1].second << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - }); - - if (isTransposeRequired) - std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); - - const unsigned numLoadPerOutRepCluster = - mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad); - LLVM_DEBUG(llvm::dbgs() << "numLoadPerOutRepCluster = " - << numLoadPerOutRepCluster << "\n"); - - unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst * - numOperandsOuterDimPerLoad * - numOperandsInnerDimPerLoad; - Type load2DGenXType = - LLVM::getVectorType(loadResultElemType, numValuesPerLoad); - - // The stride for the replicates. - unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum; - unsigned repStride = - elemsPerDPASInst[dimOuter] * numOperandsOuterDimPerLoad; - unsigned warpOuterStride = warpShape[dimOuter]; - unsigned repKStride = elemsPerDPASInst[dimInner]; - LLVM_DEBUG({ - llvm::dbgs() << "outerDimWarpNum = " << outerDimWarpNum << "\n"; - llvm::dbgs() << "repOuterStride = " << repOuterStride << "\n"; - llvm::dbgs() << "repStride = " << repStride << "\n"; - llvm::dbgs() << "warpOuterStride = " << warpOuterStride << "\n"; - llvm::dbgs() << "repKStride = " << repKStride << "\n"; - }); - - unsigned numRepOuter = numReps[bool(opIdx) ? 2 : 1]; - unsigned numRepInner = numReps[bool(opIdx) ? 1 : 2]; - - LLVM_DEBUG({ - llvm::dbgs() << "numRepOuter = " << numRepOuter << "\n"; - llvm::dbgs() << "numRepInner = " << numRepInner << "\n"; - }); - - // For the kLoad dimension we create the basis vector directly, which allows - // us to control the stride between loads and create a non-surjective - // layout. 
- auto bases = tileLayout.getBases(); - std::vector> newLoadBases; - - SmallVector> outDims; - for (auto [name, size] : - llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) { - outDims.push_back(std::make_pair(name, size)); - } - assert(outDims[0].first == S("dim0")); - assert(outDims[1].first == S("dim1")); - - for (size_t i = 0; - i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); ++i) { - newLoadBases.push_back({0, static_cast((1 << i) * repKStride * - numOperandsInnerDimPerLoad)}); - outDims[1].second *= repKStride * numOperandsInnerDimPerLoad; - } - for (size_t i = 0; i < llvm::Log2_32(numLoadPerOutRepCluster); ++i) { - newLoadBases.push_back({static_cast((1 << i) * repStride), 0}); - outDims[0].second *= repStride; - } - for (size_t i = 0; i < llvm::Log2_32(numRepOuter); ++i) { - newLoadBases.push_back({static_cast((1 << i) * repOuterStride), 0}); - outDims[0].second *= repOuterStride; - } - - LLVM_DEBUG({ - llvm::dbgs() << "Created Load Bases:\n"; - for (auto &base : newLoadBases) { - assert(base.size() == 2); - llvm::dbgs() << base[0] << ", " << base[1] << "\n"; - } - }); - - LLVM_DEBUG({ - llvm::dbgs() << "New tile layout dimensions after adding load bases:\n"; - for (size_t i = 0; i < outDims.size(); ++i) { - llvm::dbgs() << outDims[i].first << " = " << outDims[i].second << "\n"; - } - }); - - // Disable building the load layout if we are not going to use it. Building - // the layout manually can cause an error which would abort the pass - // pipeline and block us from getting debug info. - // add the bases to the map and replace the tile layout with the new - // layout - bases[kLoad] = newLoadBases; - tileLayout = LinearLayout(bases, outDims, - /*requiredSurjective=*/false); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding loads: " - << tileLayout << "\n"; - for (size_t load = 0; load < tileLayout.getInDimSize(kLoad); ++load) { - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); ++itr) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = tileLayout.apply( - {{kOffset, offset}, {kIteration, itr}, {kLoad, load}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << load << ", " << itr << ", " << offset << " : " - << tensorVals[0].second << ", " << tensorVals[1].second - << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - } - }); - - Value pitch; - if (memoryRowMajor) { - pitch = b.trunc(i32_ty, rowStride); - } else { - // Column major memory. We need to swap the width and height because HW - // only support row major memory layout. - pitch = b.trunc(i32_ty, colStride); - std::swap(baseWidth, baseHeight); - } - // HW requires the pitch to be at least 64 bytes. - if (auto pitchConst = mlir::triton::intel::getFoldedConstantValue(pitch)) { - if ((*pitchConst * elemSizeInBits / 8) < 64) - return failure(); - } - - baseWidth = b.trunc(i32_ty, baseWidth); - baseHeight = b.trunc(i32_ty, baseHeight); - - if (auto widthConst = - mlir::triton::intel::getFoldedConstantValue(baseWidth)) { - if ((*widthConst * elemSizeInBits / 8) < 64) - return failure(); - } - - const unsigned originalElemBits = elemSizeInBits; - if (isTransposeRequired) { - // adjust the block io parameter to align HW's limitations on - // transposing load. 
- elemSizeInBits = 32; - } - Value elemSizeInBytes = b.i32_val(originalElemBits / 8); - - LLVM_DEBUG({ - const unsigned numLoads = numRepOuter * numLoadPerOutRepCluster * - numRepInner / numOperandsInnerDimPerLoad; - llvm::dbgs() << "Preparing to dispatch " << numLoads << " loads\n"; - llvm::dbgs() << "Outer loads: " << numRepOuter * numLoadPerOutRepCluster - << " (" << numLoadPerOutRepCluster - << " per out rep cluster)\n"; - llvm::dbgs() << "Inner loads: " - << numRepInner / numOperandsInnerDimPerLoad << "\n"; - llvm::dbgs() << "Load dimension: " << tileHeight << ", " - << tileWidth * vBlocks << " (" << elemSizeInBits - << " bits)\n"; - }); - - ValueTable loadVals; - for (int outer = 0; outer < numRepOuter; ++outer) { - for (int rep = 0; rep < numLoadPerOutRepCluster; ++rep) { - for (int k = 0; k < numRepInner; k += numOperandsInnerDimPerLoad) { - LLVM_DEBUG({ - llvm::dbgs() << "outer, rep, k: " << outer << ", " << rep << ", " - << k << "\n"; - }); - - const int loadIdx = (outer * numLoadPerOutRepCluster * - (numRepInner / numOperandsInnerDimPerLoad)) + - rep * (numRepInner / numOperandsInnerDimPerLoad) + - k / numOperandsInnerDimPerLoad; - LLVM_DEBUG(llvm::dbgs() << "loadIdx: " << loadIdx << "\n"); - - const auto offset = tileLayout.apply( - {{kOffset, 0}, {kIteration, 0}, {kLoad, loadIdx}}); - assert(offset.size() == 2); - - const auto layoutOffsetX = offset[dimInner].second; - const auto layoutOffsetY = offset[dimOuter].second; - LLVM_DEBUG({ - llvm::dbgs() << "x offset ll: " << layoutOffsetX << "\n"; - llvm::dbgs() << "y offset ll: " << layoutOffsetY << "\n"; - }); - - Value offsetX, offsetY; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: { - LLVM_DEBUG({ - llvm::dbgs() << "x offset: " << k * repKStride << "\n"; - llvm::dbgs() << "y offset: " - << outer * repOuterStride + rep * repStride << "\n"; - }); - offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetY)); - offsetX = b.i32_val(layoutOffsetX); - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - LLVM_DEBUG({ - llvm::dbgs() << "x offset: " - << outer * repOuterStride + rep * repStride << "\n"; - llvm::dbgs() << "y offset: " << k * repKStride << "\n"; - }); - offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetX)); - offsetY = b.i32_val(layoutOffsetY); - } break; - case DpasEncodingAttr::OpIdx::OperandC: { - llvm_unreachable("unexpected OpIdx::OperandC"); - } break; - } - - offsetX = b.add(offsetX, offsetBaseX); - offsetY = b.add(offsetY, offsetBaseY); - - if (!memoryRowMajor) { - // Column major memory. We need to swap the X and Y because HW only - // support row major memory layout. - std::swap(offsetX, offsetY); - } - - if (isTransposeRequired) { - // adjust the block io parameter to align HW's limitations on - // transposing load. 
- offsetX = b.udiv(offsetX, b.i32_val(32 / originalElemBits)); - } - - auto load2dOp = rewriter.create( - loc, load2DGenXType, - /*ptr*/ base, - /*base_width*/ b.mul(baseWidth, elemSizeInBytes), - /*base_height*/ baseHeight, - /*base_pitch*/ b.mul(pitch, elemSizeInBytes), - /*x*/ offsetX, - /*y*/ offsetY, - /*elem_size_in_bits*/ elemSizeInBits, - /*tile_width*/ tileWidth, - /*tile_height*/ tileHeight, - /*v_blocks*/ vBlocks, - /*transpose*/ isTransposeRequired, - /*vnni_transform*/ - (usePackedType && !isOperandA && !isTransposeRequired && - originalElemBits != 32)); - if (failed(load2dOp.verify())) { - // delete the op so that the verifier will not abort the pass - // pipeline later, as we can fail this path and try a different - // approach. - rewriter.eraseOp(load2dOp); - return failure(); - } - LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n"); - - unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA - ? numOperandsOuterDimPerLoad - : numOperandsInnerDimPerLoad; - unsigned packedColNum = opIdx == DpasEncodingAttr::OpIdx::OperandA - ? numOperandsInnerDimPerLoad - : numOperandsOuterDimPerLoad; - - // Decompose the return value to multiple operands. - unsigned packedColNumPerVBlock = packedColNum / vBlocks; - for (int vblk = 0; vblk < vBlocks; ++vblk) - for (int row = 0; row < packedRowNum; ++row) - for (int col = 0; col < packedColNumPerVBlock; ++col) { - - unsigned operandStartOffset = (vblk * packedRowNum + row) * - packedColNumPerVBlock * - packedElemsPerLanePerDPASInst; - - SmallVector indices(packedElemsPerLanePerDPASInst); - for (int elemIdx = 0; elemIdx < packedElemsPerLanePerDPASInst; - ++elemIdx) { - indices[elemIdx] = operandStartOffset + - elemIdx * packedColNumPerVBlock + col; - LLVM_DEBUG({ - llvm::dbgs() << "indices[" << elemIdx << "]" << " = " - << indices[elemIdx] << "\n"; - }); - } - DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(indices); - Value loadVal = rewriter.create( - loc, packedDPASOperandType, load2dOp, load2dOp, attr); - - // Save the decomposed vals to the map; - switch (opIdx) { - case DpasEncodingAttr::OpIdx::OperandA: { - LLVM_DEBUG({ - llvm::dbgs() << "load vals index: " - << std::to_string(outer * packedRowNum * - numLoadPerOutRepCluster + - rep * packedRowNum + row) - << ", " - << std::to_string( - k + vblk * packedColNumPerVBlock + col) - << "\n"; - }); - loadVals[{outer * packedRowNum * numLoadPerOutRepCluster + - rep * packedRowNum + row, - k + vblk * packedColNumPerVBlock + col}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - case DpasEncodingAttr::OpIdx::OperandB: { - LLVM_DEBUG({ - llvm::dbgs() - << "load vals index: " - << std::to_string(outer * packedColNum * - numLoadPerOutRepCluster + - rep * packedColNum + - vblk * packedColNumPerVBlock + col) - << ", " << std::to_string(k + row) << "\n"; - }); - loadVals[{outer * packedColNum * numLoadPerOutRepCluster + - rep * packedColNum + - vblk * packedColNumPerVBlock + col, - k + row}] = - b.bitcast(loadVal, unpackedDPASOperandType); - } break; - case DpasEncodingAttr::OpIdx::OperandC: { - llvm_unreachable("unexpected OpIdx::OperandC"); - } break; - } - } - } - } - } - - // Extract the value returned by the load ops. And put the values in the - // expected order for the layout. 
-    SmallVector<Value> unpackedLoadedVals;
-    for (int outer = 0; outer < numRepOuter; ++outer) {
-      for (int k = 0; k < numRepInner; ++k) {
-        for (int rep = 0; rep < repCluster[unsigned(opIdx)]; ++rep) {
-          if (loadVals.find({outer * repCluster[unsigned(opIdx)] + rep, k}) ==
-              loadVals.end()) {
-            // generate a nice error message before the throw below aborts our
-            // pipeline
-            llvm::errs() << "Failed to find key at "
-                         << outer * repCluster[unsigned(opIdx)] + rep << ", "
-                         << k << "\n";
-          }
-          Value loadVal =
-              loadVals.at({outer * repCluster[unsigned(opIdx)] + rep, k});
-          VectorType loadTy = cast<VectorType>(loadVal.getType());
-          for (int i = 0; i < loadTy.getNumElements(); ++i) {
-            auto val = b.extract_element(loadVal, b.i32_val(i));
-            unpackedLoadedVals.push_back(val);
-          }
-        }
-      }
-    }
-
-    Type llvmResultStructTy = typeConverter->convertType(op.getType());
-    Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals,
-                                        rewriter, llvmResultStructTy);
-    rewriter.replaceOp(op, {resultStruct});
-
-    return success();
-  }
-
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1901,13 +1242,6 @@ struct LoadOpToBlockIOConversion
     if (op.getPadding() && op.getPadding() == PaddingOption::PAD_NAN)
       return failure();
 
-    Value ptr = op.getPtr();
-    if (isTensorPointerType(ptr.getType())) {
-      if (!isBlockIOCandidate(op))
-        return failure();
-      return rewriteTensorPointerLoad(op, adaptor, rewriter);
-    }
-
     static const bool enableBlockIOForAllLayout =
         triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS");
     if (!isBlockIOCandidate(op, enableBlockIOForAllLayout))
@@ -1974,9 +1308,13 @@ struct LoadOpToBlockIOConversion
         /*requireSurjective=*/true);
 
     // Get the LLVM values for pointers
+    Value ptr = op.getPtr();
     Value llPtr = adaptor.getPtr();
-    SmallVector<Value> ptrElems = unpackLLElements(loc, llPtr, rewriter);
     unsigned numElems = getTotalElemsPerThread(resultType);
+    SmallVector<Value> unpackedPtr =
+        unpackLLElements(ptr.getLoc(), llPtr, rewriter);
+    SmallVector<Value> ptrElems =
+        getBases(rewriter, ptr, unpackedPtr, numElems);
     assert(ptrElems.size() == numElems &&
            "the number of pointer values is not matched with the number of "
           "elements");
@@ -2056,7 +1394,7 @@ struct LoadOpToBlockIOConversion
       }
     }
 
-    // Get the LLVM values for `other`
+    // Get the LLVM values for other
     Value other = op.getOther();
     SmallVector<Value> otherElems;
     Value llOther = adaptor.getOther();
@@ -2129,17 +1467,38 @@ struct LoadOpToBlockIOConversion
     Type load2DGenXType = LLVM::getVectorType(packedType, numValuesPerLoad);
     Type unpackedType = LLVM::getVectorType(eltTy, numElemsPerLoad);
 
-    Value pitch =
-        getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
+    Value pitch = getPitch(rewriter, ptr, unpackedPtr, elemSizeInBits,
+                           memoryRowMajor ? 0 : 1);
     if (!pitch)
      return failure();
 
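+    // For block pointers getPitch always materializes a runtime SSA value
+    // from the unpacked stride, so the early exit above is only taken for
+    // regular pointers whose compile-time stride is unsupported.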
-    // If the stride is 0, we want to load only the first row.
-    int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
-    unsigned baseHeightInt = (stride == 0 ? 1 : tileHeight);
-    Value baseHeight = b.i32_val(baseHeightInt);
-    Value baseWidth =
-        b.i32_val(vBlocks * tileWidth * (packedElemSizeInBits / 8));
+    SmallVector<Value> baseOffsets =
+        getBaseOffsets(rewriter, ptr, unpackedPtr);
+
+    Value baseWidth, baseHeight;
+    unsigned baseHeightInt = 0;
+    if (isTensorPointerType(ptr.getType())) {
+      // See unpackLLBlockPointer for the block pointer struct layout; the
+      // shape entries start at index `rank`.
+      assert((unpackedPtr.size() - 1) % 3 == 0 &&
+             "unexpected number of values unpacked from a block pointer");
+      unsigned rank = (unpackedPtr.size() - 1) / 3;
+      unsigned blockShape = 1 * rank;
+      baseWidth = b.trunc(
+          i32_ty, unpackedPtr[blockShape + (memoryRowMajor ? 1 : 0)]);
+      baseHeight = b.trunc(
+          i32_ty, unpackedPtr[blockShape + (memoryRowMajor ? 0 : 1)]);
+      baseWidth = b.mul(baseWidth, b.i32_val(elemSizeInBits / 8));
+    } else {
+      // If the stride is 0, we want to load only the first row.
+      int stride = getStride(ptr, memoryRowMajor ? 0 : 1);
+      baseHeightInt = (stride == 0 ? 1 : tileHeight);
+      baseHeight = b.i32_val(baseHeightInt);
+      baseWidth = b.i32_val(vBlocks * tileWidth * (packedElemSizeInBits / 8));
+    }
 
     bool useVNNIFormat = false;
     Type packedDPASOperandType;
@@ -2264,29 +1623,61 @@ struct LoadOpToBlockIOConversion
 
       // Use the top-left address of the block to load the data.
       Value addrElem = ptrElems[registerIdx];
-      addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
-
-      // Adjust the baseWidth, offsetX and base address use the original base
-      // of the BLOCK.
-      Value offsetX = offsets[isTransposeRequired ? rowDim : colDim].second;
-      Value offsetY = b.i32_val(0);
-      Value negOffsetX = b.sub(b.i32_val(0), offsetX);
-      addrElem = b.gep(ptr_ty(ctx, 1), eltTy, addrElem, negOffsetX);
-      // The offset is in number of original elements. So we need to scale it
-      // by element bytes size.
-      Value adjustedBaseWidth =
-          b.add(baseWidth, b.mul(offsetX, b.i32_val(elemSizeInBits / 8)));
-      adjustedBaseWidth = b.umax(adjustedBaseWidth, b.i32_val(64));
+      Value offsetX, offsetY;
+      Value adjustedBaseWidth = baseWidth, adjustedBaseHeight = baseHeight;
       Value pred;
-      if (maskElems.size()) {
-        pred = targetInfo.shuffleIdx(rewriter, loc, maskElems[registerIdx], 0);
-        // We leverage the GPU block I/O hardware out-of-bound protection
-        // feature by setting the offset to an invalid value when 'pred'
-        // is false (the HW will not read out-of-bounds values). Later on,
-        // after issuing the 2d block read operation, we will select the
-        // result of the load only if the mask evaluate to true, otherwise
-        // we will use 'other'.
-        offsetY = b.select(pred, offsetY, baseHeight);
+      if (isTensorPointerType(ptr.getType())) {
+        unsigned c = isTransposeRequired ? rowDim : colDim;
+        unsigned r = isTransposeRequired ? colDim : rowDim;
+        offsetX = b.add(baseOffsets[c], offsets[c].second);
+        offsetY = b.add(baseOffsets[r], offsets[r].second);
+
+        // To prevent triggering the hardware boundary protection, expand the
+        // base shape sufficiently when a boundary check is absent.
+        SetVector<int32_t> boundaryCheck(op.getBoundaryCheck().begin(),
+                                         op.getBoundaryCheck().end());
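+
+        // Illustrative example of the expansion below: with tileWidth = 16,
+        // vBlocks = 2 and 4-byte packed values, an unchecked column dim
+        // advertises a base width of max(64, 2 * 16 * 4) = 128 bytes and
+        // folds the block offset into the base address instead, so the
+        // hardware out-of-bound clamping never truncates the access.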
+        if (!boundaryCheck.contains(c)) {
+          adjustedBaseWidth = b.i32_val(
+              std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
+          // offsetX is in number of original elements, not packed elements.
+          addrElem = b.gep(ptr_ty(ctx, 1), eltTy, addrElem, offsetX);
+          offsetX = b.i32_val(0);
+        }
+        if (!boundaryCheck.contains(r)) {
+          adjustedBaseHeight = b.i32_val(tileHeight);
+          // Use i8_ty as the pitch is in number of bytes.
+          Value off = b.mul(offsetY, pitch);
+          addrElem = b.gep(ptr_ty(ctx, 1), i8_ty, addrElem, off);
+          offsetY = b.i32_val(0);
+        }
+      } else {
+        addrElem = targetInfo.shuffleIdx(rewriter, loc, addrElem, 0);
+
+        // Adjust the baseWidth, offsetX and base address to use the original
+        // base of the BLOCK.
+        offsetX = offsets[isTransposeRequired ? rowDim : colDim].second;
+        offsetY = b.i32_val(0);
+        Value negOffsetX = b.sub(b.i32_val(0), offsetX);
+        addrElem = b.gep(ptr_ty(ctx, 1), eltTy, addrElem, negOffsetX);
+        // The offset is in number of original elements. So we need to scale
+        // it by the element byte size.
+        adjustedBaseWidth =
+            b.add(baseWidth, b.mul(offsetX, b.i32_val(elemSizeInBits / 8)));
+        adjustedBaseWidth = b.umax(adjustedBaseWidth, b.i32_val(64));
+        // Use the top-left address and mask of the block to load the data
+        // (the first value referred to by the registerIdx).
+        if (maskElems.size()) {
+          pred =
+              targetInfo.shuffleIdx(rewriter, loc, maskElems[registerIdx], 0);
+          // We leverage the GPU block I/O hardware out-of-bound protection
+          // feature by setting the offset to an invalid value when 'pred'
+          // is false (the HW will not read out-of-bounds values). Later on,
+          // after issuing the 2d block read operation, we will select the
+          // result of the load only if the mask evaluates to true, otherwise
+          // we will use 'other'.
+          offsetY = b.select(pred, offsetY, baseHeight);
+        }
       }
 
       assert(numPackedVals > 0 && "numPackedVals should be greater than zero.");
@@ -2294,7 +1685,7 @@ struct LoadOpToBlockIOConversion
           loc, load2DGenXType,
          /*ptr*/ addrElem,
          /*base_width*/ adjustedBaseWidth,
-          /*base_height*/ baseHeight,
+          /*base_height*/ adjustedBaseHeight,
          /*base_pitch*/ pitch,
          // offsetX was in terms of original elements. The 2d block io requires
          // offsetX to be in terms of packed elements.
@@ -2307,45 +1698,48 @@ struct LoadOpToBlockIOConversion
          /*transpose*/ isTransposeRequired,
          /*vnni_transform*/ !isTransposeRequired && useVNNIFormat);
 
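+      // Note: vnni_transform asks the hardware to repack sub-32-bit values
+      // from consecutive rows into dword lanes so the result can feed DPAS
+      // operand B directly; it is only legal on the non-transposed path,
+      // hence the !isTransposeRequired guard above.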
-      // When strides[0] is 0, we only want to load the first row, so we
-      // set the base height to be 1. If tile height is bigger than 1,
-      // then only the first row contain valid data. To ensure the entire
-      // tile is filled with valid data, we must replicate the first row
-      // throughout the tile.
-      if (baseHeightInt < tileHeight && baseHeightInt == 1) {
-        unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
-        SmallVector<int32_t> shuffleIndices(numValuesPerLoad);
-
-        // Create a vector to store the data of the first index of each
-        // matrix.
-        VectorType vecTy = vec_ty(packedType, vBlocks);
-        Value firstIndexVec = b.undef(vecTy);
-
-        for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
-             ++valueIndex) {
-          unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
-          // Handle case where an index spans two rows.
-          if (valueIndex % numIndicesPerMatrix == 0) {
-            Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
-            Value newVal = oldVal;
-            if (tileWidth < threadsPerWarp) {
-              assert(tileWidth * 2 == threadsPerWarp &&
-                     "Expecting tileWidth to be 2x threadsPerWarp");
-              Value threadId = getThreadId(rewriter, loc);
-              newVal =
-                  targetInfo.shuffleIdx(rewriter, loc, oldVal,
-                                        b.urem(threadId, b.i32_val(tileWidth)));
+      if (!isTensorPointerType(ptr.getType())) {
+        // When strides[0] is 0, we only want to load the first row, so we
+        // set the base height to be 1. If the tile height is bigger than 1,
+        // then only the first row contains valid data. To ensure the entire
+        // tile is filled with valid data, we must replicate the first row
+        // throughout the tile.
+        if (baseHeightInt < tileHeight && baseHeightInt == 1) {
+          unsigned numIndicesPerMatrix = numValuesPerLoad / vBlocks;
+          SmallVector<int32_t> shuffleIndices(numValuesPerLoad);
+
+          // Create a vector to store the data of the first index of each
+          // matrix.
+          VectorType vecTy = vec_ty(packedType, vBlocks);
+          Value firstIndexVec = b.undef(vecTy);
+
+          for (unsigned valueIndex = 0; valueIndex < numValuesPerLoad;
+               ++valueIndex) {
+            unsigned firstIndexVecIdx = valueIndex / numIndicesPerMatrix;
+            // Handle case where an index spans two rows.
+            if (valueIndex % numIndicesPerMatrix == 0) {
+              Value oldVal = b.extract_element(ret, b.i32_val(valueIndex));
+              Value newVal = oldVal;
+              if (tileWidth < threadsPerWarp) {
+                assert(tileWidth * 2 == threadsPerWarp &&
+                       "Expecting threadsPerWarp to be 2x tileWidth");
+                Value threadId = getThreadId(rewriter, loc);
+                newVal = targetInfo.shuffleIdx(
+                    rewriter, loc, oldVal,
+                    b.urem(threadId, b.i32_val(tileWidth)));
+              }
+              firstIndexVec =
+                  b.insert_element(firstIndexVec.getType(), firstIndexVec,
+                                   newVal, b.i32_val(firstIndexVecIdx));
             }
-            firstIndexVec =
-                b.insert_element(firstIndexVec.getType(), firstIndexVec, newVal,
-                                 b.i32_val(firstIndexVecIdx));
-          }
 
-          shuffleIndices[valueIndex] = firstIndexVecIdx;
+            shuffleIndices[valueIndex] = firstIndexVecIdx;
+          }
+          DenseI32ArrayAttr attr =
+              rewriter.getDenseI32ArrayAttr(shuffleIndices);
+          ret = rewriter.create<LLVM::ShuffleVectorOp>(
+              loc, load2DGenXType, firstIndexVec, firstIndexVec, attr);
         }
-        DenseI32ArrayAttr attr = rewriter.getDenseI32ArrayAttr(shuffleIndices);
-        ret = rewriter.create<LLVM::ShuffleVectorOp>(
-            loc, load2DGenXType, firstIndexVec, firstIndexVec, attr);
       }
 
       unsigned numElemsPerUnpackedType =
@@ -2734,7 +2128,8 @@ struct StoreOpToBlockIOConversion
 
     baseWidth = b.i32_val(vBlocks * tileWidth * (packedElemSizeInBits / 8));
     baseHeight = b.i32_val(tileHeight);
-    pitch = getPitch(rewriter, ptr, elemSizeInBits, memoryRowMajor ? 0 : 1);
+    pitch = getPitch(rewriter, ptr, ptrElems, elemSizeInBits,
+                     memoryRowMajor ? 0 : 1);
     if (!pitch)
       return failure();
     offsetBaseX = b.i32_val(0);
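
Reviewer note (not part of the diff; helper names match the patch, the
surrounding pattern details are elided): after this change both pointer
kinds funnel through the same three accessors, in condensed form

    // `ptr` is either a block pointer (tt.make_tensor_ptr) or a tensor of
    // regular per-element pointers; `llPtr` is its converted LLVM value.
    SmallVector<Value> unpackedPtr = unpackLLElements(loc, llPtr, rewriter);
    SmallVector<Value> ptrElems =
        getBases(rewriter, ptr, unpackedPtr, numElems);
    Value pitch = getPitch(rewriter, ptr, unpackedPtr, elemSizeInBits, dim);
    if (!pitch)
      return failure(); // unsupported compile-time pitch, fall back
    SmallVector<Value> baseOffsets = getBaseOffsets(rewriter, ptr, unpackedPtr);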