Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ keywords = ["protobuf", "protoc"]
license = "MIT"
desc = "Julia protobuf implementation"
authors = ["Tomáš Drvoštěp <[email protected]>"]
version = "1.1.1"
version = "1.1.2"

[deps]
BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d"
Expand Down
61 changes: 43 additions & 18 deletions src/codec/Codecs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,37 @@

using BufferedStreams: BufferedOutputStream, BufferedInputStream

macro _const(ex)
ex = esc(ex)
if VERSION < v"1.8.0-DEV.1148"
return ex

Check warning on line 8 in src/codec/Codecs.jl

View check run for this annotation

Codecov / codecov/patch

src/codec/Codecs.jl#L8

Added line #L8 was not covered by tests
else
return Expr(:const, ex)
end
end
const var"@const" = var"@_const"

@enum(WireType::UInt32, VARINT=0, FIXED64=1, LENGTH_DELIMITED=2, START_GROUP=3, END_GROUP=4, FIXED32=5)

abstract type AbstractProtoDecoder end
abstract type AbstractProtoEncoder end
struct ProtoDecoder{I<:IO,F<:Function} <: AbstractProtoDecoder
io::I
message_done::F
end
message_done(d::ProtoDecoder) = d.message_done(d.io)
ProtoDecoder(io::IO) = ProtoDecoder(io, eof)
get_stream(d::AbstractProtoDecoder) = d.io

struct LengthDelimitedProtoDecoder{I<:IO} <: AbstractProtoDecoder
io::I
endpos::Int
mutable struct ProtoDecoder{I<:IO,F<:Function} <: AbstractProtoDecoder
@const io::I
@const message_done::F
end
message_done(d::LengthDelimitedProtoDecoder) = d.endpos == position(d.io)

struct GroupProtoDecoder{I<:IO} <: AbstractProtoDecoder
io::I
end
function message_done(d::GroupProtoDecoder)
done = peek(d.io) == UInt8(END_GROUP)
done && skip(d.io, 1)
function message_done(d::AbstractProtoDecoder, endpos::Int, group::Bool)
io = get_stream(d)
if group
done = peek(io) == UInt8(END_GROUP)
done && skip(io, 1)
else
done = d.message_done(io) || (endpos > 0 && position(io) >= endpos)
end
return done
end
ProtoDecoder(io::IO) = ProtoDecoder(io, eof)

struct ProtoEncoder{I<:IO} <: AbstractProtoEncoder
io::I
Expand Down Expand Up @@ -61,4 +67,23 @@

export encode, decode

end # module

# Backwards compatibility with old older decode methods

message_done(d::ProtoDecoder) = d.message_done(d.io)
struct LengthDelimitedProtoDecoder{I<:IO} <: AbstractProtoDecoder
io::I
endpos::Int
end
message_done(d::LengthDelimitedProtoDecoder) = d.endpos == position(d.io)

struct GroupProtoDecoder{I<:IO} <: AbstractProtoDecoder
io::I
end
function message_done(d::GroupProtoDecoder)
done = peek(d.io) == UInt8(END_GROUP)
done && skip(d.io, 1)
return done

Check warning on line 86 in src/codec/Codecs.jl

View check run for this annotation

Codecov / codecov/patch

src/codec/Codecs.jl#L83-L86

Added lines #L83 - L86 were not covered by tests
end

end # module
139 changes: 83 additions & 56 deletions src/codec/decode.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function decode_tag(d::AbstractProtoDecoder)
b = vbyte_decode(d.io, UInt32)
b = vbyte_decode(get_stream(d), UInt32)
field_number = b >> 3
wire_type = WireType(b & 0x07)
return field_number, wire_type
Expand All @@ -9,94 +9,99 @@
const _ScalarTypesEnum = Union{_ScalarTypes,Enum}

