@@ -14,21 +14,22 @@ limitations under the License.
1414
1515#include " tsl/platform/numbers.h"
1616
17- #include < ctype.h>
18- #include < float.h>
1917#include < stdio.h>
2018#include < stdlib.h>
2119
2220#include < algorithm>
2321#include < charconv>
2422#include < cmath>
2523#include < cstdint>
26- #include < locale>
24+ #include < limits>
25+ #include < optional>
2726#include < string>
2827#include < system_error> // NOLINT
29- #include < unordered_map >
28+ #include < type_traits >
3029
30+ #include " absl/strings/charconv.h"
3131#include " absl/strings/str_cat.h"
32+ #include " absl/strings/str_format.h"
3233#include " absl/strings/string_view.h"
3334#include " xla/tsl/platform/logging.h"
3435#include " xla/tsl/platform/macros.h"
@@ -40,102 +41,23 @@ namespace tsl {
4041namespace {
4142
4243template <typename T>
43- const std::unordered_map<std::string, T>* GetSpecialNumsSingleton () {
44- static const std::unordered_map<std::string, T>* special_nums =
45- CHECK_NOTNULL ((new const std::unordered_map<std::string, T>{
46- {" inf" , std::numeric_limits<T>::infinity ()},
47- {" +inf" , std::numeric_limits<T>::infinity ()},
48- {" -inf" , -std::numeric_limits<T>::infinity ()},
49- {" infinity" , std::numeric_limits<T>::infinity ()},
50- {" +infinity" , std::numeric_limits<T>::infinity ()},
51- {" -infinity" , -std::numeric_limits<T>::infinity ()},
52- {" nan" , std::numeric_limits<T>::quiet_NaN ()},
53- {" +nan" , std::numeric_limits<T>::quiet_NaN ()},
54- {" -nan" , -std::numeric_limits<T>::quiet_NaN ()},
55- }));
56- return special_nums;
57- }
58-
59- template <typename T>
60- T locale_independent_strtonum (const char * str, const char ** endptr) {
61- auto special_nums = GetSpecialNumsSingleton<T>();
62- std::stringstream s (str);
63-
64- // Check if str is one of the special numbers.
65- std::string special_num_str;
66- s >> special_num_str;
67-
68- for (size_t i = 0 ; i < special_num_str.length (); ++i) {
69- special_num_str[i] =
70- std::tolower (special_num_str[i], std::locale::classic ());
71- }
72-
73- auto entry = special_nums->find (special_num_str);
74- if (entry != special_nums->end ()) {
75- *endptr = str + (s.eof () ? static_cast <std::iostream::pos_type>(strlen (str))
76- : s.tellg ());
77- return entry->second ;
78- } else {
79- // Perhaps it's a hex number
80- if (special_num_str.compare (0 , 2 , " 0x" ) == 0 ||
81- special_num_str.compare (0 , 3 , " -0x" ) == 0 ) {
82- return strtol (str, const_cast <char **>(endptr), 16 );
83- }
84- }
85- // Reset the stream
86- s.str (str);
87- s.clear ();
88- // Use the "C" locale
89- s.imbue (std::locale::classic ());
90-
91- T result;
92- s >> result;
93-
94- // Set to result to what strto{f,d} functions would have returned. If the
95- // number was outside the range, the stringstream sets the fail flag, but
96- // returns the +/-max() value, whereas strto{f,d} functions return +/-INF.
97- if (s.fail ()) {
98- if (result == std::numeric_limits<T>::max () ||
99- result == std::numeric_limits<T>::infinity ()) {
100- result = std::numeric_limits<T>::infinity ();
101- s.clear (s.rdstate () & ~std::ios::failbit);
102- } else if (result == -std::numeric_limits<T>::max () ||
103- result == -std::numeric_limits<T>::infinity ()) {
104- result = -std::numeric_limits<T>::infinity ();
105- s.clear (s.rdstate () & ~std::ios::failbit);
106- }
44+ std::optional<T> AsciiToFp (absl::string_view str) {
45+ T value;
46+ absl::from_chars_result result =
47+ absl::from_chars (str.data (), str.data () + str.size (), value);
48+ if (result.ec != std::errc{}) {
49+ return std::nullopt ;
10750 }
108-
109- if (endptr) {
110- *endptr =
111- str +
112- (s.fail () ? static_cast <std::iostream::pos_type>(0 )
113- : (s.eof () ? static_cast <std::iostream::pos_type>(strlen (str))
114- : s.tellg ()));
51+ if (result.ptr != str.data () + str.size ()) {
52+ // Not all characters consumed.
53+ return std::nullopt ;
11554 }
116- return result ;
55+ return value ;
11756}
11857
119- } // namespace
120-
121- namespace strings {
122-
123- size_t FastInt32ToBufferLeft (int32_t i, char * buffer) {
124- uint32_t u = i;
125- size_t length = 0 ;
126- if (i < 0 ) {
127- *buffer++ = ' -' ;
128- ++length;
129- // We need to do the negation in modular (i.e., "unsigned")
130- // arithmetic; MSVC++ apparently warns for plain "-u", so
131- // we write the equivalent expression "0 - u" instead.
132- u = 0 - u;
133- }
134- length += FastUInt32ToBufferLeft (u, buffer);
135- return length;
136- }
137-
138- size_t FastUInt32ToBufferLeft (uint32_t i, char * buffer) {
58+ template <typename T>
59+ size_t FastUIntToBufferLeft (T i, char * buffer) {
60+ static_assert (std::is_unsigned_v<T>);
13961 char * start = buffer;
14062 do {
14163 *buffer++ = ((i % 10 ) + ' 0' );
@@ -146,103 +68,107 @@ size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
14668 return buffer - start;
14769}
14870
149- size_t FastInt64ToBufferLeft (int64_t i, char * buffer) {
150- uint64_t u = i;
// Writes the decimal representation of the signed integer `i` into `buffer`,
// starting at `buffer[0]`, emitting a leading '-' for negative values and
// delegating the digits of the magnitude to FastUIntToBufferLeft.
//
// Returns the total number of characters written: one for the sign (when
// present) plus the count reported by FastUIntToBufferLeft.
template <typename T>
size_t FastIntToBufferLeft(T i, char* buffer) {
  // std::is_signed_v is also true for floating-point types, which
  // std::make_unsigned_t rejects; require a signed *integer* up front.
  static_assert(std::is_integral_v<T> && std::is_signed_v<T>);
  std::make_unsigned_t<T> u = i;
  size_t length = 0;
  if (i < 0) {
    *buffer++ = '-';
    ++length;
    // We need to do the negation in modular (i.e., "unsigned")
    // arithmetic; MSVC++ apparently warns for plain "-u", so
    // we write the equivalent expression "0 - u" instead.
    u = 0 - u;
  }
  length += FastUIntToBufferLeft(u, buffer);
  return length;
}
87+ } // namespace
16088
161- size_t FastUInt64ToBufferLeft (uint64_t i, char * buffer) {
162- char * start = buffer;
163- do {
164- *buffer++ = ((i % 10 ) + ' 0' );
165- i /= 10 ;
166- } while (i > 0 );
167- *buffer = 0 ;
168- std::reverse (start, buffer);
169- return buffer - start;
170- }
171-
172- static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001 ;
173-
174- size_t DoubleToBuffer (double value, char * buffer) {
175- // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
176- // platforms these days. Just in case some system exists where DBL_DIG
177- // is significantly larger -- and risks overflowing our buffer -- we have
178- // this assert.
179- static_assert (DBL_DIG < 20 , " DBL_DIG is too big" );
180-
181- if (std::isnan (value)) {
182- int snprintf_result = snprintf (buffer, kFastToBufferSize , " %snan" ,
183- std::signbit (value) ? " -" : " " );
184- // Paranoid check to ensure we don't overflow the buffer.
185- DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
186- return snprintf_result;
187- }
89+ namespace strings {
18890
189- if ( std::abs (value) <= kDoublePrecisionCheckMax ) {
190- int snprintf_result =
191- snprintf (buffer, kFastToBufferSize , " %.*g " , DBL_DIG, value);
91+ size_t FastInt32ToBufferLeft ( int32_t i, char * buffer ) {
92+ return FastIntToBufferLeft (i, buffer);
93+ }
19294
193- // The snprintf should never overflow because the buffer is significantly
194- // larger than the precision we asked for.
195- DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
95+ size_t FastUInt32ToBufferLeft ( uint32_t i, char * buffer) {
96+ return FastUIntToBufferLeft (i, buffer);
97+ }
19698
197- if (locale_independent_strtonum<double >(buffer, nullptr ) == value) {
198- // Round-tripping the string to double works; we're done.
199- return snprintf_result;
200- }
201- // else: full precision formatting needed. Fall through.
202- }
99+ size_t FastInt64ToBufferLeft (int64_t i, char * buffer) {
100+ return FastIntToBufferLeft (i, buffer);
101+ }
203102
204- int snprintf_result =
205- snprintf (buffer, kFastToBufferSize , " %.*g" , DBL_DIG + 2 , value);
103+ size_t FastUInt64ToBufferLeft (uint64_t i, char * buffer) {
104+ return FastUIntToBufferLeft (i, buffer);
105+ }
206106
207- // Should never overflow; see above.
208- DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
107+ namespace {
209108
210- return snprintf_result;
// Returns the number of decimal digits needed to print the magnitude of `n`.
// The sign is not counted, and zero counts as one digit. Usable in constant
// expressions.
constexpr int NumDecimalDigits(int n) {
  int count = 0;
  // do/while so that n == 0 still yields one digit. Integer division
  // truncates toward zero, so negative inputs terminate as well.
  do {
    ++count;
    n /= 10;
  } while (n != 0);
  return count;
}
212117
213- size_t FloatToBuffer (float value, char * buffer) {
214- // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
215- // platforms these days. Just in case some system exists where FLT_DIG
216- // is significantly larger -- and risks overflowing our buffer -- we have
217- // this assert.
218- static_assert (FLT_DIG < 10 , " FLT_DIG is too big" );
219-
118+ template <typename T>
119+ size_t FpToBuffer (T value, char * buffer) {
120+ // Out of an abundance of caution, we ensure that the buffer is large enough
121+ // to hold the worst-case formatting of any floating-point number.
122+ constexpr size_t kMaxExponentDigits10 =
123+ std::max (NumDecimalDigits (std::numeric_limits<T>::max_exponent10),
124+ NumDecimalDigits (std::numeric_limits<T>::min_exponent10));
125+ constexpr size_t kMaxCharsWritten =
126+ 1 + // sign bit
127+ std::numeric_limits<T>::max_digits10 + // decimal digits
128+ 1 + // decimal point
129+ 1 + // exponent character
130+ 1 + // exponent sign
131+ kMaxExponentDigits10 ; // exponent digits
132+ static_assert (kMaxCharsWritten < kFastToBufferSize );
220133 if (std::isnan (value)) {
221- int snprintf_result = snprintf (buffer, kFastToBufferSize , " %snan" ,
222- std::signbit (value) ? " -" : " " );
134+ int snprintf_result = absl::SNPrintF (buffer, kFastToBufferSize , " %snan" ,
135+ std::signbit (value) ? " -" : " " );
223136 // Paranoid check to ensure we don't overflow the buffer.
224137 DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
225138 return snprintf_result;
226139 }
227140
228- int snprintf_result =
229- snprintf (buffer, kFastToBufferSize , " %.*g " , FLT_DIG , value);
141+ int snprintf_result = absl::SNPrintF (buffer, kFastToBufferSize , " %.*g " ,
142+ std::numeric_limits<T>::digits10 , value);
230143
231144 // The snprintf should never overflow because the buffer is significantly
232145 // larger than the precision we asked for.
233- DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
146+ DCHECK (snprintf_result > 0 && snprintf_result <= kMaxCharsWritten );
234147
235- float parsed_value;
236- if (!absl::SimpleAtof (buffer, &parsed_value) || parsed_value != value) {
148+ if (auto parsed_value = AsciiToFp<T>(buffer); parsed_value != value) {
149+ // Round-trip conversion failed, so we need to use full precision
150+ // formatting.
237151 snprintf_result =
238- snprintf (buffer, kFastToBufferSize , " %.*g" , FLT_DIG + 3 , value);
152+ absl::SNPrintF (buffer, kFastToBufferSize , " %.*g" ,
153+ std::numeric_limits<T>::max_digits10, value);
239154
240155 // Should never overflow; see above.
241- DCHECK (snprintf_result > 0 && snprintf_result < kFastToBufferSize );
156+ DCHECK (snprintf_result > 0 && snprintf_result <= kMaxCharsWritten );
242157 }
158+
243159 return snprintf_result;
244160}
245161
162+ } // namespace
163+
164+ size_t DoubleToBuffer (double value, char * buffer) {
165+ return FpToBuffer (value, buffer);
166+ }
167+
168+ size_t FloatToBuffer (float value, char * buffer) {
169+ return FpToBuffer (value, buffer);
170+ }
171+
246172strings_internal::AlphaNumBuffer LegacyPrecision (double d) {
247173 strings_internal::AlphaNumBuffer result;
248174 result.size = DoubleToBuffer (d, result.data .data ());
0 commit comments