diff --git a/src/utils/types.cpp b/src/utils/types.cpp index 4163c21..b5710e4 100644 --- a/src/utils/types.cpp +++ b/src/utils/types.cpp @@ -52,8 +52,21 @@ fp16_t _f32_to_f16(float val) { // Infinity return fp16_t{static_cast(sign | 0x7C00)}; } else if (exponent >= -14) { // Normalized case - return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13))}; - } else if (exponent >= -24) { + // --- START MODIFICATION --- + // Add 0x1000 (2^12), which is half of the value of the 13th bit. + // This effectively rounds the 10-bit mantissa to the nearest value. + uint32_t rounded_mantissa = mantissa + 0x1000; + + // Check for overflow in the mantissa after rounding + if (rounded_mantissa & 0x800000) { + // If mantissa overflows, we need to increment the exponent + // and reset mantissa. This is rare but important for correctness. + return fp16_t{(uint16_t)(sign | ((exponent + 15 + 1) << 10))}; + } + + return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (rounded_mantissa >> 13))}; + // --- END MODIFICATION --- +} else if (exponent >= -24) { mantissa |= 0x800000; // Add implicit leading 1 mantissa >>= (-14 - exponent); return fp16_t{(uint16_t)(sign | (mantissa >> 13))};