Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,113 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
<< "};\n\n";
}

// Inline MSL helpers for storage-only FP8 emulation (e4m3 / e5m2).
// Metal has no native FP8 ALU type, so FP8 is realised as `uchar` storage
// with explicit dequantize-on-load / quantize-on-store. The helpers mirror
// the encoding from the OCP "OFP8 Formats for Deep Learning" v1.0 spec
// (E4M3 finite-only with a single NaN pattern, E5M2 IEEE-style with
// NaN/Inf).
//
// Emits the helper functions as raw MSL source into `os`; no codegen state
// is touched.
void CodeGenMetal::PrintFP8Prelude(std::ostream& os) {
  os <<
      "// FP8 storage-only emulation helpers (MSL has no native float8 type).\n"
      "// See OCP \"OFP8 Formats for Deep Learning\" v1.0 spec.\n"
      "inline half __tvm_fp8_e4m3_to_half(uchar x) {\n"
      "  ushort sign = (ushort)(x & 0x80) << 8;\n"
      "  ushort mant = (ushort)(x & 0x07);\n"
      "  ushort exp = (ushort)((x >> 3) & 0x0F);\n"
      "  ushort h;\n"
      "  if (exp == 0) {\n"
      "    if (mant == 0) {\n"
      "      h = sign; // signed zero\n"
      "    } else {\n"
      "      // subnormal: e4m3 value = mant * 2^-9. After shifting the\n"
      "      // mantissa so the leading 1 hits bit 2 (0x4), the unbiased\n"
      "      // exponent in half is (e - 9 + 1) = e - 8, giving biased\n"
      "      // (e - 8 + 15) = e + 7. fp8 bias=7, half bias=15.\n"
      "      ushort m = mant;\n"
      "      short e = 1; // signed: drops to -1 for the smallest subnormal\n"
      "      while ((m & 0x4) == 0) { m <<= 1; e -= 1; }\n"
      "      m &= 0x3;\n"
      "      h = (ushort)(sign | ((ushort)(e + 7) << 10) | (ushort)(m << 8));\n"
      "    }\n"
      "  } else if (exp == 0x0F && mant == 0x07) {\n"
      "    // E4M3 finite-only spec uses S.1111.111 as NaN; map to half NaN.\n"
      "    h = (ushort)(sign | 0x7E00);\n"
      "  } else {\n"
      "    // normal: rebias exp from 7 to 15, shift mantissa from 3 to 10 bits.\n"
      "    h = (ushort)(sign | ((ushort)(exp + 8) << 10) | (ushort)(mant << 7));\n"
      "  }\n"
      "  return as_type<half>(h);\n"
      "}\n"
      "inline half __tvm_fp8_e5m2_to_half(uchar x) {\n"
      "  // E5M2 is bit-compatible with half right-shifted by 8 (same exponent\n"
      "  // bias, just truncated mantissa).\n"
      "  ushort h = ((ushort)x) << 8;\n"
      "  return as_type<half>(h);\n"
      "}\n"
      "inline uchar __tvm_half_to_fp8_e4m3(half v) {\n"
      "  ushort h = as_type<ushort>(v);\n"
      "  ushort sign = (h >> 8) & 0x80;\n"
      "  short he = (short)((h >> 10) & 0x1F);\n"
      "  ushort hm = h & 0x3FF;\n"
      "  if (he == 0x1F) {\n"
      "    // half NaN/Inf -> E4M3 NaN (S.1111.111); E4M3 has no Inf encoding.\n"
      "    return (uchar)(sign | 0x7F);\n"
      "  }\n"
      "  // exponent rebias: half bias 15 -> fp8 bias 7\n"
      "  short e = he - 8;\n"
      "  if (e > 0x0F) {\n"
      "    // |v| >= 512: saturate to max finite 448 (S.1111.110).\n"
      "    return (uchar)(sign | 0x7E);\n"
      "  }\n"
      "  if (e <= 0) {\n"
      "    // subnormal / underflow path: shift mantissa with implicit 1\n"
      "    if (e < -3) return (uchar)sign; // below 2^-10 -> signed zero\n"
      "    ushort m = hm | 0x400; // restore implicit leading 1\n"
      "    ushort shift = (ushort)(7 + 1 - e);\n"
      "    // round-to-nearest-even on the discarded bits\n"
      "    ushort round_bit = (ushort)1 << (shift - 1);\n"
      "    ushort q = m >> shift;\n"
      "    ushort rem = m & ((round_bit << 1) - 1);\n"
      "    if (rem > round_bit || (rem == round_bit && (q & 1))) q += 1;\n"
      "    // a rounding carry out of the mantissa lands on the smallest normal\n"
      "    return (uchar)(sign | (q & 0x7F));\n"
      "  }\n"
      "  // normal: rebias exp, shift mantissa 10 -> 3 bits with RNE rounding.\n"
      "  ushort q = hm >> 7;\n"
      "  ushort rem = hm & 0x7F;\n"
      "  if (rem > 0x40 || (rem == 0x40 && (q & 1))) {\n"
      "    q += 1;\n"
      "    if (q == 0x08) { q = 0; e += 1; }\n"
      "  }\n"
      "  // exp==0x0F is still a finite range in E4M3 (256..448); only the\n"
      "  // S.1111.111 pattern is NaN. Saturate only past 448 or on true\n"
      "  // exponent overflow.\n"
      "  if (e > 0x0F || (e == 0x0F && q == 0x07)) return (uchar)(sign | 0x7E);\n"
      "  return (uchar)(sign | (ushort)(e << 3) | (q & 0x07));\n"
      "}\n"
      "inline uchar __tvm_half_to_fp8_e5m2(half v) {\n"
      "  // E5M2 round-to-nearest-even: take the top byte of the half bit\n"
      "  // pattern with mantissa rounding from 10 -> 2 bits. E5M2 shares\n"
      "  // half's exponent bias and keeps Inf/NaN, so overflow rounds to Inf\n"
      "  // (IEEE-style, not saturating).\n"
      "  ushort h = as_type<ushort>(v);\n"
      "  ushort sign = h & 0x8000;\n"
      "  ushort exp = (h >> 10) & 0x1F;\n"
      "  ushort mant = h & 0x3FF;\n"
      "  if (exp == 0x1F) {\n"
      "    // NaN propagates as quiet NaN (S.11111.10 + nonzero); Inf passes through.\n"
      "    if (mant != 0) return (uchar)((sign >> 8) | 0x7E);\n"
      "    return (uchar)((sign >> 8) | 0x7C);\n"
      "  }\n"
      "  // RNE on bottom 8 bits of half mantissa.\n"
      "  ushort q = mant >> 8;\n"
      "  ushort rem = mant & 0xFF;\n"
      "  if (rem > 0x80 || (rem == 0x80 && (q & 1))) {\n"
      "    q += 1;\n"
      "    if (q == 0x4) { q = 0; exp += 1; }\n"
      "    if (exp == 0x1F) return (uchar)((sign >> 8) | 0x7C); // overflow -> Inf\n"
      "  }\n"
      "  return (uchar)((sign >> 8) | (uchar)(exp << 2) | (uchar)(q & 0x3));\n"
      "}\n\n";
}

