Skip to content

Commit bc6f87d

Browse files
authored
Fix task cancellation leak (pointfreeco#1418)
* Fix leak in withTaskCancellation(id:) * Add failing test that is fixed * wip * clean up * update tests * wip * wip * wip * wip * clean up * fix 13.4 * try to fix test
1 parent a27a7a5 commit bc6f87d

File tree

4 files changed

+41
-28
lines changed

4 files changed

+41
-28
lines changed

Sources/ComposableArchitecture/Effects/Cancellation.swift

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -196,27 +196,24 @@ public func withTaskCancellation<T: Sendable>(
196196
cancelInFlight: Bool = false,
197197
operation: @Sendable @escaping () async throws -> T
198198
) async rethrows -> T {
199-
let task = { () -> Task<T, Error> in
200-
cancellablesLock.lock()
201-
let id = CancelToken(id: id)
199+
let id = CancelToken(id: id)
200+
let (cancellable, task) = cancellablesLock.sync { () -> (AnyCancellable, Task<T, Error>) in
202201
if cancelInFlight {
203202
cancellationCancellables[id]?.forEach { $0.cancel() }
204203
}
205204
let task = Task { try await operation() }
206-
var cancellable: AnyCancellable!
207-
cancellable = AnyCancellable {
208-
task.cancel()
209-
cancellablesLock.sync {
210-
cancellationCancellables[id]?.remove(cancellable)
211-
if cancellationCancellables[id]?.isEmpty == .some(true) {
212-
cancellationCancellables[id] = nil
213-
}
205+
let cancellable = AnyCancellable { task.cancel() }
206+
cancellationCancellables[id, default: []].insert(cancellable)
207+
return (cancellable, task)
208+
}
209+
defer {
210+
cancellablesLock.sync {
211+
cancellationCancellables[id]?.remove(cancellable)
212+
if cancellationCancellables[id]?.isEmpty == .some(true) {
213+
cancellationCancellables[id] = nil
214214
}
215215
}
216-
cancellationCancellables[id, default: []].insert(cancellable)
217-
cancellablesLock.unlock()
218-
return task
219-
}()
216+
}
220217
do {
221218
return try await task.cancellableValue
222219
} catch {
@@ -252,10 +249,8 @@ extension Task where Success == Never, Failure == Never {
252249
/// Cancel any currently in-flight operation with the given identifier.
253250
///
254251
/// - Parameter id: An identifier.
255-
public static func cancel<ID: Hashable & Sendable>(id: ID) async {
256-
await MainActor.run {
257-
cancellablesLock.sync { cancellationCancellables[.init(id: id)]?.forEach { $0.cancel() } }
258-
}
252+
public static func cancel<ID: Hashable & Sendable>(id: ID) {
253+
cancellablesLock.sync { cancellationCancellables[.init(id: id)]?.forEach { $0.cancel() } }
259254
}
260255

261256
/// Cancel any currently in-flight operation with the given identifier.
@@ -264,8 +259,8 @@ extension Task where Success == Never, Failure == Never {
264259
/// identifier.
265260
///
266261
/// - Parameter id: A unique type identifying the operation.
267-
public static func cancel(id: Any.Type) async {
268-
await self.cancel(id: ObjectIdentifier(id))
262+
public static func cancel(id: Any.Type) {
263+
self.cancel(id: ObjectIdentifier(id))
269264
}
270265
}
271266

Tests/ComposableArchitectureTests/EffectRunTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ final class EffectRunTests: XCTestCase {
8080
switch action {
8181
case .tapped:
8282
return .run { send in
83-
await Task.cancel(id: CancelID.self)
83+
Task.cancel(id: CancelID.self)
8484
try Task.checkCancellation()
8585
await send(.response)
8686
}
@@ -101,7 +101,7 @@ final class EffectRunTests: XCTestCase {
101101
switch action {
102102
case .tapped:
103103
return .run { send in
104-
await Task.cancel(id: CancelID.self)
104+
Task.cancel(id: CancelID.self)
105105
try Task.checkCancellation()
106106
await send(.responseA)
107107
} catch: { @Sendable _, send in // NB: Explicit '@Sendable' required in 5.5.2

Tests/ComposableArchitectureTests/EffectTaskTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ final class EffectTaskTests: XCTestCase {
8080
switch action {
8181
case .tapped:
8282
return .task {
83-
await Task.cancel(id: CancelID.self)
83+
Task.cancel(id: CancelID.self)
8484
try Task.checkCancellation()
8585
return .response
8686
}
@@ -101,7 +101,7 @@ final class EffectTaskTests: XCTestCase {
101101
switch action {
102102
case .tapped:
103103
return .task {
104-
await Task.cancel(id: CancelID.self)
104+
Task.cancel(id: CancelID.self)
105105
try Task.checkCancellation()
106106
return .responseA
107107
} catch: { @Sendable _ in // NB: Explicit '@Sendable' required in 5.5.2

Tests/ComposableArchitectureTests/TaskCancellationTests.swift

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import XCTest
55

66
final class TaskCancellationTests: XCTestCase {
77
func testCancellation() async throws {
8-
cancellationCancellables.removeAll()
8+
cancellablesLock.sync {
9+
cancellationCancellables.removeAll()
10+
}
911
enum ID {}
1012
let (stream, continuation) = AsyncStream<Void>.streamWithContinuation()
1113
let task = Task {
@@ -16,12 +18,28 @@ final class TaskCancellationTests: XCTestCase {
1618
}
1719
}
1820
await stream.first(where: { true })
19-
await Task.cancel(id: ID.self)
20-
XCTAssertEqual(cancellationCancellables, [:])
21+
Task.cancel(id: ID.self)
22+
await Task.megaYield(count: 20)
23+
XCTAssertEqual(cancellablesLock.sync { cancellationCancellables }, [:])
2124
do {
2225
try await task.cancellableValue
2326
XCTFail()
2427
} catch {
2528
}
2629
}
30+
31+
func testWithTaskCancellationCleansUpTask() async throws {
32+
let task = Task {
33+
try await withTaskCancellation(id: 0) {
34+
try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1000)
35+
}
36+
}
37+
38+
try await Task.sleep(nanoseconds: NSEC_PER_SEC / 3)
39+
XCTAssertEqual(cancellationCancellables.count, 1)
40+
41+
task.cancel()
42+
try await Task.sleep(nanoseconds: NSEC_PER_SEC / 3)
43+
XCTAssertEqual(cancellationCancellables.count, 0)
44+
}
2745
}

0 commit comments

Comments
 (0)