diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler.swift b/Sources/Valkey/Connection/ValkeyChannelHandler.swift index ff3698d0..9207227f 100644 --- a/Sources/Valkey/Connection/ValkeyChannelHandler.swift +++ b/Sources/Valkey/Connection/ValkeyChannelHandler.swift @@ -140,10 +140,18 @@ final class ValkeyChannelHandler: ChannelInboundHandler { } func handleToken(context: ChannelHandlerContext, token: RESPToken) { - guard let promise = commands.popFirst() else { - preconditionFailure("Unexpected response") + switch token.identifier { + case .simpleError, .bulkError: + guard let promise = commands.popFirst() else { + preconditionFailure("Unexpected response") + } + promise.fail(ValkeyClientError(.commandError, message: token.errorString.map { String(buffer: $0) })) + default: + guard let promise = commands.popFirst() else { + preconditionFailure("Unexpected response") + } + promise.succeed(token) } - promise.succeed(token) } func handleError(context: ChannelHandlerContext, error: Error) { diff --git a/Sources/Valkey/RESP/RESPToken.swift b/Sources/Valkey/RESP/RESPToken.swift index 94b382f3..de2344c0 100644 --- a/Sources/Valkey/RESP/RESPToken.swift +++ b/Sources/Valkey/RESP/RESPToken.swift @@ -117,7 +117,7 @@ public struct RESPToken: Hashable, Sendable { return .bulkString(local.readSlice(length: length)!) - case .blobError: + case .bulkError: var lengthSlice = try! local.readCRLFTerminatedSlice2()! let lengthString = lengthSlice.readString(length: lengthSlice.readableBytes)! let length = Int(lengthString)! @@ -197,7 +197,7 @@ public struct RESPToken: Hashable, Sendable { case .simpleError: let slice = try! local.readCRLFTerminatedSlice2()! return slice - case .blobError: + case .bulkError: var lengthSlice = try! local.readCRLFTerminatedSlice2()! let lengthString = lengthSlice.readString(length: lengthSlice.readableBytes)! let length = Int(lengthString)! @@ -207,6 +207,10 @@ public struct RESPToken: Hashable, Sendable { } } + public var identifier: RESPTypeIdentifier { + self.base.getValidatedRESP3TypeIdentifier() + } + public init?(consuming buffer: inout ByteBuffer) throws { try self.init(consuming: &buffer, depth: 0) } @@ -223,7 +227,7 @@ public struct RESPToken: Hashable, Sendable { case .some(.bulkString), .some(.verbatimString), - .some(.blobError): + .some(.bulkError): validated = try buffer.readRESPBlobStringSlice() case .some(.simpleString), @@ -260,7 +264,7 @@ public struct RESPToken: Hashable, Sendable { } extension ByteBuffer { - fileprivate mutating func getRESP3TypeIdentifier(at index: Int) throws -> RESPTypeIdentifier? { + fileprivate func getRESP3TypeIdentifier(at index: Int) throws -> RESPTypeIdentifier? { guard let int = self.getInteger(at: index, as: UInt8.self) else { return nil } @@ -272,6 +276,11 @@ extension ByteBuffer { return id } + fileprivate func getValidatedRESP3TypeIdentifier() -> RESPTypeIdentifier { + let int = self.getInteger(at: self.readerIndex, as: UInt8.self)! + return RESPTypeIdentifier(rawValue: int)! + } + fileprivate mutating func readValidatedRESP3TypeIdentifier() -> RESPTypeIdentifier { let int = self.readInteger(as: UInt8.self)! return RESPTypeIdentifier(rawValue: int)! @@ -311,7 +320,7 @@ extension ByteBuffer { fileprivate mutating func readRESPBlobStringSlice() throws -> ByteBuffer? { let marker = try self.getRESP3TypeIdentifier(at: self.readerIndex)! - precondition(marker == .bulkString || marker == .verbatimString || marker == .blobError) + precondition(marker == .bulkString || marker == .verbatimString || marker == .bulkError) guard var lengthSlice = try self.getCRLFTerminatedSlice(at: self.readerIndex + 1) else { return nil } diff --git a/Sources/Valkey/RESP/RESPTypeIdentifier.swift b/Sources/Valkey/RESP/RESPTypeIdentifier.swift index 0b45ea22..485d3590 100644 --- a/Sources/Valkey/RESP/RESPTypeIdentifier.swift +++ b/Sources/Valkey/RESP/RESPTypeIdentifier.swift @@ -18,7 +18,7 @@ public enum RESPTypeIdentifier: UInt8 { case simpleString = 43 // UInt8.plus case simpleError = 45 // UInt8.min case bulkString = 36 // UInt8.dollar - case blobError = 33 // UInt8.exclamationMark + case bulkError = 33 // UInt8.exclamationMark case verbatimString = 61 // UInt8.equals case boolean = 35 // UInt8.pound case null = 95 // UInt8.underscore diff --git a/Sources/Valkey/ValkeyClientError.swift b/Sources/Valkey/ValkeyClientError.swift index 4ca6a74a..953b1115 100644 --- a/Sources/Valkey/ValkeyClientError.swift +++ b/Sources/Valkey/ValkeyClientError.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// /// Errors returned by ``ValkeyClient`` -public struct ValkeyClientError: Error, CustomStringConvertible { +public struct ValkeyClientError: Error, CustomStringConvertible, Equatable { public struct ErrorCode: Equatable, Sendable { fileprivate enum _Internal: Equatable, Sendable { case connectionClosed diff --git a/Tests/IntegrationTests/ValkeyTests.swift b/Tests/IntegrationTests/ValkeyTests.swift index 6854dd56..12bdf506 100644 --- a/Tests/IntegrationTests/ValkeyTests.swift +++ b/Tests/IntegrationTests/ValkeyTests.swift @@ -171,6 +171,18 @@ struct GeneratedCommands { } } + @Test("Test command error is thrown") + func testCommandError() async throws { + var logger = Logger(label: "Valkey") + logger.logLevel = .debug + try await ValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger).withConnection(logger: logger) { connection in + try await withKey(connection: connection) { key in + _ = try await connection.set(key: key, value: "Hello") + await #expect(throws: ValkeyClientError.self) { _ = try await connection.rpop(key: key) } + } + } + } + @Test func testMultiplexing() async throws { var logger = Logger(label: "Valkey") diff --git a/Tests/ValkeyTests/ValkeyConnectionTests.swift b/Tests/ValkeyTests/ValkeyConnectionTests.swift index 64b10e73..63065d41 100644 --- a/Tests/ValkeyTests/ValkeyConnectionTests.swift +++ b/Tests/ValkeyTests/ValkeyConnectionTests.swift @@ -35,4 +35,42 @@ struct ConnectionTests { try await channel.writeInbound(ByteBuffer(string: "$3\r\nBar\r\n")) #expect(try await fooResult == "Bar") } + + @Test + func testSimpleError() async throws { + let channel = NIOAsyncTestingChannel() + let logger = Logger(label: "test") + let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger) + + async let fooResult = connection.get(key: "foo") + _ = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + + try await channel.writeInbound(ByteBuffer(string: "-Error!\r\n")) + do { + _ = try await fooResult + Issue.record() + } catch let error as ValkeyClientError { + #expect(error.errorCode == .commandError) + #expect(error.message == "Error!") + } + } + + @Test + func testBulkError() async throws { + let channel = NIOAsyncTestingChannel() + let logger = Logger(label: "test") + let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger) + + async let fooResult = connection.get(key: "foo") + _ = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + + try await channel.writeInbound(ByteBuffer(string: "!10\r\nBulkError!\r\n")) + do { + _ = try await fooResult + Issue.record() + } catch let error as ValkeyClientError { + #expect(error.errorCode == .commandError) + #expect(error.message == "BulkError!") + } + } }