Skip to content

Commit 23faafd

Browse files
committed
Add support for psubscribe
1 parent e1c66e1 commit 23faafd

File tree

8 files changed

+278
-47
lines changed

8 files changed

+278
-47
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,17 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
176176
promise.fail(ValkeyClientError(.commandError, message: token.errorString.map { String(buffer: $0) }))
177177

178178
case .push:
179-
if (try? self.subscriptions.notify(token)) == true {
180-
guard let promise = commands.popFirst() else {
181-
preconditionFailure("Unexpected response")
179+
// If subscription notify throws an error then assume something has gone wrong
180+
// and close the channel with the error
181+
do {
182+
if try self.subscriptions.notify(token) == true {
183+
guard let promise = commands.popFirst() else {
184+
preconditionFailure("Unexpected response")
185+
}
186+
promise.succeed(Self.simpleOk)
182187
}
183-
promise.succeed(Self.simpleOk)
188+
} catch {
189+
self.handleError(context: context, error: error)
184190
}
185191

186192
default:

Sources/Valkey/Subscriptions/PushToken.swift

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ struct PushToken: RESPTokenRepresentable {
1717
case subscribe(subscriptionCount: Int)
1818
case unsubscribe(subscriptionCount: Int)
1919
case message(String)
20+
case psubscribe(subscriptionCount: Int)
21+
case punsubscribe(subscriptionCount: Int)
22+
case pmessage(channel: String, message: String)
2023
}
21-
let channel: String
24+
let value: String
2225
let type: TokenType
2326

2427
init(from token: RESPToken) throws {
@@ -28,20 +31,37 @@ struct PushToken: RESPTokenRepresentable {
2831
throw RESPParsingError(code: .invalidData, buffer: token.base)
2932
}
3033
var arrayIterator = respArray.makeIterator()
31-
switch try String(from: arrayIterator.next()!) {
32-
case "subscribe", "psubscribe":
33-
self.channel = try String(from: arrayIterator.next()!)
34+
let notification = try String(from: arrayIterator.next()!)
35+
switch notification {
36+
case "subscribe":
37+
self.value = try String(from: arrayIterator.next()!)
3438
self.type = try TokenType.subscribe(subscriptionCount: Int(from: arrayIterator.next()!))
3539

36-
case "unsubscribe", "punsubscribe":
37-
self.channel = try String(from: arrayIterator.next()!)
40+
case "unsubscribe":
41+
self.value = try String(from: arrayIterator.next()!)
3842
self.type = try TokenType.unsubscribe(subscriptionCount: Int(from: arrayIterator.next()!))
3943

4044
case "message":
41-
self.channel = try String(from: arrayIterator.next()!)
45+
self.value = try String(from: arrayIterator.next()!)
4246
self.type = try TokenType.message(String(from: arrayIterator.next()!))
4347

48+
case "psubscribe":
49+
self.value = try String(from: arrayIterator.next()!)
50+
self.type = try TokenType.psubscribe(subscriptionCount: Int(from: arrayIterator.next()!))
51+
52+
case "punsubscribe":
53+
self.value = try String(from: arrayIterator.next()!)
54+
self.type = try TokenType.punsubscribe(subscriptionCount: Int(from: arrayIterator.next()!))
55+
56+
case "pmessage":
57+
self.value = try String(from: arrayIterator.next()!)
58+
self.type = try TokenType.pmessage(
59+
channel: String(from: arrayIterator.next()!),
60+
message: String(from: arrayIterator.next()!)
61+
)
62+
4463
default:
64+
print("Unrecognised push notification \(notification)")
4565
throw RESPParsingError(code: .invalidData, buffer: token.base)
4666
}
4767
default:

Sources/Valkey/Subscriptions/ValkeyConnection+subscribe.swift

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,30 @@ public struct ValkeySubscriptionMessage: Sendable, Equatable {
2525
}
2626

2727
extension ValkeyConnection {
28-
// Function used internally by subscribe
29-
@inlinable
30-
func _send<Command: RESPCommand>(command: Command) -> EventLoopFuture<RESPToken> {
31-
self.channel.eventLoop.assertInEventLoop()
32-
var encoder = RESPCommandEncoder()
33-
command.encode(into: &encoder)
34-
let buffer = encoder.buffer
35-
36-
let promise = channel.eventLoop.makePromise(of: RESPToken.self)
37-
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: buffer, promise: .nio(promise)))
38-
return promise.futureResult
39-
}
40-
4128
public func subscribe(to channels: String...) async throws -> some AsyncSequence<ValkeySubscriptionMessage, Error> {
4229
try await self.subscribe(to: channels)
4330
}
4431

4532
public func subscribe(to channels: [String]) async throws -> some AsyncSequence<ValkeySubscriptionMessage, Error> {
46-
let (stream, streamContinuation) = ValkeySubscriptionAsyncStream.makeStream()
4733
let command = SUBSCRIBE(channel: channels)
34+
return try await subscribe(command: command, filter: .channels(Set(channels)))
35+
}
36+
37+
public func psubscribe(to pattern: String...) async throws -> some AsyncSequence<ValkeySubscriptionMessage, Error> {
38+
try await self.psubscribe(to: pattern)
39+
}
4840

49-
let channels = Set(channels)
41+
public func psubscribe(to pattern: [String]) async throws -> some AsyncSequence<ValkeySubscriptionMessage, Error> {
42+
let command = PSUBSCRIBE(pattern: pattern)
43+
return try await subscribe(command: command, filter: .patterns(Set(pattern)))
44+
}
45+
46+
func subscribe(command: some RESPCommand, filter: ValkeySubscriptionFilter) async throws -> some AsyncSequence<ValkeySubscriptionMessage, Error> {
47+
let (stream, streamContinuation) = ValkeySubscriptionAsyncStream.makeStream()
5048
if self.channel.eventLoop.inEventLoop {
5149
let subscriptionID = self.channelHandler.value.addSubscription(
5250
continuation: streamContinuation,
53-
filter: .channels(channels)
51+
filter: filter
5452
)
5553
_ = try await self._send(command: command)
5654
.flatMapErrorThrowing { error in
@@ -62,7 +60,7 @@ extension ValkeyConnection {
6260
_ = try await self.channel.eventLoop.flatSubmit {
6361
let subscriptionID = self.channelHandler.value.addSubscription(
6462
continuation: streamContinuation,
65-
filter: .channels(channels)
63+
filter: filter
6664
)
6765
return self._send(command: command)
6866
.flatMapErrorThrowing { error in
@@ -73,4 +71,17 @@ extension ValkeyConnection {
7371
}
7472
return stream
7573
}
74+
75+
// Function used internally by subscribe
76+
@inlinable
77+
func _send<Command: RESPCommand>(command: Command) -> EventLoopFuture<RESPToken> {
78+
self.channel.eventLoop.assertInEventLoop()
79+
var encoder = RESPCommandEncoder()
80+
command.encode(into: &encoder)
81+
let buffer = encoder.buffer
82+
83+
let promise = channel.eventLoop.makePromise(of: RESPToken.self)
84+
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: buffer, promise: .nio(promise)))
85+
return promise.futureResult
86+
}
7687
}

Sources/Valkey/Subscriptions/ValkeySubscription+stateMachine.swift

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,27 @@ struct ValkeySubscriptionStateMachine {
2828
}
2929

3030
enum ReceivedTokenAction {
31-
case sendMessage(String)
31+
case sendMessage(ValkeySubscriptionMessage)
3232
case deleteSubscriptionAndReturnOk
3333
case returnOk
3434
case doNothing
3535
case fail(Error)
3636
}
3737
mutating func receivedToken(_ token: PushToken) -> ReceivedTokenAction {
38-
guard filter.filter(token.channel) else { return .doNothing }
38+
guard filter.filter(token.value) else { return .doNothing }
3939
switch token.type {
4040
case .subscribe:
41-
return receivedSubscribe(channel: token.channel)
41+
return receivedSubscribe(channel: token.value)
4242
case .unsubscribe:
43-
return receivedUnsubscribe(channel: token.channel)
43+
return receivedUnsubscribe(channel: token.value)
4444
case .message(let message):
45-
return receivedMessage(message)
45+
return receivedMessage(channel: token.value, message: message)
46+
case .psubscribe:
47+
return receivedPatternSubscribe(channel: token.value)
48+
case .punsubscribe:
49+
return receivedPatternUnsubscribe(channel: token.value)
50+
case .pmessage(let channel, let message):
51+
return receivedMessage(channel: channel, message: message)
4652
}
4753
}
4854

@@ -94,14 +100,62 @@ struct ValkeySubscriptionStateMachine {
94100
}
95101
}
96102

97-
mutating func receivedMessage(_ message: String) -> ReceivedTokenAction {
103+
mutating func receivedMessage(channel: String, message: String) -> ReceivedTokenAction {
98104
switch state {
99105
case .initialized, .starting:
100106
let error = ValkeyClientError(.subscriptionError, message: "Received message before in listening state")
101107
self.state = .failed(error)
102108
return .fail(error)
103109
case .listening:
104-
return .sendMessage(message)
110+
return .sendMessage(.init(channel: channel, message: message))
111+
case .failed(let error):
112+
return .fail(error)
113+
}
114+
}
115+
116+
mutating func receivedPatternSubscribe(channel: String) -> ReceivedTokenAction {
117+
switch state {
118+
case .initialized:
119+
let filter = ValkeySubscriptionFilter.patterns([channel])
120+
if self.filter == filter {
121+
self.state = .listening
122+
return .returnOk
123+
} else {
124+
self.state = .starting(filter: filter)
125+
return .doNothing
126+
}
127+
128+
case .starting(let filter):
129+
let filter = filter.addingPattern(channel)
130+
if self.filter == filter {
131+
self.state = .listening
132+
return .returnOk
133+
} else {
134+
self.state = .starting(filter: filter)
135+
return .doNothing
136+
}
137+
138+
case .listening:
139+
return .doNothing
140+
141+
case .failed(let error):
142+
return .fail(error)
143+
}
144+
}
145+
146+
mutating func receivedPatternUnsubscribe(channel: String) -> ReceivedTokenAction {
147+
switch state {
148+
case .initialized, .starting:
149+
let error = ValkeyClientError(.subscriptionError, message: "Received unsubscribe before in listening state")
150+
self.state = .failed(error)
151+
return .fail(error)
152+
case .listening:
153+
self.filter = self.filter.removingPattern(channel)
154+
if self.filter.isEmpty {
155+
return .deleteSubscriptionAndReturnOk
156+
} else {
157+
return .returnOk
158+
}
105159
case .failed(let error):
106160
return .fail(error)
107161
}

Sources/Valkey/Subscriptions/ValkeySubscriptionFilter.swift

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15-
enum ValkeySubscriptionFilter: Equatable {
15+
@usableFromInline
16+
enum ValkeySubscriptionFilter: Equatable, Sendable {
1617
case channels(Set<String>)
1718
case patterns(Set<String>)
1819

1920
func filter(_ value: String) -> Bool {
2021
switch self {
2122
case .channels(let channels):
2223
channels.contains(value)
23-
case .patterns(_):
24-
preconditionFailure("We don't support patterns yet")
24+
case .patterns(let patterns):
25+
patterns.contains(value)
2526
}
2627
}
2728

@@ -30,8 +31,8 @@ enum ValkeySubscriptionFilter: Equatable {
3031
case .channels(var channels):
3132
channels.insert(channel)
3233
return .channels(channels)
33-
case .patterns(_):
34-
preconditionFailure("We don't support patterns yet")
34+
case .patterns:
35+
preconditionFailure("Cannot add channel to pattern filter")
3536
}
3637
}
3738

@@ -40,17 +41,37 @@ enum ValkeySubscriptionFilter: Equatable {
4041
case .channels(var channels):
4142
channels.remove(channel)
4243
return .channels(channels)
43-
case .patterns(_):
44-
preconditionFailure("We don't support patterns yet")
44+
case .patterns:
45+
preconditionFailure("Cannot remove channel from pattern filter")
46+
}
47+
}
48+
49+
func addingPattern(_ channel: String) -> Self {
50+
switch self {
51+
case .channels:
52+
preconditionFailure("Cannot add pattern to channel filter")
53+
case .patterns(var patterns):
54+
patterns.insert(channel)
55+
return .patterns(patterns)
56+
}
57+
}
58+
59+
func removingPattern(_ channel: String) -> Self {
60+
switch self {
61+
case .channels:
62+
preconditionFailure("Cannot remove pattern from channel filter")
63+
case .patterns(var patterns):
64+
patterns.remove(channel)
65+
return .patterns(patterns)
4566
}
4667
}
4768

4869
var isEmpty: Bool {
4970
switch self {
5071
case .channels(let channels):
5172
return channels.isEmpty
52-
case .patterns(_):
53-
preconditionFailure("We don't support patterns yet")
73+
case .patterns(let patterns):
74+
return patterns.isEmpty
5475
}
5576
}
5677
}

Sources/Valkey/Subscriptions/ValkeySubscriptions.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,15 @@ struct ValkeySubscriptions {
3333
}
3434

3535
mutating func notify(_ token: RESPToken) throws -> Bool {
36-
// TODO: What should we do if push token doesnt decode
37-
guard let pushToken = try? PushToken(from: token) else { return false }
36+
let pushToken = try PushToken(from: token)
3837

3938
self.logger.trace("\(pushToken)")
4039

4140
var returnValue = false
4241
for index in subscriptions.indices {
4342
switch subscriptions[index].stateMachine.receivedToken(pushToken) {
4443
case .sendMessage(let message):
45-
subscriptions[index].subscription.sendMessage(.init(channel: pushToken.channel, message: message))
44+
subscriptions[index].subscription.sendMessage(message)
4645
case .fail(let error):
4746
subscriptions[index].subscription.sendError(error)
4847
subscriptions.remove(at: index)

0 commit comments

Comments
 (0)