Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Sources/Valkey/Connection/ValkeyChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,18 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
}

func handleToken(context: ChannelHandlerContext, token: RESPToken) {
guard let promise = commands.popFirst() else {
preconditionFailure("Unexpected response")
switch token.identifier {
case .simpleError, .bulkError:
guard let promise = commands.popFirst() else {
preconditionFailure("Unexpected response")
}
promise.fail(ValkeyClientError(.commandError, message: token.errorString.map { String(buffer: $0) }))
default:
guard let promise = commands.popFirst() else {
preconditionFailure("Unexpected response")
}
promise.succeed(token)
}
promise.succeed(token)
}

func handleError(context: ChannelHandlerContext, error: Error) {
Expand Down
19 changes: 14 additions & 5 deletions Sources/Valkey/RESP/RESPToken.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public struct RESPToken: Hashable, Sendable {

return .bulkString(local.readSlice(length: length)!)

case .blobError:
case .bulkError:
var lengthSlice = try! local.readCRLFTerminatedSlice2()!
let lengthString = lengthSlice.readString(length: lengthSlice.readableBytes)!
let length = Int(lengthString)!
Expand Down Expand Up @@ -197,7 +197,7 @@ public struct RESPToken: Hashable, Sendable {
case .simpleError:
let slice = try! local.readCRLFTerminatedSlice2()!
return slice
case .blobError:
case .bulkError:
var lengthSlice = try! local.readCRLFTerminatedSlice2()!
let lengthString = lengthSlice.readString(length: lengthSlice.readableBytes)!
let length = Int(lengthString)!
Expand All @@ -207,6 +207,10 @@ public struct RESPToken: Hashable, Sendable {
}
}

public var identifier: RESPTypeIdentifier {
self.base.getValidatedRESP3TypeIdentifier()
}

public init?(consuming buffer: inout ByteBuffer) throws {
try self.init(consuming: &buffer, depth: 0)
}
Expand All @@ -223,7 +227,7 @@ public struct RESPToken: Hashable, Sendable {

case .some(.bulkString),
.some(.verbatimString),
.some(.blobError):
.some(.bulkError):
validated = try buffer.readRESPBlobStringSlice()

case .some(.simpleString),
Expand Down Expand Up @@ -260,7 +264,7 @@ public struct RESPToken: Hashable, Sendable {
}

extension ByteBuffer {
fileprivate mutating func getRESP3TypeIdentifier(at index: Int) throws -> RESPTypeIdentifier? {
fileprivate func getRESP3TypeIdentifier(at index: Int) throws -> RESPTypeIdentifier? {
guard let int = self.getInteger(at: index, as: UInt8.self) else {
return nil
}
Expand All @@ -272,6 +276,11 @@ extension ByteBuffer {
return id
}

fileprivate func getValidatedRESP3TypeIdentifier() -> RESPTypeIdentifier {
let int = self.getInteger(at: self.readerIndex, as: UInt8.self)!
return RESPTypeIdentifier(rawValue: int)!
}

fileprivate mutating func readValidatedRESP3TypeIdentifier() -> RESPTypeIdentifier {
let int = self.readInteger(as: UInt8.self)!
return RESPTypeIdentifier(rawValue: int)!
Expand Down Expand Up @@ -311,7 +320,7 @@ extension ByteBuffer {

fileprivate mutating func readRESPBlobStringSlice() throws -> ByteBuffer? {
let marker = try self.getRESP3TypeIdentifier(at: self.readerIndex)!
precondition(marker == .bulkString || marker == .verbatimString || marker == .blobError)
precondition(marker == .bulkString || marker == .verbatimString || marker == .bulkError)
guard var lengthSlice = try self.getCRLFTerminatedSlice(at: self.readerIndex + 1) else {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/Valkey/RESP/RESPTypeIdentifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public enum RESPTypeIdentifier: UInt8 {
case simpleString = 43 // UInt8.plus
case simpleError = 45 // UInt8.min
case bulkString = 36 // UInt8.dollar
case blobError = 33 // UInt8.exclamationMark
case bulkError = 33 // UInt8.exclamationMark
case verbatimString = 61 // UInt8.equals
case boolean = 35 // UInt8.pound
case null = 95 // UInt8.underscore
Expand Down
2 changes: 1 addition & 1 deletion Sources/Valkey/ValkeyClientError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//===----------------------------------------------------------------------===//

/// Errors returned by ``ValkeyClient``
public struct ValkeyClientError: Error, CustomStringConvertible {
public struct ValkeyClientError: Error, CustomStringConvertible, Equatable {
public struct ErrorCode: Equatable, Sendable {
fileprivate enum _Internal: Equatable, Sendable {
case connectionClosed
Expand Down
12 changes: 12 additions & 0 deletions Tests/IntegrationTests/ValkeyTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ struct GeneratedCommands {
}
}

@Test("Test command error is thrown")
func testCommandError() async throws {
var logger = Logger(label: "Valkey")
logger.logLevel = .debug
try await ValkeyClient(.hostname(valkeyHostname, port: 6379), logger: logger).withConnection(logger: logger) { connection in
try await withKey(connection: connection) { key in
_ = try await connection.set(key: key, value: "Hello")
await #expect(throws: ValkeyClientError.self) { _ = try await connection.rpop(key: key) }
}
}
}
Comment on lines +175 to +184
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you maybe want to add this as a unit test for both RESP variants?


@Test
func testMultiplexing() async throws {
var logger = Logger(label: "Valkey")
Expand Down
38 changes: 38 additions & 0 deletions Tests/ValkeyTests/ValkeyConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,42 @@ struct ConnectionTests {
try await channel.writeInbound(ByteBuffer(string: "$3\r\nBar\r\n"))
#expect(try await fooResult == "Bar")
}

@Test
func testSimpleError() async throws {
let channel = NIOAsyncTestingChannel()
let logger = Logger(label: "test")
let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger)

async let fooResult = connection.get(key: "foo")
_ = try await channel.waitForOutboundWrite(as: ByteBuffer.self)

try await channel.writeInbound(ByteBuffer(string: "-Error!\r\n"))
do {
_ = try await fooResult
Issue.record()
} catch let error as ValkeyClientError {
#expect(error.errorCode == .commandError)
#expect(error.message == "Error!")
}
}

@Test
func testBulkError() async throws {
let channel = NIOAsyncTestingChannel()
let logger = Logger(label: "test")
let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger)

async let fooResult = connection.get(key: "foo")
_ = try await channel.waitForOutboundWrite(as: ByteBuffer.self)

try await channel.writeInbound(ByteBuffer(string: "!10\r\nBulkError!\r\n"))
do {
_ = try await fooResult
Issue.record()
} catch let error as ValkeyClientError {
#expect(error.errorCode == .commandError)
#expect(error.message == "BulkError!")
}
}
}