Skip to content

Commit 1a0ffce

Browse files
committed
Add subscription removal on cancellation
1 parent 192d503 commit 1a0ffce

File tree

4 files changed

+64
-7
lines changed

4 files changed

+64
-7
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,30 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
100100
}
101101
}
102102

103+
func addSubscription(
104+
continuation: ValkeySubscriptionAsyncStream.Continuation,
105+
filter: ValkeySubscriptionFilter
106+
) -> Int {
107+
self.eventLoop.assertInEventLoop()
108+
let id = ValkeySubscriptions.getSubscriptionID()
109+
let loopBoundHandler = NIOLoopBound(self, eventLoop: self.eventLoop)
110+
continuation.onTermination = { [eventLoop] termination in
111+
switch termination {
112+
case .cancelled:
113+
eventLoop.execute {
114+
loopBoundHandler.value.subscriptions.removeSubscription(id: id)
115+
}
116+
case .finished:
117+
break
118+
119+
@unknown default:
120+
break
121+
}
122+
}
123+
self.subscriptions.addSubscription(id: id, continuation: continuation, filter: filter)
124+
return id
125+
}
126+
103127
@usableFromInline
104128
func handlerAdded(context: ChannelHandlerContext) {
105129
self.context = context

Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ extension ValkeyConnection {
4848

4949
let channels = Set(channels)
5050
if self.channel.eventLoop.inEventLoop {
51-
let subscriptionID = self.channelHandler.value.subscriptions.addSubscription(
51+
let subscriptionID = self.channelHandler.value.addSubscription(
5252
continuation: streamContinuation,
5353
filter: .channels(channels)
5454
)
@@ -60,7 +60,7 @@ extension ValkeyConnection {
6060
.get()
6161
} else {
6262
_ = try await self.channel.eventLoop.flatSubmit {
63-
let subscriptionID = self.channelHandler.value.subscriptions.addSubscription(
63+
let subscriptionID = self.channelHandler.value.addSubscription(
6464
continuation: streamContinuation,
6565
filter: .channels(channels)
6666
)

Sources/Valkey/Subscriptions/ValkeySubscriptions.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,18 @@ struct ValkeySubscriptions {
6666
self.subscriptions = []
6767
}
6868

69-
mutating func addSubscription(continuation: ValkeySubscriptionAsyncStream.Continuation, filter: ValkeySubscriptionFilter) -> Int {
70-
let id = Self.globalSubscriptionId.wrappingAdd(1, ordering: .relaxed)
69+
static func getSubscriptionID() -> Int {
70+
Self.globalSubscriptionId.wrappingAdd(1, ordering: .relaxed).newValue
71+
}
72+
73+
mutating func addSubscription(id: Int, continuation: ValkeySubscriptionAsyncStream.Continuation, filter: ValkeySubscriptionFilter) {
7174
subscriptions.append(
7275
.init(
73-
id: id.newValue,
76+
id: id,
7477
subscription: .init(continuation: continuation, filter: filter, logger: self.logger),
7578
stateMachine: .init(filter: filter)
7679
)
7780
)
78-
return id.newValue
7981
}
8082

8183
mutating func removeSubscription(id: Int) {

Tests/ValkeyTests/ValkeySubscriptionTests.swift

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import Logging
1616
import NIOCore
1717
import NIOEmbedded
1818
import Testing
19-
import Valkey
19+
20+
@testable import Valkey
2021

2122
@Suite
2223
struct SubscriptionTests {
@@ -225,4 +226,34 @@ struct SubscriptionTests {
225226
try await group.waitForAll()
226227
}
227228
}
229+
230+
@Test
231+
func testRemoveSubscriptionOnCancellation() async throws {
232+
let channel = NIOAsyncTestingChannel()
233+
let logger = Logger(label: "test")
234+
let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger)
235+
236+
try await withThrowingTaskGroup(of: Void.self) { group in
237+
group.addTask {
238+
let subscription = try await connection.subscribe(to: "test")
239+
for try await message in subscription {
240+
#expect(message == .init(channel: "test", message: "Testing!"))
241+
}
242+
}
243+
group.addTask {
244+
let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
245+
// expect SUBSCRIBE command
246+
#expect(String(buffer: outbound) == "*2\r\n$9\r\nSUBSCRIBE\r\n$4\r\ntest\r\n")
247+
// push SUBSCRIBE channel
248+
try await channel.writeInbound(ByteBuffer(string: ">3\r\n$9\r\nsubscribe\r\n$4\r\ntest\r\n:1\r\n"))
249+
// push SUBSCRIBE message
250+
try await channel.writeInbound(ByteBuffer(string: ">3\r\n$7\r\nmessage\r\n$4\r\ntest\r\n$8\r\nTesting!\r\n"))
251+
}
252+
try await group.next()
253+
group.cancelAll()
254+
}
255+
try await connection.channel.eventLoop.submit {
256+
#expect(connection.channelHandler.value.subscriptions.subscriptions.count == 0)
257+
}.get()
258+
}
228259
}

0 commit comments

Comments
 (0)