diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 01042776c971..b8bf845fb4f8 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -56,6 +56,116 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
              << "};\n\n";
 }
 
+// Inline MSL helpers for storage-only FP8 emulation (e4m3 / e5m2).
+// Apple Silicon (M4 Max and earlier; the M5 GPU's neural accelerators are
+// FP16/INT8 only) has NO native FP8 ALU support, so FP8 is realised as
+// `uchar` storage with explicit dequantize-on-load / quantize-on-store.
+// The helpers mirror the IEEE 754 derived encodings from the OFP8 spec
+// (E4M3 with finite-only encoding, E5M2 IEEE-style with NaN/Inf).
+void CodeGenMetal::PrintFP8Prelude(std::ostream& os) {
+  os <<
+      "// FP8 storage-only emulation helpers (MSL has no native float8 type).\n"
+      "// See OCP \"OFP8 Formats for Deep Learning\" v1.0 spec.\n"
+      "inline half __tvm_fp8_e4m3_to_half(uchar x) {\n"
+      "  ushort sign = (ushort)(x & 0x80) << 8;\n"
+      "  ushort mant = (ushort)(x & 0x07);\n"
+      "  ushort exp = (ushort)((x >> 3) & 0x0F);\n"
+      "  ushort h;\n"
+      "  if (exp == 0) {\n"
+      "    if (mant == 0) {\n"
+      "      h = sign;  // signed zero\n"
+      "    } else {\n"
+      "      // subnormal: e4m3 value = mant * 2^-9. After shifting the\n"
+      "      // mantissa so the leading 1 hits bit 2 (0x4), the unbiased\n"
+      "      // exponent in half is (e - 9 + 1) = e - 8, giving biased\n"
+      "      // (e - 8 + 15) = e + 7. fp8 bias=7, half bias=15.\n"
+      "      ushort m = mant;\n"
+      "      short e = 1;  // signed: drops below zero for mant < 4\n"
+      "      while ((m & 0x4) == 0) { m <<= 1; e -= 1; }\n"
+      "      m &= 0x3;\n"
+      "      h = (ushort)(sign | ((ushort)(e + 7) << 10) | (ushort)(m << 8));\n"
+      "    }\n"
+      "  } else if (exp == 0x0F && mant == 0x07) {\n"
+      "    // E4M3 finite-only spec uses S.1111.111 as NaN; map to half NaN.\n"
+      "    h = (ushort)(sign | 0x7E00);\n"
+      "  } else {\n"
+      "    // normal: rebias exp from 7 to 15, shift mantissa from 3 to 10 bits.\n"
+      "    h = (ushort)(sign | ((ushort)(exp + 8) << 10) | (ushort)(mant << 7));\n"
+      "  }\n"
+      "  return as_type<half>(h);\n"
+      "}\n"
+      "inline half __tvm_fp8_e5m2_to_half(uchar x) {\n"
+      "  // E5M2 is bit-compatible with half right-shifted by 8 (same exponent\n"
+      "  // bias, just a truncated mantissa).\n"
+      "  ushort h = ((ushort)x) << 8;\n"
+      "  return as_type<half>(h);\n"
+      "}\n"
+      "inline uchar __tvm_half_to_fp8_e4m3(half v) {\n"
+      "  ushort h = as_type<ushort>(v);\n"
+      "  ushort sign = (h >> 8) & 0x80;\n"
+      "  short he = (short)((h >> 10) & 0x1F);\n"
+      "  ushort hm = h & 0x3FF;\n"
+      "  if (he == 0x1F) {\n"
+      "    // half NaN/Inf -> E4M3 NaN (S.1111.111).\n"
+      "    return (uchar)(sign | 0x7F);\n"
+      "  }\n"
+      "  // exponent rebias: half bias 15 -> fp8 bias 7\n"
+      "  short e = he - 8;\n"
+      "  if (e > 0x0F) {\n"
+      "    // saturate to max finite (S.1111.110 = 448) since E4M3 has no Inf.\n"
+      "    // e == 0x0F stays representable and is handled on the normal path.\n"
+      "    return (uchar)(sign | 0x7E);\n"
+      "  }\n"
+      "  if (e <= 0) {\n"
+      "    // subnormal / underflow path: shift mantissa with implicit 1\n"
+      "    if (e < -3) return (uchar)sign;  // underflow -> signed zero\n"
+      "    ushort m = hm | 0x400;  // restore implicit leading 1\n"
+      "    ushort shift = (ushort)(7 + 1 - e);\n"
+      "    // round-to-nearest-even on the discarded bits\n"
+      "    ushort round_bit = (ushort)1 << (shift - 1);\n"
+      "    ushort q = m >> shift;\n"
+      "    ushort rem = m & ((round_bit << 1) - 1);\n"
+      "    if (rem > round_bit || (rem == round_bit && (q & 1))) q += 1;\n"
+      "    // a carry out of the 3-bit subnormal mantissa lands exactly on the\n"
+      "    // min-normal encoding (exp field 1, mant 0), so no special casing.\n"
+      "    return (uchar)(sign | (q & 0x7F));\n"
+      "  }\n"
+      "  // normal: rebias exp, shift mantissa 10 -> 3 bits with RNE rounding.\n"
+      "  ushort q = hm >> 7;\n"
+      "  ushort rem = hm & 0x7F;\n"
+      "  if (rem > 0x40 || (rem == 0x40 && (q & 1))) {\n"
+      "    q += 1;\n"
+      "    if (q == 0x08) { q = 0; e += 1; }\n"
+      "  }\n"
+      "  // saturate on overflow and on exp=15/mant=111, which would alias NaN.\n"
+      "  if (e > 0x0F || (e == 0x0F && q == 0x07)) return (uchar)(sign | 0x7E);\n"
+      "  return (uchar)(sign | (ushort)(e << 3) | (q & 0x07));\n"
+      "}\n"
+      "inline uchar __tvm_half_to_fp8_e5m2(half v) {\n"
+      "  // E5M2 round-to-nearest-even: take the top byte of the half bit\n"
+      "  // pattern with mantissa rounding from 10 -> 2 bits. E5M2 keeps the\n"
+      "  // IEEE NaN/Inf encodings, so overflow goes to Inf rather than saturating.\n"
+      "  ushort h = as_type<ushort>(v);\n"
+      "  ushort sign = h & 0x8000;\n"
+      "  ushort exp = (h >> 10) & 0x1F;\n"
+      "  ushort mant = h & 0x3FF;\n"
+      "  if (exp == 0x1F) {\n"
+      "    // NaN propagates as quiet NaN (S.11111.10 + nonzero); Inf passes through.\n"
+      "    if (mant != 0) return (uchar)((sign >> 8) | 0x7E);\n"
+      "    return (uchar)((sign >> 8) | 0x7C);\n"
+      "  }\n"
+      "  // RNE on the bottom 8 bits of the half mantissa.\n"
+      "  ushort q = mant >> 8;\n"
+      "  ushort rem = mant & 0xFF;\n"
+      "  if (rem > 0x80 || (rem == 0x80 && (q & 1))) {\n"
+      "    q += 1;\n"
+      "    if (q == 0x4) { q = 0; exp += 1; }\n"
+      "    if (exp == 0x1F) return (uchar)((sign >> 8) | 0x7C);  // overflow -> Inf\n"
+      "  }\n"
+      "  return (uchar)((sign >> 8) | (uchar)(exp << 2) | (uchar)(q & 0x3));\n"
+      "}\n\n";
+}
+
 void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
   // NOTE: There is no inter-function calls among Metal kernels.
   // For now we keep the metal codegen without inter-function call
@@ -267,6 +377,28 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   } else if (t.is_bfloat16()) {
     os << "bfloat";
     return;
+  } else if (t.is_float8()) {
+    // FP8 is storage-only on Metal: print as `uchar`/`ucharN` and emit explicit
+    // dequantize/quantize helpers via the FP8 prelude. Caller-side casts must
+    // route through __tvm_fp8_*_to_half / __tvm_half_to_fp8_*.
+    enable_fp8_ = true;
+    if (lanes == 1) {
+      os << "uchar";
+      return;
+    }
+    if (lanes >= 2 && lanes <= 4) {
+      os << "uchar" << lanes;
+      return;
+    }
+    if (lanes == 8) {
+      // 8 packed FP8 values fit into a uint2 (8 bytes).
+      os << "uint2";
+      return;
+    }
+    if (lanes == 16) {
+      os << "uint4";
+      return;
+    }
   }
   LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
 }
@@ -412,6 +544,82 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
   }
 }
 
+void CodeGenMetal::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
+  DataType from_ty = op->value.dtype();
+  DataType target_ty = op->dtype;
+  // Storage-only FP8 emulation: route casts through the inline helpers from
+  // the FP8 prelude. Anything else falls back to CodeGenC.
+  if (target_ty.is_float8() || from_ty.is_float8()) {
+    enable_fp8_ = true;
+    ICHECK_EQ(target_ty.lanes(), from_ty.lanes())
+        << "FP8 vector cast lanes must match: " << from_ty << " -> " << target_ty;
+    auto fp8_to_half = [&](DataType ft, std::string val) {
+      // Choose the helper function name based on the e4m3/e5m2 variant.
+      const char* helper = ft.code() == DataType::kFloat8_e5m2 ? "__tvm_fp8_e5m2_to_half"
+                                                               : "__tvm_fp8_e4m3_to_half";
+      return std::string(helper) + "(" + val + ")";
+    };
+    auto half_to_fp8 = [&](DataType tt, std::string val) {
+      const char* helper = tt.code() == DataType::kFloat8_e5m2 ? "__tvm_half_to_fp8_e5m2"
+                                                               : "__tvm_half_to_fp8_e4m3";
+      return std::string(helper) + "(" + val + ")";
+    };
+    if (target_ty.lanes() == 1) {
+      // Scalar path: dequantize to the target, or cast source -> half -> quantize.
+      std::string val = PrintExpr(op->value);
+      if (from_ty.is_float8() && !target_ty.is_float8()) {
+        std::string h = fp8_to_half(from_ty, val);
+        if (target_ty == DataType::Float(16)) {
+          os << h;
+        } else {
+          // Re-cast from half to whatever target the user wanted.
+          os << "((";
+          PrintType(target_ty, os);
+          os << ")(" << h << "))";
+        }
+      } else if (!from_ty.is_float8() && target_ty.is_float8()) {
+        std::string h = from_ty == DataType::Float(16) ? val : "((half)(" + val + "))";
+        os << half_to_fp8(target_ty, h);
+      } else {
+        // FP8 -> FP8 (e4m3 <-> e5m2): go through half.
+        std::string h = fp8_to_half(from_ty, val);
+        os << half_to_fp8(target_ty, h);
+      }
+      return;
+    }
+    // Vector path: not supported by this storage-only patch. Falling through
+    // to CodeGenC would produce raw uchar <-> float casts, which are wrong
+    // for FP8 semantics, so fail loudly instead of emitting bad code and
+    // require callers to scalarise the cast first.
+    LOG(FATAL) << "Vector FP8 casts (lanes=" << target_ty.lanes()
+               << ") are not yet supported by Metal storage-only FP8 emulation;"
+               << " scalarise the cast or extend codegen_metal.cc.";
+  }
+  CodeGenC::VisitExpr_(op, os);
+}
+
+std::string CodeGenMetal::Finish() {
+  // Inject the FP8 prelude (after the includes) if any FP8 dtype was referenced.
+  // We splice the helpers between the existing decl_stream contents and the
+  // function bodies by emitting them through a side stream and concatenating.
+  std::ostringstream prelude;
+  if (enable_fp8_) {
+    PrintFP8Prelude(prelude);
+  }
+  std::string base = CodeGenC::Finish();
+  if (prelude.str().empty()) return base;
+  // Find the spot right after `using namespace metal;` to inject the helpers
+  // so they can use `half`, `uchar`, `as_type` etc. without further qualification.
+  const std::string anchor = "using namespace metal;\n";
+  auto pos = base.find(anchor);
+  if (pos == std::string::npos) {
+    // Fallback: prepend (still legal; the prelude is self-contained MSL).
+    return prelude.str() + base;
+  }
+  pos += anchor.size();
+  return base.substr(0, pos) + "\n" + prelude.str() + base.substr(pos);
+}
+
 void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
   std::ostringstream temp;
   if (std::isinf(op->value)) {
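
For orientation (not part of the patch): with the codegen changes above, a scalar e4m3 dequantize kernel would compile to MSL along the following lines. This is an illustrative sketch only; the kernel name, buffer bindings, and indexing are hypothetical, not verbatim TVM output.

    #include <metal_stdlib>
    using namespace metal;

    // ... __tvm_fp8_* helpers from PrintFP8Prelude() are spliced in here
    // by CodeGenMetal::Finish() ...

    kernel void cast_e4m3_to_f32(device const uchar* A [[buffer(0)]],
                                 device float* B [[buffer(1)]],
                                 uint gid [[thread_position_in_grid]]) {
      // FP8 lives in memory as uchar; the cast is routed through the e4m3
      // helper (dequantize-on-load), matching the scalar path in
      // VisitExpr_(CastNode).
      B[gid] = ((float)(__tvm_fp8_e4m3_to_half(A[gid])));
    }
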
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 9bc0e15d155f..97bb6071f038 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -55,15 +55,25 @@ class CodeGenMetal final : public CodeGenC {
   void VisitExpr_(const SelectNode* op, std::ostream& os) final;     // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) final;       // NOLINT(*)
+  void VisitExpr_(const CastNode* op, std::ostream& os) final;       // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;   // NOLINT(*)
+  // Override to inject the FP8 prelude (storage-only emulation helpers) when
+  // any FP8 dtype was referenced.
+  std::string Finish() final;
+
   // reuse parent's function.
   using CodeGenC::PrintType;
 
  private:
+  // Emit inline MSL helpers for storage-only FP8 (e4m3 / e5m2) emulation.
+  void PrintFP8Prelude(std::ostream& os);
+
   std::unordered_map<const VarNode*, DataType> simdgroup_dtype_;
   int thread_index_bits_{32};
   int thread_work_dim_{0};
+  // Set when an FP8 dtype is referenced; gates emission of the FP8 prelude helpers.
+  bool enable_fp8_{false};
   Target target_;
 };
 }  // namespace codegen
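
As a host-side sanity check (not part of the patch), the e4m3 encode/decode logic emitted above can be mirrored in plain C++ and exhaustively round-tripped over all 256 byte encodings. A minimal sketch under stated assumptions: it uses float in place of Metal's half (every e4m3 value is exactly representable in both), relies on std::nearbyint's default round-to-nearest-even mode, and all names are ours, not TVM's.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Decode an e4m3 byte (bias 7, finite-only: S.1111.111 is NaN) to float.
    static float E4M3ToFloat(uint8_t x) {
      float s = (x & 0x80) ? -1.0f : 1.0f;
      int exp = (x >> 3) & 0xF;
      int mant = x & 0x7;
      if (exp == 0) return s * std::ldexp(static_cast<float>(mant), -9);  // subnormal
      if (exp == 0xF && mant == 0x7) return NAN;
      return s * std::ldexp(1.0f + mant / 8.0f, exp - 7);  // normal
    }

    // Encode float to e4m3: round-to-nearest-even, saturate to +-448 (no Inf).
    static uint8_t FloatToE4M3(float v) {
      uint8_t sign = std::signbit(v) ? 0x80 : 0x00;
      if (std::isnan(v)) return sign | 0x7F;
      float a = std::fabs(v);
      if (a > 448.0f) return sign | 0x7E;  // saturate: E4M3 has no Inf
      if (a == 0.0f) return sign;          // signed zero
      int e;
      std::frexp(a, &e);
      e -= 1;  // unbiased exponent: position of the leading 1
      if (e < -6) {
        // Subnormal range: quantise in units of 2^-9. A carry out of the
        // 3-bit mantissa lands on the min-normal encoding, as in the MSL helper.
        int q = static_cast<int>(std::nearbyint(std::ldexp(a, 9)));
        return sign | static_cast<uint8_t>(q);
      }
      // Normal range: quantise the 1.mmm significand to 4 bits (values 8..16).
      int q = static_cast<int>(std::nearbyint(std::ldexp(a, 3 - e)));
      if (q == 16) { q = 8; e += 1; }                        // rounding carried out
      if (e > 8 || (e == 8 && q == 15)) return sign | 0x7E;  // avoid the NaN pattern
      return sign | static_cast<uint8_t>((e + 7) << 3) | static_cast<uint8_t>(q & 0x7);
    }

    int main() {
      int mismatches = 0;
      for (int b = 0; b < 256; ++b) {
        if ((b & 0x7F) == 0x7F) continue;  // the two NaN encodings compare unequal
        float f = E4M3ToFloat(static_cast<uint8_t>(b));
        uint8_t r = FloatToE4M3(f);
        if (r != b) {
          std::printf("round-trip mismatch: 0x%02x -> %g -> 0x%02x\n", b, f, r);
          ++mismatches;
        }
      }
      std::printf("%d mismatches over 254 finite e4m3 encodings\n", mismatches);
      return mismatches != 0;
    }

The same harness extends naturally to e5m2 (swap in the E5M2 bias and 2-bit mantissa, and let overflow produce infinity instead of saturating), which is a useful check precisely because the two formats disagree on NaN/Inf handling.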