Skip to content
Open
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions cpp/include/tensorrt_llm/runtime/memoryCounters.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <atomic>
#include <cstddef>
#include <limits>
#include <string>

namespace tensorrt_llm::runtime
Expand Down Expand Up @@ -84,9 +85,18 @@ class MemoryCounters
return mPinnedPoolDiff;
}

template <MemoryType T>
struct always_false : std::false_type
{
};

template <MemoryType T>
void allocate(SizeType32 size)
{
if (size > static_cast<SizeType32>(std::numeric_limits<DiffType>::max()))
{
TLLM_THROW("Memory size too large for diff type: %zu", size);
}
auto const sizeDiff = static_cast<DiffType>(size);
if constexpr (T == MemoryType::kGPU)
{
Expand Down Expand Up @@ -115,7 +125,7 @@ class MemoryCounters
}
else
{
TLLM_THROW("Unknown memory type: %s", MemoryTypeString<T>::value);
static_assert(always_false<T>::value, "Unknown memory type!");
}
}

Expand All @@ -124,6 +134,10 @@ class MemoryCounters
template <MemoryType T>
void deallocate(SizeType32 size)
{
if (size > static_cast<SizeType32>(std::numeric_limits<DiffType>::max()))
{
TLLM_THROW("Memory size too large for diff type: %zu", size);
}
auto const sizeDiff = -static_cast<DiffType>(size);
if constexpr (T == MemoryType::kGPU)
{
Expand Down Expand Up @@ -152,14 +166,19 @@ class MemoryCounters
}
else
{
TLLM_THROW("Unknown memory type: %s", MemoryTypeString<T>::value);
static_assert(always_false<T>::value, "Unknown memory type!");
}
}

void deallocate(MemoryType memoryType, SizeType32 size);

static MemoryCounters& getInstance();

MemoryCounters(MemoryCounters const&) = delete;
MemoryCounters& operator=(MemoryCounters const&) = delete;
MemoryCounters(MemoryCounters&&) = delete;
MemoryCounters& operator=(MemoryCounters&&) = delete;

static std::string bytesToString(SizeType32 bytes, int precision = 2);

static std::string bytesToString(DiffType bytes, int precision = 2);
Expand Down