Skip to content

Commit 9d4dfb5

Browse files
authored
[coro_io] add support for cancellation (#887)
* [coro_io] add support for cancellation * fix * fix ssl compile error * remove useless dispatch * fix timer user-after-free * f * f * 1 * fix mem order * fix * fix mem order * fix * fix format
1 parent 56fbbba commit 9d4dfb5

File tree

12 files changed

+615
-259
lines changed

12 files changed

+615
-259
lines changed

include/ylt/coro_io/coro_io.hpp

+306-173
Large diffs are not rendered by default.

include/ylt/coro_io/io_context_pool.hpp

+31-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <vector>
3232

3333
#include "asio/dispatch.hpp"
34+
#include "async_simple/Signal.h"
3435
#ifdef __linux__
3536
#include <pthread.h>
3637
#include <sched.h>
@@ -104,11 +105,36 @@ class ExecutorWrapper : public async_simple::Executor {
104105
}
105106
void schedule(Func func, Duration dur, uint64_t hint,
106107
async_simple::Slot *slot = nullptr) override {
107-
auto timer = std::make_unique<asio::steady_timer>(executor_, dur);
108-
auto tm = timer.get();
109-
tm->async_wait([fn = std::move(func), timer = std::move(timer)](auto ec) {
110-
fn();
111-
});
108+
auto timer =
109+
std::make_shared<std::pair<asio::steady_timer, std::atomic<bool>>>(
110+
asio::steady_timer{executor_, dur}, false);
111+
if (!slot) {
112+
timer->first.async_wait([fn = std::move(func), timer](const auto &ec) {
113+
fn();
114+
});
115+
}
116+
else {
117+
if (!async_simple::signalHelper{async_simple::SignalType::Terminate}
118+
.tryEmplace(
119+
slot, [timer](auto signalType, auto *signal) mutable {
120+
if (bool expected = false;
121+
!timer->second.compare_exchange_strong(
122+
expected, true, std::memory_order_acq_rel)) {
123+
timer->first.cancel();
124+
}
125+
})) {
126+
asio::dispatch(timer->first.get_executor(), func);
127+
}
128+
else {
129+
timer->first.async_wait([fn = std::move(func), timer](const auto &ec) {
130+
fn();
131+
});
132+
if (bool expected = false; !timer->second.compare_exchange_strong(
133+
expected, true, std::memory_order_acq_rel)) {
134+
timer->first.cancel();
135+
}
136+
}
137+
}
112138
}
113139
};
114140