# uint32, uint64
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{UInt32,UInt64}} = vbyte_decode(d.io, T)
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{UInt32,UInt64}} = vbyte_decode(get_stream(d), T)
# int32: Negative int32 are encoded in 10 bytes...
# TODO: add check the int is negative if larger than typemax UInt32
decode(d::AbstractProtoDecoder, ::Type{Int32}) = reinterpret(Int32, UInt32(vbyte_decode(d.io, UInt64) % UInt32))
decode(d::AbstractProtoDecoder, ::Type{Int32}) = reinterpret(Int32, UInt32(vbyte_decode(get_stream(d), UInt64) % UInt32))
# int64
decode(d::AbstractProtoDecoder, ::Type{Int64}) = reinterpret(Int64, vbyte_decode(d.io, UInt64))
decode(d::AbstractProtoDecoder, ::Type{Int64}) = reinterpret(Int64, vbyte_decode(get_stream(d), UInt64))
# sfixed32, sfixed64, # fixed32, fixed64
decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}} = read(d.io, T)
decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}} = read(get_stream(d), T)
# sint32, sint64
function decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:zigzag}}) where {T <: Union{Int32,Int64}}
v = vbyte_decode(d.io, unsigned(T))
v = vbyte_decode(get_stream(d), unsigned(T))
z = zigzag_decode(v)
return reinterpret(T, z)
end
decode(d::AbstractProtoDecoder, ::Type{Bool}) = Bool(read(d.io, UInt8))
decode(d::AbstractProtoDecoder, ::Type{Bool}) = Bool(read(get_stream(d), UInt8))
function decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Enum{Int32},Enum{UInt32}}}
val = vbyte_decode(d.io, UInt32)
val = vbyte_decode(get_stream(d), UInt32)
return Core.bitcast(T, reinterpret(Int32, val))
end
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(d.io, T)
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(get_stream(d), T)
function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V<:_ScalarTypesEnum}
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
io = get_stream(d)
pair_len = vbyte_decode(io, UInt32)
pair_end_pos = position(io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
@assert position(io) == pair_end_pos
buffer[key] = val
nothing
end

function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V}
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
io = get_stream(d)
pair_len = vbyte_decode(io, UInt32)
pair_end_pos = position(io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, Ref{V})
@assert position(d.io) == pair_end_pos
@assert position(io) == pair_end_pos
buffer[key] = val
nothing
end

for T in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
io = get_stream(d)
pair_len = vbyte_decode(io, UInt32)
pair_end_pos = position(io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(T)})
@assert position(d.io) == pair_end_pos
@assert position(io) == pair_end_pos
buffer[key] = val
nothing
end

@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
io = get_stream(d)
pair_len = vbyte_decode(io, UInt32)
pair_end_pos = position(io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
@assert position(io) == pair_end_pos
buffer[key] = val
nothing
end
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
io = get_stream(d)
pair_len = vbyte_decode(io, UInt32)
pair_end_pos = position(io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(S)})
@assert position(d.io) == pair_end_pos
@assert position(io) == pair_end_pos
buffer[key] = val
nothing
end
end

function decode(d::AbstractProtoDecoder, ::Type{String})
bytelen = vbyte_decode(d.io, UInt32)
bytelen = vbyte_decode(get_stream(d), UInt32)
str = Base._string_n(bytelen)
Base.unsafe_read(d.io, pointer(str), bytelen)
Base.unsafe_read(get_stream(d), pointer(str), bytelen)
return str
end
function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{String})
Expand All @@ -105,12 +110,12 @@
end

function decode(d::AbstractProtoDecoder, ::Type{Vector{UInt8}})
bytelen = vbyte_decode(d.io, UInt32)
return read(d.io, bytelen)
bytelen = vbyte_decode(get_stream(d), UInt32)
return read(get_stream(d), bytelen)
end
function decode(d::AbstractProtoDecoder, ::Type{Base.CodeUnits{UInt8, String}})
bytelen = vbyte_decode(d.io, UInt32)
return read(d.io, bytelen)
bytelen = vbyte_decode(get_stream(d), UInt32)
return read(get_stream(d), bytelen)
end
function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{Vector{UInt8}})
buffer[] = decode(d, Vector{UInt8})
Expand All @@ -119,12 +124,13 @@

function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}) where {T <: Union{Int32,Int64,UInt32,UInt64,Enum{Int32},Enum{UInt32}}}
if w == LENGTH_DELIMITED
bytelen = vbyte_decode(d.io, UInt32)
endpos = bytelen + position(d.io)
while position(d.io) < endpos
io = get_stream(d)
bytelen = vbyte_decode(io, UInt32)
endpos = bytelen + position(io)
while position(io) < endpos
buffer[] = decode(d, T)
end
@assert position(d.io) == endpos
@assert position(io) == endpos
else
buffer[] = decode(d, T)
end
Expand All @@ -133,12 +139,13 @@

function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}, ::Type{Val{:zigzag}}) where {T <: Union{Int32,Int64}}
if w == LENGTH_DELIMITED
bytelen = vbyte_decode(d.io, UInt32)
endpos = bytelen + position(d.io)
while position(d.io) < endpos
io = get_stream(d)
bytelen = vbyte_decode(io, UInt32)
endpos = bytelen + position(io)
while position(io) < endpos
buffer[] = decode(d, T, Val{:zigzag})
end
@assert position(d.io) == endpos
@assert position(io) == endpos
else
buffer[] = decode(d, T, Val{:zigzag})
end
Expand All @@ -147,16 +154,17 @@

