Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions Sources/Valkey/Connection/ValkeyChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ final class ValkeyChannelHandler: ChannelDuplexHandler {
private var context: ChannelHandlerContext?
private let logger: Logger

init(channel: Channel, logger: Logger) {
self.eventLoop = channel.eventLoop
init(eventLoop: EventLoop, logger: Logger) {
self.eventLoop = eventLoop
self.commands = .init()
self.decoder = NIOSingleStepByteToMessageProcessor(RESPTokenDecoder())
self.context = nil
Expand All @@ -52,17 +52,12 @@ final class ValkeyChannelHandler: ChannelDuplexHandler {
/// - promise: Promise to fulfill when command is complete
@inlinable
func write(request: ValkeyRequest) {
if self.eventLoop.inEventLoop {
self._write(request: request)
} else {
eventLoop.execute {
self._write(request: request)
}
}
self._write(request: request)
}

@usableFromInline
func _write(request: ValkeyRequest) {
self.eventLoop.assertInEventLoop()
guard let context = self.context else {
preconditionFailure("Trying to use valkey connection before it is setup")
}
Expand Down Expand Up @@ -137,7 +132,3 @@ final class ValkeyChannelHandler: ChannelDuplexHandler {
promise.fail(error)
}
}

// The ValkeyChannelHandler needs to be Sendable so the ValkeyConnection can pass it
// around at initialisation
extension ValkeyChannelHandler: @unchecked Sendable {}
145 changes: 69 additions & 76 deletions Sources/Valkey/Connection/ValkeyConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public final class ValkeyConnection: Sendable {
@usableFromInline
let channel: Channel
@usableFromInline
let channelHandler: ValkeyChannelHandler
let channelHandler: NIOLoopBound<ValkeyChannelHandler>
let configuration: ValkeyClientConfiguration
let isClosed: Atomic<Bool>

Expand All @@ -60,7 +60,7 @@ public final class ValkeyConnection: Sendable {
logger: Logger
) {
self.channel = channel
self.channelHandler = channelHandler
self.channelHandler = .init(channelHandler, eventLoop: channel.eventLoop)
self.configuration = configuration
self.logger = logger
self.isClosed = .init(false)
Expand All @@ -77,16 +77,17 @@ public final class ValkeyConnection: Sendable {
public static func connect(
address: ServerAddress,
configuration: ValkeyClientConfiguration,
eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton,
eventLoop: EventLoop = MultiThreadedEventLoopGroup.singleton.any(),
logger: Logger
) async throws -> ValkeyConnection {
let (channel, channelHandler) = try await makeClient(
address: address,
eventLoopGroup: eventLoopGroup,
configuration: configuration,
logger: logger
)
let connection = ValkeyConnection(channel: channel, channelHandler: channelHandler, configuration: configuration, logger: logger)
let future = if eventLoop.inEventLoop {
self._makeClient(address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
} else {
eventLoop.flatSubmit {
self._makeClient(address: address, eventLoop: eventLoop, configuration: configuration, logger: logger)
}
}
let connection = try await future.get()
if configuration.respVersion == .v3 {
try await connection.resp3Upgrade()
}
Expand All @@ -111,9 +112,16 @@ public final class ValkeyConnection: Sendable {
public func send<Command: RESPCommand>(command: Command) async throws -> Command.Response {
var encoder = RESPCommandEncoder()
command.encode(into: &encoder)
let result = encoder.buffer

let promise = channel.eventLoop.makePromise(of: RESPToken.self)
channelHandler.write(request: ValkeyRequest.single(buffer: encoder.buffer, promise: promise))
if self.channel.eventLoop.inEventLoop {
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: result, promise: promise))
} else {
self.channel.eventLoop.execute {
self.channelHandler.value.write(request: ValkeyRequest.single(buffer: result, promise: promise))
}
}
return try await .init(from: promise.futureResult.get())
}

Expand All @@ -127,14 +135,23 @@ public final class ValkeyConnection: Sendable {
_ commands: repeat each Command
) async throws -> (repeat (each Command).Response) {
// this currently allocates a promise for every command. We could collpase this down to one promise
var promises: [EventLoopPromise<RESPToken>] = []
var mpromises: [EventLoopPromise<RESPToken>] = []
var encoder = RESPCommandEncoder()
for command in repeat each commands {
command.encode(into: &encoder)
promises.append(channel.eventLoop.makePromise(of: RESPToken.self))
mpromises.append(channel.eventLoop.makePromise(of: RESPToken.self))
}
let outBuffer = encoder.buffer
let promises = mpromises
// write directly to channel handler
channelHandler.write(request: ValkeyRequest.multiple(buffer: encoder.buffer, promises: promises))
if self.channel.eventLoop.inEventLoop {
self.channelHandler.value.write(request: ValkeyRequest.multiple(buffer: outBuffer, promises: promises))
} else {
self.channel.eventLoop.execute {
self.channelHandler.value.write(request: ValkeyRequest.multiple(buffer: outBuffer, promises: promises))
}
}

// get response from channel handler
var index = AutoIncrementingInteger()
return try await (repeat (each Command).Response(from: promises[index.next()].futureResult.get()))
Expand All @@ -146,70 +163,64 @@ public final class ValkeyConnection: Sendable {
}

/// Create Valkey connection and return channel connection is running on and the Valkey channel handler
private static func makeClient(
private static func _makeClient(
address: ServerAddress,
eventLoopGroup: EventLoopGroup,
eventLoop: EventLoop,
configuration: ValkeyClientConfiguration,
logger: Logger
) async throws -> (Channel, ValkeyChannelHandler) {
// get bootstrap
let bootstrap: ClientBootstrapProtocol
) -> EventLoopFuture<ValkeyConnection> {
eventLoop.assertInEventLoop()

let bootstrap: NIOClientTCPBootstrapProtocol
#if canImport(Network)
if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoopGroup, tlsOptions: nil) {
if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoop, tlsOptions: nil) {
bootstrap = tsBootstrap
} else {
#if os(iOS) || os(tvOS)
self.logger.warning(
"Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework"
)
#endif
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup)
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
}
#else
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup)
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
#endif

// connect
let channel: Channel
let channelHandler: ValkeyChannelHandler
do {
switch address.value {
case .hostname(let host, let port):
(channel, channelHandler) =
try await bootstrap
.connect(host: host, port: port) { channel in
setupChannel(channel, configuration: configuration, logger: logger)
}
let connect = bootstrap.channelInitializer { channel in
do {
let sync = channel.pipeline.syncOperations
if case .enable(let sslContext, let tlsServerName) = configuration.tls.base {
try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
}
let valkeyChannelHandler = ValkeyChannelHandler(
eventLoop: channel.eventLoop,
logger: logger
)
try sync.addHandler(valkeyChannelHandler)
return eventLoop.makeSucceededVoidFuture()
} catch {
return eventLoop.makeFailedFuture(error)
}
}

let future: EventLoopFuture<Channel>
switch address.value {
case .hostname(let host, let port):
future = connect.connect(host: host, port: port)
future.whenSuccess { _ in
logger.debug("Client connnected to \(host):\(port)")
case .unixDomainSocket(let path):
(channel, channelHandler) =
try await bootstrap
.connect(unixDomainSocketPath: path) { channel in
setupChannel(channel, configuration: configuration, logger: logger)
}
}
case .unixDomainSocket(let path):
future = connect.connect(unixDomainSocketPath: path)
future.whenSuccess { _ in
logger.debug("Client connnected to socket path \(path)")
}
return (channel, channelHandler)
} catch {
throw error
}
}

private static func setupChannel(
_ channel: Channel,
configuration: ValkeyClientConfiguration,
logger: Logger
) -> EventLoopFuture<(Channel, ValkeyChannelHandler)> {
channel.eventLoop.makeCompletedFuture {
if case .enable(let sslContext, let tlsServerName) = configuration.tls.base {
try channel.pipeline.syncOperations.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
}
let valkeyChannelHandler = ValkeyChannelHandler(
channel: channel,
logger: logger
)
try channel.pipeline.syncOperations.addHandler(valkeyChannelHandler)
return (channel, valkeyChannelHandler)
return future.flatMapThrowing { channel in
let handler = try channel.pipeline.syncOperations.handler(type: ValkeyChannelHandler.self)
return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger)
}
}

Expand All @@ -236,24 +247,6 @@ public final class ValkeyConnection: Sendable {
#endif
}

protocol ClientBootstrapProtocol {
func connect<Output: Sendable>(
host: String,
port: Int,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output

func connect<Output: Sendable>(
unixDomainSocketPath: String,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output
}

extension ClientBootstrap: ClientBootstrapProtocol {}
#if canImport(Network)
extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {}
#endif

// Used in ValkeyConnection.pipeline
@usableFromInline
struct AutoIncrementingInteger {
Expand Down
2 changes: 1 addition & 1 deletion Sources/Valkey/ValkeyClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ extension ValkeyClient {
let valkeyConnection = try await ValkeyConnection.connect(
address: self.serverAddress,
configuration: self.configuration,
eventLoopGroup: self.eventLoopGroup,
eventLoop: self.eventLoopGroup.any(),
logger: logger
)
let value: Value
Expand Down