Skip to content

Commit dc7f04e

Browse files
authored
Remove unchecked Sendable on ValkeyChannelHander (#15)
* Remove unchecked Sendable on ValkeyChannelHander * Remove now unnecessary eventLoop check.
1 parent 08181b2 commit dc7f04e

File tree

3 files changed

+74
-90
lines changed

3 files changed

+74
-90
lines changed

Sources/Valkey/Connection/ValkeyChannelHandler.swift

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
3636
private var context: ChannelHandlerContext?
3737
private let logger: Logger
3838

39-
init(channel: Channel, logger: Logger) {
40-
self.eventLoop = channel.eventLoop
39+
init(eventLoop: EventLoop, logger: Logger) {
40+
self.eventLoop = eventLoop
4141
self.commands = .init()
4242
self.decoder = NIOSingleStepByteToMessageProcessor(RESPTokenDecoder())
4343
self.context = nil
@@ -50,17 +50,12 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
5050
/// - promise: Promise to fulfill when command is complete
5151
@inlinable
5252
func write(request: ValkeyRequest) {
53-
if self.eventLoop.inEventLoop {
54-
self._write(request: request)
55-
} else {
56-
eventLoop.execute {
57-
self._write(request: request)
58-
}
59-
}
53+
self._write(request: request)
6054
}
6155

6256
@usableFromInline
6357
func _write(request: ValkeyRequest) {
58+
self.eventLoop.assertInEventLoop()
6459
guard let context = self.context else {
6560
preconditionFailure("Trying to use valkey connection before it is setup")
6661
}
@@ -135,7 +130,3 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
135130
promise.fail(error)
136131
}
137132
}
138-
139-
// The ValkeyChannelHandler needs to be Sendable so the ValkeyConnection can pass it
140-
// around at initialisation
141-
extension ValkeyChannelHandler: @unchecked Sendable {}

Sources/Valkey/Connection/ValkeyConnection.swift

Lines changed: 69 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public final class ValkeyConnection: Sendable {
4848
@usableFromInline
4949
let channel: Channel
5050
@usableFromInline
51-
let channelHandler: ValkeyChannelHandler
51+
let channelHandler: NIOLoopBound<ValkeyChannelHandler>
5252
let configuration: ValkeyClientConfiguration
5353
let isClosed: Atomic<Bool>
5454

@@ -60,7 +60,7 @@ public final class ValkeyConnection: Sendable {
6060
logger: Logger
6161
) {
6262
self.channel = channel
63-
self.channelHandler = channelHandler
63+
self.channelHandler = .init(channelHandler, eventLoop: channel.eventLoop)
6464
self.configuration = configuration
6565
self.logger = logger
6666
self.isClosed = .init(false)
@@ -77,16 +77,17 @@ public final class ValkeyConnection: Sendable {
7777
public static func connect(
7878
address: ServerAddress,
7979
configuration: ValkeyClientConfiguration,
80-
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
80+
eventLoop: EventLoop = MultiThreadedEventLoopGroup.singleton.any(),
8181
logger: Logger
8282
) async throws -> ValkeyConnection {
83-
let (channel, channelHandler) = try await makeClient(
84-
address: address,
85-
eventLoopGroup: eventLoopGroup,
86-
configuration: configuration,
87-
logger: logger
88-
)
89-
let connection = ValkeyConnection(channel: channel, channelHandler: channelHandler, configuration: configuration, logger: logger)
83+
let future = if eventLoop.inEventLoop {
84+
self._makeClient(address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
85+
} else {
86+
eventLoop.flatSubmit {
87+
self._makeClient(address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
88+
}
89+
}
90+
let connection = try await future.get()
9091
if configuration.respVersion == .v3 {
9192
try await connection.resp3Upgrade()
9293
}
@@ -111,9 +112,16 @@ public final class ValkeyConnection: Sendable {
111112
public func send<Command: RESPCommand>(command: Command) async throws -> Command.Response {
112113
var encoder = RESPCommandEncoder()
113114
command.encode(into: &encoder)
115+
let result = encoder.buffer
114116

115117
let promise = channel.eventLoop.makePromise(of: RESPToken.self)
116-
channelHandler.write(request: ValkeyRequest.single(buffer: encoder.buffer, promise: promise))
118+
if self.channel.eventLoop.inEventLoop {
119+
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: result, promise: promise))
120+
} else {
121+
self.channel.eventLoop.execute {
122+
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: result, promise: promise))
123+
}
124+
}
117125
return try await .init(from: promise.futureResult.get())
118126
}
119127

@@ -127,14 +135,23 @@ public final class ValkeyConnection: Sendable {
127135
_ commands: repeat each Command
128136
) async throws -> (repeat (each Command).Response) {
129137
// this currently allocates a promise for every command. We could collpase this down to one promise
130-
var promises: [EventLoopPromise<RESPToken>] = []
138+
var mpromises: [EventLoopPromise<RESPToken>] = []
131139
var encoder = RESPCommandEncoder()
132140
for command in repeat each commands {
133141
command.encode(into: &encoder)
134-
promises.append(channel.eventLoop.makePromise(of: RESPToken.self))
142+
mpromises.append(channel.eventLoop.makePromise(of: RESPToken.self))
135143
}
144+
let outBuffer = encoder.buffer
145+
let promises = mpromises
136146
// write directly to channel handler
137-
channelHandler.write(request: ValkeyRequest.multiple(buffer: encoder.buffer, promises: promises))
147+
if self.channel.eventLoop.inEventLoop {
148+
self.channelHandler.value.write(request: ValkeyRequest.multiple(buffer: outBuffer, promises: promises))
149+
} else {
150+
self.channel.eventLoop.execute {
151+
self.channelHandler.value.write(request: ValkeyRequest.multiple(buffer: outBuffer, promises: promises))
152+
}
153+
}
154+
138155
// get response from channel handler
139156
var index = AutoIncrementingInteger()
140157
return try await (repeat (each Command).Response(from: promises[index.next()].futureResult.get()))
@@ -146,70 +163,64 @@ public final class ValkeyConnection: Sendable {
146163
}
147164

148165
/// Create Valkey connection and return channel connection is running on and the Valkey channel handler
149-
private static func makeClient(
166+
private static func _makeClient(
150167
address: ServerAddress,
151-
eventLoopGroup: EventLoopGroup,
168+
eventLoop: EventLoop,
152169
configuration: ValkeyClientConfiguration,
153170
logger: Logger
154-
) async throws -> (Channel, ValkeyChannelHandler) {
155-
// get bootstrap
156-
let bootstrap: ClientBootstrapProtocol
171+
) -> EventLoopFuture<ValkeyConnection> {
172+
eventLoop.assertInEventLoop()
173+
174+
let bootstrap: NIOClientTCPBootstrapProtocol
157175
#if canImport(Network)
158-
if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoopGroup, tlsOptions: nil) {
176+
if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoop, tlsOptions: nil) {
159177
bootstrap = tsBootstrap
160178
} else {
161179
#if os(iOS) || os(tvOS)
162180
self.logger.warning(
163181
"Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework"
164182
)
165183
#endif
166-
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup)
184+
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
167185
}
168186
#else
169-
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup)
187+
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
170188
#endif
171189

172-
// connect
173-
let channel: Channel
174-
let channelHandler: ValkeyChannelHandler
175-
do {
176-
switch address.value {
177-
case .hostname(let host, let port):
178-
(channel, channelHandler) =
179-
try await bootstrap
180-
.connect(host: host, port: port) { channel in
181-
setupChannel(channel, configuration: configuration, logger: logger)
182-
}
190+
let connect = bootstrap.channelInitializer { channel in
191+
do {
192+
let sync = channel.pipeline.syncOperations
193+
if case .enable(let sslContext, let tlsServerName) = configuration.tls.base {
194+
try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
195+
}
196+
let valkeyChannelHandler = ValkeyChannelHandler(
197+
eventLoop: channel.eventLoop,
198+
logger: logger
199+
)
200+
try sync.addHandler(valkeyChannelHandler)
201+
return eventLoop.makeSucceededVoidFuture()
202+
} catch {
203+
return eventLoop.makeFailedFuture(error)
204+
}
205+
}
206+
207+
let future: EventLoopFuture<Channel>
208+
switch address.value {
209+
case .hostname(let host, let port):
210+
future = connect.connect(host: host, port: port)
211+
future.whenSuccess { _ in
183212
logger.debug("Client connnected to \(host):\(port)")
184-
case .unixDomainSocket(let path):
185-
(channel, channelHandler) =
186-
try await bootstrap
187-
.connect(unixDomainSocketPath: path) { channel in
188-
setupChannel(channel, configuration: configuration, logger: logger)
189-
}
213+
}
214+
case .unixDomainSocket(let path):
215+
future = connect.connect(unixDomainSocketPath: path)
216+
future.whenSuccess { _ in
190217
logger.debug("Client connnected to socket path \(path)")
191218
}
192-
return (channel, channelHandler)
193-
} catch {
194-
throw error
195219
}
196-
}
197220

198-
private static func setupChannel(
199-
_ channel: Channel,
200-
configuration: ValkeyClientConfiguration,
201-
logger: Logger
202-
) -> EventLoopFuture<(Channel, ValkeyChannelHandler)> {
203-
channel.eventLoop.makeCompletedFuture {
204-
if case .enable(let sslContext, let tlsServerName) = configuration.tls.base {
205-
try channel.pipeline.syncOperations.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
206-
}
207-
let valkeyChannelHandler = ValkeyChannelHandler(
208-
channel: channel,
209-
logger: logger
210-
)
211-
try channel.pipeline.syncOperations.addHandler(valkeyChannelHandler)
212-
return (channel, valkeyChannelHandler)
221+
return future.flatMapThrowing { channel in
222+
let handler = try channel.pipeline.syncOperations.handler(type: ValkeyChannelHandler.self)
223+
return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger)
213224
}
214225
}
215226

@@ -236,24 +247,6 @@ public final class ValkeyConnection: Sendable {
236247
#endif
237248
}
238249

239-
protocol ClientBootstrapProtocol {
240-
func connect<Output: Sendable>(
241-
host: String,
242-
port: Int,
243-
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
244-
) async throws -> Output
245-
246-
func connect<Output: Sendable>(
247-
unixDomainSocketPath: String,
248-
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
249-
) async throws -> Output
250-
}
251-
252-
extension ClientBootstrap: ClientBootstrapProtocol {}
253-
#if canImport(Network)
254-
extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {}
255-
#endif
256-
257250
// Used in ValkeyConnection.pipeline
258251
@usableFromInline
259252
struct AutoIncrementingInteger {

Sources/Valkey/ValkeyClient.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ extension ValkeyClient {
6767
let valkeyConnection = try await ValkeyConnection.connect(
6868
address: self.serverAddress,
6969
configuration: self.configuration,
70-
eventLoopGroup: self.eventLoopGroup,
70+
eventLoop: self.eventLoopGroup.any(),
7171
logger: logger
7272
)
7373
let value: Value

0 commit comments

Comments
 (0)