From a0d0511d2bd6f5875c3451d48045afd7a412712b Mon Sep 17 00:00:00 2001 From: fanyunfan <2569548856@qq.com> Date: Fri, 25 Jul 2025 08:03:06 +0800 Subject: [PATCH 1/5] [FIX] fix bugs caused by None attention_bias during qwen3 model conversion Signed-off-by: fanyunfan <2569548856@qq.com> --- tensorrt_llm/models/qwen/config.py | 4 +++- tensorrt_llm/models/qwen/convert.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/models/qwen/config.py b/tensorrt_llm/models/qwen/config.py index 47d1e15baea..c9e57ecdf69 100644 --- a/tensorrt_llm/models/qwen/config.py +++ b/tensorrt_llm/models/qwen/config.py @@ -109,7 +109,9 @@ def from_hugging_face(cls, assert qwen_type in valid_types, f"Unsupported Qwen type: {qwen_type}, only {valid_types} are acceptable." num_key_value_heads = getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads) - head_dim = hf_config.hidden_size // hf_config.num_attention_heads + head_dim = getattr( + hf_config, "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads) head_size = getattr(hf_config, "kv_channels", head_dim) hidden_act = getattr(hf_config, "hidden_act", "silu") if qwen_type == "qwen2_moe": diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py index dc2bc355683..0bcc762ba17 100644 --- a/tensorrt_llm/models/qwen/convert.py +++ b/tensorrt_llm/models/qwen/convert.py @@ -537,19 +537,26 @@ def convert_hf_qwen(hf_model, tensor_parallel) assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - assert (k_bias.shape[0] % (mapping.tp_size * head_size)) == 0 - assert (v_bias.shape[0] % (mapping.tp_size * head_size)) == 0 + + if k_bias is not None and v_bias is not None: + assert (k_bias.shape[0] % + (mapping.tp_size * head_size)) == 0 + assert (v_bias.shape[0] % + (mapping.tp_size * head_size)) == 0 wq = split(q_weight, mapping.tp_size, mapping.tp_rank) wk = 
split(k_weight, mapping.tp_size, mapping.tp_rank) wv = split(v_weight, mapping.tp_size, mapping.tp_rank) - bq = split(q_bias, mapping.tp_size, mapping.tp_rank) - bk = split(k_bias, mapping.tp_size, mapping.tp_rank) - bv = split(v_bias, mapping.tp_size, mapping.tp_rank) - qkv_w = torch.concat((wq, wk, wv)) - qkv_b = torch.concat((bq, bk, bv)) + + if q_bias is not None and k_bias is not None and v_bias is not None: + bq = split(q_bias, mapping.tp_size, mapping.tp_rank) + bk = split(k_bias, mapping.tp_size, mapping.tp_rank) + bv = split(v_bias, mapping.tp_size, mapping.tp_rank) + qkv_b = torch.concat((bq, bk, bv)) + else: + qkv_b = None else: qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) From bbb21981906e93b115ec76361306c0e3c9b97f56 Mon Sep 17 00:00:00 2001 From: fanyunfan <2569548856@qq.com> Date: Sat, 4 Oct 2025 08:53:37 +0800 Subject: [PATCH 2/5] [None][fix] Enhance memory counters with compile-time safety and bounds checking Signed-off-by: fanyunfan <2569548856@qq.com> --- cpp/include/tensorrt_llm/runtime/memoryCounters.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/cpp/include/tensorrt_llm/runtime/memoryCounters.h b/cpp/include/tensorrt_llm/runtime/memoryCounters.h index 42ccdc13d62..59d1c277014 100644 --- a/cpp/include/tensorrt_llm/runtime/memoryCounters.h +++ b/cpp/include/tensorrt_llm/runtime/memoryCounters.h @@ -87,6 +87,10 @@ class MemoryCounters template void allocate(SizeType32 size) { + if (size > static_cast(std::numeric_limits::max())) + { + TLLM_THROW("Memory size too large for diff type: %zu", size); + } auto const sizeDiff = static_cast(size); if constexpr (T == MemoryType::kGPU) { @@ -115,7 +119,7 @@ class MemoryCounters } else { - TLLM_THROW("Unknown memory type: %s", MemoryTypeString::value); + static_assert(!std::is_same_v, "Unknown memory type!"); } } @@ -124,6 +128,10 @@ class MemoryCounters template void deallocate(SizeType32 size) { + if 
(size > static_cast(std::numeric_limits::max())) + { + TLLM_THROW("Memory size too large for diff type: %zu", size); + } auto const sizeDiff = -static_cast(size); if constexpr (T == MemoryType::kGPU) { @@ -152,7 +160,7 @@ class MemoryCounters } else { - TLLM_THROW("Unknown memory type: %s", MemoryTypeString::value); + static_assert(!std::is_same_v, "Unknown memory type!"); } } From 4511c590818f2a27acbe7fa130253951285222df Mon Sep 17 00:00:00 2001 From: fanyunfan <2569548856@qq.com> Date: Sat, 4 Oct 2025 11:29:13 +0800 Subject: [PATCH 3/5] [None][fix] Explicitly prohibit all copy and assignment operations for MemoryCounters Singleton. Signed-off-by: fanyunfan <2569548856@qq.com> --- cpp/include/tensorrt_llm/runtime/memoryCounters.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/cpp/include/tensorrt_llm/runtime/memoryCounters.h b/cpp/include/tensorrt_llm/runtime/memoryCounters.h index 59d1c277014..e96c3f8729c 100644 --- a/cpp/include/tensorrt_llm/runtime/memoryCounters.h +++ b/cpp/include/tensorrt_llm/runtime/memoryCounters.h @@ -84,6 +84,11 @@ class MemoryCounters return mPinnedPoolDiff; } + template + struct always_false : std::false_type + { + }; + template void allocate(SizeType32 size) { @@ -119,7 +124,7 @@ class MemoryCounters } else { - static_assert(!std::is_same_v, "Unknown memory type!"); + static_assert(always_false::value, "Unknown memory type!"); } } @@ -160,7 +165,7 @@ class MemoryCounters } else { - static_assert(!std::is_same_v, "Unknown memory type!"); + static_assert(always_false::value, "Unknown memory type!"); } } @@ -168,6 +173,11 @@ class MemoryCounters static MemoryCounters& getInstance(); + MemoryCounters(MemoryCounters const&) = delete; + MemoryCounters& operator=(MemoryCounters const&) = delete; + MemoryCounters(MemoryCounters&&) = delete; + MemoryCounters& operator=(MemoryCounters&&) = delete; + static std::string bytesToString(SizeType32 bytes, int precision = 2); static std::string bytesToString(DiffType 
bytes, int precision = 2); From a6761061617ec0daf7e128f2f5a9be088e93c3ad Mon Sep 17 00:00:00 2001 From: fanyunfan <2569548856@qq.com> Date: Sat, 1 Nov 2025 00:07:37 +0800 Subject: [PATCH 4/5] Add header file for std::numeric_limits Signed-off-by: fanyunfan <2569548856@qq.com> --- cpp/include/tensorrt_llm/runtime/memoryCounters.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/tensorrt_llm/runtime/memoryCounters.h b/cpp/include/tensorrt_llm/runtime/memoryCounters.h index e96c3f8729c..eb59bef90ef 100644 --- a/cpp/include/tensorrt_llm/runtime/memoryCounters.h +++ b/cpp/include/tensorrt_llm/runtime/memoryCounters.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tensorrt_llm::runtime From 8c43e5526aa3a2863df92d881787c0764e09b26e Mon Sep 17 00:00:00 2001 From: fanyunfan <2569548856@qq.com> Date: Thu, 6 Nov 2025 09:16:41 +0800 Subject: [PATCH 5/5] [None][fix] Fix compile-time check and remove unnecessary checks and constraints Signed-off-by: fanyunfan <2569548856@qq.com> --- .../tensorrt_llm/runtime/memoryCounters.h | 38 ++++++------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/cpp/include/tensorrt_llm/runtime/memoryCounters.h b/cpp/include/tensorrt_llm/runtime/memoryCounters.h index eb59bef90ef..6fd48e66be1 100644 --- a/cpp/include/tensorrt_llm/runtime/memoryCounters.h +++ b/cpp/include/tensorrt_llm/runtime/memoryCounters.h @@ -21,7 +21,6 @@ #include #include -#include #include namespace tensorrt_llm::runtime @@ -30,32 +29,32 @@ namespace tensorrt_llm::runtime class MemoryCounters { public: - using SizeType32 = std::size_t; + using SizeType = std::size_t; using DiffType = std::ptrdiff_t; MemoryCounters() = default; - [[nodiscard]] SizeType32 getGpu() const + [[nodiscard]] SizeType getGpu() const { return mGpu; } - [[nodiscard]] SizeType32 getCpu() const + [[nodiscard]] SizeType getCpu() const { return mCpu; } - [[nodiscard]] SizeType32 getPinned() const + [[nodiscard]] SizeType getPinned() const { return mPinned; } 
- [[nodiscard]] SizeType32 getUVM() const + [[nodiscard]] SizeType getUVM() const { return mUVM; } - [[nodiscard]] SizeType32 getPinnedPool() const + [[nodiscard]] SizeType getPinnedPool() const { return mPinnedPool; } @@ -91,12 +90,8 @@ class MemoryCounters }; template - void allocate(SizeType32 size) + void allocate(SizeType size) { - if (size > static_cast(std::numeric_limits::max())) - { - TLLM_THROW("Memory size too large for diff type: %zu", size); - } auto const sizeDiff = static_cast(size); if constexpr (T == MemoryType::kGPU) { @@ -129,15 +124,11 @@ class MemoryCounters } } - void allocate(MemoryType memoryType, SizeType32 size); + void allocate(MemoryType memoryType, SizeType size); template - void deallocate(SizeType32 size) + void deallocate(SizeType size) { - if (size > static_cast(std::numeric_limits::max())) - { - TLLM_THROW("Memory size too large for diff type: %zu", size); - } auto const sizeDiff = -static_cast(size); if constexpr (T == MemoryType::kGPU) { @@ -170,23 +161,18 @@ class MemoryCounters } } - void deallocate(MemoryType memoryType, SizeType32 size); + void deallocate(MemoryType memoryType, SizeType size); static MemoryCounters& getInstance(); - MemoryCounters(MemoryCounters const&) = delete; - MemoryCounters& operator=(MemoryCounters const&) = delete; - MemoryCounters(MemoryCounters&&) = delete; - MemoryCounters& operator=(MemoryCounters&&) = delete; - - static std::string bytesToString(SizeType32 bytes, int precision = 2); + static std::string bytesToString(SizeType bytes, int precision = 2); static std::string bytesToString(DiffType bytes, int precision = 2); [[nodiscard]] std::string toString() const; private: - std::atomic mGpu{}, mCpu{}, mPinned{}, mUVM{}, mPinnedPool{}; + std::atomic mGpu{}, mCpu{}, mPinned{}, mUVM{}, mPinnedPool{}; std::atomic mGpuDiff{}, mCpuDiff{}, mPinnedDiff{}, mUVMDiff{}, mPinnedPoolDiff{}; };