diff --git a/Project.toml b/Project.toml index ad8f622..c7bb4f9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ keywords = ["protobuf", "protoc"] license = "MIT" desc = "Julia protobuf implementation" authors = ["Tomáš Drvoštěp "] -version = "1.1.1" +version = "1.1.2" [deps] BufferedStreams = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" diff --git a/src/codec/Codecs.jl b/src/codec/Codecs.jl index e4c01af..622646f 100644 --- a/src/codec/Codecs.jl +++ b/src/codec/Codecs.jl @@ -2,31 +2,37 @@ module Codecs using BufferedStreams: BufferedOutputStream, BufferedInputStream +macro _const(ex) + ex = esc(ex) + if VERSION < v"1.8.0-DEV.1148" + return ex + else + return Expr(:const, ex) + end +end +const var"@const" = var"@_const" + @enum(WireType::UInt32, VARINT=0, FIXED64=1, LENGTH_DELIMITED=2, START_GROUP=3, END_GROUP=4, FIXED32=5) abstract type AbstractProtoDecoder end abstract type AbstractProtoEncoder end -struct ProtoDecoder{I<:IO,F<:Function} <: AbstractProtoDecoder - io::I - message_done::F -end -message_done(d::ProtoDecoder) = d.message_done(d.io) -ProtoDecoder(io::IO) = ProtoDecoder(io, eof) +get_stream(d::AbstractProtoDecoder) = d.io -struct LengthDelimitedProtoDecoder{I<:IO} <: AbstractProtoDecoder - io::I - endpos::Int +mutable struct ProtoDecoder{I<:IO,F<:Function} <: AbstractProtoDecoder + @const io::I + @const message_done::F end -message_done(d::LengthDelimitedProtoDecoder) = d.endpos == position(d.io) - -struct GroupProtoDecoder{I<:IO} <: AbstractProtoDecoder - io::I -end -function message_done(d::GroupProtoDecoder) - done = peek(d.io) == UInt8(END_GROUP) - done && skip(d.io, 1) +function message_done(d::AbstractProtoDecoder, endpos::Int, group::Bool) + io = get_stream(d) + if group + done = peek(io) == UInt8(END_GROUP) + done && skip(io, 1) + else + done = d.message_done(io) || (endpos > 0 && position(io) >= endpos) + end return done end +ProtoDecoder(io::IO) = ProtoDecoder(io, eof) struct ProtoEncoder{I<:IO} <: AbstractProtoEncoder io::I @@ -61,4 
+67,23 @@ include("encode.jl") export encode, decode -end # module \ No newline at end of file + +# Backwards compatibility with older decode methods + +message_done(d::ProtoDecoder) = d.message_done(d.io) +struct LengthDelimitedProtoDecoder{I<:IO} <: AbstractProtoDecoder + io::I + endpos::Int +end +message_done(d::LengthDelimitedProtoDecoder) = d.endpos == position(d.io) + +struct GroupProtoDecoder{I<:IO} <: AbstractProtoDecoder + io::I +end +function message_done(d::GroupProtoDecoder) + done = peek(d.io) == UInt8(END_GROUP) + done && skip(d.io, 1) + return done +end + +end # module diff --git a/src/codec/decode.jl b/src/codec/decode.jl index 3f32bb8..146aff9 100644 --- a/src/codec/decode.jl +++ b/src/codec/decode.jl @@ -1,5 +1,5 @@ function decode_tag(d::AbstractProtoDecoder) - b = vbyte_decode(d.io, UInt32) + b = vbyte_decode(get_stream(d), UInt32) field_number = b >> 3 wire_type = WireType(b & 0x07) return field_number, wire_type @@ -9,71 +9,75 @@ const _ScalarTypes = Union{Float64,Float32,Int32,Int64,UInt64,UInt32,Bool,String const _ScalarTypesEnum = Union{_ScalarTypes,Enum} # uint32, uint64 -decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{UInt32,UInt64}} = vbyte_decode(d.io, T) +decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{UInt32,UInt64}} = vbyte_decode(get_stream(d), T) # int32: Negative int32 are encoded in 10 bytes...
# TODO: add check the int is negative if larger than typemax UInt32 -decode(d::AbstractProtoDecoder, ::Type{Int32}) = reinterpret(Int32, UInt32(vbyte_decode(d.io, UInt64) % UInt32)) +decode(d::AbstractProtoDecoder, ::Type{Int32}) = reinterpret(Int32, UInt32(vbyte_decode(get_stream(d), UInt64) % UInt32)) # int64 -decode(d::AbstractProtoDecoder, ::Type{Int64}) = reinterpret(Int64, vbyte_decode(d.io, UInt64)) +decode(d::AbstractProtoDecoder, ::Type{Int64}) = reinterpret(Int64, vbyte_decode(get_stream(d), UInt64)) # sfixed32, sfixed64, # fixed32, fixed64 -decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}} = read(d.io, T) +decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}} = read(get_stream(d), T) # sint32, sint64 function decode(d::AbstractProtoDecoder, ::Type{T}, ::Type{Val{:zigzag}}) where {T <: Union{Int32,Int64}} - v = vbyte_decode(d.io, unsigned(T)) + v = vbyte_decode(get_stream(d), unsigned(T)) z = zigzag_decode(v) return reinterpret(T, z) end -decode(d::AbstractProtoDecoder, ::Type{Bool}) = Bool(read(d.io, UInt8)) +decode(d::AbstractProtoDecoder, ::Type{Bool}) = Bool(read(get_stream(d), UInt8)) function decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Enum{Int32},Enum{UInt32}}} - val = vbyte_decode(d.io, UInt32) + val = vbyte_decode(get_stream(d), UInt32) return Core.bitcast(T, reinterpret(Int32, val)) end -decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(d.io, T) +decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(get_stream(d), T) function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V<:_ScalarTypesEnum} - pair_len = vbyte_decode(d.io, UInt32) - pair_end_pos = position(d.io) + pair_len + io = get_stream(d) + pair_len = vbyte_decode(io, UInt32) + pair_end_pos = position(io) + pair_len field_number, wire_type = decode_tag(d) key = decode(d, K) 
field_number, wire_type = decode_tag(d) val = decode(d, V) - @assert position(d.io) == pair_end_pos + @assert position(io) == pair_end_pos buffer[key] = val nothing end function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V} - pair_len = vbyte_decode(d.io, UInt32) - pair_end_pos = position(d.io) + pair_len + io = get_stream(d) + pair_len = vbyte_decode(io, UInt32) + pair_end_pos = position(io) + pair_len field_number, wire_type = decode_tag(d) key = decode(d, K) field_number, wire_type = decode_tag(d) val = decode(d, Ref{V}) - @assert position(d.io) == pair_end_pos + @assert position(io) == pair_end_pos buffer[key] = val nothing end for T in (:(:fixed), :(:zigzag)) @eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V} - pair_len = vbyte_decode(d.io, UInt32) - pair_end_pos = position(d.io) + pair_len + io = get_stream(d) + pair_len = vbyte_decode(io, UInt32) + pair_end_pos = position(io) + pair_len field_number, wire_type = decode_tag(d) key = decode(d, K) field_number, wire_type = decode_tag(d) val = decode(d, V, Val{$(T)}) - @assert position(d.io) == pair_end_pos + @assert position(io) == pair_end_pos buffer[key] = val nothing end @eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V} - pair_len = vbyte_decode(d.io, UInt32) - pair_end_pos = position(d.io) + pair_len + io = get_stream(d) + pair_len = vbyte_decode(io, UInt32) + pair_end_pos = position(io) + pair_len field_number, wire_type = decode_tag(d) key = decode(d, K, Val{$(T)}) field_number, wire_type = decode_tag(d) val = decode(d, V) - @assert position(d.io) == pair_end_pos + @assert position(io) == pair_end_pos buffer[key] = val nothing end @@ -81,22 +85,23 @@ end for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag)) @eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V} - pair_len = vbyte_decode(d.io, UInt32) - 
pair_end_pos = position(d.io) + pair_len + io = get_stream(d) + pair_len = vbyte_decode(io, UInt32) + pair_end_pos = position(io) + pair_len field_number, wire_type = decode_tag(d) key = decode(d, K, Val{$(T)}) field_number, wire_type = decode_tag(d) val = decode(d, V, Val{$(S)}) - @assert position(d.io) == pair_end_pos + @assert position(io) == pair_end_pos buffer[key] = val nothing end end function decode(d::AbstractProtoDecoder, ::Type{String}) - bytelen = vbyte_decode(d.io, UInt32) + bytelen = vbyte_decode(get_stream(d), UInt32) str = Base._string_n(bytelen) - Base.unsafe_read(d.io, pointer(str), bytelen) + Base.unsafe_read(get_stream(d), pointer(str), bytelen) return str end function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{String}) @@ -105,12 +110,12 @@ function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{String}) end function decode(d::AbstractProtoDecoder, ::Type{Vector{UInt8}}) - bytelen = vbyte_decode(d.io, UInt32) - return read(d.io, bytelen) + bytelen = vbyte_decode(get_stream(d), UInt32) + return read(get_stream(d), bytelen) end function decode(d::AbstractProtoDecoder, ::Type{Base.CodeUnits{UInt8, String}}) - bytelen = vbyte_decode(d.io, UInt32) - return read(d.io, bytelen) + bytelen = vbyte_decode(get_stream(d), UInt32) + return read(get_stream(d), bytelen) end function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{Vector{UInt8}}) buffer[] = decode(d, Vector{UInt8}) @@ -119,12 +124,13 @@ end function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}) where {T <: Union{Int32,Int64,UInt32,UInt64,Enum{Int32},Enum{UInt32}}} if w == LENGTH_DELIMITED - bytelen = vbyte_decode(d.io, UInt32) - endpos = bytelen + position(d.io) - while position(d.io) < endpos + io = get_stream(d) + bytelen = vbyte_decode(io, UInt32) + endpos = bytelen + position(io) + while position(io) < endpos buffer[] = decode(d, T) end - @assert position(d.io) == endpos + @assert position(io) == endpos else buffer[] = decode(d, T) 
end @@ -133,12 +139,13 @@ end function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}, ::Type{Val{:zigzag}}) where {T <: Union{Int32,Int64}} if w == LENGTH_DELIMITED - bytelen = vbyte_decode(d.io, UInt32) - endpos = bytelen + position(d.io) - while position(d.io) < endpos + io = get_stream(d) + bytelen = vbyte_decode(io, UInt32) + endpos = bytelen + position(io) + while position(io) < endpos buffer[] = decode(d, T, Val{:zigzag}) end - @assert position(d.io) == endpos + @assert position(io) == endpos else buffer[] = decode(d, T, Val{:zigzag}) end @@ -147,16 +154,17 @@ end function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}, ::Type{Val{:fixed}}) where {T <: Union{Int32,Int64,UInt32,UInt64}} if w == LENGTH_DELIMITED - bytelen = vbyte_decode(d.io, UInt32) + io = get_stream(d) + bytelen = vbyte_decode(io, UInt32) n_incoming = div(bytelen, sizeof(T)) n_current = length(buffer.elements) resize!(buffer.elements, n_current + n_incoming) - endpos = bytelen + position(d.io) + endpos = bytelen + position(io) for i in (n_current+1):(n_current + n_incoming) buffer.occupied += 1 @inbounds buffer.elements[i] = decode(d, T, Val{:fixed}) end - @assert position(d.io) == endpos + @assert position(io) == endpos else buffer[] = decode(d, T, Val{:fixed}) end @@ -165,16 +173,17 @@ end function decode!(d::AbstractProtoDecoder, w::WireType, buffer::BufferedVector{T}) where {T <: Union{Bool,Float32,Float64}} if w == LENGTH_DELIMITED - bytelen = vbyte_decode(d.io, UInt32) + io = get_stream(d) + bytelen = vbyte_decode(io, UInt32) n_incoming = div(bytelen, sizeof(T)) n_current = length(buffer.elements) resize!(buffer.elements, n_current + n_incoming) - endpos = bytelen + position(d.io) + endpos = bytelen + position(io) for i in (n_current+1):(n_current + n_incoming) buffer.occupied += 1 @inbounds buffer.elements[i] = decode(d, T) end - @assert position(d.io) == endpos + @assert position(io) == endpos else buffer[] = decode(d, T) end @@ 
-187,10 +196,19 @@ end # We don't reuse the decode!(d::AbstractProtoDecoder, buffer::Base.RefValue{T}) method above # as with OneOf fields, we can't be sure that the previous OneOf value was also T. function decode(d::AbstractProtoDecoder, ::Type{Ref{T}}) where {T} - bytelen = vbyte_decode(d.io, UInt32) - endpos = bytelen + position(d.io) - out = decode(LengthDelimitedProtoDecoder(d.io, endpos), T) - @assert position(d.io) == endpos + io = get_stream(d) + bytelen = vbyte_decode(io, UInt32) + endpos = bytelen + position(io) + if hasmethod(decode, Tuple{AbstractProtoDecoder, Type{T}, Int, Bool}) + out = decode(d, T, endpos, false) + else + @warn "You are using code generated by an older version of ProtoBuf.jl, which \ + was deprecated. Please regenerate your protobuf definitions with the current version of \ + ProtoBuf.jl. The new version will allow for defining custom AbstractProtoDecoder variants. \ + This warning is only printed once per session." maxlog=1 T=T + out = decode(LengthDelimitedProtoDecoder(get_stream(d), endpos), T) + end + @assert position(io) == endpos "$(T) decode: expected position $(endpos), got $(position(io))" return out end @@ -200,7 +218,15 @@ function decode!(d::AbstractProtoDecoder, buffer::BufferedVector{T}) where {T} end function decode(d::AbstractProtoDecoder, ::Type{Ref{T}}, ::Type{Val{:group}}) where {T} - out = decode(GroupProtoDecoder(d.io), T) + if hasmethod(decode, Tuple{AbstractProtoDecoder, Type{T}, Int, Bool}) + out = decode(d, T, 0, true) + else + @warn "You are using code generated by an older version of ProtoBuf.jl, which \ + was deprecated. Please regenerate your protobuf definitions with the current version of \ + ProtoBuf.jl. The new version will allow for defining custom AbstractProtoDecoder variants. \ + This warning is only printed once per session." 
maxlog=1 T=T + out = decode(GroupProtoDecoder(get_stream(d)), T) + end return out end @@ -291,20 +317,21 @@ end end @inline function Base.skip(d::AbstractProtoDecoder, wire_type::WireType) + io = get_stream(d) if wire_type == VARINT - while read(d.io, UInt8) >= 0x80 end + while read(io, UInt8) >= 0x80 end elseif wire_type == FIXED64 - skip(d.io, 8) + skip(io, 8) elseif wire_type == LENGTH_DELIMITED - bytelen = vbyte_decode(d.io, UInt32) - skip(d.io, bytelen) + bytelen = vbyte_decode(io, UInt32) + skip(io, bytelen) elseif wire_type == START_GROUP - while peek(d.io) != UInt8(END_GROUP) + while peek(io) != UInt8(END_GROUP) skip(d, decode_tag(d)[2]) end - skip(d.io, 1) + skip(io, 1) elseif wire_type == FIXED32 - skip(d.io, 4) + skip(io, 4) else wire_type == END_GROUP error("Encountered END_GROUP wiretype while skipping") end diff --git a/src/codegen/decode_methods.jl b/src/codegen/decode_methods.jl index 6dee86d..c82ce9e 100644 --- a/src/codegen/decode_methods.jl +++ b/src/codegen/decode_methods.jl @@ -88,13 +88,13 @@ jl_fieldname_deref(f::GroupType, ::Context) = "$(jl_fieldname(f))[]" function generate_decode_method(io, t::MessageType, ctx::Context) - println(io, "function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:$(safename(t))})") + println(io, "function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:$(safename(t))}, _endpos::Int=0, _group::Bool=false)") n = length(t.fields) has_fields = n > 0 for field in t.fields println(io, " ", jl_fieldname(field)::String, " = ", jl_init_value(field, ctx)::String) end - println(io, " while !PB.message_done(d)") + println(io, " while !PB.message_done(d, _endpos, _group)") println(io, " field_number, wire_type = PB.decode_tag(d)") for (i, field) in enumerate(t.fields) field_decode_expr(io, field, i, ctx) diff --git a/test/test_decode.jl b/test/test_decode.jl index ed3fab4..8795339 100644 --- a/test/test_decode.jl +++ b/test/test_decode.jl @@ -322,5 +322,15 @@ end @test decode(d, TestInner) == TestInner(0) end + + @testset 
"backwards compat" begin + OlderVersion = include(joinpath(test_dir, "test_protos", "test_decode_backwards_compat.jl")) + msg = OlderVersion.TestInner(0, OlderVersion.TestInner(1, OlderVersion.TestInner(2, nothing))) + io = IOBuffer() + e = PB.ProtoEncoder(io) + PB.encode(e, msg) + roundtripped = PB.decode(PB.ProtoDecoder(seekstart(io)), OlderVersion.TestInner) + @test roundtripped == msg + end end end # module diff --git a/test/test_protos/test_decode_backwards_compat.jl b/test/test_protos/test_decode_backwards_compat.jl new file mode 100644 index 0000000..f895817 --- /dev/null +++ b/test/test_protos/test_decode_backwards_compat.jl @@ -0,0 +1,47 @@ +module OlderVersion + # Autogenerated using ProtoBuf.jl v1.1.1 on 2025-07-08T10:40:02.200 + # original file: /Users/tdrvostep/.julia/dev/ProtoBuf/test/test_protos/test_messages_for_codec.proto (proto3 syntax) + + import ProtoBuf as PB + using ProtoBuf: OneOf + using ProtoBuf.EnumX: @enumx + + abstract type var"##Abstract#TestStruct" end + + + struct TestInner + x::Int64 + r::Union{Nothing,TestInner} + end + PB.default_values(::Type{TestInner}) = (;x = zero(Int64), r = nothing) + PB.field_numbers(::Type{TestInner}) = (;x = 1, r = 2) + + function PB.decode(d::PB.AbstractProtoDecoder, ::Type{<:TestInner}) + x = zero(Int64) + r = Ref{Union{Nothing,TestInner}}(nothing) + while !PB.message_done(d) + field_number, wire_type = PB.decode_tag(d) + if field_number == 1 + x = PB.decode(d, Int64) + elseif field_number == 2 + PB.decode!(d, r) + else + Base.skip(d, wire_type) + end + end + return TestInner(x, r[]) + end + + function PB.encode(e::PB.AbstractProtoEncoder, x::TestInner) + initpos = position(e.io) + x.x != zero(Int64) && PB.encode(e, 1, x.x) + !isnothing(x.r) && PB.encode(e, 2, x.r) + return position(e.io) - initpos + end + function PB._encoded_size(x::TestInner) + encoded_size = 0 + x.x != zero(Int64) && (encoded_size += PB._encoded_size(x.x, 1)) + !isnothing(x.r) && (encoded_size += PB._encoded_size(x.r, 2)) + return 
encoded_size + end +end