Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion third_party/tsl/tsl/platform/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ cc_library(
srcs = ["numbers.cc"],
hdrs = ["numbers.h"],
deps = [
":str_util",
":stringpiece",
":stringprintf",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@xla//xla/tsl/platform:logging",
"@xla//xla/tsl/platform:macros",
"@xla//xla/tsl/platform:types",
Expand Down Expand Up @@ -1036,6 +1036,7 @@ cc_library(
deps = [
":str_util",
":stringpiece",
"@com_google_absl//absl/strings",
"@xla//xla/tsl/platform:macros",
],
)
Expand Down
264 changes: 93 additions & 171 deletions third_party/tsl/tsl/platform/numbers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,22 @@ limitations under the License.

#include "tsl/platform/numbers.h"

#include <ctype.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <charconv>
#include <cmath>
#include <cstdint>
#include <locale>
#include <limits>
#include <optional>
#include <string>
#include <system_error> // NOLINT
#include <unordered_map>
#include <type_traits>

#include "absl/strings/charconv.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/macros.h"
Expand All @@ -40,102 +41,23 @@ namespace tsl {
namespace {

template <typename T>
const std::unordered_map<std::string, T>* GetSpecialNumsSingleton() {
static const std::unordered_map<std::string, T>* special_nums =
CHECK_NOTNULL((new const std::unordered_map<std::string, T>{
{"inf", std::numeric_limits<T>::infinity()},
{"+inf", std::numeric_limits<T>::infinity()},
{"-inf", -std::numeric_limits<T>::infinity()},
{"infinity", std::numeric_limits<T>::infinity()},
{"+infinity", std::numeric_limits<T>::infinity()},
{"-infinity", -std::numeric_limits<T>::infinity()},
{"nan", std::numeric_limits<T>::quiet_NaN()},
{"+nan", std::numeric_limits<T>::quiet_NaN()},
{"-nan", -std::numeric_limits<T>::quiet_NaN()},
}));
return special_nums;
}

template <typename T>
T locale_independent_strtonum(const char* str, const char** endptr) {
auto special_nums = GetSpecialNumsSingleton<T>();
std::stringstream s(str);

// Check if str is one of the special numbers.
std::string special_num_str;
s >> special_num_str;

for (size_t i = 0; i < special_num_str.length(); ++i) {
special_num_str[i] =
std::tolower(special_num_str[i], std::locale::classic());
}

auto entry = special_nums->find(special_num_str);
if (entry != special_nums->end()) {
*endptr = str + (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
: s.tellg());
return entry->second;
} else {
// Perhaps it's a hex number
if (special_num_str.compare(0, 2, "0x") == 0 ||
special_num_str.compare(0, 3, "-0x") == 0) {
return strtol(str, const_cast<char**>(endptr), 16);
}
}
// Reset the stream
s.str(str);
s.clear();
// Use the "C" locale
s.imbue(std::locale::classic());

T result;
s >> result;

// Set to result to what strto{f,d} functions would have returned. If the
// number was outside the range, the stringstream sets the fail flag, but
// returns the +/-max() value, whereas strto{f,d} functions return +/-INF.
if (s.fail()) {
if (result == std::numeric_limits<T>::max() ||
result == std::numeric_limits<T>::infinity()) {
result = std::numeric_limits<T>::infinity();
s.clear(s.rdstate() & ~std::ios::failbit);
} else if (result == -std::numeric_limits<T>::max() ||
result == -std::numeric_limits<T>::infinity()) {
result = -std::numeric_limits<T>::infinity();
s.clear(s.rdstate() & ~std::ios::failbit);
}
std::optional<T> AsciiToFp(absl::string_view str) {
T value;
absl::from_chars_result result =
absl::from_chars(str.data(), str.data() + str.size(), value);
if (result.ec != std::errc{}) {
return std::nullopt;
}

if (endptr) {
*endptr =
str +
(s.fail() ? static_cast<std::iostream::pos_type>(0)
: (s.eof() ? static_cast<std::iostream::pos_type>(strlen(str))
: s.tellg()));
if (result.ptr != str.data() + str.size()) {
// Not all characters consumed.
return std::nullopt;
}
return result;
return value;
}

} // namespace

namespace strings {

size_t FastInt32ToBufferLeft(int32_t i, char* buffer) {
uint32_t u = i;
size_t length = 0;
if (i < 0) {
*buffer++ = '-';
++length;
// We need to do the negation in modular (i.e., "unsigned")
// arithmetic; MSVC++ apparently warns for plain "-u", so
// we write the equivalent expression "0 - u" instead.
u = 0 - u;
}
length += FastUInt32ToBufferLeft(u, buffer);
return length;
}

size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
template <typename T>
size_t FastUIntToBufferLeft(T i, char* buffer) {
static_assert(std::is_unsigned_v<T>);
char* start = buffer;
do {
*buffer++ = ((i % 10) + '0');
Expand All @@ -146,103 +68,107 @@ size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
return buffer - start;
}

size_t FastInt64ToBufferLeft(int64_t i, char* buffer) {
uint64_t u = i;
template <typename T>
size_t FastIntToBufferLeft(T i, char* buffer) {
static_assert(std::is_signed_v<T>);
std::make_unsigned_t<T> u = i;
size_t length = 0;
if (i < 0) {
*buffer++ = '-';
++length;
// We need to do the negation in modular (i.e., "unsigned")
// arithmetic; MSVC++ apparently warns for plain "-u", so
// we write the equivalent expression "0 - u" instead.
u = 0 - u;
}
length += FastUInt64ToBufferLeft(u, buffer);
length += FastUIntToBufferLeft(u, buffer);
return length;
}
} // namespace

size_t FastUInt64ToBufferLeft(uint64_t i, char* buffer) {
char* start = buffer;
do {
*buffer++ = ((i % 10) + '0');
i /= 10;
} while (i > 0);
*buffer = 0;
std::reverse(start, buffer);
return buffer - start;
}

static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;

size_t DoubleToBuffer(double value, char* buffer) {
// DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
// platforms these days. Just in case some system exists where DBL_DIG
// is significantly larger -- and risks overflowing our buffer -- we have
// this assert.
static_assert(DBL_DIG < 20, "DBL_DIG is too big");

if (std::isnan(value)) {
int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
std::signbit(value) ? "-" : "");
// Paranoid check to ensure we don't overflow the buffer.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
return snprintf_result;
}
namespace strings {

if (std::abs(value) <= kDoublePrecisionCheckMax) {
int snprintf_result =
snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value);
size_t FastInt32ToBufferLeft(int32_t i, char* buffer) {
return FastIntToBufferLeft(i, buffer);
}

// The snprintf should never overflow because the buffer is significantly
// larger than the precision we asked for.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
size_t FastUInt32ToBufferLeft(uint32_t i, char* buffer) {
return FastUIntToBufferLeft(i, buffer);
}

if (locale_independent_strtonum<double>(buffer, nullptr) == value) {
// Round-tripping the string to double works; we're done.
return snprintf_result;
}
// else: full precision formatting needed. Fall through.
}
size_t FastInt64ToBufferLeft(int64_t i, char* buffer) {
return FastIntToBufferLeft(i, buffer);
}

int snprintf_result =
snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value);
size_t FastUInt64ToBufferLeft(uint64_t i, char* buffer) {
return FastUIntToBufferLeft(i, buffer);
}

// Should never overflow; see above.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
namespace {

return snprintf_result;
constexpr int NumDecimalDigits(int n) {
int count = 0;
do {
++count;
n /= 10;
} while (n != 0);
return count;
}

size_t FloatToBuffer(float value, char* buffer) {
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
// platforms these days. Just in case some system exists where FLT_DIG
// is significantly larger -- and risks overflowing our buffer -- we have
// this assert.
static_assert(FLT_DIG < 10, "FLT_DIG is too big");

template <typename T>
size_t FpToBuffer(T value, char* buffer) {
// Out of an abundance of caution, we ensure that the buffer is large enough
// to hold the worst-case formatting of any floating-point number.
constexpr size_t kMaxExponentDigits10 =
std::max(NumDecimalDigits(std::numeric_limits<T>::max_exponent10),
NumDecimalDigits(std::numeric_limits<T>::min_exponent10));
constexpr size_t kMaxCharsWritten =
1 + // sign bit
std::numeric_limits<T>::max_digits10 + // decimal digits
1 + // decimal point
1 + // exponent character
1 + // exponent sign
kMaxExponentDigits10; // exponent digits
static_assert(kMaxCharsWritten < kFastToBufferSize);
if (std::isnan(value)) {
int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan",
std::signbit(value) ? "-" : "");
int snprintf_result = absl::SNPrintF(buffer, kFastToBufferSize, "%snan",
std::signbit(value) ? "-" : "");
// Paranoid check to ensure we don't overflow the buffer.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
return snprintf_result;
}

int snprintf_result =
snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value);
int snprintf_result = absl::SNPrintF(buffer, kFastToBufferSize, "%.*g",
std::numeric_limits<T>::digits10, value);

// The snprintf should never overflow because the buffer is significantly
// larger than the precision we asked for.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
DCHECK(snprintf_result > 0 && snprintf_result <= kMaxCharsWritten);

float parsed_value;
if (!absl::SimpleAtof(buffer, &parsed_value) || parsed_value != value) {
if (auto parsed_value = AsciiToFp<T>(buffer); parsed_value != value) {
// Round-trip conversion failed, so we need to use full precision
// formatting.
snprintf_result =
snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value);
absl::SNPrintF(buffer, kFastToBufferSize, "%.*g",
std::numeric_limits<T>::max_digits10, value);

// Should never overflow; see above.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
DCHECK(snprintf_result > 0 && snprintf_result <= kMaxCharsWritten);
}

return snprintf_result;
}

} // namespace

size_t DoubleToBuffer(double value, char* buffer) {
return FpToBuffer(value, buffer);
}

size_t FloatToBuffer(float value, char* buffer) {
return FpToBuffer(value, buffer);
}

strings_internal::AlphaNumBuffer LegacyPrecision(double d) {
strings_internal::AlphaNumBuffer result;
result.size = DoubleToBuffer(d, result.data.data());
Expand Down Expand Up @@ -304,33 +230,29 @@ std::string HumanReadableNumBytes(int64_t num_bytes) {
return "-8E";
}

const char* neg_str = (num_bytes < 0) ? "-" : "";
static absl::string_view kNegSign = "-";
absl::string_view neg_str = (num_bytes < 0) ? kNegSign : "";
if (num_bytes < 0) {
num_bytes = -num_bytes;
}

// Special case for bytes.
if (num_bytes < 1024) {
// No fractions for bytes.
char buf[8]; // Longest possible string is '-XXXXB'
snprintf(buf, sizeof(buf), "%s%lldB", neg_str,
static_cast<long long>(num_bytes));
return std::string(buf);
return absl::StrCat(neg_str, num_bytes, "B");
}

static const char units[] = "KMGTPE"; // int64 only goes up to E.
const char* unit = units;
static absl::string_view kUnits = "KMGTPE"; // int64 only goes up to E.
auto unit = kUnits.begin();
while (num_bytes >= static_cast<int64_t>(1024) * 1024) {
num_bytes /= 1024;
++unit;
CHECK(unit < units + TF_ARRAYSIZE(units));
CHECK(unit < kUnits.end());
}

// We use SI prefixes.
char buf[16];
snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"),
neg_str, num_bytes / 1024.0, *unit);
return std::string(buf);
return absl::StrFormat("%s%.*f%ciB", neg_str, *unit == 'K' ? 1 : 2,
num_bytes / 1024.0, *unit);
}

std::string HumanReadableElapsedTime(double seconds) {
Expand Down
Loading
Loading