void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
// NOTE: There is no inter-function calls among Metal kernels.
// For now we keep the metal codegen without inter-function call
Expand Down Expand Up @@ -267,6 +374,28 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
} else if (t.is_bfloat16()) {
os << "bfloat";
return;
} else if (t.is_float8()) {
// FP8 is storage-only on Metal: print as `uchar`/`ucharN` and emit explicit
// dequantize/quantize helpers via the FP8 prelude. Caller-side casts must
// route through __tvm_fp8_*_to_half / __tvm_half_to_fp8_*.
enable_fp8_ = true;
if (lanes == 1) {
os << "uchar";
return;
}
if (lanes >= 2 && lanes <= 4) {
os << "uchar" << lanes;
return;
}
if (lanes == 8) {
// 8 packed FP8 values fit into a uint2 (8 bytes).
os << "uint2";
return;
}
if (lanes == 16) {
os << "uint4";
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
Expand Down Expand Up @@ -412,6 +541,82 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
}
}

void CodeGenMetal::VisitExpr_(const CastNode* op, std::ostream& os) {  // NOLINT(*)
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  // Storage-only FP8 emulation: route casts through the inline helpers from
  // the FP8 prelude. Anything else falls back to CodeGenC.
  if (target_ty.is_float8() || from_ty.is_float8()) {
    enable_fp8_ = true;
    ICHECK_EQ(target_ty.lanes(), from_ty.lanes())
        << "FP8 vector cast lanes must match: " << from_ty << " -> " << target_ty;
    // Choose the dequantize helper name based on the e4m3/e5m2 variant.
    auto fp8_to_half = [](DataType ft, const std::string& val) {
      const char* helper = ft.code() == DataType::kFloat8_e5m2 ? "__tvm_fp8_e5m2_to_half"
                                                               : "__tvm_fp8_e4m3_to_half";
      return std::string(helper) + "(" + val + ")";
    };
    // Choose the quantize helper name based on the e4m3/e5m2 variant.
    auto half_to_fp8 = [](DataType tt, const std::string& val) {
      const char* helper = tt.code() == DataType::kFloat8_e5m2 ? "__tvm_half_to_fp8_e5m2"
                                                               : "__tvm_half_to_fp8_e4m3";
      return std::string(helper) + "(" + val + ")";
    };
    if (target_ty.lanes() == 1) {
      // Scalar path: dequant->target, or src->half->quant.
      std::string val = PrintExpr(op->value);
      if (from_ty.is_float8() && !target_ty.is_float8()) {
        std::string h = fp8_to_half(from_ty, val);
        if (target_ty == DataType::Float(16)) {
          os << h;
        } else {
          // Re-cast from half to whatever target the user wanted.
          os << "((";
          PrintType(target_ty, os);
          os << ")(" << h << "))";
        }
      } else if (!from_ty.is_float8() && target_ty.is_float8()) {
        // Quantize via half, which is the helpers' interchange type.
        std::string h = from_ty == DataType::Float(16) ? val : "((half)(" + val + "))";
        os << half_to_fp8(target_ty, h);
      } else {
        // FP8 -> FP8 (e4m3 <-> e5m2): go through half.
        std::string h = fp8_to_half(from_ty, val);
        os << half_to_fp8(target_ty, h);
      }
      return;
    }
    // Vector path: deliberately unimplemented. Falling through to CodeGenC
    // would emit raw uchar<->float vector conversions, which are numerically
    // wrong for FP8 semantics, so abort instead of producing bad code.
    // Callers must scalarise the cast (or this codegen must grow a per-lane
    // expansion) before vector FP8 casts can be supported.
    LOG(FATAL) << "Vector FP8 casts (lanes=" << target_ty.lanes()
               << ") are not yet supported by Metal storage-only FP8 emulation;"
               << " scalarise the cast or extend codegen_metal.cc.";
  }
  CodeGenC::VisitExpr_(op, os);
}

