[Bug] [TIR] variable b has been used before definition! #17572

Open
hmz0412 opened this issue Dec 24, 2024 · 2 comments
Labels: needs-triage, type: bug

hmz0412 commented Dec 24, 2024

I am trying to write my own prefill TIR function, but when I build the module to test it, I get the error shown in the title.
I don't understand what causes it.

Actual situation
Traceback (most recent call last): File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/relax/frontend/nn/llm/test_equal.py", line 554, in <module> lib_cpu = tvm.build(IR_cpu, target="llvm") File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/driver/build_module.py", line 297, in build rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__ raise_last_ffi_error() File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error raise py_err File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 531, in operator() return TIRToRuntime(inputs_arg, host_target); File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 492, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&) auto pair = SplitMixedModule(ir_module, target, target_host); File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 418, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&) mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 291, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential) mod = seq(std::move(mod)); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line 458, in operator() func = MakePackedAPI(std::move(func)); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line 420, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc) Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", line 186, in tvm::tir::UndefinedVars(tvm::tir::Stmt const&, tvm::runtime::Array<tvm::tir::Var, void> const&) m(stmt); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, 
in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > fvisit(arr[i]); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 119, in operator() VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) this->VisitStmt(op->body); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", line 61, in tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*) this->HandleDef(op->loop_var); File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", line 136, in tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&) ICHECK(!def_count_.count(v)) << "variable " << v->name_hint tvm.error.InternalError: Traceback (most recent call last): 45: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:531 44: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&) at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:492 43: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&) 
at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:418 42: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential) at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:291 41: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:458 40: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:420 39: tvm::tir::UndefinedVars(tvm::tir::Stmt const&, tvm::runtime::Array<tvm::tir::Var, void> const&) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:186 38: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 37: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 36: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 35: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 34: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 33: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 32: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 31: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 30: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 29: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 28: 
VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 27: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 26: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 25: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 24: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 23: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 22: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 21: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 20: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 19: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 18: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 17: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 16: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 15: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 14: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) 
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 13: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 12: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 11: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 10: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> > at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35 9: operator() at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119 8: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 7: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 6: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 5: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 4: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 3: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58 1: tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:61 0: tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&) at /home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:136 File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", line 138 InternalError: Check failed: (!use_count_.count(v)) is false: variable b has been used before definition!

Reproduce

import math
from typing import Any, Dict, Tuple

import tvm
from tvm import relax as rx
from tvm import tir
from tvm.relax.frontend.nn import Object, Tensor
from tvm.runtime import DataType
from tvm.script import tir as T
from tvm.script import ir as I
from tvm.target import Target

def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype):

    group_size = h_q // h_kv
    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))

    @I.ir_module
    class cpu_module:

        @T.prim_func
        def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
            var_q: T.handle, # [total_len, h_q, d]
            var_q_indptr: T.handle, # [batch_size + 1]
            var_k: T.handle, # [total_len, h_kv, d]
            var_v: T.handle, # [total_len, h_kv, d]
            var_kv_indptr: T.handle, # [batch_size + 1]
            var_q_rope_position: T.handle, # [total_q_len]
            var_k_rope_pos_offset: T.handle, # [b]
            var_output: T.handle, # [total_len, h_q, d]
            var_lse: T.handle, # [total_len, h_q]
            causal: T.int32,
            rotary_mode: T.int32,
            rope_scale: T.float32,
            rope_theta: T.float32,
            attn_score_scaling_factor: T.float32
        ):
            batch_size = T.int32(is_size_var=True)
            qo_len = T.int32(is_size_var=True)
            kv_len = T.int32(is_size_var=True)
            q_indptr_elem_offset = T.int32(is_size_var=True)
            kv_indptr_elem_offset = T.int32(is_size_var=True)
            q_rope_position_elem_offset = T.int32(is_size_var=True)
            k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)

            q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
            q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
            k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
            v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
            kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset)
            q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset)
            k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
            output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
            lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable


            for b in T.serial(batch_size):
                with T.block("attn"):
                    
                    # q_token_start = T.alloc_buffer([1,], "uint32")
                    # q_num = T.alloc_buffer([1,], "uint32")
                    # k_token_start = T.alloc_buffer([1,], "int32")
                    # k_num = T.alloc_buffer([1,], "int32")

                    softmax_sum = T.alloc_buffer([h_q], "float32")
                    m_prev = T.alloc_buffer([h_q], "float32")
                    m_new = T.alloc_buffer([h_q], "float32")
                    d_prev = T.alloc_buffer([h_q], "float32")
                    d_new = T.alloc_buffer([h_q], "float32")
                    sum = T.alloc_buffer([d], "float32")

                    max_score = T.alloc_buffer([h_q], "float32")
                    attention_scores = T.alloc_buffer([kv_indptr[b + 1] - kv_indptr[b], h_q], "float32")
                    exp_scores = T.alloc_buffer([kv_indptr[b + 1] - kv_indptr[b], h_q], "float32")
                    attention_score = T.alloc_buffer([1,], "float32")
                    
                    for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
                    
                        for i in T.serial(h_q):
                            max_score[i] = -5e4
                            m_prev[i] = -5e4
                            d_prev[i] = 1.0

                        for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                            for h in T.serial(h_q):
                                h_kv_idx = h // group_size

                                if _causal_mask(causal,
                                                row=q_idx,
                                                col=k_idx,
                                                kv_len=kv_indptr[b + 1] - kv_indptr[b],
                                                qo_len=q_indptr[b + 1] - q_indptr[b]):
                                    result = 0.0
                                    for d_idx in T.serial(d):
                                        result += q[q_indptr[b] + q_idx, h, d_idx] * k[kv_indptr[b] + k_idx, h_kv_idx, d_idx]
                                    attention_score[0] = result * sm_scale
                                            
                                else:
                                    attention_score[0] = -5e4 * sm_scale
                                attention_scores[k_idx, h] = attention_score[0]
                                max_score[h] = T.max(max_score[h], attention_score[0])
                                m_new[h] = T.max(m_prev[h], max_score[h])
                                

                        for h in T.serial(h_q):
                            d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])                

                        for h in T.serial(h_q):
                            softmax_sum[h] = 0.0
                            for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                                exp_scores[k_idx, h] = T.exp(attention_scores[k_idx, h] - m_new[h])
                                softmax_sum[h] += exp_scores[k_idx, h]
                            d_new[h]+=softmax_sum[h]
                                
                        d_prev = d_new

                        for h in T.serial(h_q):
                            h_kv_idx = h // group_size
                            
                            for i in T.serial(d):
                                sum[i] = 0.0
                            for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
                                weight = exp_scores[v_idx, h] / softmax_sum[h]
                                for i in T.serial(d):
                                    sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
                            for i in T.serial(d):
                                output[q_indptr[b] + q_idx, h, i] = sum[i]
    return cpu_module
IR_cpu = _attention_prefill_ragged_cpu(2, 16, 256, "float32")
lib_cpu = tvm.build(IR_cpu, target="llvm")

cc @Hzfengsy @junrushao @quic-sanirudh @shingjan

hmz0412 added the needs-triage and type: bug labels on Dec 24, 2024
@Hzfengsy
Member

exp_scores = T.alloc_buffer([kv_indptr[b + 1] - kv_indptr[b], h_q], "float32")

Dynamic allocation (depends on loop var) is not allowed :)


hmz0412 commented Dec 25, 2024

I want to allocate a buffer whose size depends on the number of KV tokens in each batch entry. Since dynamic allocation is not possible, how can I achieve this while complying with TIR's constraints?

For example, would it be possible to allocate a buffer outside the main loop that is large enough for every batch entry and then reuse it inside the loop? Or are there other approaches?
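
The "allocate once outside the loop" idea can be sketched in plain Python (not TIR) to show the indexing pattern. This is only an illustration under assumptions: `ragged_softmax` and `max_kv_len` are hypothetical names that do not appear in the code above, and whether the same shape passes every TIR check has not been confirmed in this thread. The point is that the scratch buffer's size no longer depends on the loop variable `b`; each batch entry only touches its first `n` slots.

```python
import math
from typing import List


def ragged_softmax(scores: List[float], indptr: List[int], max_kv_len: int) -> List[float]:
    """Softmax over each ragged segment, using one scratch buffer sized up front.

    `max_kv_len` is an assumed upper bound on any segment length; in a kernel it
    would have to come from the caller or be a compile-time constant.
    """
    batch_size = len(indptr) - 1
    # Allocated once, before the batch loop: the shape does not depend on b.
    scratch = [0.0] * max_kv_len
    out = [0.0] * len(scores)

    for b in range(batch_size):
        start = indptr[b]
        n = indptr[b + 1] - start  # number of valid entries for this batch entry
        if n > max_kv_len:
            raise ValueError("segment longer than the preallocated scratch buffer")
        if n == 0:
            continue
        # Subtract the segment maximum for numerical stability.
        seg_max = max(scores[start + i] for i in range(n))
        total = 0.0
        for i in range(n):
            scratch[i] = math.exp(scores[start + i] - seg_max)
            total += scratch[i]
        # Only the first n slots of scratch are meaningful for this b.
        for i in range(n):
            out[start + i] = scratch[i] / total
    return out


probs = ragged_softmax([0.0, 0.0, 1.0, 2.0, 3.0], [0, 2, 5], max_kv_len=3)
```

The trade-off is memory: the buffer is sized for the longest segment even when most segments are short, and the caller has to supply (or bound) that maximum.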
