diff --git a/projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp b/projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp index df43b9ff579a..3709106b2cd2 100644 --- a/projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp +++ b/projects/hipblaslt/tensilelite/client/cpu_gemm_driver.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "ProgramOptions.hpp" @@ -118,6 +119,14 @@ namespace }; #endif +#ifndef _WIN32 + template <> + struct TypeTraits + { + static constexpr rocisa::DataType value = rocisa::DataType::Float4; + }; +#endif + // A slow, easy to understand, golden reference implementation of GEMM. // Used strictly for validating the correctness of the optimized path. // Calculates, for each element (i, j): @@ -129,7 +138,12 @@ namespace // * scaleAlphaVec[factorDim == 0 ? i : j] (if scaleAlphaVec != nullptr) // // scaleA is always indexed by row (M), scaleB always by col (N). - // factorDim only affects scaleAlphaVec: 0 = row-dim (length M), 1 = col-dim (length N). + // factorDim only affects scaleAlphaVec: 0 = row-dim (length M), 1 = col-dim (length N). + // + // When mxBlockA/B > 0 and mxScaleA/mxScaleB are non-null, the K-reduction + // is block-structured: accumulate min(mxBlockA, mxBlockB) products, then + // scale by the MX scale factors before adding to the running sum. + // Quantize a value through `Narrow`, then return as float. This mirrors // what the GPU MFMA path does when storage is wider than the MAC input // type (e.g. Half stored, F8 used in MFMA). @@ -170,14 +184,26 @@ namespace bool transB, float alpha, float beta, - const float* biasVec = nullptr, - const float* scaleAlphaVec = nullptr, - ActivationType activation = ActivationType::None, - const float* scaleAVec = nullptr, - const float* scaleBVec = nullptr, - int factorDim = 0, - QuantizeFn quantizeA = nullptr, - QuantizeFn quantizeB = nullptr) + const float* biasVec = nullptr, + const float* scaleAlphaVec = nullptr, + ActivationType activation = ActivationType::None, + const float* scaleAVec = nullptr, + const float* scaleBVec = nullptr, + int factorDim = 0, + QuantizeFn quantizeA = nullptr, + QuantizeFn quantizeB = nullptr +#ifndef _WIN32 + , + const E8* mxScaleA = nullptr, + const E8* mxScaleB = nullptr, + int mxBlockA = 0, + int mxBlockB = 0, + size_t mxScaleAStrideM = 0, + size_t mxScaleAStrideKBlk = 0, + size_t mxScaleBStrideN = 0, + size_t mxScaleBStrideKBlk = 0 +#endif + ) { switch(activation) @@ -200,14 +226,61 @@ namespace for(size_t j = 0; j < n; j++) { float sum = 0.0f; - for(size_t l = 0; l < k; l++) + +#ifndef _WIN32 + // MX scale tensor layout follows the ContractionProblem tensor + // descriptors. Even with padScaleTensor=false, setMXScaleA/B + // can pad the scale-K dimension, so using k/mxBlock as the + // free-dimension stride would address the wrong scale element. + // + // Both mxBlockA and mxBlockB are required to be > 0 (#1) and + // powers of 2 (validated in runGemm). When they differ, one + // divides the other; step the inner reduction by the smaller + // of the two so each inner segment has constant (sa, sb) and + // can pick the correct scale index for each side. + if(mxBlockA > 0 && mxBlockB > 0 && mxScaleA && mxScaleB) { - float aVal = a[i * strideAM + l * strideAK]; - float bVal = b[l * strideBK + j * strideBN]; - if(quantizeA) aVal = quantizeA(aVal); - if(quantizeB) bVal = quantizeB(bVal); - sum += aVal * bVal; + size_t step = static_cast( + std::min(mxBlockA, mxBlockB)); + for(size_t lBase = 0; lBase < k; lBase += step) + { + float blockSum = 0.0f; + for(size_t t = 0; t < step; t++) + { + size_t l = lBase + t; + float aVal = a[i * strideAM + l * strideAK]; + float bVal = b[l * strideBK + j * strideBN]; + if(quantizeA) aVal = quantizeA(aVal); + if(quantizeB) bVal = quantizeB(bVal); + blockSum += aVal * bVal; + } + + size_t blkA = lBase / static_cast(mxBlockA); + size_t blkB = lBase / static_cast(mxBlockB); + + size_t mxsaIdx = i * mxScaleAStrideM + + blkA * mxScaleAStrideKBlk; + size_t mxsbIdx = j * mxScaleBStrideN + + blkB * mxScaleBStrideKBlk; + + float mxScale = static_cast(mxScaleA[mxsaIdx]) + * static_cast(mxScaleB[mxsbIdx]); + sum += blockSum * mxScale; + } } + else +#endif + { + for(size_t l = 0; l < k; l++) + { + float aVal = a[i * strideAM + l * strideAK]; + float bVal = b[l * strideBK + j * strideBN]; + if(quantizeA) aVal = quantizeA(aVal); + if(quantizeB) bVal = quantizeB(bVal); + sum += aVal * bVal; + } + } + float effectiveAlpha = alpha; if(scaleAVec) effectiveAlpha *= scaleAVec[i]; @@ -253,8 +326,16 @@ int runGemm(size_t m, const std::string& useScaleAB, int factorDim, rocisa::DataType computeInputA = rocisa::DataType::None, - rocisa::DataType computeInputB = rocisa::DataType::None) + rocisa::DataType computeInputB = rocisa::DataType::None, + int mxBlockA = 0, + int mxBlockB = 0, + size_t batchCount = 1) { + if(batchCount < 1) + { + std::cerr << "Error: batchCount (" << batchCount << ") must be >= 1" << std::endl; + return 1; + } constexpr rocisa::DataType dtypeEnumA = TypeTraits::value; constexpr rocisa::DataType dtypeEnumB = TypeTraits::value; if(computeInputA == rocisa::DataType::None) computeInputA = dtypeEnumA; @@ -262,11 +343,78 @@ int runGemm(size_t m, static_assert(std::is_same::value, "Currently only float accumulation is supported"); +#ifndef _WIN32 + constexpr bool isInputAFP4 = std::is_same_v; + constexpr bool isInputBFP4 = std::is_same_v; + static_assert(isInputAFP4 == isInputBFP4, + "FP4 input storage must be used for both A and B, or neither."); + constexpr bool isFP4 = isInputAFP4; +#else + constexpr bool isFP4 = false; +#endif + + if constexpr(!isFP4) + { + mxBlockA = 0; + mxBlockB = 0; + } + + if constexpr(isFP4) + { + // One-sided MX (only A or only B scaled) is not supported by either + // reference path; they would disagree about what one-sided MX means. + if((mxBlockA > 0) != (mxBlockB > 0)) + { + std::cerr << "Error: one-sided MX is not supported " + << "(mxBlockA=" << mxBlockA << ", mxBlockB=" << mxBlockB + << "); set both > 0 or both 0." << std::endl; + return 1; + } + auto checkSide = [&](const char* name, int b) -> int { + if(b <= 0) return 0; + if((b & (b - 1)) != 0) + { + std::cerr << "Error: " << name << " (" << b << ") must be a power of 2" + << std::endl; + return 1; + } + if(k < static_cast(b)) + { + std::cerr << "Error: K (" << k << ") must be >= " << name << " (" << b << ")" + << std::endl; + return 1; + } + if(k % static_cast(b) != 0) + { + std::cerr << "Error: K (" << k << ") must be a multiple of " << name + << " (" << b << ")" << std::endl; + return 1; + } + return 0; + }; + if(int rc = checkSide("mxBlockA", mxBlockA)) return rc; + if(int rc = checkSide("mxBlockB", mxBlockB)) return rc; + + // Asymmetric MX (mxBlockA != mxBlockB) is only supported on the fast + // path. The production slow path's MX inner loop uses a single scale + // per max(mxBlockA, mxBlockB)-sized segment, which collapses the + // smaller-blocked side's per-segment scales onto the first one and + // produces wrong results. Reject the combination at the driver rather + // than ship a known-wrong slow path. + if(mxBlockA != mxBlockB && !tryFastPath) + { + std::cerr << "Error: asymmetric MX (mxBlockA=" << mxBlockA + << " != mxBlockB=" << mxBlockB + << ") is only supported on the fast path " + << "(use --tryFastPath)." << std::endl; + return 1; + } + } + // Calculate strides assuming standard column-major packed storage size_t lda = transA ? k : m; size_t ldb = transB ? n : k; size_t ldc = m; - size_t batchCount = 1; // Define the contraction problem (geometry, strides, types) ContractionProblemGemm contraction @@ -275,7 +423,7 @@ int runGemm(size_t m, dtypeEnumA, dtypeEnumB, rocisa::DataType::Float, - rocisa::DataType::Float, // A, B, C, D types + rocisa::DataType::Float, m, n, k, @@ -295,11 +443,30 @@ int runGemm(size_t m, contraction.setAlphaType(rocisa::DataType::Float); contraction.setBetaType(rocisa::DataType::Float); - // Allocate host memory for inputs and outputs - std::vector a(m * k); - std::vector b(k * n); - std::vector c(m * n); - std::vector d(m * n); + // Allocate host memory for inputs and outputs. Each batch slice is packed. + size_t numA = m * k; + size_t numB = k * n; + size_t numC = m * n; + + size_t storageA, storageB; +#ifndef _WIN32 + if constexpr(isFP4) + { + // Packed batch stride: 2 nibbles per byte, packed per batch slice. + storageA = ((numA + 1) / 2) * batchCount; + storageB = ((numB + 1) / 2) * batchCount; + } + else +#endif + { + storageA = numA * batchCount; + storageB = numB * batchCount; + } + + std::vector a(storageA); + std::vector b(storageB); + std::vector c(numC * batchCount); + std::vector d(numC * batchCount); // Initialize inputs with random values. We use ±1 (binary) for A and B by // default because it is exactly representable in every supported storage @@ -311,6 +478,9 @@ int runGemm(size_t m, // are NOT on the F8 grid - otherwise the quantization step has nothing to // do and the bug being tested for can't be reproduced. We give an operand // such values when its storage type is wider than its computeInput type. + // + // For FP4 with mxBlockA/B>0 (mxfp4), inputs are drawn from the discrete + // E2M1-representable value set so the MX-scale logic is exercised. size_t seed = 42; std::mt19937 gen(seed); std::uniform_int_distribution<> binary_distribution(0, 1); @@ -318,27 +488,77 @@ int runGemm(size_t m, auto randomGen = [&]() { return binary_distribution(gen) ? 1.0f : -1.0f; }; - auto initOperand = [&](auto& vec, bool quantizes) { - using T = typename std::decay_t::value_type; - if(quantizes) - { - // Values representable in storage but not on the compute-input grid - - // for storage=Half/compute=F8N, values like 0.7 that Half holds - // exactly but F8N rounds to 0.625 or 0.75. - std::generate(vec.begin(), vec.end(), - [&]() { return static_cast(realDist(gen)); }); - } - else - { - std::generate(vec.begin(), vec.end(), - [&]() { return static_cast(randomGen()); }); - } - }; +#ifndef _WIN32 + if constexpr(isFP4) + { + // Full E2M1-representable value set: ±0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6. + // Drawing from the entire grid (not just the powers of two near zero) + // exercises the MX-scale path with values whose products span more of + // the FP4 range, while still being exactly representable. + constexpr float fp4Values[] + = {-6.0f, -4.0f, -3.0f, -2.0f, -1.5f, -1.0f, -0.5f, 0.0f, + 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; + constexpr int fp4ValueCount = sizeof(fp4Values) / sizeof(fp4Values[0]); + std::uniform_int_distribution<> fp4Dist(0, fp4ValueCount - 1); + auto randomFp4 = [&]() { return fp4Values[fp4Dist(gen)]; }; + + // Pack 2 logical FP4 values per byte (Float4x2). When the logical + // element count is odd, the second slot of the last byte has no + // element behind it — it's padding. We must still initialize that + // slot to a valid FP4 value (we use 0), because the fast-path + // ShadowBuffer FP4 decoder unconditionally reads both slots of every + // byte (the guard is on the *write* back, not the read), and reading + // uninitialized memory would be UB. + auto initFp4Operand = [&](auto& vec, size_t numLogical) { + const size_t storage = vec.size(); + const bool hasOddTail = (numLogical % 2 != 0); + for(size_t i = 0; i < storage; ++i) + { + const bool isPaddingSlot = hasOddTail && (i == storage - 1); + float slot0 = randomFp4(); + float slot1 = isPaddingSlot ? 0.0f : randomFp4(); + vec[i] = Float4x2(slot0, slot1); + } + }; + initFp4Operand(a, numA); + initFp4Operand(b, numB); + } + else +#endif + { + auto initOperand = [&](auto& vec, bool quantizes) { + using T = typename std::decay_t::value_type; +#ifndef _WIN32 + if constexpr(std::is_same_v) + { + // FP4 mixed-input init unsupported in this branch; the FP4-only + // path above handles the pure FP4 case. Mixed FP4/non-FP4 + // dispatch is rejected at the dispatcher level. + throw std::runtime_error( + "Mixed FP4 / non-FP4 input is not supported."); + } + else +#endif + if(quantizes) + { + // Values representable in storage but not on the compute-input grid - + // for storage=Half/compute=F8N, values like 0.7 that Half holds + // exactly but F8N rounds to 0.625 or 0.75. + std::generate(vec.begin(), vec.end(), + [&]() { return static_cast(realDist(gen)); }); + } + else + { + std::generate(vec.begin(), vec.end(), + [&]() { return static_cast(randomGen()); }); + } + }; - bool quantizesA = (sizeof(InputAT) > 1) && (computeInputA != dtypeEnumA); - bool quantizesB = (sizeof(InputBT) > 1) && (computeInputB != dtypeEnumB); - initOperand(a, quantizesA); - initOperand(b, quantizesB); + bool quantizesA = (sizeof(InputAT) > 1) && (computeInputA != dtypeEnumA); + bool quantizesB = (sizeof(InputBT) > 1) && (computeInputB != dtypeEnumB); + initOperand(a, quantizesA); + initOperand(b, quantizesB); + } std::generate(c.begin(), c.end(), [&]() { return static_cast(randomGen()); }); // Optional feature buffers @@ -403,12 +623,59 @@ int runGemm(size_t m, contraction.setParams().setActivationEnum(activation); } +#ifndef _WIN32 + // MX scale setup (FP4 with mxBlockA/B > 0 only) + [[maybe_unused]] std::vector mxsa, mxsb; + + if constexpr(isFP4) + { + if(mxBlockA > 0 || mxBlockB > 0) + { + // Use unpadded MX scale tensors so the columnMajorGemm reference + // indexing matches: mxsa = {m, k/mxBlockA} with m as leading + // stride (and analogous for B). Default padScaleTensor=true would + // round M up to next 32 and K/mxBlockA/B up to next 8, breaking + // the index math below. + contraction.setMXScaleA(rocisa::DataType::E8, mxBlockA, /*saStride=*/{}, /*padScaleTensor=*/false); + contraction.setMXScaleB(rocisa::DataType::E8, mxBlockB, /*sbStride=*/{}, /*padScaleTensor=*/false); + + size_t nmxsa = contraction.mxsa().totalLogicalElements(); + size_t nmxsb = contraction.mxsb().totalLogicalElements(); + + if(nmxsa == 0 || nmxsb == 0) + { + std::cerr << "Error: MX scale tensor has zero elements (nmxsa=" << nmxsa + << ", nmxsb=" << nmxsb << ")" << std::endl; + return 1; + } + + mxsa.resize(nmxsa); + mxsb.resize(nmxsb); + + // Distinct exponents in [0..7] so wrong indexing breaks validation + std::uniform_int_distribution<> expDist(0, 7); + for(size_t i = 0; i < nmxsa; i++) + mxsa[i] = E8(std::ldexp(1.0f, expDist(gen))); + for(size_t i = 0; i < nmxsb; i++) + mxsb[i] = E8(std::ldexp(1.0f, expDist(gen))); + } + } +#endif + ContractionInputs inputs(a.data(), b.data(), c.data(), d.data(), alpha, beta); inputs.bias = useBias ? biasVec.data() : nullptr; inputs.scaleAlphaVec = useScaleAlphaVec ? scaleAlphaVecBuf.data() : nullptr; inputs.scaleA = (useScaleAB != "none") ? scaleABuf.data() : nullptr; inputs.scaleB = (useScaleAB != "none") ? scaleBBuf.data() : nullptr; +#ifndef _WIN32 + if constexpr(isFP4) + { + inputs.mxsa = (mxBlockA > 0) ? mxsa.data() : nullptr; + inputs.mxsb = (mxBlockB > 0) ? mxsb.data() : nullptr; + } +#endif + auto start = std::chrono::high_resolution_clock::now(); if(tryFastPath && !TensileLite::Client::isFastPathEligible(contraction)) @@ -432,9 +699,40 @@ int runGemm(size_t m, { std::cout << "Validating..." << std::endl; - // Convert inputs to f32 for the golden reference comparison - std::vector aF32(a.begin(), a.end()); - std::vector bF32(b.begin(), b.end()); + // Convert inputs to f32 for the golden reference comparison. + // For batched problems, A/B/C are batchCount slices of size numA/numB/numC + // (column-major packed; batch stride = numA / numB / numC). + size_t totalA = numA * batchCount; + size_t totalB = numB * batchCount; + size_t totalC = numC * batchCount; + + std::vector aF32, bF32; + +#ifndef _WIN32 + if constexpr(isFP4) + { + aF32.resize(totalA); + for(size_t i = 0; i < totalA; i++) + aF32[i] = a[i / 2].getElement(i % 2); + bF32.resize(totalB); + for(size_t i = 0; i < totalB; i++) + bF32[i] = b[i / 2].getElement(i % 2); + } + else if constexpr(std::is_same_v || std::is_same_v) + { + throw std::runtime_error("Mixed FP4 / non-FP4 input is not supported."); + } + else +#endif + { + aF32.resize(totalA); + for(size_t i = 0; i < totalA; i++) + aF32[i] = static_cast(a[i]); + bF32.resize(totalB); + for(size_t i = 0; i < totalB; i++) + bF32[i] = static_cast(b[i]); + } + std::vector cF32(c.begin(), c.end()); std::vector dRef(d.size()); @@ -444,38 +742,83 @@ int runGemm(size_t m, QuantizeFn quantA = (computeInputA != dtypeEnumA) ? quantizerFor(computeInputA) : nullptr; QuantizeFn quantB = (computeInputB != dtypeEnumB) ? quantizerFor(computeInputB) : nullptr; - // Run the golden reference - columnMajorGemm(aF32.data(), - bF32.data(), - cF32.data(), - dRef.data(), - m, - n, - k, - transA, - transB, - (useScaleAB == "Scalar") ? alpha * scaleABuf[0] * scaleBBuf[0] : alpha, - beta, - useBias ? biasVec.data() : nullptr, - useScaleAlphaVec ? scaleAlphaVecBuf.data() : nullptr, - activation, - (useScaleAB == "Vector") ? scaleABuf.data() : nullptr, - (useScaleAB == "Vector") ? scaleBBuf.data() : nullptr, - factorDim, - quantA, - quantB); - - // Compare results +#ifndef _WIN32 + size_t mxsaBatchStride = 0, mxsbBatchStride = 0; + size_t mxsaStrideM = 0, mxsaStrideKBlk = 0; + size_t mxsbStrideN = 0, mxsbStrideKBlk = 0; + if constexpr(isFP4) + { + if(mxBlockA > 0) + { + auto const& mxsaTensor = contraction.mxsa(); + mxsaStrideM = mxsaTensor.strides()[contraction.freeIndicesA()[0].i]; + mxsaStrideKBlk = mxsaTensor.strides()[contraction.boundIndices()[0].a]; + mxsaBatchStride = mxsaTensor.strides()[contraction.batchIndices()[0].a]; + } + if(mxBlockB > 0) + { + auto const& mxsbTensor = contraction.mxsb(); + mxsbStrideN = mxsbTensor.strides()[contraction.freeIndicesB()[0].i]; + mxsbStrideKBlk = mxsbTensor.strides()[contraction.boundIndices()[0].b]; + mxsbBatchStride = mxsbTensor.strides()[contraction.batchIndices()[0].b]; + } + } +#endif + + // Run the golden reference per-batch. + for(size_t batch = 0; batch < batchCount; ++batch) + { + const float* aPtr = aF32.data() + batch * numA; + const float* bPtr = bF32.data() + batch * numB; + const float* cPtr = cF32.data() + batch * numC; + float* dPtr = dRef.data() + batch * numC; + + columnMajorGemm(aPtr, + bPtr, + cPtr, + dPtr, + m, + n, + k, + transA, + transB, + (useScaleAB == "Scalar") ? alpha * scaleABuf[0] * scaleBBuf[0] : alpha, + beta, + useBias ? biasVec.data() : nullptr, + useScaleAlphaVec ? scaleAlphaVecBuf.data() : nullptr, + activation, + (useScaleAB == "Vector") ? scaleABuf.data() : nullptr, + (useScaleAB == "Vector") ? scaleBBuf.data() : nullptr, + factorDim, + quantA, + quantB +#ifndef _WIN32 + , + (isFP4 && mxBlockA > 0) ? mxsa.data() + batch * mxsaBatchStride : nullptr, + (isFP4 && mxBlockB > 0) ? mxsb.data() + batch * mxsbBatchStride : nullptr, + mxBlockA, + mxBlockB, + mxsaStrideM, + mxsaStrideKBlk, + mxsbStrideN, + mxsbStrideKBlk +#endif + ); + } + + // Compare results — FP4 with MX scales needs wider tolerance + float tolerance = isFP4 ? 0.5f : 0.05f; + bool allClose = true; float maxDiff = 0.0f; - for(size_t i = 0; i < m * n; i++) + for(size_t i = 0; i < totalC; i++) { float valDut = static_cast(d[i]); float valRef = dRef[i]; float diff = std::abs(valDut - valRef); - if(diff > 0.05f) + if(diff > tolerance) { allClose = false; maxDiff = std::max(maxDiff, diff); @@ -513,7 +856,7 @@ int main(int argc, char* argv[]) "transB", po::value()->default_value(false), "Transpose B")( "alpha", po::value()->default_value(1.0f), "Alpha scalar")( "beta", po::value()->default_value(0.0f), "Beta scalar")( - "type", po::value()->default_value("f32"), "Data type for A and B (f32, f16, bf16, f8, bf8, f8fnuz, bf8fnuz)")( + "type", po::value()->default_value("f32"), "Data type for A and B (f32, f16, bf16, f8, bf8, f8fnuz, bf8fnuz, f4)")( "typeA", po::value()->default_value(""), "Override A storage type (defaults to --type)")( "typeB", po::value()->default_value(""), "Override B storage type (defaults to --type)")( "computeInputA", po::value()->default_value(""), "Override A compute-input type for MAC (defaults to --typeA). Set smaller than storage to mimic kernels that quantize A.")( @@ -524,7 +867,10 @@ int main(int argc, char* argv[]) "activation", po::value()->default_value("none"), "Activation (none, relu)")( "scaleAlphaVec", po::value()->default_value(false), "Enable per-row alpha scaling")( "factorDim", po::value()->default_value(0), "ScaleAlphaVec dimension: 0=row(M), 1=col(N)")( - "useScaleAB", po::value()->default_value("none"), "ScaleAB mode (none, Scalar, Vector)"); + "useScaleAB", po::value()->default_value("none"), "ScaleAB mode (none, Scalar, Vector)")( + "mxBlockA", po::value()->default_value(0), "MX block size for the A side (FP4 only, must be power of 2; both --mxBlockA and --mxBlockB must be set together)")( + "mxBlockB", po::value()->default_value(0), "MX block size for the B side (FP4 only, must be power of 2; both --mxBlockA and --mxBlockB must be set together)")( + "batchCount", po::value()->default_value(1), "Batch count (default 1)"); po::variables_map vm; try @@ -563,6 +909,8 @@ int main(int argc, char* argv[]) auto strToDataType = [](const std::string& s, rocisa::DataType& out) -> bool { if(s == "f32") { out = rocisa::DataType::Float; return true; } + if(s == "f64") { out = rocisa::DataType::Double; return true; } + if(s == "tf32") { out = rocisa::DataType::Float; return true; } if(s == "f16") { out = rocisa::DataType::Half; return true; } if(s == "bf16") { out = rocisa::DataType::BFloat16; return true; } #ifdef TENSILE_USE_FP8_BF8 @@ -570,6 +918,9 @@ int main(int argc, char* argv[]) if(s == "bf8") { out = rocisa::DataType::BFloat8; return true; } if(s == "f8fnuz") { out = rocisa::DataType::Float8_fnuz; return true; } if(s == "bf8fnuz") { out = rocisa::DataType::BFloat8_fnuz; return true; } +#endif +#ifndef _WIN32 + if(s == "f4") { out = rocisa::DataType::Float4; return true; } #endif return false; }; @@ -590,6 +941,39 @@ int main(int argc, char* argv[]) bool useScaleAlphaVec = vm["scaleAlphaVec"].as(); int factorDim = vm["factorDim"].as(); std::string useScaleAB = vm["useScaleAB"].as(); + int mxBlockA = vm["mxBlockA"].as(); + int mxBlockB = vm["mxBlockB"].as(); + size_t batchCount = vm["batchCount"].as(); + + if(mxBlockA < 0 || mxBlockB < 0) + { + std::cerr << "Error: mxBlockA/mxBlockB must be non-negative" << std::endl; + return 1; + } + // One-sided MX is rejected (see review #1). When either per-side flag is + // given, require both > 0. + if((mxBlockA > 0) != (mxBlockB > 0)) + { + std::cerr << "Error: --mxBlockA and --mxBlockB must both be > 0 " + << "(mxBlockA=" << mxBlockA << ", mxBlockB=" << mxBlockB << ")" + << std::endl; + return 1; + } + + if((mxBlockA > 0 || mxBlockB > 0) && typeStr != "f4") + { + std::cerr << "Error: mxBlockA/mxBlockB is only supported for type f4, not " + << typeStr << std::endl; + return 1; + } + +#ifndef _WIN32 + if((typeAStr == "f4") != (typeBStr == "f4")) + { + std::cerr << "Error: mixed FP4 / non-FP4 input is not supported." << std::endl; + return 1; + } +#endif if(useScaleAB != "none" && useScaleAB != "Scalar" && useScaleAB != "Vector") { @@ -624,10 +1008,21 @@ int main(int argc, char* argv[]) using AT = decltype(aTag); auto callB = [&](auto bTag) -> int { using BT = decltype(bTag); +#ifndef _WIN32 + constexpr bool isMixedFP4 = std::is_same_v + != std::is_same_v; + if constexpr(isMixedFP4) + { + std::cerr << "Error: mixed FP4 / non-FP4 input is not supported." + << std::endl; + return 1; + } + else +#endif return runGemm( m, n, k, transA, transB, alpha, beta, validate, tryFastPath, useBias, activation, useScaleAlphaVec, useScaleAB, factorDim, - computeInputA, computeInputB); + computeInputA, computeInputB, mxBlockA, mxBlockB, batchCount); }; if(typeBStr == "f32") return callB(float{}); if(typeBStr == "f16") return callB(Half{}); @@ -637,6 +1032,9 @@ int main(int argc, char* argv[]) if(typeBStr == "bf8") return callB(BFloat8{}); if(typeBStr == "f8fnuz") return callB(Float8_fnuz{}); if(typeBStr == "bf8fnuz") return callB(BFloat8_fnuz{}); +#endif +#ifndef _WIN32 + if(typeBStr == "f4") return callB(Float4x2{}); #endif std::cerr << "Unknown typeB: " << typeBStr << std::endl; return 1; @@ -652,6 +1050,9 @@ int main(int argc, char* argv[]) if(typeAStr == "bf8") return dispatchB(BFloat8{}); if(typeAStr == "f8fnuz") return dispatchB(Float8_fnuz{}); if(typeAStr == "bf8fnuz") return dispatchB(BFloat8_fnuz{}); +#endif +#ifndef _WIN32 + if(typeAStr == "f4") return dispatchB(Float4x2{}); #endif std::cerr << "Unknown typeA: " << typeAStr << std::endl; return 1; diff --git a/projects/hipblaslt/tensilelite/client/src/Reference.cpp b/projects/hipblaslt/tensilelite/client/src/Reference.cpp index d984e9875104..08ec4c537050 100644 --- a/projects/hipblaslt/tensilelite/client/src/Reference.cpp +++ b/projects/hipblaslt/tensilelite/client/src/Reference.cpp @@ -194,6 +194,33 @@ namespace TensileLite m_storage = loadToFloat(ptr, N); m_ptr = m_storage.data(); } +#endif +#ifndef _WIN32 + else if(type == rocisa::DataType::Float4) + { + // Dense, lane-contiguous FP4 packing: N logical elements + // stored in (N+1)/2 Float4x2 words. One intrinsic call per word. + // + // NOTE: when N is odd, the trailing Float4x2 word is read + // unconditionally (both nibbles), but only the low nibble + // is written to m_storage (the high nibble's `v.y` is + // discarded by the `if(2*w+1 < N)` guard). Callers must + // therefore ensure the upper nibble of that trailing word + // is initialized (zero-padded or otherwise valid storage); + // it must NOT be out-of-bounds memory, since the intrinsic + // still reads the full byte. + m_storage.resize(N); + const Float4x2* fp4 = static_cast(ptr); + for(size_t w = 0; w < (N + 1) / 2; ++w) + { + auto v = __amd_cvt_fp4x2_to_floatx2_scale( + fp4[w].data, __AMD_OCP_E2M1, 0); + m_storage[2 * w] = v.x; + if(2 * w + 1 < N) + m_storage[2 * w + 1] = v.y; + } + m_ptr = m_storage.data(); + } #endif else { @@ -219,20 +246,6 @@ namespace TensileLite return m_ptr; } - // Return a writable copy of the shadow buffer. Copies from the - // direct pointer when the buffer aliases user-owned float storage. - float* ensureMutable(size_t count) - { - if(m_ptr == nullptr) - return nullptr; - if(m_storage.empty()) - m_storage.assign(m_ptr, m_ptr + count); - else if(m_storage.size() < count) - throw std::runtime_error("ShadowBuffer ensureMutable: count exceeds storage"); - m_ptr = m_storage.data(); - return m_storage.data(); - } - explicit operator bool() const { return m_ptr != nullptr; @@ -245,62 +258,6 @@ namespace TensileLite } }; - // Pre-multiply shadow A/B by MX block scales (rocroller ScaledCPUMM style) - // so the fast f32 GEMM path matches the slow per-element reference. - void applyMXScaleToShadow(ShadowBuffer& shadow, - ContractionProblemGemm const& problem, - TensorDescriptor const& dataTensor, - TensorDescriptor const& scaleTensor, - void const* scaleBase, - rocisa::DataType mxType, - int mxBlock, - bool forMatrixA) - { - if(mxBlock <= 0 || scaleBase == nullptr) - return; - - float* dataPtr = shadow.ensureMutable(dataTensor.totalAllocatedElements()); - if(dataPtr == nullptr) - return; - - auto const& boundIndices = problem.boundIndices(); - - std::vector boundSize(boundIndices.size()); - for(size_t i = 0; i < boundIndices.size(); ++i) - boundSize[i] = problem.boundSize(i); - -#pragma omp parallel for - for(ptrdiff_t elemNumber = 0; - elemNumber < static_cast(dataTensor.totalLogicalElements()); - ++elemNumber) - { - std::vector dataCoord(dataTensor.dimensions()); - std::vector scaleCoord(scaleTensor.dimensions()); - - CoordNumbered(static_cast(elemNumber), - dataCoord.begin(), - dataCoord.end(), - dataTensor.sizes().begin(), - dataTensor.sizes().end()); - - for(size_t d = 0; d < dataCoord.size(); ++d) - scaleCoord[d] = dataCoord[d]; - - for(size_t i = 0; i < boundIndices.size(); ++i) - { - auto const& bi = boundIndices[i]; - size_t pos = forMatrixA ? bi.a : bi.b; - size_t val = scaleCoord[pos]; - if(forMatrixA ? bi.aMirror : bi.bMirror) - val = boundSize[i] - val - 1; - scaleCoord[pos] = val / static_cast(mxBlock); - } - - size_t dataIdx = dataTensor.index(dataCoord); - size_t scaleIdx = scaleTensor.index(scaleCoord); - dataPtr[dataIdx] *= mxScaleElementAsFloat(mxType, scaleBase, scaleIdx); - } - } } namespace Client @@ -1187,6 +1144,10 @@ namespace TensileLite || t == rocisa::DataType::Float8_fnuz || t == rocisa::DataType::BFloat8_fnuz) return true; +#endif +#ifndef _WIN32 + if(t == rocisa::DataType::Float4) + return true; #endif return false; }; @@ -1204,6 +1165,41 @@ namespace TensileLite return rejectFast(detail.c_str()); } + constexpr size_t FAST_BLOCK_K = 8; + size_t mxBlockA = problem.mxBlockA(); + size_t mxBlockB = problem.mxBlockB(); + +#ifndef _WIN32 + if(isMXFP4Problem(problem) && problem.a().dataType() != problem.b().dataType()) + return rejectFast("mixed_mxfp4_input_types"); +#endif + + if(mxBlockA > 0 || mxBlockB > 0) + { + // One-sided MX (only A or only B scaled) is not supported. The + // slow path's columnMajorGemm reference also rejects this case, + // so the two paths agree on what "MX" means. + if((mxBlockA > 0) != (mxBlockB > 0)) + return rejectFast("one_sided_mx_not_supported"); + + if(mxBlockA > 0 && mxBlockA % FAST_BLOCK_K != 0) + return rejectFast("mxBlockA_not_aligned_to_BLOCK_K"); + if(mxBlockB > 0 && mxBlockB % FAST_BLOCK_K != 0) + return rejectFast("mxBlockB_not_aligned_to_BLOCK_K"); + + size_t sizeK = problem.boundSize(0); + if(mxBlockA > 0 && sizeK % mxBlockA != 0) + return rejectFast("K_not_multiple_of_mxBlockA"); + if(mxBlockB > 0 && sizeK % mxBlockB != 0) + return rejectFast("K_not_multiple_of_mxBlockB"); + } + + if(problem.boundIndices().size() >= 1) + { + if(problem.boundIndices()[0].aMirror || problem.boundIndices()[0].bMirror) + return rejectFast("mirror_indices"); + } + if(problem.batchIndices().empty()) { return rejectFast("no_batch_indices"); @@ -1320,29 +1316,6 @@ namespace TensileLite ShadowBuffer shadowC( inputs.c, problem.c().dataType(), problem.c().totalAllocatedElements()); - if(problem.mxBlockA() > 0 && inputs.mxsa != nullptr) - { - applyMXScaleToShadow(shadowA, - problem, - problem.a(), - problem.mxsa(), - inputs.mxsa, - problem.mxTypeA(), - problem.mxBlockA(), - /*forMatrixA=*/true); - } - if(problem.mxBlockB() > 0 && inputs.mxsb != nullptr) - { - applyMXScaleToShadow(shadowB, - problem, - problem.b(), - problem.mxsb(), - inputs.mxsb, - problem.mxTypeB(), - problem.mxBlockB(), - /*forMatrixA=*/false); - } - std::vector shadowD; float* ptrD = nullptr; if(problem.d().dataType() == rocisa::DataType::Float) @@ -1391,6 +1364,35 @@ namespace TensileLite } } + // MX block-scaling metadata (FP4 with MX) + size_t mxBlockA = problem.mxBlockA(); + size_t mxBlockB = problem.mxBlockB(); + bool hasMX = (mxBlockA > 0) || (mxBlockB > 0); + const E8* mxsaPtr + = (mxBlockA > 0) ? static_cast(inputs.mxsa) : nullptr; + const E8* mxsbPtr + = (mxBlockB > 0) ? static_cast(inputs.mxsb) : nullptr; + size_t strideMxsaM = 0, strideMxsaBlk = 0; + size_t strideMxsbN = 0, strideMxsbBlk = 0; + size_t strideBatchMxsa = 0, strideBatchMxsb = 0; + if(hasMX) + { + if(mxBlockA > 0) + { + strideMxsaM = problem.mxsa().strides()[indexMA]; + strideMxsaBlk = problem.mxsa().strides()[indexKA]; + strideBatchMxsa + = problem.mxsa().strides()[problem.batchIndices()[0].a]; + } + if(mxBlockB > 0) + { + strideMxsbN = problem.mxsb().strides()[indexNB]; + strideMxsbBlk = problem.mxsb().strides()[indexKB]; + strideBatchMxsb + = problem.mxsb().strides()[problem.batchIndices()[0].b]; + } + } + constexpr size_t BLOCK_M = 32; constexpr size_t BLOCK_N = 32; constexpr size_t BLOCK_K = 8; @@ -1409,6 +1411,12 @@ namespace TensileLite const float* curBatchB = shadowB.data() + (b * strideBatchB); const float* curBatchC = shadowC.data() + (b * strideBatchC); float* curBatchD = ptrD + (b * strideBatchD); + + const E8* mxsaBatch + = mxsaPtr ? mxsaPtr + b * strideBatchMxsa : nullptr; + const E8* mxsbBatch + = mxsbPtr ? mxsbPtr + b * strideBatchMxsb : nullptr; + for(size_t m = 0; m < mTiles; ++m) { auto m0 = m * BLOCK_M; @@ -1419,11 +1427,8 @@ namespace TensileLite std::array aReg = {0}; std::array bReg = {0}; std::array cReg = {0}; - for(size_t k = 0; k < kTiles; ++k) - { - auto k0 = k * BLOCK_K; - // Populate A 'registers': + auto loadATile = [&](size_t k0) { for(size_t km = 0; km < BLOCK_K; ++km) { for(size_t mm = 0; mm < BLOCK_M; ++mm) @@ -1432,7 +1437,8 @@ namespace TensileLite size_t global_m = m0 + mm; if(global_k < sizeK && global_m < sizeM) { - auto offset = global_m * strideMA + global_k * strideKA; + auto offset + = global_m * strideMA + global_k * strideKA; aReg[km * BLOCK_M + mm] = curBatchA[offset]; } else @@ -1441,8 +1447,9 @@ namespace TensileLite } } } + }; - // Populate B 'registers': + auto loadBTile = [&](size_t k0) { for(size_t kn = 0; kn < BLOCK_K; ++kn) { for(size_t nn = 0; nn < BLOCK_N; ++nn) @@ -1451,8 +1458,8 @@ namespace TensileLite size_t global_n = n0 + nn; if(global_k < sizeK && global_n < sizeN) { - bReg[kn * BLOCK_N + nn] - = curBatchB[global_n * strideNB + global_k * strideKB]; + bReg[kn * BLOCK_N + nn] = curBatchB + [global_n * strideNB + global_k * strideKB]; } else { @@ -1460,29 +1467,102 @@ namespace TensileLite } } } + }; - // Perform matrix multiplication accumulation with k as inner-most (fastest) - // dimension for both A and B. A, B, and C of sizes defined by BLOCK_M, BLOCK_N, BLOCK_K. - // Store result in row-major order. - auto innerReduction = [BLOCK_M, BLOCK_N, BLOCK_K]( - const float* A, const float* B, float* C) { - for(size_t k_i = 0; k_i < BLOCK_K; ++k_i) + auto innerReduction = [BLOCK_M, BLOCK_N, BLOCK_K]( + const float* A, + const float* B, + float* C) { + for(size_t k_i = 0; k_i < BLOCK_K; ++k_i) + { + for(size_t m_i = 0; m_i < BLOCK_M; ++m_i) { - for(size_t m_i = 0; m_i < BLOCK_M; ++m_i) + for(size_t n_i = 0; n_i < BLOCK_N; ++n_i) { - for(size_t n_i = 0; n_i < BLOCK_N; ++n_i) - { - auto b_index = k_i * BLOCK_N + n_i; - auto a_index = k_i * BLOCK_M + m_i; - auto c_index = m_i * BLOCK_N + n_i; - float valB = B[b_index]; - float valA = A[a_index]; - C[c_index] += valA * valB; - } + auto b_index = k_i * BLOCK_N + n_i; + auto a_index = k_i * BLOCK_M + m_i; + auto c_index = m_i * BLOCK_N + n_i; + float valB = B[b_index]; + float valA = A[a_index]; + C[c_index] += valA * valB; + } + } + } + }; + + if(hasMX) + { + // Iterate per BLOCK_K tile and apply MX scales + // per tile. Each mxBlock is a multiple of BLOCK_K + // (enforced by isFastPathEligible), so all k values + // within a tile share the same scale. This correctly + // handles asymmetric mxBlockA != mxBlockB. + // + // Invariant: K is a multiple of BLOCK_K. This holds + // transitively from mxBlockA/B % FAST_BLOCK_K == 0 + // plus K % mxBlockA/B == 0 (both checked in + // isFastPathEligible). Assert it here so a future + // edit to the eligibility check can't silently + // produce a partial-tile MX block with the wrong + // scale alignment. + assert(sizeK % BLOCK_K == 0 + && "MX fast path requires K divisible by BLOCK_K"); + for(size_t k = 0; k < kTiles; ++k) + { + auto k0 = k * BLOCK_K; + loadATile(k0); + loadBTile(k0); + + std::array tilePartial = {0}; + innerReduction( + aReg.data(), bReg.data(), tilePartial.data()); + + size_t mxsaI + = (mxBlockA > 0) ? k0 / mxBlockA : 0; + size_t mxsbI + = (mxBlockB > 0) ? k0 / mxBlockB : 0; + + for(size_t mm = 0; mm < BLOCK_M; ++mm) + { + size_t global_m = m0 + mm; + if(global_m >= sizeM) + continue; + + float sa = (mxBlockA > 0 && mxsaBatch) + ? static_cast(mxsaBatch + [global_m * strideMxsaM + + mxsaI * strideMxsaBlk]) + : 1.0f; + + for(size_t nn = 0; nn < BLOCK_N; ++nn) + { + size_t global_n = n0 + nn; + if(global_n >= sizeN) + continue; + + float sb = (mxBlockB > 0 && mxsbBatch) + ? static_cast(mxsbBatch + [global_n * strideMxsbN + + mxsbI * strideMxsbBlk]) + : 1.0f; + + cReg[mm * BLOCK_N + nn] + += tilePartial[mm * BLOCK_N + nn] + * sa * sb; } } - }; - innerReduction(aReg.data(), bReg.data(), cReg.data()); + } + } + else + { + for(size_t k = 0; k < kTiles; ++k) + { + auto k0 = k * BLOCK_K; + loadATile(k0); + loadBTile(k0); + innerReduction( + aReg.data(), bReg.data(), cReg.data()); + } } // Copy from cReg back. diff --git a/projects/hipblaslt/tensilelite/tests/CMakeLists.txt b/projects/hipblaslt/tensilelite/tests/CMakeLists.txt index 2e3296d9b2d3..c022273e5b86 100644 --- a/projects/hipblaslt/tensilelite/tests/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/tests/CMakeLists.txt @@ -297,3 +297,161 @@ if(NOT WIN32) ) endforeach() endif() + +################################################# +# FP4 / MX Scale tests (fast & slow) # +################################################# +if(NOT WIN32) + # M/N similar to f32 tests; K must be a multiple of mxBlock (32). + set(F4_SIZES_32 -M 133 -N 147 -K 160) + set(F4_SIZES_64 -M 133 -N 147 -K 192) + + foreach(PATH fast slow) + if("${PATH}" STREQUAL "fast") + set(PATH_FLAGS --tryFastPath) + else() + set(PATH_FLAGS "") + endif() + + # Basic F4 with MX scaling, all transpose combos + foreach(TRANS NN TN NT TT) + get_trans_flags(${TRANS} TFLAGS) + add_cpu_gemm_test(${PATH}_f4_MX32_${TRANS} + ARGS ${F4_SIZES_32} ${TFLAGS} --type f4 --mxBlockA 32 --mxBlockB 32 --validate ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_K192_${TRANS} + ARGS ${F4_SIZES_64} ${TFLAGS} --type f4 --mxBlockA 32 --mxBlockB 32 --validate ${PATH_FLAGS} + ) + endforeach() + + # F4 without MX (mxBlock=0, the default) + add_cpu_gemm_test(${PATH}_f4_NoMX_NN + ARGS -M 16 -N 16 -K 16 --type f4 --validate ${PATH_FLAGS} + ) + + # F4 + MX + epilogue features + add_cpu_gemm_test(${PATH}_f4_MX32_Bias + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --bias ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_Relu + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --activation relu ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_ScaleAlphaVec + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --scaleAlphaVec ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_AllFeatures + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --bias --activation relu --scaleAlphaVec ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_Beta + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --beta 1.0 --validate ${PATH_FLAGS} + ) + + # F4 + MX + ScaleAB + foreach(MODE Scalar Vector) + add_cpu_gemm_test(${PATH}_f4_MX32_ScaleAB_${MODE} + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --useScaleAB ${MODE} ${PATH_FLAGS} + ) + endforeach() + endforeach() + + # Fast-path-only: mxBlock=16 (tests non-hardcoded mxBlock; 16 % BLOCK_K=8 == 0) + add_cpu_gemm_test(fast_f4_MX16_NN + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 16 --mxBlockB 16 --validate --tryFastPath + ) + + # Batched mxfp4 (batchCount > 1) — exercises strideBatchMxsa/strideBatchMxsb + # in the fast path's MX loop (the site of a previous .a / .b typo bug). + foreach(PATH fast slow) + if("${PATH}" STREQUAL "fast") + set(PATH_FLAGS --tryFastPath) + else() + set(PATH_FLAGS "") + endif() + foreach(TRANS NN TT) + get_trans_flags(${TRANS} TFLAGS) + add_cpu_gemm_test(${PATH}_f4_MX32_Batched_${TRANS} + ARGS ${F4_SIZES_32} ${TFLAGS} --type f4 --mxBlockA 32 --mxBlockB 32 --validate --batchCount 2 ${PATH_FLAGS} + ) + endforeach() + endforeach() + + # Asymmetric MX (mxBlockA != mxBlockB). K must be a multiple of + # lcm(mxBlockA, mxBlockB) — here lcm(32, 64) = 64, which F4_SIZES_64 + # already satisfies (K=192). Fast path only: the production slow path + # collapses the smaller side's per-segment scales (see driver rejection + # in runGemm). + add_cpu_gemm_test(fast_f4_MX_A32_B64_NN + ARGS ${F4_SIZES_64} --type f4 --mxBlockA 32 --mxBlockB 64 --validate --tryFastPath + ) + add_cpu_gemm_test(fast_f4_MX_A64_B32_NN + ARGS ${F4_SIZES_64} --type f4 --mxBlockA 64 --mxBlockB 32 --validate --tryFastPath + ) + + # Negative test: slow path must reject asymmetric MX. + add_test( + NAME "CPUGemm.slow_f4_MX_A32_B64_Rejected" + COMMAND cpu-gemm-driver ${F4_SIZES_64} --type f4 --mxBlockA 32 --mxBlockB 64 --validate + ) + set_tests_properties("CPUGemm.slow_f4_MX_A32_B64_Rejected" PROPERTIES WILL_FAIL TRUE) + + # Negative test: one-sided MX (only A or only B) must be rejected. + add_test( + NAME "CPUGemm.f4_MX_OneSided_Rejected" + COMMAND cpu-gemm-driver ${F4_SIZES_32} --type f4 --mxBlockA 32 --validate + ) + set_tests_properties("CPUGemm.f4_MX_OneSided_Rejected" PROPERTIES WILL_FAIL TRUE) + + # Edge-size mxfp4 coverage: K equal to mxBlock (single MX block per row), + # M=1 / N=1 strips, and alpha=0 short-circuit. Both fast and slow paths. + foreach(PATH fast slow) + if("${PATH}" STREQUAL "fast") + set(PATH_FLAGS --tryFastPath) + else() + set(PATH_FLAGS "") + endif() + add_cpu_gemm_test(${PATH}_f4_MX32_K_eq_mxBlock + ARGS -M 16 -N 16 -K 32 --type f4 --mxBlockA 32 --mxBlockB 32 --validate ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_M1 + ARGS -M 1 -N 147 -K 160 --type f4 --mxBlockA 32 --mxBlockB 32 --validate ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_N1 + ARGS -M 133 -N 1 -K 160 --type f4 --mxBlockA 32 --mxBlockB 32 --validate ${PATH_FLAGS} + ) + add_cpu_gemm_test(${PATH}_f4_MX32_Alpha0 + ARGS ${F4_SIZES_32} --type f4 --mxBlockA 32 --mxBlockB 32 --alpha 0 --validate ${PATH_FLAGS} + ) + endforeach() + + # F4 with odd dimensions (M*K is odd) to exercise ShadowBuffer edge cases + add_cpu_gemm_test(fast_f4_NoMX_OddDims + ARGS -M 3 -N 5 -K 7 --type f4 --validate --tryFastPath + ) + + # Negative tests: invalid K/mxBlock combos should fail with clear error + add_test( + NAME "CPUGemm.slow_f4_MX32_BadK_Unaligned" + COMMAND cpu-gemm-driver -M 16 -N 16 -K 33 --type f4 --mxBlockA 32 --mxBlockB 32 --validate + ) + set_tests_properties("CPUGemm.slow_f4_MX32_BadK_Unaligned" PROPERTIES WILL_FAIL TRUE) + + add_test( + NAME "CPUGemm.slow_f4_MX32_BadK_TooSmall" + COMMAND cpu-gemm-driver -M 16 -N 16 -K 16 --type f4 --mxBlockA 32 --mxBlockB 32 --validate + ) + set_tests_properties("CPUGemm.slow_f4_MX32_BadK_TooSmall" PROPERTIES WILL_FAIL TRUE) + + # Negative test: non-power-of-2 mxBlock should be rejected + add_test( + NAME "CPUGemm.slow_f4_MX6_BadPow2" + COMMAND cpu-gemm-driver -M 4 -N 4 -K 12 --type f4 --mxBlockA 6 --mxBlockB 6 --validate + ) + set_tests_properties("CPUGemm.slow_f4_MX6_BadPow2" PROPERTIES WILL_FAIL TRUE) + + # Negative test: mxBlock with non-FP4 types should be rejected + add_test( + NAME "CPUGemm.slow_f32_MX32_Rejected" + COMMAND cpu-gemm-driver -M 16 -N 16 -K 32 --type f32 --mxBlockA 32 --mxBlockB 32 --validate + ) + set_tests_properties("CPUGemm.slow_f32_MX32_Rejected" PROPERTIES WILL_FAIL TRUE) +endif() diff --git a/projects/hipblaslt/tensilelite/tests/ReferenceMXFastPath_test.cpp b/projects/hipblaslt/tensilelite/tests/ReferenceMXFastPath_test.cpp index 23c8d24f8c54..f86d73182668 100644 --- a/projects/hipblaslt/tensilelite/tests/ReferenceMXFastPath_test.cpp +++ b/projects/hipblaslt/tensilelite/tests/ReferenceMXFastPath_test.cpp @@ -16,12 +16,17 @@ using namespace TensileLite::Client; namespace { - ContractionProblemGemm makeMXFP8Problem(size_t M, size_t N, size_t K, int mxBlock) + ContractionProblemGemm makeMXProblem(rocisa::DataType typeA, + rocisa::DataType typeB, + size_t M, + size_t N, + size_t K, + int mxBlock) { auto problem = ContractionProblemGemm::GEMM_Strides(false, false, - rocisa::DataType::Float8, - rocisa::DataType::Float8, + typeA, + typeB, rocisa::DataType::Float, rocisa::DataType::Float, M, @@ -40,8 +45,8 @@ namespace problem.setMXScaleA(rocisa::DataType::E8, mxBlock, {}, /*padScaleTensor=*/false); problem.setMXScaleB(rocisa::DataType::E8, mxBlock, {}, /*padScaleTensor=*/false); - problem.setComputeInputTypeA(rocisa::DataType::Float8); - problem.setComputeInputTypeB(rocisa::DataType::Float8); + problem.setComputeInputTypeA(typeA); + problem.setComputeInputTypeB(typeB); problem.setAlphaType(rocisa::DataType::Float); problem.setBetaType(rocisa::DataType::Float); return problem; @@ -71,6 +76,30 @@ namespace } } // namespace +#ifndef _WIN32 + +TEST(ReferenceMXFastPath, RejectsMixedInputTypesWithMXFP4) +{ + const size_t M = 64; + const size_t N = 64; + const size_t K = 128; + const int mxBlock = 32; + + auto problemA = makeMXProblem( + rocisa::DataType::Float4, rocisa::DataType::Float, M, N, K, mxBlock); + EXPECT_FALSE(isFastPathEligible(problemA)); + + auto problemB = makeMXProblem( + rocisa::DataType::Float, rocisa::DataType::Float4, M, N, K, mxBlock); + EXPECT_FALSE(isFastPathEligible(problemB)); + + auto problemBoth = makeMXProblem( + rocisa::DataType::Float4, rocisa::DataType::Float4, M, N, K, mxBlock); + EXPECT_TRUE(isFastPathEligible(problemBoth)); +} + +#endif + #ifdef TENSILE_USE_FP8_BF8 TEST(ReferenceMXFastPath, MatchesSlowPathForScaledFP8Gemm) @@ -80,7 +109,8 @@ TEST(ReferenceMXFastPath, MatchesSlowPathForScaledFP8Gemm) const size_t K = 128; const int mxBlock = 32; - auto problem = makeMXFP8Problem(M, N, K, mxBlock); + auto problem = makeMXProblem( + rocisa::DataType::Float8, rocisa::DataType::Float8, M, N, K, mxBlock); ASSERT_TRUE(isFastPathEligible(problem)); std::vector a(M * K); @@ -118,7 +148,8 @@ TEST(ReferenceMXFastPath, MatchesSlowPathWithBetaAndBias) const size_t K = 96; const int mxBlock = 32; - auto problem = makeMXFP8Problem(M, N, K, mxBlock); + auto problem = makeMXProblem( + rocisa::DataType::Float8, rocisa::DataType::Float8, M, N, K, mxBlock); problem.setUseBias(1); problem.setBias(rocisa::DataType::Float, M, M); ASSERT_TRUE(isFastPathEligible(problem));