std::string CodeGenMetal::Finish() {
  // Let CodeGenC assemble the full module first, then splice the FP8 helper
  // prelude in (only when an FP8 dtype was referenced) right after the
  // `using namespace metal;` line, so the helpers can use `half`, `uchar`,
  // `as_type` etc. without further qualification.
  std::string code = CodeGenC::Finish();
  if (!enable_fp8_) return code;
  std::ostringstream helpers;
  PrintFP8Prelude(helpers);
  const std::string helper_src = helpers.str();
  if (helper_src.empty()) return code;
  const std::string anchor = "using namespace metal;\n";
  const auto anchor_pos = code.find(anchor);
  if (anchor_pos == std::string::npos) {
    // Fallback: prepend (still legal — prelude is self-contained MSL).
    return helper_src + code;
  }
  const auto insert_at = anchor_pos + anchor.size();
  return code.substr(0, insert_at) + "\n" + helper_src + code.substr(insert_at);
}

void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
std::ostringstream temp;
if (std::isinf(op->value)) {
Expand Down
10 changes: 10 additions & 0 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,25 @@ class CodeGenMetal final : public CodeGenC {
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)

// Override to inject FP8 prelude (storage-only emulation helpers) when
// any FP8 dtype was referenced.
std::string Finish() final;

// reuse parent's function.
using CodeGenC::PrintType;

private:
// Emit inline MSL helpers for storage-only FP8 (e4m3 / e5m2) emulation.
void PrintFP8Prelude(std::ostream& os);

std::unordered_map<const VarNode*, std::string> simdgroup_dtype_;
int thread_index_bits_{32};
int thread_work_dim_{0};
// Set when an FP8 dtype is referenced; gates emission of FP8 prelude helpers.
bool enable_fp8_{false};
Target target_;
};
} // namespace codegen
Expand Down