@@ -39,12 +39,10 @@ func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.T
3939
4040func (d * decoderBuilder ) newStructUnionDecoder (t reflect.Type ) decoderFunc {
4141 type variantDecoder struct {
42- decoder decoderFunc
43- field reflect.StructField
44- discriminatorValue any
42+ decoder decoderFunc
43+ field reflect.StructField
4544 }
46-
47- variants := []variantDecoder {}
45+ decoders := []variantDecoder {}
4846 for i := 0 ; i < t .NumField (); i ++ {
4947 field := t .Field (i )
5048
@@ -53,18 +51,26 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
5351 }
5452
5553 decoder := d .typeDecoder (field .Type )
56- variants = append (variants , variantDecoder {
54+ decoders = append (decoders , variantDecoder {
5755 decoder : decoder ,
5856 field : field ,
5957 })
6058 }
6159
60+ type discriminatedDecoder struct {
61+ variantDecoder
62+ discriminator any
63+ }
64+ discriminatedDecoders := []discriminatedDecoder {}
6265 unionEntry , discriminated := unionRegistry [t ]
63- for _ , unionVariant := range unionEntry .variants {
64- for i := 0 ; i < len (variants ); i ++ {
65- variant := & variants [i ]
66- if variant .field .Type .Elem () == unionVariant .Type {
67- variant .discriminatorValue = unionVariant .DiscriminatorValue
66+ for _ , variant := range unionEntry .variants {
67+ // For each union variant, find a matching decoder and save it
68+ for _ , decoder := range decoders {
69+ if decoder .field .Type .Elem () == variant .Type {
70+ discriminatedDecoders = append (discriminatedDecoders , discriminatedDecoder {
71+ decoder ,
72+ variant .DiscriminatorValue ,
73+ })
6874 break
6975 }
7076 }
@@ -73,10 +79,10 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
7379 return func (n gjson.Result , v reflect.Value , state * decoderState ) error {
7480 if discriminated && n .Type == gjson .JSON && len (unionEntry .discriminatorKey ) != 0 {
7581 discriminator := n .Get (unionEntry .discriminatorKey ).Value ()
76- for _ , variant := range variants {
77- if discriminator == variant . discriminatorValue {
78- inner := v .FieldByIndex (variant .field .Index )
79- return variant .decoder (n , inner , state )
82+ for _ , decoder := range discriminatedDecoders {
83+ if discriminator == decoder . discriminator {
84+ inner := v .FieldByIndex (decoder .field .Index )
85+ return decoder .decoder (n , inner , state )
8086 }
8187 }
8288 return errors .New ("apijson: was not able to find discriminated union variant" )
@@ -85,15 +91,15 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
8591 // Set bestExactness to worse than loose
8692 bestExactness := loose - 1
8793 bestVariant := - 1
88- for i , variant := range variants {
94+ for i , decoder := range decoders {
8995 // Pointers are used to discern JSON object variants from value variants
90- if n .Type != gjson .JSON && variant .field .Type .Kind () == reflect .Ptr {
96+ if n .Type != gjson .JSON && decoder .field .Type .Kind () == reflect .Ptr {
9197 continue
9298 }
9399
94100 sub := decoderState {strict : state .strict , exactness : exact }
95- inner := v .FieldByIndex (variant .field .Index )
96- err := variant .decoder (n , inner , & sub )
101+ inner := v .FieldByIndex (decoder .field .Index )
102+ err := decoder .decoder (n , inner , & sub )
97103 if err != nil {
98104 continue
99105 }
@@ -116,11 +122,11 @@ func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
116122 return errors .New ("apijson: was not able to coerce type as union strictly" )
117123 }
118124
119- for i := 0 ; i < len (variants ); i ++ {
125+ for i := 0 ; i < len (decoders ); i ++ {
120126 if i == bestVariant {
121127 continue
122128 }
123- v .FieldByIndex (variants [i ].field .Index ).SetZero ()
129+ v .FieldByIndex (decoders [i ].field .Index ).SetZero ()
124130 }
125131
126132 return nil
0 commit comments