Skip to content

Commit

Permalink
[*] improve log message for storage view content
Browse files Browse the repository at this point in the history
  • Loading branch information
tirivo committed Jun 4, 2024
1 parent 5eb5d5a commit b3c01d5
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 35 deletions.
32 changes: 32 additions & 0 deletions include/ctranslate2/storage_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "types.h"
#include "utils.h"

#define PRINT_MAX_VALUES 6

namespace ctranslate2 {

#define ASSERT_DTYPE(DTYPE) \
Expand Down Expand Up @@ -238,6 +240,36 @@ namespace ctranslate2 {

friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);

template <typename T>
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
std::string indentation(indent, ' ');
if (dim == shape.size() - 1) {
os << indentation << "[";
for (dim_t i = 0; i < shape[dim]; ++i) {
if (i > 0) os << ", ";
if (i < PRINT_MAX_VALUES / 2 || i >= shape[dim] - PRINT_MAX_VALUES / 2) {
os << data[offset + i];
} else if (i == PRINT_MAX_VALUES / 2) {
os << "...";
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
}
}
os << "]";
} else {
os << indentation << "[\n";
for (dim_t i = 0; i < shape[dim]; ++i) {
if (i > 0) os << ",\n";
if (i < PRINT_MAX_VALUES / 2 || i >= shape[dim] - PRINT_MAX_VALUES / 2) {
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent + 2);
} else if (i == PRINT_MAX_VALUES / 2) {
os << indentation << " ...";
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
}
}
os << "\n" << indentation << "]";
}
}

protected:
DataType _dtype = DataType::FLOAT32;
Device _device = Device::CPU;
Expand Down
104 changes: 69 additions & 35 deletions src/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "dispatch.h"

#define PRINT_MAX_VALUES 6

namespace ctranslate2 {

Expand Down Expand Up @@ -440,44 +439,79 @@ namespace ctranslate2 {
return os;
}

std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
StorageView printable(storage.dtype());
printable.copy_from(storage);
TYPE_DISPATCH(
printable.dtype(),
const auto* values = printable.data<T>();
if (printable.size() <= PRINT_MAX_VALUES) {
for (dim_t i = 0; i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);
}
}
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
// Create a printable copy of the storage
StorageView printable(storage.dtype());
printable.copy_from(storage);

// Check the data type and print accordingly
TYPE_DISPATCH(
printable.dtype(),
const auto* values = printable.data<T>();
const auto& shape = printable.shape();

// Print tensor contents based on dimensionality
if (shape.empty()) { // Scalar case
os << "Data (Scalar): " << values[0] << std::endl;
} else if (shape.size() == 1) { // Vector case
os << "Data (1D Vector):" << std::endl;
os << "[";
for (dim_t i = 0; i < printable.size(); ++i) {
if (i > 0) os << ", ";
if (i < PRINT_MAX_VALUES / 2 || i >= printable.size() - PRINT_MAX_VALUES / 2) {
os << values[i];
} else if (i == PRINT_MAX_VALUES / 2) {
os << "...";
i = printable.size() - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
}
}
os << "]\n";
} else if (shape.size() == 2) { // 2D Matrix case
os << "Data (2D Matrix):" << std::endl;
os << "[\n";
for (dim_t i = 0; i < shape[0]; ++i) {
if (i > 0) os << ",\n";
if (i < PRINT_MAX_VALUES / 2 || i >= shape[0] - PRINT_MAX_VALUES / 2) {
os << " [";
for (dim_t j = 0; j < shape[1]; ++j) {
if (j > 0) os << ", ";
if (j < PRINT_MAX_VALUES / 2 || j >= shape[1] - PRINT_MAX_VALUES / 2) {
os << values[i * shape[1] + j];
} else if (j == PRINT_MAX_VALUES / 2) {
os << "...";
j = shape[1] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
}
}
os << "]";
} else if (i == PRINT_MAX_VALUES / 2) {
os << " ...";
i = shape[0] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
}
}
os << "\n]\n";
} else { // Higher-dimensional tensors
os << "Data (" << shape.size() << "D Tensor):" << std::endl;
storage.print_tensor(os, values, shape, 0, 0, 0);
os << std::endl;
}
);

// Print metadata
os << "[device:" << device_to_str(storage.device(), storage.device_index())
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
if (storage.is_scalar())
os << "scalar";
else {
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
os << ' ';
print_value(os, values[i]);
for (dim_t i = 0; i < storage.rank(); ++i) {
if (i > 0)
os << 'x';
os << storage.dim(i);
}
os << " ...";
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);
}
}
os << std::endl);
os << "[" << device_to_str(storage.device(), storage.device_index())
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
if (storage.is_scalar())
os << "scalar";
else {
for (dim_t i = 0; i < storage.rank(); ++i) {
if (i > 0)
os << 'x';
os << storage.dim(i);
}
os << ']';
return os;
}
os << ']';
return os;
}


#define DECLARE_IMPL(T) \
template \
Expand Down

0 comments on commit b3c01d5

Please sign in to comment.