include/ylt/metric/summary_impl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ class summary_impl {
135135
if (piece) {
136136
if constexpr (inc_order) {
137137
for (int j = 0; j < piece->size(); ++j) {
138+
// tsan check data race here is expected. stat dont need to be very
139+
// strict. we allow old value.
138140
auto value = (*piece)[j].load(std::memory_order_relaxed);
139141
if (value) {
140142
result.emplace_back(get_ordered_index(i * piece_size + j), value);

include/ylt/thirdparty/async_simple/Signal.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#define ASYNC_SIMPLE_SIGNAL_H
1818

1919
#ifndef ASYNC_SIMPLE_USE_MODULES
20-
2120
#include <assert.h>
2221
#include <any>
2322
#include <atomic>
@@ -205,14 +204,16 @@ class Slot {
205204
"we dont allow emplace an empty signal handler");
206205
logicAssert(std::popcount(static_cast<uint64_t>(type)) == 1,
207206
"It's not allow to emplace for multiple signals");
208-
// trigger-once signal has already been triggered
207+
auto handler = std::make_unique<detail::SignalSlotSharedState::Handler>(
208+
std::forward<Args>(args)...);
209+
auto oldHandlerPtr = loadHandler<true>(type);
210+
// check trigger-once signal has already been triggered
211+
// if signal has already been triggered, return false
209212
if (!detail::SignalSlotSharedState::isMultiTriggerSignal(type) &&
210213
(signal()->state() & type)) {
211214
return false;
212215
}
213-
auto handler = std::make_unique<detail::SignalSlotSharedState::Handler>(
214-
std::forward<Args>(args)...);
215-
auto oldHandlerPtr = loadHandler<true>(type);
216+
// if signal triggered later, we will found it by cas failed.
216217
auto oldHandler = oldHandlerPtr->load(std::memory_order_acquire);
217218
if (oldHandler ==
218219
&detail::SignalSlotSharedState::HandlerManager::emittedTag) {

include/ylt/thirdparty/async_simple/coro/Collect.h

+41-46
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ struct CollectAnyAwaiter {
166166
_slot, [c = continuation, e = event, size = input.size()](
167167
SignalType type, Signal*) mutable {
168168
auto count = e->downCount();
169-
if (count > size + 1) {
170-
c.resume();
169+
if (count == size + 1) {
170+
c.resume();
171171
}
172172
})) { // has canceled
173173
return false;
@@ -186,14 +186,14 @@ struct CollectAnyAwaiter {
186186
assert(e != nullptr);
187187
auto count = e->downCount();
188188
// n+1: n coro + 1 cancel handler
189-
if (count > size + 1) {
190-
_result = std::make_unique<ResultType>();
191-
_result->_idx = i;
192-
_result->_value = std::move(result);
193-
if (auto ptr = local->getSlot(); ptr) {
194-
ptr->signal()->emit(_SignalType);
195-
}
196-
c.resume();
189+
if (count == size + 1) {
190+
_result = std::make_unique<ResultType>();
191+
_result->_idx = i;
192+
_result->_value = std::move(result);
193+
if (auto ptr = local->getSlot(); ptr) {
194+
ptr->signal()->emit(_SignalType);
195+
}
196+
c.resume();
197197
}
198198
});
199199
} // end for
@@ -268,8 +268,8 @@ struct CollectAnyVariadicAwaiter {
268268
_slot, [c = continuation, e = event](SignalType type,
269269
Signal*) mutable {
270270
auto count = e->downCount();
271-
if (count > std::tuple_size<InputType>() + 1) {
272-
c.resume();
271+
if (count == std::tuple_size<InputType>() + 1) {
272+
c.resume();
273273
}
274274
})) { // has canceled
275275
return false;
@@ -290,13 +290,13 @@ struct CollectAnyVariadicAwaiter {
290290
res) mutable {
291291
auto count = e->downCount();
292292
// n+1: n coro + 1 cancel handler
293-
if (count > std::tuple_size<InputType>() + 1) {
294-
_result = std::make_unique<ResultType>(
295-
std::in_place_index_t<index>(), std::move(res));
296-
if (auto ptr = local->getSlot(); ptr) {
297-
ptr->signal()->emit(_SignalType);
298-
}
299-
c.resume();
293+
if (count == std::tuple_size<InputType>() + 1) {
294+
_result = std::make_unique<ResultType>(
295+
std::in_place_index_t<index>(), std::move(res));
296+
if (auto ptr = local->getSlot(); ptr) {
297+
ptr->signal()->emit(_SignalType);
298+
}
299+
c.resume();
300300
}
301301
});
302302
}(),
@@ -388,15 +388,19 @@ struct CollectAllAwaiter {
388388
_slot->chainedSignal(_signal.get());
389389

390390
auto executor = promise_type._executor;
391-
for (size_t i = 0; i < _input.size(); ++i) {
392-
auto& exec = _input[i]._coro.promise()._executor;
393-
if (exec == nullptr) {
394-
exec = executor;
395-
}
396-
std::unique_ptr<LazyLocalBase> local;
397-
local = std::make_unique<LazyLocalBase>(_signal.get());
398-
_input[i]._coro.promise()._lazy_local = local.get();
399-
auto&& func = [this, i, local = std::move(local)]() mutable {
391+
392+
_event.setAwaitingCoro(continuation);
393+
auto size = _input.size();
394+
for (size_t i = 0; i < size; ++i) {
395+
auto& exec = _input[i]._coro.promise()._executor;
396+
if (exec == nullptr) {
397+
exec = executor;
398+
}
399+
std::unique_ptr<LazyLocalBase> local;
400+
local = std::make_unique<LazyLocalBase>(_signal.get());
401+
_input[i]._coro.promise()._lazy_local = local.get();
402+
auto&& func =
403+
[this, i, local = std::move(local)]() mutable {
400404
_input[i].start([this, i, local = std::move(local)](
401405
Try<ValueType>&& result) {
402406
_output[i] = std::move(result);
@@ -412,20 +416,15 @@ struct CollectAllAwaiter {
412416
awaitingCoro.resume();
413417
}
414418
});
415-
};
416-
if (Para == true && _input.size() > 1) {
417-
if (exec != nullptr)
418-
AS_LIKELY {
419-
exec->schedule_move_only(std::move(func));
420-
continue;
421-
}
422-
}
423-
func();
424-
}
425-
_event.setAwaitingCoro(continuation);
426-
auto awaitingCoro = _event.down();
427-
if (awaitingCoro) {
428-
awaitingCoro.resume();
419+
};
420+
if (Para == true && _input.size() > 1) {
421+
if (exec != nullptr)
422+
AS_LIKELY {
423+
exec->schedule_move_only(std::move(func));
424+
continue;
425+
}
426+
}
427+
func();
429428
}
430429
}
431430
inline auto await_resume() { return std::move(_output); }
@@ -602,10 +601,6 @@ struct CollectAllVariadicAwaiter {
602601
}
603602
}(std::get<index>(_inputs), std::get<index>(_results)),
604603
...);
605-
606-
if (auto awaitingCoro = _event.down(); awaitingCoro) {
607-
awaitingCoro.resume();
608-
}
609604
}
610605

611606
void await_suspend(std::coroutine_handle<> continuation) {

include/ylt/thirdparty/async_simple/coro/CountEvent.h

+8-8
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ namespace detail {
3434
// The last 'down' will resume the awaiting coroutine on this event.
3535
class CountEvent {
3636
public:
37-
CountEvent(size_t count) : _count(count + 1) {}
38-
CountEvent(const CountEvent&) = delete;
39-
CountEvent(CountEvent&& other)
40-
: _count(other._count.exchange(0, std::memory_order_relaxed)),
41-
_awaitingCoro(std::exchange(other._awaitingCoro, nullptr)) {}
37+
CountEvent(size_t count) : _count(count) {}
38+
CountEvent(const CountEvent&) = delete;
39+
CountEvent(CountEvent&& other)
40+
: _count(other._count.exchange(0, std::memory_order_relaxed)),
41+
_awaitingCoro(std::exchange(other._awaitingCoro, nullptr)) {}
4242

43-
[[nodiscard]] CoroHandle<> down(size_t n = 1) {
44-
std::size_t oldCount;
45-
return down(oldCount, n);
43+
[[nodiscard]] CoroHandle<> down(size_t n = 1) {
44+
std::size_t oldCount;
45+
return down(oldCount, n);
4646
}
4747
[[nodiscard]] CoroHandle<> down(size_t& oldCount, std::size_t n) {
4848
// read acquire and write release, _awaitingCoro store can not be

src/coro_io/tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_executable(coro_io_test
55
test_client_pool.cpp
66
test_rate_limiter.cpp
77
test_coro_channel.cpp
8+
test_cancel.cpp
89
main.cpp
910
)
1011
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_SYSTEM_NAME MATCHES "Windows") # mingw-w64

0 commit comments

Comments
 (0)