function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}}
if w == LENGTH_DELIMITED
bytelen = vbyte_decode(d.io, UInt32)
io = get_stream(d)
bytelen = vbyte_decode(io, UInt32)
n_incoming = div(bytelen, sizeof(T))
n_current = length(buffer.elements)
resize!(buffer.elements, n_current + n_incoming)
endpos = bytelen + position(d.io)
endpos = bytelen + position(io)
for i in (n_current+1):(n_current + n_incoming)
buffer.occupied += 1
@inbounds buffer.elements[i] = decode(d, T, Val{:fixed})
end
@assert position(d.io) == endpos
@assert position(io) == endpos
else
buffer[] = decode(d, T, Val{:fixed})
end
Expand All @@ -165,16 +173,17 @@

function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}) where {T <: Union{Bool,Float32,Float64}}
if w == LENGTH_DELIMITED
bytelen = vbyte_decode(d.io, UInt32)
io = get_stream(d)
bytelen = vbyte_decode(io, UInt32)
n_incoming = div(bytelen, sizeof(T))
n_current = length(buffer.elements)
resize!(buffer.elements, n_current + n_incoming)
endpos = bytelen + position(d.io)
endpos = bytelen + position(io)
for i in (n_current+1):(n_current + n_incoming)
buffer.occupied += 1
@inbounds buffer.elements[i] = decode(d, T)
end
@assert position(d.io) == endpos
@assert position(io) == endpos
else
buffer[] = decode(d, T)
end
Expand All @@ -187,10 +196,19 @@
# We don't reuse the decode!(d::AbstractProtoDecoder, buffer::Base.RefValue{T}) method above
# as with OneOf fields, we can't be sure that the previous OneOf value was also T.
function decode(d::AbstractProtoDecoder, ::Type{Ref{T}}) where {T}
bytelen = vbyte_decode(d.io, UInt32)
endpos = bytelen + position(d.io)
out = decode(LengthDelimitedProtoDecoder(d.io, endpos), T)
@assert position(d.io) == endpos
io = get_stream(d)
bytelen = vbyte_decode(io, UInt32)
endpos = bytelen + position(io)
if hasmethod(decode, Tuple{AbstractProtoDecoder, Type{T}, Int, Bool})
out = decode(d, T, endpos, false)
else
@warn "You are using code generated by an older version of ProtoBuf.jl, which \
was deprecated. Please regenerate your protobuf definitions with the current version of \
ProtoBuf.jl. The new version will allow for defining custom AbstractProtoDecoder variants. \
This warning is only printed once per session." maxlog=1 T=T
out = decode(LengthDelimitedProtoDecoder(get_stream(d), endpos), T)
end
@assert position(io) == endpos "$(T) decode: expected position $(endpos), got $(position(io))"
return out
end

Expand All @@ -200,7 +218,15 @@
end

function decode(d::AbstractProtoDecoder, ::Type{Ref{T}}, ::Type{Val{:group}}) where {T}
out = decode(GroupProtoDecoder(d.io), T)
if hasmethod(decode, Tuple{AbstractProtoDecoder, Type{T}, Int, Bool})
out = decode(d, T, 0, true)
else
@warn "You are using code generated by an older version of ProtoBuf.jl, which \

Check warning on line 224 in src/codec/decode.jl

View check run for this annotation

Codecov / codecov/patch

src/codec/decode.jl#L224

Added line #L224 was not covered by tests
was deprecated. Please regenerate your protobuf definitions with the current version of \
ProtoBuf.jl. The new version will allow for defining custom AbstractProtoDecoder variants. \
This warning is only printed once per session." maxlog=1 T=T
out = decode(GroupProtoDecoder(get_stream(d)), T)

Check warning on line 228 in src/codec/decode.jl

View check run for this annotation

Codecov / codecov/patch

src/codec/decode.jl#L228

Added line #L228 was not covered by tests
end
return out
end

Expand Down Expand Up @@ -291,20 +317,21 @@
end

@inline function Base.skip(d::AbstractProtoDecoder, wire_type::WireType)
io = get_stream(d)
if wire_type == VARINT
while read(d.io, UInt8) >= 0x80 end
while read(io, UInt8) >= 0x80 end
elseif wire_type == FIXED64
skip(d.io, 8)
skip(io, 8)
elseif wire_type == LENGTH_DELIMITED
bytelen = vbyte_decode(d.io, UInt32)
skip(d.io, bytelen)
bytelen = vbyte_decode(io, UInt32)
skip(io, bytelen)
elseif wire_type == START_GROUP
while peek(d.io) != UInt8(END_GROUP)
while peek(io) != UInt8(END_GROUP)
skip(d, decode_tag(d)[2])
end
skip(d.io, 1)
skip(io, 1)
elseif wire_type == FIXED32
skip(d.io, 4)
skip(io, 4)
else wire_type == END_GROUP
error("Encountered END_GROUP wiretype while skipping")
end
Expand Down
Loading
Loading