diff --git a/Package.swift b/Package.swift index aeded741..983d6a30 100644 --- a/Package.swift +++ b/Package.swift @@ -5,11 +5,12 @@ import PackageDescription let package = Package( name: "swift-valkey", - platforms: [.macOS(.v13)], + platforms: [.macOS(.v15)], products: [ .library(name: "Valkey", targets: ["Valkey"]) ], dependencies: [ + .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.79.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.29.0"), @@ -19,6 +20,7 @@ let package = Package( .target( name: "Valkey", dependencies: [ + .product(name: "DequeModule", package: "swift-collections"), .product(name: "Logging", package: "swift-log"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler.swift b/Sources/Valkey/Connection/ValkeyChannelHandler.swift new file mode 100644 index 00000000..b197a389 --- /dev/null +++ b/Sources/Valkey/Connection/ValkeyChannelHandler.swift @@ -0,0 +1,143 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the swift-valkey project +// +// Copyright (c) 2025 the swift-valkey authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See swift-valkey/CONTRIBUTORS.txt for the list of swift-valkey authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import DequeModule +import Logging +import NIOCore + +@usableFromInline +enum ValkeyRequest: Sendable { + case single(buffer: ByteBuffer, promise: EventLoopPromise) + case multiple(buffer: ByteBuffer, promises: [EventLoopPromise]) +} + +@usableFromInline +final class ValkeyChannelHandler: ChannelDuplexHandler { + @usableFromInline + typealias OutboundIn = ValkeyRequest + @usableFromInline + typealias OutboundOut = ByteBuffer + @usableFromInline + typealias InboundIn = ByteBuffer + + @usableFromInline + let eventLoop: EventLoop + private var commands: Deque> + private var decoder: NIOSingleStepByteToMessageProcessor + private var context: ChannelHandlerContext? + private let logger: Logger + + init(channel: Channel, logger: Logger) { + self.eventLoop = channel.eventLoop + self.commands = .init() + self.decoder = NIOSingleStepByteToMessageProcessor(RESPTokenDecoder()) + self.context = nil + self.logger = logger + } + + /// Write valkey command/commands to channel + /// - Parameters: + /// - request: Valkey command request + /// - promise: Promise to fulfill when command is complete + @inlinable + func write(request: ValkeyRequest) { + if self.eventLoop.inEventLoop { + self._write(request: request) + } else { + eventLoop.execute { + self._write(request: request) + } + } + } + + @usableFromInline + func _write(request: ValkeyRequest) { + guard let context = self.context else { + preconditionFailure("Trying to use valkey connection before it is setup") + } + switch request { + case .single(let buffer, let tokenPromise): + self.commands.append(tokenPromise) + context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) + + case .multiple(let buffer, let tokenPromises): + for tokenPromise in tokenPromises { + self.commands.append(tokenPromise) + } + context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil) + } + } + + @usableFromInline + func handlerAdded(context: ChannelHandlerContext) { + self.context = context + } + + @usableFromInline + func handlerRemoved(context: ChannelHandlerContext) { + self.context = nil + while let promise = commands.popFirst() { + promise.fail(ValkeyClientError.init(.connectionClosed)) + } + } + + @usableFromInline + func channelInactive(context: ChannelHandlerContext) { + do { + try self.decoder.finishProcessing(seenEOF: true) { token in + self.handleToken(context: context, token: token) + } + } catch let error as RESPParsingError { + self.handleError(context: context, error: error) + } catch { + preconditionFailure("Expected to only get RESPParsingError from the RESPTokenDecoder.") + } + + self.logger.trace("Channel inactive.") + } + + @usableFromInline + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let buffer = self.unwrapInboundIn(data) + + do { + try self.decoder.process(buffer: buffer) { token in + self.handleToken(context: context, token: token) + } + } catch let error as RESPParsingError { + self.handleError(context: context, error: error) + } catch { + preconditionFailure("Expected to only get RESPParsingError from the RESPTokenDecoder.") + } + } + + func handleToken(context: ChannelHandlerContext, token: RESPToken) { + guard let promise = commands.popFirst() else { + preconditionFailure("Unexpected response") + } + promise.succeed(token) + } + + func handleError(context: ChannelHandlerContext, error: Error) { + self.logger.debug("ValkeyCommandHandler: ERROR \(error)") + guard let promise = commands.popFirst() else { + preconditionFailure("Unexpected response") + } + promise.fail(error) + } +} + +// The ValkeyChannelHandler needs to be Sendable so the ValkeyConnection can pass it +// around at initialisation +extension ValkeyChannelHandler: @unchecked Sendable {} diff --git a/Sources/Valkey/Connection/ValkeyConnection.swift b/Sources/Valkey/Connection/ValkeyConnection.swift index e96b9c13..c4419d47 100644 --- a/Sources/Valkey/Connection/ValkeyConnection.swift +++ b/Sources/Valkey/Connection/ValkeyConnection.swift @@ -16,6 +16,7 @@ import Logging import NIOCore import NIOPosix import NIOSSL +import Synchronization #if canImport(Network) import Network @@ -41,198 +42,118 @@ public struct ServerAddress: Sendable, Equatable { } /// Single connection to a Valkey database -public struct ValkeyConnection: Sendable { - enum Request { - case command(ByteBuffer) - case pipelinedCommands(ByteBuffer, Int) - } - enum Response { - case token(RESPToken) - case pipelinedResponse([Result]) - } - typealias RequestStreamElement = (Request, CheckedContinuation) +public final class ValkeyConnection: Sendable { /// Logger used by Server let logger: Logger - let eventLoopGroup: EventLoopGroup + @usableFromInline + let channel: Channel + @usableFromInline + let channelHandler: ValkeyChannelHandler let configuration: ValkeyClientConfiguration - let address: ServerAddress - #if canImport(Network) - let tlsOptions: NWProtocolTLS.Options? - #endif - - let requestStream: AsyncStream - let requestContinuation: AsyncStream.Continuation + let isClosed: Atomic - /// Initialize Client - public init( - address: ServerAddress, + /// Initialize connection + private init( + channel: Channel, + channelHandler: ValkeyChannelHandler, configuration: ValkeyClientConfiguration, - eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger ) { - self.address = address + self.channel = channel + self.channelHandler = channelHandler self.configuration = configuration - self.eventLoopGroup = eventLoopGroup self.logger = logger - #if canImport(Network) - self.tlsOptions = nil - #endif - (self.requestStream, self.requestContinuation) = AsyncStream.makeStream(of: RequestStreamElement.self) + self.isClosed = .init(false) } - public func run() async throws { - let asyncChannel = try await self.makeClient( - address: self.address + /// Connect to Valkey database and return connection + /// + /// - Parameters: + /// - address: Internet address of database + /// - configuration: Configuration of Valkey connection + /// - eventLoopGroup: EventLoopGroup to use + /// - logger: Logger for connection + /// - Returns: ValkeyConnection + public static func connect( + address: ServerAddress, + configuration: ValkeyClientConfiguration, + eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, + logger: Logger + ) async throws -> ValkeyConnection { + let (channel, channelHandler) = try await makeClient( + address: address, + eventLoopGroup: eventLoopGroup, + configuration: configuration, + logger: logger ) - do { - try await withTaskCancellationHandler { - try await asyncChannel.executeThenClose { inbound, outbound in - var inboundIterator = inbound.makeAsyncIterator() - if self.configuration.respVersion == .v3 { - try await resp3Upgrade(outbound: outbound, inboundIterator: &inboundIterator) - } - for await (request, continuation) in requestStream { - do { - switch request { - case .command(let command): - try await outbound.write(command) - let response = try await inboundIterator.next() - if let response { - continuation.resume(returning: .token(response)) - } else { - requestContinuation.finish() - continuation.resume( - throwing: ValkeyClientError( - .connectionClosed, - message: "The connection to the database was unexpectedly closed." - ) - ) - } - case .pipelinedCommands(let commands, let count): - try await outbound.write(commands) - var responses: [Result] = .init() - for _ in 0.. EventLoopFuture { + guard self.isClosed.compareExchange(expected: false, desired: true, successOrdering: .relaxed, failureOrdering: .relaxed).exchanged else { + return channel.eventLoop.makeSucceededVoidFuture() } + self.channel.close(mode: .all, promise: nil) + return self.channel.closeFuture } - @discardableResult public func send(command: Command) async throws -> Command.Response { + /// Send RESP command to Valkey connection + /// - Parameter command: RESPCommand structure + /// - Returns: The command response as defined in the RESPCommand + + @inlinable + public func send(command: Command) async throws -> Command.Response { var encoder = RESPCommandEncoder() command.encode(into: &encoder) - let response: Response = try await withCheckedThrowingContinuation { continuation in - switch requestContinuation.yield((.command(encoder.buffer), continuation)) { - case .enqueued: - break - case .dropped, .terminated: - continuation.resume( - throwing: ValkeyClientError( - .connectionClosed, - message: "Unable to enqueue request due to the connection being shutdown." - ) - ) - default: - break - } - } - guard case .token(let token) = response else { preconditionFailure("Expected a single response") } - return try .init(from: token) + + let promise = channel.eventLoop.makePromise(of: RESPToken.self) + channelHandler.write(request: ValkeyRequest.single(buffer: encoder.buffer, promise: promise)) + return try await .init(from: promise.futureResult.get()) } - @discardableResult public func pipeline( + /// Pipeline a series of commands to Valkey connection + /// + /// This function will only return once it has the results of all the commands sent + /// - Parameter commands: Parameter pack of RESPCommands + /// - Returns: Parameter pack holding the responses of all the commands + @inlinable + public func pipeline( _ commands: repeat each Command ) async throws -> (repeat (each Command).Response) { - var count = 0 + // this currently allocates a promise for every command. We could collpase this down to one promise + var promises: [EventLoopPromise] = [] var encoder = RESPCommandEncoder() for command in repeat each commands { command.encode(into: &encoder) - count += 1 - } - - let response: Response = try await withCheckedThrowingContinuation { continuation in - switch requestContinuation.yield((.pipelinedCommands(encoder.buffer, count), continuation)) { - case .enqueued: - break - case .dropped, .terminated: - continuation.resume( - throwing: ValkeyClientError( - .connectionClosed, - message: "Unable to enqueue request due to the connection being shutdown." - ) - ) - default: - break - } + promises.append(channel.eventLoop.makePromise(of: RESPToken.self)) } - guard case .pipelinedResponse(let tokens) = response else { preconditionFailure("Expected a single response") } - + // write directly to channel handler + channelHandler.write(request: ValkeyRequest.multiple(buffer: encoder.buffer, promises: promises)) + // get response from channel handler var index = AutoIncrementingInteger() - return try (repeat (each Command).Response(from: tokens[index.next()].get())) + return try await (repeat (each Command).Response(from: promises[index.next()].futureResult.get())) } /// Try to upgrade to RESP3 - private func resp3Upgrade( - outbound: NIOAsyncChannelOutboundWriter, - inboundIterator: inout NIOAsyncChannelInboundStream.AsyncIterator - ) async throws { - var encoder = RESPCommandEncoder() - encoder.encodeArray("HELLO", 3) - try await outbound.write(encoder.buffer) - let response = try await inboundIterator.next() - guard let response else { - throw ValkeyClientError(.connectionClosed, message: "The connection to the database was unexpectedly closed.") - } - // if returned value is an error then throw that error - if let value = response.errorString { - throw ValkeyClientError(.commandError, message: String(buffer: value)) - } + private func resp3Upgrade() async throws { + _ = try await send(command: HELLO(arguments: .init(protover: 3, auth: nil, clientname: nil))) } - /// Connect to server - private func makeClient(address: ServerAddress) async throws -> NIOAsyncChannel { + /// Create Valkey connection and return channel connection is running on and the Valkey channel handler + private static func makeClient( + address: ServerAddress, + eventLoopGroup: EventLoopGroup, + configuration: ValkeyClientConfiguration, + logger: Logger + ) async throws -> (Channel, ValkeyChannelHandler) { // get bootstrap let bootstrap: ClientBootstrapProtocol #if canImport(Network) - if let tsBootstrap = self.createTSBootstrap() { + if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoopGroup, tlsOptions: nil) { bootstrap = tsBootstrap } else { #if os(iOS) || os(tvOS) @@ -240,61 +161,64 @@ public struct ValkeyConnection: Sendable { "Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework" ) #endif - bootstrap = self.createSocketsBootstrap() + bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup) } #else - bootstrap = self.createSocketsBootstrap() + bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoopGroup) #endif // connect - let result: NIOAsyncChannel + let channel: Channel + let channelHandler: ValkeyChannelHandler do { switch address.value { case .hostname(let host, let port): - result = + (channel, channelHandler) = try await bootstrap .connect(host: host, port: port) { channel in - setupChannel(channel) + setupChannel(channel, configuration: configuration, logger: logger) } - self.logger.debug("Client connnected to \(host):\(port)") + logger.debug("Client connnected to \(host):\(port)") case .unixDomainSocket(let path): - result = + (channel, channelHandler) = try await bootstrap .connect(unixDomainSocketPath: path) { channel in - setupChannel(channel) + setupChannel(channel, configuration: configuration, logger: logger) } - self.logger.debug("Client connnected to socket path \(path)") + logger.debug("Client connnected to socket path \(path)") } - return result + return (channel, channelHandler) } catch { throw error } } - private func setupChannel(_ channel: Channel) -> EventLoopFuture> { + private static func setupChannel( + _ channel: Channel, + configuration: ValkeyClientConfiguration, + logger: Logger + ) -> EventLoopFuture<(Channel, ValkeyChannelHandler)> { channel.eventLoop.makeCompletedFuture { - if case .enable(let sslContext, let tlsServerName) = self.configuration.tls.base { + if case .enable(let sslContext, let tlsServerName) = configuration.tls.base { try channel.pipeline.syncOperations.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName)) } - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(RESPTokenDecoder())) - return try NIOAsyncChannel( - wrappingChannelSynchronously: channel, - configuration: .init() - ) + let valkeyChannelHandler = ValkeyChannelHandler(channel: channel, logger: logger) + try channel.pipeline.syncOperations.addHandler(valkeyChannelHandler) + return (channel, valkeyChannelHandler) } } /// create a BSD sockets based bootstrap - private func createSocketsBootstrap() -> ClientBootstrap { - ClientBootstrap(group: self.eventLoopGroup) + private static func createSocketsBootstrap(eventLoopGroup: EventLoopGroup) -> ClientBootstrap { + ClientBootstrap(group: eventLoopGroup) .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) } #if canImport(Network) /// create a NIOTransportServices bootstrap using Network.framework - private func createTSBootstrap() -> NIOTSConnectionBootstrap? { + private static func createTSBootstrap(eventLoopGroup: EventLoopGroup, tlsOptions: NWProtocolTLS.Options?) -> NIOTSConnectionBootstrap? { guard - let bootstrap = NIOTSConnectionBootstrap(validatingGroup: self.eventLoopGroup)? + let bootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoopGroup)? .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) else { return nil @@ -325,8 +249,18 @@ extension ClientBootstrap: ClientBootstrapProtocol {} extension NIOTSConnectionBootstrap: ClientBootstrapProtocol {} #endif -private struct AutoIncrementingInteger { +// Used in ValkeyConnection.pipeline +@usableFromInline +struct AutoIncrementingInteger { + @usableFromInline var value: Int = 0 + + @inlinable + init() { + self.value = 0 + } + + @inlinable mutating func next() -> Int { value += 1 return value - 1 diff --git a/Sources/Valkey/RESP/RESPKey.swift b/Sources/Valkey/RESP/RESPKey.swift index f20f6d60..785ac692 100644 --- a/Sources/Valkey/RESP/RESPKey.swift +++ b/Sources/Valkey/RESP/RESPKey.swift @@ -15,7 +15,7 @@ import NIOCore /// Type representing a RESPKey -public struct RESPKey: RawRepresentable { +public struct RESPKey: RawRepresentable, Sendable { public var rawValue: String public init(rawValue: String) { diff --git a/Sources/Valkey/ValkeyClient.swift b/Sources/Valkey/ValkeyClient.swift index fe363734..cb36c9b5 100644 --- a/Sources/Valkey/ValkeyClient.swift +++ b/Sources/Valkey/ValkeyClient.swift @@ -64,19 +64,20 @@ extension ValkeyClient { logger: Logger, operation: @escaping @Sendable (ValkeyConnection) async throws -> Value ) async throws -> Value { - let valkeyConnection = ValkeyConnection( + let valkeyConnection = try await ValkeyConnection.connect( address: self.serverAddress, configuration: self.configuration, eventLoopGroup: self.eventLoopGroup, logger: logger ) - return try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await valkeyConnection.run() - } - let value: Value = try await operation(valkeyConnection) - group.cancelAll() - return value + let value: Value + do { + value = try await operation(valkeyConnection) + } catch { + try? await valkeyConnection.close().get() + throw error } + try await valkeyConnection.close().get() + return value } } diff --git a/Tests/ValkeyTests/ValkeyTests.swift b/Tests/ValkeyTests/ValkeyTests.swift index e0e451e4..6854dd56 100644 --- a/Tests/ValkeyTests/ValkeyTests.swift +++ b/Tests/ValkeyTests/ValkeyTests.swift @@ -171,6 +171,50 @@ struct GeneratedCommands { } } + @Test + func testMultiplexing() async throws { + var logger = Logger(label: "Valkey") + logger.logLevel = .debug + try await ValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger).withConnection(logger: logger) { connection in + try await withThrowingTaskGroup(of: Void.self) { group in + for _ in 0..<100 { + group.addTask { + try await withKey(connection: connection) { key in + _ = try await connection.set(key: key, value: key.rawValue) + let response = try await connection.get(key: key) + #expect(response == key.rawValue) + } + } + } + try await group.waitForAll() + } + } + } + + @Test + func testMultiplexingPipelinedRequests() async throws { + var logger = Logger(label: "Valkey") + logger.logLevel = .debug + try await ValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger).withConnection(logger: logger) { connection in + try await withThrowingTaskGroup(of: Void.self) { group in + try await withKey(connection: connection) { key in + // Add 100 requests get and setting the same key + for _ in 0..<100 { + group.addTask { + let value = UUID().uuidString + let responses = try await connection.pipeline( + SET(key: key, value: value), + GET(key: key) + ) + #expect(responses.1 == value) + } + } + } + try await group.waitForAll() + } + } + } + /* @Test func testSubscriptions() async throws {