diff --git a/avro/src/error.rs b/avro/src/error.rs index bdb20552..b8192011 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -274,6 +274,14 @@ pub enum Details { #[error("Could not find matching type in {schema:?} for {value:?}")] FindUnionVariant { schema: UnionSchema, value: Value }, + #[error("Union index {index} out of bounds: {num_variants} in {schema:?} for {value:?}")] + UnionIndexOutOfBounds { + schema: UnionSchema, + value: Value, + index: usize, + num_variants: usize, + }, + #[error("Union type should not be empty")] EmptyUnion, diff --git a/avro/src/serde/ser.rs b/avro/src/serde/ser.rs index d78f5017..3335e326 100644 --- a/avro/src/serde/ser.rs +++ b/avro/src/serde/ser.rs @@ -31,9 +31,8 @@ pub struct SeqSerializer { items: Vec, } -pub struct SeqVariantSerializer<'a> { +pub struct SeqVariantSerializer { index: u32, - variant: &'a str, items: Vec, } @@ -46,9 +45,8 @@ pub struct StructSerializer { fields: Vec<(String, Value)>, } -pub struct StructVariantSerializer<'a> { +pub struct StructVariantSerializer { index: u32, - variant: &'a str, fields: Vec<(String, Value)>, } @@ -63,17 +61,13 @@ impl SeqSerializer { } } -impl<'a> SeqVariantSerializer<'a> { - pub fn new(index: u32, variant: &'a str, len: Option) -> SeqVariantSerializer<'a> { +impl SeqVariantSerializer { + pub fn new(index: u32, len: Option) -> SeqVariantSerializer { let items = match len { Some(len) => Vec::with_capacity(len), None => Vec::new(), }; - SeqVariantSerializer { - index, - variant, - items, - } + SeqVariantSerializer { index, items } } } @@ -96,26 +90,25 @@ impl StructSerializer { } } -impl<'a> StructVariantSerializer<'a> { - pub fn new(index: u32, variant: &'a str, len: usize) -> StructVariantSerializer<'a> { +impl StructVariantSerializer { + pub fn new(index: u32, len: usize) -> StructVariantSerializer { StructVariantSerializer { index, - variant, fields: Vec::with_capacity(len), } } } -impl<'b> ser::Serializer for &'b mut Serializer { +impl ser::Serializer for &mut Serializer { type Ok = Value; type Error = Error; type SerializeSeq = SeqSerializer; type SerializeTuple = SeqSerializer; type SerializeTupleStruct = SeqSerializer; - type SerializeTupleVariant = SeqVariantSerializer<'b>; + type SerializeTupleVariant = SeqVariantSerializer; type SerializeMap = MapSerializer; type SerializeStruct = StructSerializer; - type SerializeStructVariant = StructVariantSerializer<'b>; + type SerializeStructVariant = StructVariantSerializer; fn serialize_bool(self, v: bool) -> Result { Ok(Value::Boolean(v)) @@ -226,21 +219,15 @@ impl<'b> ser::Serializer for &'b mut Serializer { fn serialize_newtype_variant( self, - _: &'static str, + _name: &'static str, index: u32, - variant: &'static str, + _variant: &'static str, value: &T, ) -> Result where T: Serialize + ?Sized, { - Ok(Value::Record(vec![ - ("type".to_owned(), Value::Enum(index, variant.to_owned())), - ( - "value".to_owned(), - Value::Union(index, Box::new(value.serialize(self)?)), - ), - ])) + Ok(Value::Union(index, Box::new(value.serialize(self)?))) } fn serialize_seq(self, len: Option) -> Result { @@ -261,12 +248,12 @@ impl<'b> ser::Serializer for &'b mut Serializer { fn serialize_tuple_variant( self, - _: &'static str, + _name: &'static str, index: u32, - variant: &'static str, + _variant: &'static str, len: usize, ) -> Result { - Ok(SeqVariantSerializer::new(index, variant, Some(len))) + Ok(SeqVariantSerializer::new(index, Some(len))) } fn serialize_map(self, len: Option) -> Result { @@ -283,12 +270,12 @@ impl<'b> ser::Serializer for &'b mut Serializer { fn serialize_struct_variant( self, - _: &'static str, + _name: &'static str, index: u32, - variant: &'static str, + _variant: &'static str, len: usize, ) -> Result { - Ok(StructVariantSerializer::new(index, variant, len)) + Ok(StructVariantSerializer::new(index, len)) } fn is_human_readable(&self) -> bool { @@ -346,11 +333,11 @@ impl ser::SerializeTupleStruct for SeqSerializer { } } -impl ser::SerializeSeq for SeqVariantSerializer<'_> { +impl ser::SerializeTupleVariant for SeqVariantSerializer { type Ok = Value; type Error = Error; - fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> where T: Serialize + ?Sized, { @@ -362,29 +349,7 @@ impl ser::SerializeSeq for SeqVariantSerializer<'_> { } fn end(self) -> Result { - Ok(Value::Record(vec![ - ( - "type".to_owned(), - Value::Enum(self.index, self.variant.to_owned()), - ), - ("value".to_owned(), Value::Array(self.items)), - ])) - } -} - -impl ser::SerializeTupleVariant for SeqVariantSerializer<'_> { - type Ok = Value; - type Error = Error; - - fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> - where - T: Serialize + ?Sized, - { - ser::SerializeSeq::serialize_element(self, value) - } - - fn end(self) -> Result { - ser::SerializeSeq::end(self) + Ok(Value::Union(self.index, Box::new(Value::Array(self.items)))) } } @@ -447,7 +412,7 @@ impl ser::SerializeStruct for StructSerializer { } } -impl ser::SerializeStructVariant for StructVariantSerializer<'_> { +impl ser::SerializeStructVariant for StructVariantSerializer { type Ok = Value; type Error = Error; @@ -463,16 +428,10 @@ impl ser::SerializeStructVariant for StructVariantSerializer<'_> { } fn end(self) -> Result { - Ok(Value::Record(vec![ - ( - "type".to_owned(), - Value::Enum(self.index, self.variant.to_owned()), - ), - ( - "value".to_owned(), - Value::Union(self.index, Box::new(Value::Record(self.fields))), - ), - ])) + Ok(Value::Union( + self.index, + Box::new(Value::Record(self.fields)), + )) } } @@ -789,13 +748,7 @@ mod tests { let expected = Value::Record(vec![( "a".to_owned(), - Value::Record(vec![ - ("type".to_owned(), Value::Enum(0, "Double".to_owned())), - ( - "value".to_owned(), - Value::Union(0, Box::new(Value::Double(64.0))), - ), - ]), + Value::Union(0, Box::new(Value::Double(64.0))), )]); assert_eq!( @@ -851,19 +804,13 @@ mod tests { }; let expected = Value::Record(vec![( "a".to_owned(), - Value::Record(vec![ - ("type".to_owned(), Value::Enum(0, "Val1".to_owned())), - ( - "value".to_owned(), - Value::Union( - 0, - Box::new(Value::Record(vec![ - ("x".to_owned(), Value::Float(1.0)), - ("y".to_owned(), Value::Float(2.0)), - ])), - ), - ), - ]), + Value::Union( + 0, + Box::new(Value::Record(vec![ + ("x".to_owned(), Value::Float(1.0)), + ("y".to_owned(), Value::Float(2.0)), + ])), + ), )]); assert_eq!( @@ -965,17 +912,14 @@ mod tests { let expected = Value::Record(vec![( "a".to_owned(), - Value::Record(vec![ - ("type".to_owned(), Value::Enum(1, "Val2".to_owned())), - ( - "value".to_owned(), - Value::Array(vec![ - Value::Union(1, Box::new(Value::Float(1.0))), - Value::Union(1, Box::new(Value::Float(2.0))), - Value::Union(1, Box::new(Value::Float(3.0))), - ]), - ), - ]), + Value::Union( + 1, + Box::new(Value::Array(vec![ + Value::Union(1, Box::new(Value::Float(1.0))), + Value::Union(1, Box::new(Value::Float(2.0))), + Value::Union(1, Box::new(Value::Float(3.0))), + ])), + ), )]); assert_eq!( diff --git a/avro/src/types.rs b/avro/src/types.rs index 5a54c3f2..78833664 100644 --- a/avro/src/types.rs +++ b/avro/src/types.rs @@ -1024,18 +1024,34 @@ impl Value { enclosing_namespace: &Namespace, field_default: &Option, ) -> Result { - let v = match self { - // Both are unions case. - Value::Union(_i, v) => *v, - // Reader is a union, but writer is not. - v => v, - }; - let (i, inner) = schema - .find_schema_with_known_schemata(&v, Some(names), enclosing_namespace) - .ok_or_else(|| Details::FindUnionVariant { - schema: schema.clone(), - value: v.clone(), - })?; + let (i, inner, v) = + match self { + // Both are unions case. + Value::Union(i, v) => { + let index = i as usize; + let inner = schema.schemas.get(index).ok_or_else(|| { + Details::UnionIndexOutOfBounds { + schema: schema.clone(), + value: *v.clone(), + index, + num_variants: schema.schemas.len(), + } + })?; + + (index, inner, *v) + } + // Reader is a union, but writer is not. + v => { + let (i, inner) = schema + .find_schema_with_known_schemata(&v, Some(names), enclosing_namespace) + .ok_or_else(|| Details::FindUnionVariant { + schema: schema.clone(), + value: v.clone(), + })?; + + (i, inner, v) + } + }; Ok(Value::Union( i as u32, diff --git a/avro/tests/union_serialization.rs b/avro/tests/union_serialization.rs new file mode 100644 index 00000000..48cd1287 --- /dev/null +++ b/avro/tests/union_serialization.rs @@ -0,0 +1,202 @@ +use apache_avro::{AvroResult, Schema}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] +struct Root { + field_union: Enum, + field_f: String, +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] +enum Enum { + A {}, + B {}, + C { + field_a: i64, + field_b: Option, + }, + D { + field_a: f32, + field_b: i32, + }, +} + +const SCHEMA_STR: &str = r#"{ + "name": "Root", + "type": "record", + "fields": [ + {"name": "field_union", "type": [ + { + "name": "A", + "type": "record", + "fields": [] + }, + { + "name": "B", + "type": "record", + "fields": [] + }, + { + "name": "C", + "type": "record", + "fields": [ + {"name": "field_a", "type": "long"}, + {"name": "field_b", "type": ["null", "string"]} + ] + }, + { + "name": "D", + "type": "record", + "fields": [ + {"name": "field_a", "type": "float"}, + {"name": "field_b", "type": "int"} + ] + } + ]}, + {"name": "field_f", "type": "string"} + ] +}"#; + +#[test] +fn test_union_variants_serialization() -> AvroResult<()> { + let schema = Schema::parse_str(SCHEMA_STR)?; + + // Test variant 0 + { + let input = Root { + field_union: Enum::A {}, + field_f: "test1".to_owned(), + }; + + #[rustfmt::skip] + let expected_bytes: [u8; 7] = [ + // Root { + // field_union: + 0x00, // variant 0 (Enum::A) { + // } + // field_f: + 0x0A, // string length = 5 + 0x74, 0x65, 0x73, 0x74, 0x31, // UTF-8 string "test1" + // } + ]; + + let value = apache_avro::to_value(&input)?.resolve(&schema)?; + let encoded = apache_avro::to_avro_datum(&schema, value)?; + + assert_eq!(encoded, expected_bytes); + + let value = apache_avro::from_avro_datum(&schema, &mut encoded.as_slice(), None)?; + let output: Root = apache_avro::from_value(&value)?; + + assert_eq!(input, output); + } + + // Test variant 1 + { + let input = Root { + field_union: Enum::B {}, + field_f: "test2".to_owned(), + }; + + #[rustfmt::skip] + let expected_bytes: [u8; 7] = [ + // Root { + // field_union: + 0x02, // variant 1 (Enum::B) { + // } + // field_f: + 0x0A, // string length = 5 + 0x74, 0x65, 0x73, 0x74, 0x32, // UTF-8 string "test2" + // } + ]; + + let value = apache_avro::to_value(&input)?.resolve(&schema)?; + let encoded = apache_avro::to_avro_datum(&schema, value)?; + + assert_eq!(encoded, expected_bytes); + + let value = apache_avro::from_avro_datum(&schema, &mut encoded.as_slice(), None)?; + let output: Root = apache_avro::from_value(&value)?; + + assert_eq!(input, output); + } + + // Test variant 2 + { + let input = Root { + field_union: Enum::C { + field_a: 3, + field_b: Some("test3".to_owned()), + }, + field_f: "test4".to_owned(), + }; + + #[rustfmt::skip] + let expected_bytes: [u8; 15] = [ + // Root { + // field_union: + 0x04, // variant 2 (Enum::C) { + // field_a: + 0x06, // 3 + // field_b: + 0x02, // variant 1 (Some) { + 0x0A, // string length = 5 + 0x74, 0x65, 0x73, 0x74, 0x33, // UTF-8 string "test3" + // } + // } + // field_f: + 0x0A, // string length = 5 + 0x74, 0x65, 0x73, 0x74, 0x34, // UTF-8 string "test4" + // } + ]; + + let value = apache_avro::to_value(&input)?.resolve(&schema)?; + let encoded = apache_avro::to_avro_datum(&schema, value)?; + + assert_eq!(encoded, expected_bytes); + + let value = apache_avro::from_avro_datum(&schema, &mut encoded.as_slice(), None)?; + let output: Root = apache_avro::from_value(&value)?; + + assert_eq!(input, output); + } + + // Test variant 3 + { + let input = Root { + field_union: Enum::D { + field_a: 0.0, + field_b: 4, + }, + field_f: "test5".to_owned(), + }; + + #[rustfmt::skip] + let expected_bytes: [u8; 12] = [ + // Root { + // field_union: + 0x06, // variant 3 (Enum::D) { + // field_a: + 0x00, 0x00, 0x00, 0x00, // 0.0 + // field_b: + 0x08, // 4 + // } + // field_f: + 0x0A, // string length = 5 + 0x74, 0x65, 0x73, 0x74, 0x35, // UTF-8 string "test5" + // } + ]; + + let value = apache_avro::to_value(&input)?.resolve(&schema)?; + let encoded = apache_avro::to_avro_datum(&schema, value)?; + + assert_eq!(encoded, expected_bytes); + + let value = apache_avro::from_avro_datum(&schema, &mut encoded.as_slice(), None)?; + let output: Root = apache_avro::from_value(&value)?; + + assert_eq!(input, output); + } + + Ok(()) +}