Skip to content
Open
392 changes: 316 additions & 76 deletions src/op/reduce.cc

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/op/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ class ReduceOpNode : public TileOperatorNode {
/// Generate initial value for reduction
PrimExpr MakeInitValue() const;
/// Generate reduction expression
PrimExpr MakeReduce(const PrimExpr &acc, const PrimExpr &b) const;
/// pack_lanes = 1 for scalar, 2 for add2/max2/min2, etc.
PrimExpr MakeReduce(const PrimExpr &acc, const PrimExpr &b,
int pack_lanes = 1) const;
/// Generate codegen reducer string
std::string MakeCodegenReducer() const;
};
Expand Down
89 changes: 89 additions & 0 deletions src/runtime/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#include <ctime>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>

namespace tvm {
namespace runtime {
Expand All @@ -17,6 +19,24 @@ const char *level_strings[] = {
": Error: ", // TVM_LOG_LEVEL_ERROR = 3
": Fatal: ", // TVM_LOG_LEVEL_FATAL = 4
};

// Directory marker used to shorten file paths into VLOG map keys.
constexpr const char *kSrcPrefix = "/src/";
// Number of characters in kSrcPrefix (excluding the terminating NUL).
constexpr const size_t kSrcPrefixLength = 5;
// Map key looked up when no file-specific entry matches.
constexpr const char *kDefaultKeyword = "DEFAULT";

/// Convert a source file path into the key used by the VLOG level map.
///
/// - If the path contains "/src/", the key is everything after the LAST
///   occurrence of that marker.
/// - Otherwise, if the path begins with "src/" (the marker without its leading
///   slash), the key is everything after that leading "src/".
/// - Otherwise the path is returned unchanged.
std::string FileToVLogMapKey(const std::string &filename) {
  const size_t marker_pos =
      filename.rfind(kSrcPrefix, std::string::npos, kSrcPrefixLength);
  if (marker_pos != std::string::npos) {
    return filename.substr(marker_pos + kSrcPrefixLength);
  }

  // No "/src/" anywhere: fall back to checking for a relative "src/" prefix.
  const std::string relative_prefix{kSrcPrefix + 1};
  if (filename.compare(0, relative_prefix.size(), relative_prefix) == 0) {
    return filename.substr(relative_prefix.size());
  }
  return filename;
}
} // namespace

void LogMessageImpl(const std::string &file, int lineno, int level,
Expand All @@ -39,6 +59,75 @@ void LogMessageImpl(const std::string &file, int lineno, int level,
throw InternalError(file, lineno, message);
}

TvmLogDebugSettings TvmLogDebugSettings::ParseSpec(const char *opt_spec) {
TvmLogDebugSettings settings;
if (opt_spec == nullptr) {
return settings;
}
std::string spec(opt_spec);
if (spec.empty() || spec == "0") {
return settings;
}
settings.dlog_enabled_ = true;
if (spec == "1") {
return settings;
}
std::istringstream spec_stream(spec);
auto tell_pos = [&](const std::string &last_read) {
int pos = spec_stream.tellg();
if (pos == -1) {
pos = spec.size() - last_read.size();
}
return pos;
};
while (spec_stream) {
std::string name;
if (!std::getline(spec_stream, name, '=')) {
break;
}
if (name.empty()) {
LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(name)
<< ": empty filename";
}
name = FileToVLogMapKey(name);
std::string level;
if (!std::getline(spec_stream, level, ',')) {
LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level)
<< ": expecting \"=<level>\" after \"" << name << "\"";
return settings;
}
if (level.empty()) {
LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level)
<< ": empty level after \"" << name << "\"";
return settings;
}
char *end_of_level = nullptr;
int level_val = static_cast<int>(strtol(level.c_str(), &end_of_level, 10));
if (end_of_level != level.c_str() + level.size()) {
Comment on lines +120 to +122
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n src/runtime/logging.cc | sed -n '90,120p'

Repository: tile-ai/tilelang

Length of output: 1541


🏁 Script executed:

rg -A 5 -B 5 "vlog_level_map_" src/runtime/logging.cc

Repository: tile-ai/tilelang

Length of output: 761


🏁 Script executed:

rg "VLOG\|vlog.*level" src/runtime/logging.cc | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the broader function to understand the scope
sed -n '60,116p' src/runtime/logging.cc | cat -n

Repository: tile-ai/tilelang

Length of output: 2293


Add range validation before narrowing strtol result to int.

On lines 105–106, strtol parses the level string as a long and directly casts to int without checking whether the value is in range. If a user provides a value outside [INT_MIN, INT_MAX], the cast silently wraps to an incorrect VLOG level instead of rejecting the input as malformed.

Suggested fix
+    errno = 0;
+    long parsed_level = std::strtol(level.c_str(), &end_of_level, 10);
+    if (errno == ERANGE || parsed_level < std::numeric_limits<int>::min() ||
+        parsed_level > std::numeric_limits<int>::max()) {
+      LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level)
+                 << ": level out of range: \"" << level << "\"";
+    }
-    int level_val = static_cast<int>(strtol(level.c_str(), &end_of_level, 10));
+    int level_val = static_cast<int>(parsed_level);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` around lines 104 - 106, The code currently calls
strtol(level.c_str(), &end_of_level, 10) and narrows the result into int
level_val without range checking; change the logic in the parsing branch around
end_of_level/level_val to parse into a long (keep strtol), then check that
end_of_level points to the end of the string, errno != ERANGE, and that the
returned long is between INT_MIN and INT_MAX before static_cast<int>ing to
level_val; if any check fails, treat the input as malformed (reject/log error)
instead of silently casting to an incorrect VLOG level.

LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level)
<< ": invalid level: \"" << level << "\"";
return settings;
}
LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name
<< "' up to level " << level;
Comment on lines +128 to +129
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid unconditional INFO output while parsing TVM_LOG_DEBUG.

These lines print to stderr for every valid spec entry, even when the caller only wanted to configure VLOG gating. That makes the flag itself user-visible and can pollute tests or tooling that assert on stderr. Prefer VLOG/DLOG here, or drop the log entirely.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` around lines 111 - 112, The unconditional LOG(INFO)
used when parsing TVM_LOG_DEBUG (the statement emitting "TVM_LOG_DEBUG enables
VLOG statements in '...'" in src/runtime/logging.cc) should not print to stderr
for every spec entry; replace it with a gated log such as VLOG(1) (or
DLOG(INFO)) or remove it entirely so the message only appears when verbose
logging is enabled; update the single LOG(INFO) call accordingly to use VLOG(1)
or drop the emission.

settings.vlog_level_map_.emplace(name, level_val);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Handle duplicate file entries explicitly.

Line 113 uses emplace, so repeated keys keep the first level silently. That makes duplicates behave incorrectly and also leaves the preceding log message lying about the effective level. Either overwrite (insert_or_assign / operator[]) or fail fast on duplicates.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` at line 113, The code uses
settings.vlog_level_map_.emplace(name, level_val) which silently keeps the first
value on duplicate keys; change this to explicitly handle duplicates by either
overwriting with settings.vlog_level_map_.insert_or_assign(name, level_val) (or
operator[]) so later entries take effect, or detect an existing key (find on
vlog_level_map_) and fail fast/log an error before inserting; update the
surrounding log message to reflect the actual effective level based on the
chosen behavior (overwrite or error).

}
return settings;
}

/// Decide whether a VLOG statement at `level` in `filename` is enabled.
///
/// The entry keyed by FileToVLogMapKey(filename) takes precedence; when it is
/// absent the "DEFAULT" entry is consulted. A statement is enabled when its
/// level does not exceed the configured level. Returns false when neither
/// entry exists.
bool TvmLogDebugSettings::VerboseEnabledImpl(const std::string &filename,
                                             int level) const {
  const auto file_entry = vlog_level_map_.find(FileToVLogMapKey(filename));
  if (file_entry != vlog_level_map_.end()) {
    return level <= file_entry->second;
  }

  // No file-specific setting: fall back to the catch-all entry, if any.
  const auto default_entry = vlog_level_map_.find(kDefaultKeyword);
  return default_entry != vlog_level_map_.end() &&
         level <= default_entry->second;
}

} // namespace detail
} // namespace runtime
} // namespace tvm
27 changes: 23 additions & 4 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4359,10 +4359,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const ShuffleNode *op,
os << "uint1{__pack_nv_bfloat162(" << e0 << ", " << e1 << ")}";
} else {
enable_fp16_ = true;
// __pack_half2 returns __half2 which is 32-bit.
// Reinterpret via aggregate initialisation.
os << "uint1{*(unsigned*)&(__pack_half2((__half)(" << e0 << "), (__half)("
<< e1 << ")))}";
os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
}
return;
}
// Handle ExtractElement: extract a scalar lane from a bfloat16x2 / float16x2
// vector (produced by packed reduction, etc.). The vector is stored as an
// opaque uint1 in the lowered code, but semantically it is a packed pair.
DataType vec_t =
op->vectors.size() == 1 ? op->vectors[0].dtype() : DataType();
bool vec_is_bf16x2 = vec_t.is_bfloat16() && vec_t.lanes() == 2;
bool vec_is_fp16x2 = vec_t.is_float16() && vec_t.lanes() == 2;
if ((vec_is_bf16x2 || vec_is_fp16x2) && op->vectors.size() == 1 &&
op->indices.size() == 1) {
int lane = Downcast<IntImm>(op->indices[0])->value;
std::string vec = PrintExpr(op->vectors[0]);
if (vec_is_bf16x2) {
enable_bf16_ = true;
os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
} else {
enable_fp16_ = true;
os << "half_t(((half2*)(&(" << vec << ")))->" << (lane == 0 ? "x" : "y")
<< ")";
}
return;
}
Expand Down
29 changes: 29 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,13 @@ template <typename T> TL_DEVICE uint1 to_uint1(T v) {
return r;
}

// Pack two half_t values into a single uint1: both operands are cast to
// __half, combined by __pack_half2, and the result is stored in the uint1's
// only member.
TL_DEVICE uint1 pack_half2(half_t a, half_t b) {
  const __half first = static_cast<__half>(a);
  const __half second = static_cast<__half>(b);
  const unsigned bits = __pack_half2(first, second);
  return uint1{bits};
}

// --- add2 ----------------------------------------------------------------

TL_DEVICE float2 add2(float2 a, float2 b) {
Expand Down Expand Up @@ -959,4 +966,26 @@ TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
return reinterpret_cast<bfloat16_t &>(ret16);
}

// Specializations for uint1 (packed bfloat16x2 / float16x2).
// uint1 is a 32-bit struct { unsigned x; } used to represent packed pairs.
// __shfl_xor_sync operates on native 32-bit types, so we pass the raw unsigned.

// uint1 overload: forward the raw `x` member to __shfl_xor_sync and rewrap the
// result in a uint1.
template <>
TL_DEVICE uint1 shfl_xor_sync(unsigned mask, uint1 val, int laneMask) {
  const unsigned exchanged = __shfl_xor_sync(mask, val.x, laneMask);
  return uint1{exchanged};
}

// uint1 overload: forward the raw `x` member to __shfl_down_sync and rewrap
// the result in a uint1.
template <>
TL_DEVICE uint1 shfl_down_sync(unsigned mask, uint1 val, int delta) {
  const unsigned exchanged = __shfl_down_sync(mask, val.x, delta);
  return uint1{exchanged};
}

// uint1 overload: forward the raw `x` member to __shfl_up_sync and rewrap the
// result in a uint1.
template <> TL_DEVICE uint1 shfl_up_sync(unsigned mask, uint1 val, int delta) {
  const unsigned exchanged = __shfl_up_sync(mask, val.x, delta);
  return uint1{exchanged};
}

// uint1 overload: forward the raw `x` member to __shfl_sync and rewrap the
// result in a uint1.
template <> TL_DEVICE uint1 shfl_sync(unsigned mask, uint1 val, int srcLane) {
  const unsigned exchanged = __shfl_sync(mask, val.x, srcLane);
  return uint1{exchanged};
}

} // namespace tl
42 changes: 42 additions & 0 deletions src/tl_templates/cuda/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,48 @@ struct MinOpNan {
}
};

// Sum reducer for packed bfloat16 pairs: both operands are unpacked with
// tl::from_uint1<__nv_bfloat162>, combined with tl::add2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct SumOp_bf16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__nv_bfloat162>(x);
    const auto rhs = tl::from_uint1<__nv_bfloat162>(y);
    return tl::to_uint1(tl::add2(lhs, rhs));
  }
};

// Max reducer for packed bfloat16 pairs: both operands are unpacked with
// tl::from_uint1<__nv_bfloat162>, combined with tl::max2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct MaxOp_bf16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__nv_bfloat162>(x);
    const auto rhs = tl::from_uint1<__nv_bfloat162>(y);
    return tl::to_uint1(tl::max2(lhs, rhs));
  }
};

// Min reducer for packed bfloat16 pairs: both operands are unpacked with
// tl::from_uint1<__nv_bfloat162>, combined with tl::min2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct MinOp_bf16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__nv_bfloat162>(x);
    const auto rhs = tl::from_uint1<__nv_bfloat162>(y);
    return tl::to_uint1(tl::min2(lhs, rhs));
  }
};

// Sum reducer for packed half pairs: both operands are unpacked with
// tl::from_uint1<__half2>, combined with tl::add2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct SumOp_fp16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__half2>(x);
    const auto rhs = tl::from_uint1<__half2>(y);
    return tl::to_uint1(tl::add2(lhs, rhs));
  }
};

// Max reducer for packed half pairs: both operands are unpacked with
// tl::from_uint1<__half2>, combined with tl::max2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct MaxOp_fp16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__half2>(x);
    const auto rhs = tl::from_uint1<__half2>(y);
    return tl::to_uint1(tl::max2(lhs, rhs));
  }
};

// Min reducer for packed half pairs: both operands are unpacked with
// tl::from_uint1<__half2>, combined with tl::min2, and repacked with
// tl::to_uint1.
// NOTE(review): T is presumably uint1 at every call site -- confirm.
struct MinOp_fp16x2 {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    const auto lhs = tl::from_uint1<__half2>(x);
    const auto rhs = tl::from_uint1<__half2>(y);
    return tl::to_uint1(tl::min2(lhs, rhs));
  }
};

struct BitAndOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x & y;
Expand Down
3 changes: 2 additions & 1 deletion testing/python/language/test_tilelang_language_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

tilelang.testing.set_random_seed()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -76,6 +75,8 @@ def _reduce_op(T, op, src, dst, dim, batch=1):
("sum", T.float32, 128, 64, "shared", "fragment", 256, 2),
("sum", T.float32, 128, 64, "shared", "fragment", 256, 4),
("sum", T.float16, 64, 128, "fragment", "fragment", 256, 4),
("sum", T.bfloat16, 128, 128, "fragment", "fragment", 32, 1),
("sum", T.bfloat16, 64, 128, "fragment", "fragment", 256, 4),
("max", T.bfloat16, 128, 64, "shared", "fragment", 256, 2),
("max", T.float32, 128, 128, "fragment", "fragment", 256, 4),
("min", T.float32, 64, 128, "shared", "fragment", 128, 2),
Expand Down
Loading