@@ -48,7 +48,7 @@ public final class ValkeyConnection: Sendable {
4848 @usableFromInline
4949 let channel : Channel
5050 @usableFromInline
51- let channelHandler : ValkeyChannelHandler
51+ let channelHandler : NIOLoopBound < ValkeyChannelHandler >
5252 let configuration : ValkeyClientConfiguration
5353 let isClosed : Atomic < Bool >
5454
@@ -60,7 +60,7 @@ public final class ValkeyConnection: Sendable {
6060 logger: Logger
6161 ) {
6262 self . channel = channel
63- self . channelHandler = channelHandler
63+ self . channelHandler = . init ( channelHandler, eventLoop : channel . eventLoop )
6464 self . configuration = configuration
6565 self . logger = logger
6666 self . isClosed = . init( false )
@@ -77,16 +77,17 @@ public final class ValkeyConnection: Sendable {
7777 public static func connect(
7878 address: ServerAddress ,
7979 configuration: ValkeyClientConfiguration ,
80- eventLoopGroup : EventLoopGroup = MultiThreadedEventLoopGroup . singleton,
80+ eventLoop : EventLoop = MultiThreadedEventLoopGroup . singleton. any ( ) ,
8181 logger: Logger
8282 ) async throws -> ValkeyConnection {
83- let ( channel, channelHandler) = try await makeClient (
84- address: address,
85- eventLoopGroup: eventLoopGroup,
86- configuration: configuration,
87- logger: logger
88- )
89- let connection = ValkeyConnection ( channel: channel, channelHandler: channelHandler, configuration: configuration, logger: logger)
83+ let future = if eventLoop. inEventLoop {
84+ self . _makeClient ( address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
85+ } else {
86+ eventLoop. flatSubmit {
87+ self . _makeClient ( address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
88+ }
89+ }
90+ let connection = try await future. get ( )
9091 if configuration. respVersion == . v3 {
9192 try await connection. resp3Upgrade ( )
9293 }
@@ -111,9 +112,16 @@ public final class ValkeyConnection: Sendable {
111112 public func send< Command: RESPCommand > ( command: Command ) async throws -> Command . Response {
112113 var encoder = RESPCommandEncoder ( )
113114 command. encode ( into: & encoder)
115+ let result = encoder. buffer
114116
115117 let promise = channel. eventLoop. makePromise ( of: RESPToken . self)
116- channelHandler. write ( request: ValkeyRequest . single ( buffer: encoder. buffer, promise: promise) )
118+ if self . channel. eventLoop. inEventLoop {
119+ self . channelHandler. value. write ( request: ValkeyRequest . single ( buffer: result, promise: promise) )
120+ } else {
121+ self . channel. eventLoop. execute {
122+ self . channelHandler. value. write ( request: ValkeyRequest . single ( buffer: result, promise: promise) )
123+ }
124+ }
117125 return try await . init( from: promise. futureResult. get ( ) )
118126 }
119127
@@ -127,14 +135,23 @@ public final class ValkeyConnection: Sendable {
127135 _ commands: repeat each Command
128136 ) async throws -> ( repeat ( each Command ) . Response) {
129137 // this currently allocates a promise for every command. We could collpase this down to one promise
130- var promises : [ EventLoopPromise < RESPToken > ] = [ ]
138+ var mpromises : [ EventLoopPromise < RESPToken > ] = [ ]
131139 var encoder = RESPCommandEncoder ( )
132140 for command in repeat each commands {
133141 command. encode ( into: & encoder)
134- promises . append ( channel. eventLoop. makePromise ( of: RESPToken . self) )
142+ mpromises . append ( channel. eventLoop. makePromise ( of: RESPToken . self) )
135143 }
144+ let outBuffer = encoder. buffer
145+ let promises = mpromises
136146 // write directly to channel handler
137- channelHandler. write ( request: ValkeyRequest . multiple ( buffer: encoder. buffer, promises: promises) )
147+ if self . channel. eventLoop. inEventLoop {
148+ self . channelHandler. value. write ( request: ValkeyRequest . multiple ( buffer: outBuffer, promises: promises) )
149+ } else {
150+ self . channel. eventLoop. execute {
151+ self . channelHandler. value. write ( request: ValkeyRequest . multiple ( buffer: outBuffer, promises: promises) )
152+ }
153+ }
154+
138155 // get response from channel handler
139156 var index = AutoIncrementingInteger ( )
140157 return try await ( repeat ( each Command) . Response ( from: promises [ index. next ( ) ] . futureResult. get ( ) ) )
@@ -146,70 +163,64 @@ public final class ValkeyConnection: Sendable {
146163 }
147164
148165 /// Create Valkey connection and return channel connection is running on and the Valkey channel handler
149- private static func makeClient (
166+ private static func _makeClient (
150167 address: ServerAddress ,
151- eventLoopGroup : EventLoopGroup ,
168+ eventLoop : EventLoop ,
152169 configuration: ValkeyClientConfiguration ,
153170 logger: Logger
154- ) async throws -> ( Channel , ValkeyChannelHandler ) {
155- // get bootstrap
156- let bootstrap : ClientBootstrapProtocol
171+ ) -> EventLoopFuture < ValkeyConnection > {
172+ eventLoop. assertInEventLoop ( )
173+
174+ let bootstrap : NIOClientTCPBootstrapProtocol
157175 #if canImport(Network)
158- if let tsBootstrap = createTSBootstrap ( eventLoopGroup: eventLoopGroup , tlsOptions: nil ) {
176+ if let tsBootstrap = createTSBootstrap ( eventLoopGroup: eventLoop , tlsOptions: nil ) {
159177 bootstrap = tsBootstrap
160178 } else {
161179 #if os(iOS) || os(tvOS)
162180 self . logger. warning (
163181 " Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework "
164182 )
165183 #endif
166- bootstrap = self . createSocketsBootstrap ( eventLoopGroup: eventLoopGroup )
184+ bootstrap = self . createSocketsBootstrap ( eventLoopGroup: eventLoop )
167185 }
168186 #else
169- bootstrap = self . createSocketsBootstrap ( eventLoopGroup: eventLoopGroup )
187+ bootstrap = self . createSocketsBootstrap ( eventLoopGroup: eventLoop )
170188 #endif
171189
172- // connect
173- let channel : Channel
174- let channelHandler : ValkeyChannelHandler
175- do {
176- switch address. value {
177- case . hostname( let host, let port) :
178- ( channel, channelHandler) =
179- try await bootstrap
180- . connect ( host: host, port: port) { channel in
181- setupChannel ( channel, configuration: configuration, logger: logger)
182- }
190+ let connect = bootstrap. channelInitializer { channel in
191+ do {
192+ let sync = channel. pipeline. syncOperations
193+ if case . enable( let sslContext, let tlsServerName) = configuration. tls. base {
194+ try sync. addHandler ( NIOSSLClientHandler ( context: sslContext, serverHostname: tlsServerName) )
195+ }
196+ let valkeyChannelHandler = ValkeyChannelHandler (
197+ eventLoop: channel. eventLoop,
198+ logger: logger
199+ )
200+ try sync. addHandler ( valkeyChannelHandler)
201+ return eventLoop. makeSucceededVoidFuture ( )
202+ } catch {
203+ return eventLoop. makeFailedFuture ( error)
204+ }
205+ }
206+
207+ let future : EventLoopFuture < Channel >
208+ switch address. value {
209+ case . hostname( let host, let port) :
210+ future = connect. connect ( host: host, port: port)
211+ future. whenSuccess { _ in
183212 logger. debug ( " Client connnected to \( host) : \( port) " )
184- case . unixDomainSocket( let path) :
185- ( channel, channelHandler) =
186- try await bootstrap
187- . connect ( unixDomainSocketPath: path) { channel in
188- setupChannel ( channel, configuration: configuration, logger: logger)
189- }
213+ }
214+ case . unixDomainSocket( let path) :
215+ future = connect. connect ( unixDomainSocketPath: path)
216+ future. whenSuccess { _ in
190217 logger. debug ( " Client connnected to socket path \( path) " )
191218 }
192- return ( channel, channelHandler)
193- } catch {
194- throw error
195219 }
196- }
197220
198- private static func setupChannel(
199- _ channel: Channel ,
200- configuration: ValkeyClientConfiguration ,
201- logger: Logger
202- ) -> EventLoopFuture < ( Channel , ValkeyChannelHandler ) > {
203- channel. eventLoop. makeCompletedFuture {
204- if case . enable( let sslContext, let tlsServerName) = configuration. tls. base {
205- try channel. pipeline. syncOperations. addHandler ( NIOSSLClientHandler ( context: sslContext, serverHostname: tlsServerName) )
206- }
207- let valkeyChannelHandler = ValkeyChannelHandler (
208- channel: channel,
209- logger: logger
210- )
211- try channel. pipeline. syncOperations. addHandler ( valkeyChannelHandler)
212- return ( channel, valkeyChannelHandler)
221+ return future. flatMapThrowing { channel in
222+ let handler = try channel. pipeline. syncOperations. handler ( type: ValkeyChannelHandler . self)
223+ return ValkeyConnection ( channel: channel, channelHandler: handler, configuration: configuration, logger: logger)
213224 }
214225 }
215226
@@ -236,24 +247,6 @@ public final class ValkeyConnection: Sendable {
236247 #endif
237248}
238249
239- protocol ClientBootstrapProtocol {
240- func connect< Output: Sendable > (
241- host: String ,
242- port: Int ,
243- channelInitializer: @escaping @Sendable ( Channel ) -> EventLoopFuture < Output >
244- ) async throws -> Output
245-
246- func connect< Output: Sendable > (
247- unixDomainSocketPath: String ,
248- channelInitializer: @escaping @Sendable ( Channel ) -> EventLoopFuture < Output >
249- ) async throws -> Output
250- }
251-
252- extension ClientBootstrap : ClientBootstrapProtocol { }
253- #if canImport(Network)
254- extension NIOTSConnectionBootstrap : ClientBootstrapProtocol { }
255- #endif
256-
257250// Used in ValkeyConnection.pipeline
258251@usableFromInline
259252struct AutoIncrementingInteger {
0 commit comments