Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](resource) Fix MemTableWriter attach resource context to thread context #47556

Merged
merged 3 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions be/src/olap/memtable_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,10 @@ Status MemTableWriter::_flush_memtable_async() {

Status MemTableWriter::flush_async() {
std::lock_guard<std::mutex> l(_lock);
// In order to avoid repeated ATTACH, use SWITCH here. have two calling paths:
// 1. call by local, from `VTabletWriterV2::_write_memtable`, has been ATTACH Load memory tracker
// into thread context, ATTACH cannot be repeated here.
// 2. call by remote, from `LoadChannelMgr::_get_load_channel`, no ATTACH because LoadChannelMgr
// not know Load context.
SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_resource_ctx->memory_context()->mem_tracker());
// Two calling paths:
// 1. call by local, from `VTabletWriterV2::_write_memtable`.
// 2. call by remote, from `LoadChannelMgr::_get_load_channel`.
SCOPED_SWITCH_RESOURCE_CONTEXT(_resource_ctx);
if (!_is_init || _is_closed) {
// This writer is uninitialized or closed before flushing, do nothing.
// We return OK instead of NOT_INITIALIZED or ALREADY_CLOSED.
Expand Down
1 change: 0 additions & 1 deletion be/src/runtime/load_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ Status LoadChannel::add_batch(const PTabletWriterAddBlockRequest& request,
PTabletWriterAddBlockResult* response) {
SCOPED_TIMER(_add_batch_timer);
COUNTER_UPDATE(_add_batch_times, 1);
SCOPED_ATTACH_TASK(_resource_ctx);
int64_t index_id = request.index_id();
// 1. get tablets channel
std::shared_ptr<BaseTabletsChannel> channel;
Expand Down
2 changes: 2 additions & 0 deletions be/src/runtime/load_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class LoadChannel {

bool is_high_priority() const { return _is_high_priority; }

// Returns this channel's resource context. The shared_ptr is returned by value,
// so the caller receives its own owning reference.
std::shared_ptr<ResourceContext> resource_ctx() const { return _resource_ctx; }

RuntimeProfile::Counter* get_mgr_add_batch_timer() { return _mgr_add_batch_timer; }
RuntimeProfile::Counter* get_handle_mem_limit_timer() { return _handle_mem_limit_timer; }

Expand Down
1 change: 1 addition & 0 deletions be/src/runtime/load_channel_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Status LoadChannelMgr::add_batch(const PTabletWriterAddBlockRequest& request,
return status;
}
SCOPED_TIMER(channel->get_mgr_add_batch_timer());
SCOPED_ATTACH_TASK(channel->resource_ctx());

if (!channel->is_high_priority()) {
// 2. check if mem consumption exceed limit
Expand Down
16 changes: 12 additions & 4 deletions be/src/runtime/memory/thread_mem_tracker_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ void ThreadMemTrackerMgr::attach_limiter_tracker(
DCHECK(mem_tracker);
CHECK(init());
flush_untracked_mem();
_last_attach_snapshots_stack.push_back({_reserved_mem, _consumer_tracker_stack});
_last_attach_snapshots_stack.push_back(
{_limiter_tracker, _wg_wptr, _reserved_mem, _consumer_tracker_stack});
if (_reserved_mem != 0) {
// _untracked_mem temporary store bytes that not synchronized to process reserved memory,
// but bytes have been subtracted from thread _reserved_mem.
Expand All @@ -41,16 +42,23 @@ void ThreadMemTrackerMgr::attach_limiter_tracker(
_limiter_tracker = mem_tracker;
}

void ThreadMemTrackerMgr::detach_limiter_tracker(
const std::shared_ptr<MemTrackerLimiter>& old_mem_tracker) {
// Attaches `mem_tracker` as the thread's limiter tracker and records `wg_wptr`
// as the thread's workload group.
//
// @param mem_tracker  limiter tracker to make current (forwarded to the
//                     single-argument overload).
// @param wg_wptr      workload group to store in `_wg_wptr`.
void ThreadMemTrackerMgr::attach_limiter_tracker(
const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
const std::weak_ptr<WorkloadGroup>& wg_wptr) {
// Delegate first: the single-argument overload pushes a snapshot that includes
// the current `_wg_wptr`, so it must run before `_wg_wptr` is overwritten below.
attach_limiter_tracker(mem_tracker);
_wg_wptr = wg_wptr;
}

void ThreadMemTrackerMgr::detach_limiter_tracker() {
CHECK(init());
flush_untracked_mem();
shrink_reserved();
DCHECK(!_last_attach_snapshots_stack.empty());
_limiter_tracker = _last_attach_snapshots_stack.back().limiter_tracker;
_wg_wptr = _last_attach_snapshots_stack.back().wg_wptr;
_reserved_mem = _last_attach_snapshots_stack.back().reserved_mem;
_consumer_tracker_stack = _last_attach_snapshots_stack.back().consumer_tracker_stack;
_last_attach_snapshots_stack.pop_back();
_limiter_tracker = old_mem_tracker;
}

} // namespace doris
22 changes: 6 additions & 16 deletions be/src/runtime/memory/thread_mem_tracker_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,9 @@ class ThreadMemTrackerMgr {

// After attach, the current thread Memory Hook starts to consume/release task mem_tracker
void attach_limiter_tracker(const std::shared_ptr<MemTrackerLimiter>& mem_tracker);
void detach_limiter_tracker(const std::shared_ptr<MemTrackerLimiter>& old_mem_tracker =
ExecEnv::GetInstance()->orphan_mem_tracker());

void attach_task(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
const std::weak_ptr<WorkloadGroup>& wg_wptr) {
DCHECK(mem_tracker);
attach_limiter_tracker(mem_tracker);
_wg_wptr = wg_wptr;
enable_wait_gc();
}
void detach_task(const std::shared_ptr<MemTrackerLimiter>& old_mem_tracker) {
detach_limiter_tracker(old_mem_tracker);
_wg_wptr.reset();
disable_wait_gc();
}
void attach_limiter_tracker(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
const std::weak_ptr<WorkloadGroup>& wg_wptr);
void detach_limiter_tracker();

// Must be fast enough! Thread update_tracker may be called very frequently.
bool push_consumer_tracker(MemTracker* mem_tracker);
Expand Down Expand Up @@ -134,6 +122,8 @@ class ThreadMemTrackerMgr {

private:
// Per-attach record of the manager state. attach_limiter_tracker() pushes one
// onto `_last_attach_snapshots_stack`, and detach_limiter_tracker() copies the
// fields back and pops it.
// The field order matters: the snapshot is built with a positional brace-init
// list {limiter tracker, workload group, reserved mem, consumer tracker stack}.
struct LastAttachSnapshot {
// Value of `_limiter_tracker` at the time of the attach.
std::shared_ptr<MemTrackerLimiter> limiter_tracker {nullptr};
// Value of `_wg_wptr` at the time of the attach.
std::weak_ptr<WorkloadGroup> wg_wptr;
// Value of `_reserved_mem` at the time of the attach.
int64_t reserved_mem = 0;
// Copy of `_consumer_tracker_stack` at the time of the attach.
std::vector<MemTracker*> consumer_tracker_stack;
};
Expand All @@ -155,7 +145,7 @@ class ThreadMemTrackerMgr {
// A thread of query/load will only wait once during execution.
bool _wait_gc = false;

std::shared_ptr<MemTrackerLimiter> _limiter_tracker;
std::shared_ptr<MemTrackerLimiter> _limiter_tracker {nullptr};
std::vector<MemTracker*> _consumer_tracker_stack;
std::weak_ptr<WorkloadGroup> _wg_wptr;

Expand Down
46 changes: 28 additions & 18 deletions be/src/runtime/thread_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class MemTracker;

// Setup helper called by the AttachTask constructor that takes a ResourceContext.
// Runs three steps in order:
//   1. ThreadLocalHandle::create_thread_local_if_not_exits()
//   2. passes rc's task id to signal::set_signal_task_id()
//   3. attaches `rc` to the current thread context.
// `rc` is dereferenced unconditionally, so it must be non-null.
void AttachTask::init(const std::shared_ptr<ResourceContext>& rc) {
ThreadLocalHandle::create_thread_local_if_not_exits();
// NOTE(review): presumably lets signal/crash handling report which task the
// thread was running — confirm against the signal module.
signal::set_signal_task_id(rc->task_controller()->task_id());
thread_context()->attach_task(rc);
}

AttachTask::AttachTask(const std::shared_ptr<ResourceContext>& rc) {
signal::set_signal_task_id(rc->task_controller()->task_id());
init(rc);
}

Expand All @@ -50,37 +50,47 @@ AttachTask::AttachTask(QueryContext* query_ctx) {
}

// Reverses the attach steps: resets the signal task id to a default-constructed
// TUniqueId, detaches the task from the current thread context, then calls
// ThreadLocalHandle::del_thread_local_if_count_is_zero().
AttachTask::~AttachTask() {
signal::set_signal_task_id(TUniqueId());
thread_context()->detach_task();
// NOTE(review): by its name this releases the thread-local context once no
// handle still needs it — confirm in ThreadLocalHandle.
ThreadLocalHandle::del_thread_local_if_count_is_zero();
}

// Replaces the resource context stored in the current thread context with `rc`
// for the lifetime of this object; the destructor restores the previous one.
// Nothing is switched when `rc` is already the thread's resource context.
//
// @param rc  resource context to switch to; must be non-null (DCHECK only).
SwitchResourceContext::SwitchResourceContext(const std::shared_ptr<ResourceContext>& rc) {
DCHECK(rc != nullptr);
// Debug-only precondition: a task must already be attached on this thread.
// NOTE(review): this check calls thread_context() before
// create_thread_local_if_not_exits() below — confirm that ordering is intended.
DCHECK(thread_context()->is_attach_task());
doris::ThreadLocalHandle::create_thread_local_if_not_exits();
if (rc != thread_context()->resource_ctx()) {
signal::set_signal_task_id(rc->task_controller()->task_id());
// Remember the previous context. A non-null value here is the only thing that
// tells the destructor a switch happened and must be undone.
// NOTE(review): if the thread had no resource context (only ruled out by the
// DCHECK above, i.e. in debug builds), this stays null and the destructor
// skips detach_limiter_tracker() — confirm that cannot occur in release builds.
old_resource_ctx_ = thread_context()->resource_ctx();
thread_context()->resource_ctx_ = rc;
// Attach rc's memory tracker and workload group to the thread's tracker
// manager; paired with detach_limiter_tracker() in the destructor.
thread_context()->thread_mem_tracker_mgr->attach_limiter_tracker(
rc->memory_context()->mem_tracker(),
rc->workload_group_context()->workload_group());
}
}

// Undoes the switch performed by the constructor, if one happened, and then
// calls ThreadLocalHandle::del_thread_local_if_count_is_zero() unconditionally
// (the constructor calls create_thread_local_if_not_exits() unconditionally).
SwitchResourceContext::~SwitchResourceContext() {
// A non-null saved context means the constructor actually switched.
if (old_resource_ctx_ != nullptr) {
// Restore the previous task id, resource context and limiter tracker state.
signal::set_signal_task_id(old_resource_ctx_->task_controller()->task_id());
thread_context()->resource_ctx_ = old_resource_ctx_;
thread_context()->thread_mem_tracker_mgr->detach_limiter_tracker();
}
doris::ThreadLocalHandle::del_thread_local_if_count_is_zero();
}

SwitchThreadMemTrackerLimiter::SwitchThreadMemTrackerLimiter(
const std::shared_ptr<doris::MemTrackerLimiter>& mem_tracker) {
DCHECK(mem_tracker);
doris::ThreadLocalHandle::create_thread_local_if_not_exits();
if (mem_tracker != thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker()) {
_old_mem_tracker = thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
thread_context()->thread_mem_tracker_mgr->attach_limiter_tracker(mem_tracker);
}
}

SwitchThreadMemTrackerLimiter::SwitchThreadMemTrackerLimiter(ResourceContext* rc) {
doris::ThreadLocalHandle::create_thread_local_if_not_exits();
// switch in the same task execution thread.
DCHECK(thread_context()->resource_ctx()->task_controller()->task_id() ==
rc->task_controller()->task_id());
DCHECK(rc->memory_context()->mem_tracker());
if (rc->memory_context()->mem_tracker() !=
thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker()) {
_old_mem_tracker = thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
thread_context()->thread_mem_tracker_mgr->attach_limiter_tracker(
rc->memory_context()->mem_tracker());
is_switched_ = true;
}
}

SwitchThreadMemTrackerLimiter::~SwitchThreadMemTrackerLimiter() {
if (_old_mem_tracker != nullptr) {
thread_context()->thread_mem_tracker_mgr->detach_limiter_tracker(_old_mem_tracker);
if (is_switched_) {
thread_context()->thread_mem_tracker_mgr->detach_limiter_tracker();
}
doris::ThreadLocalHandle::del_thread_local_if_count_is_zero();
}
Expand Down
32 changes: 24 additions & 8 deletions be/src/runtime/thread_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
// This will save some info about a working thread in the thread context.
// Looking forward to tracking memory during thread execution into MemTrackerLimiter.
#define SCOPED_ATTACH_TASK(arg1) auto VARNAME_LINENUM(attach_task) = AttachTask(arg1)
// Switch resource context in thread context, used after SCOPED_ATTACH_TASK.
#define SCOPED_SWITCH_RESOURCE_CONTEXT(arg1) \
auto VARNAME_LINENUM(switch_resource_context) = doris::SwitchResourceContext(arg1)

// Switch MemTrackerLimiter for count memory during thread execution.
// Used after SCOPED_ATTACH_TASK, in order to count the memory into another
Expand All @@ -61,6 +64,8 @@
// thread context need to be initialized, required by Allocator and elsewhere.
#define SCOPED_ATTACH_TASK(arg1, ...) \
auto VARNAME_LINENUM(scoped_tls_at) = doris::ScopedInitThreadContext()
#define SCOPED_SWITCH_RESOURCE_CONTEXT(arg1) \
auto VARNAME_LINENUM(switch_resource_context) = doris::ScopedInitThreadContext()
#define SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(arg1) \
auto VARNAME_LINENUM(scoped_tls_stmtl) = doris::ScopedInitThreadContext()
#define SCOPED_CONSUME_MEM_TRACKER(mem_tracker) \
Expand Down Expand Up @@ -161,6 +166,7 @@ namespace doris {
class ThreadContext;
class MemTracker;
class RuntimeState;
class SwitchResourceContext;

extern bthread_key_t btls_key;

Expand Down Expand Up @@ -189,15 +195,16 @@ class ThreadContext {
// will only attach_task at the beginning of the thread function, there should be no duplicate attach_task.
DCHECK(resource_ctx_ == nullptr);
resource_ctx_ = rc;
old_mem_tracker_ = thread_mem_tracker_mgr->limiter_mem_tracker();
thread_mem_tracker_mgr->attach_task(rc->memory_context()->mem_tracker(),
rc->workload_group_context()->workload_group());
thread_mem_tracker_mgr->attach_limiter_tracker(
rc->memory_context()->mem_tracker(),
rc->workload_group_context()->workload_group());
thread_mem_tracker_mgr->enable_wait_gc();
}

void detach_task() {
resource_ctx_.reset();
thread_mem_tracker_mgr->detach_task(old_mem_tracker_);
old_mem_tracker_.reset();
thread_mem_tracker_mgr->detach_limiter_tracker();
thread_mem_tracker_mgr->disable_wait_gc();
}

// True while a resource context is set on this thread context, i.e. `resource_ctx_`
// is non-null.
bool is_attach_task() { return resource_ctx_ != nullptr; }
Expand Down Expand Up @@ -233,8 +240,8 @@ class ThreadContext {
int thread_local_handle_count = 0;

private:
friend class SwitchResourceContext;
std::shared_ptr<ResourceContext> resource_ctx_;
std::shared_ptr<doris::MemTrackerLimiter> old_mem_tracker_ {nullptr};
};

class ThreadLocalHandle {
Expand Down Expand Up @@ -357,16 +364,25 @@ class AttachTask {
~AttachTask();
};

// Scoped guard that switches the resource context of the current thread context
// to `rc` on construction and restores the previous one on destruction.
// Normally created through the SCOPED_SWITCH_RESOURCE_CONTEXT macro.
class SwitchResourceContext {
public:
// Switches to `rc` unless it is already the thread's current resource context.
explicit SwitchResourceContext(const std::shared_ptr<ResourceContext>& rc);

// Restores the saved resource context if the constructor performed a switch.
~SwitchResourceContext();

private:
// Resource context that was active before the switch; stays null when no
// switch was performed.
std::shared_ptr<ResourceContext> old_resource_ctx_ {nullptr};
};

class SwitchThreadMemTrackerLimiter {
public:
explicit SwitchThreadMemTrackerLimiter(
const std::shared_ptr<doris::MemTrackerLimiter>& mem_tracker);
explicit SwitchThreadMemTrackerLimiter(ResourceContext* rc);

~SwitchThreadMemTrackerLimiter();

private:
std::shared_ptr<doris::MemTrackerLimiter> _old_mem_tracker {nullptr};
bool is_switched_ {false};
};

class AddThreadMemTrackerConsumer {
Expand Down
10 changes: 5 additions & 5 deletions be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
EXPECT_EQ(t1->consumption(), size1 + size2 + size1); // not changed, now consume t2
EXPECT_EQ(t2->consumption(), size1 + size2);

thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(t1); // detach
thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(); // detach
EXPECT_EQ(t2->consumption(),
size1 + size2 + size1); // detach automatic call flush_untracked_mem.

Expand All @@ -149,7 +149,7 @@ TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
thread_context->thread_mem_tracker_mgr->consume(-size1);
EXPECT_EQ(t3->consumption(), size1);

thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(t2); // detach
thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(); // detach
EXPECT_EQ(t1->consumption(), size1 + size2 + size1 + size2 + size2);
EXPECT_EQ(t2->consumption(), size1 + size2);
EXPECT_EQ(t3->consumption(), 0);
Expand All @@ -160,7 +160,7 @@ TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
EXPECT_EQ(t1->consumption(), size1 + size2 + size1 + size2 + size2);
EXPECT_EQ(t2->consumption(), 0);

thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(t1); // detach
thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(); // detach
EXPECT_EQ(t1->consumption(), size1 + size2 + size1 + size2 + size2);
EXPECT_EQ(t2->consumption(), -size1);

Expand Down Expand Up @@ -439,14 +439,14 @@ TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTrackerReserveMemory) {
EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(),
size3 - size2 + size3 + size2);

thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(t2); // detach
thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(); // detach
EXPECT_EQ(t1->consumption(), size3);
EXPECT_EQ(t2->consumption(), size3 + size2);
EXPECT_EQ(t3->consumption(), -size1 - size2); // size3 - _reserved_mem
// size3 - size2 + size3 + size2 - (_reserved_mem + _untracked_mem)
EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), size3 - size2);

thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(t1); // detach
thread_context->thread_mem_tracker_mgr->detach_limiter_tracker(); // detach
EXPECT_EQ(t1->consumption(), size3);
// not changed, reserved memory used done.
EXPECT_EQ(t2->consumption(), size3 + size2);
Expand Down
Loading