diff --git a/src/TiledArray/tensor/print.h b/src/TiledArray/tensor/print.h index a2497bff15..18d09be7a7 100644 --- a/src/TiledArray/tensor/print.h +++ b/src/TiledArray/tensor/print.h @@ -40,12 +40,45 @@ namespace detail { class NDArrayPrinter { public: NDArrayPrinter(int width = 10, int precision = 6) - : width(width), precision(precision) {} + : width(width), + precision(precision), + truncate_(0.5 * std::pow(10., -precision)) {} private: int width = 10; int precision = 10; + /// truncates (=sets to zero) small floating-point numbers + class FloatTruncate { + public: + /// truncates numbers smaller than @p threshold + FloatTruncate(double threshold) noexcept : threshold_{threshold} {} + + [[nodiscard]] auto operator()(std::floating_point auto val) const noexcept { + return std::abs(val) < threshold_ ? decltype(val){0} : val; + } + + template + requires detail::is_complex_v && + std::floating_point + [[nodiscard]] auto operator()(T const& val) const noexcept { + using std::imag; + using std::real; + return T{(*this)(real(val)), (*this)(imag(val))}; + } + + template + requires(!(std::floating_point || detail::is_complex_v)) + [[nodiscard]] auto operator()(T const& val) const noexcept { + return val; + } + + private: + double threshold_; + }; + + FloatTruncate truncate_; + // Helper function to recursively print the array template > diff --git a/src/TiledArray/tensor/print.ipp b/src/TiledArray/tensor/print.ipp index 8634418138..c3535409b3 100644 --- a/src/TiledArray/tensor/print.ipp +++ b/src/TiledArray/tensor/print.ipp @@ -50,9 +50,9 @@ void NDArrayPrinter::printArray(const T* data, const std::size_t order, for (size_t i = 0; i < extents[level]; ++i) { if (level == order - 1) { + auto value = truncate_(data[offset + i * strides[level]]); // At the deepest level, print the actual values - os << std::fixed << std::setprecision(precision) << std::setw(width) << std::setfill(Char(' ')) - << data[offset + i * strides[level]]; + os << std::fixed << std::setprecision(precision) << std::setw(width) << std::setfill(Char(' ')) << value; if (i < extents[level] - 1) { os << ", "; }