Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/utils/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,21 @@ fp16_t _f32_to_f16(float val) {
// Infinity
return fp16_t{static_cast<uint16_t>(sign | 0x7C00)};
} else if (exponent >= -14) { // Normalized case
return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13))};
} else if (exponent >= -24) {
// --- START MODIFICATION ---
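// Add 0x1000 (2^12): half the weight of bit 13, the lowest mantissa bit
// that survives the >> 13 below. Adding it before the shift rounds the
// 10-bit result to the nearest value; exact ties round up in magnitude
// (this is not round-half-to-even).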
uint32_t rounded_mantissa = mantissa + 0x1000;

// Check for overflow in the mantissa after rounding
if (rounded_mantissa & 0x800000) {
// Rounding carried into bit 23, i.e. out of the 23-bit mantissa field.
// Bump the exponent field by one and leave the result's mantissa bits
// zero. This is rare but important for correctness.
return fp16_t{(uint16_t)(sign | ((exponent + 15 + 1) << 10))};
}

return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (rounded_mantissa >> 13))};
// --- END MODIFICATION ---
} else if (exponent >= -24) {
mantissa |= 0x800000; // Add implicit leading 1
mantissa >>= (-14 - exponent);
return fp16_t{(uint16_t)(sign | (mantissa >> 13))};
Expand Down