Skip to content

Commit 6d69896

Browse files
authored
chore: Add support for typed nested objects in autogen (#3801)
1 parent e25a01c commit 6d69896

26 files changed

+846
-129
lines changed

.golangci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ linters:
108108
- linters:
109109
- gocritic
110110
text: "^hugeParam: req is heavy"
111-
- path: schema\.go # exclude rules for schema files as it's auto-genereated from OpenAPI spec
111+
- path: schema\.go # exclude rules for schema files as it's auto-generated from OpenAPI spec
112112
text: var-naming|exceeds the maximum|regexpSimplify
113113
- path: (.+)\.go$
114114
text: declaration of ".*" shadows declaration at line .*
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
package customtype
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"reflect"
7+
8+
"github.com/hashicorp/terraform-plugin-framework/attr"
9+
"github.com/hashicorp/terraform-plugin-framework/diag"
10+
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
11+
"github.com/hashicorp/terraform-plugin-go/tftypes"
12+
)
13+
14+
/*
15+
Custom Object type used in auto-generated code to enable the generic marshal/unmarshal operations to access nested attribute struct tags during conversion.
16+
Custom types docs: https://developer.hashicorp.com/terraform/plugin/framework/handling-data/types/custom
17+
18+
Usage:
19+
- Schema definition:
20+
"sample_nested_object": schema.SingleNestedAttribute{
21+
...
22+
CustomType: customtype.NewObjectType[TFSampleNestedObjectModel](ctx),
23+
Attributes: map[string]schema.Attribute{
24+
"string_attribute": schema.StringAttribute{...},
25+
},
26+
}
27+
28+
- TF Models:
29+
type TFModel struct {
30+
SampleNestedObject customtype.ObjectValue[TFSampleNestedObjectModel] `tfsdk:"sample_nested_object"`
31+
...
32+
}
33+
34+
type TFSampleNestedObjectModel struct {
35+
StringAttribute types.String `tfsdk:"string_attribute"`
36+
...
37+
}
38+
*/
39+
40+
var (
41+
_ basetypes.ObjectTypable = ObjectType[struct{}]{}
42+
_ basetypes.ObjectValuable = ObjectValue[struct{}]{}
43+
_ ObjectValueInterface = ObjectValue[struct{}]{}
44+
)
45+
46+
type ObjectType[T any] struct {
47+
basetypes.ObjectType
48+
}
49+
50+
func NewObjectType[T any](ctx context.Context) ObjectType[T] {
51+
result := ObjectType[T]{}
52+
53+
attrTypes, diags := getAttributeTypes[T](ctx)
54+
if diags.HasError() {
55+
panic(fmt.Errorf("error creating ObjectType: %v", diags))
56+
}
57+
58+
result.ObjectType = basetypes.ObjectType{AttrTypes: attrTypes}
59+
return result
60+
}
61+
62+
func (t ObjectType[T]) Equal(o attr.Type) bool {
63+
other, ok := o.(ObjectType[T])
64+
if !ok {
65+
return false
66+
}
67+
return t.ObjectType.Equal(other.ObjectType)
68+
}
69+
70+
func (ObjectType[T]) String() string {
71+
var t T
72+
return fmt.Sprintf("ObjectType[%T]", t)
73+
}
74+
75+
func (t ObjectType[T]) ValueFromObject(ctx context.Context, in basetypes.ObjectValue) (basetypes.ObjectValuable, diag.Diagnostics) {
76+
if in.IsNull() {
77+
return NewObjectValueNull[T](ctx), nil
78+
}
79+
80+
if in.IsUnknown() {
81+
return NewObjectValueUnknown[T](ctx), nil
82+
}
83+
84+
attrTypes, diags := getAttributeTypes[T](ctx)
85+
if diags.HasError() {
86+
return nil, diags
87+
}
88+
89+
baseObjectValue, diags := basetypes.NewObjectValue(attrTypes, in.Attributes())
90+
if diags.HasError() {
91+
return nil, diags
92+
}
93+
94+
return ObjectValue[T]{ObjectValue: baseObjectValue}, nil
95+
}
96+
97+
func (t ObjectType[T]) ValueFromTerraform(ctx context.Context, in tftypes.Value) (attr.Value, error) {
98+
attrValue, err := t.ObjectType.ValueFromTerraform(ctx, in)
99+
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
objectValue, ok := attrValue.(basetypes.ObjectValue)
105+
if !ok {
106+
return nil, fmt.Errorf("unexpected value type of %T", attrValue)
107+
}
108+
109+
objectValuable, diags := t.ValueFromObject(ctx, objectValue)
110+
if diags.HasError() {
111+
return nil, fmt.Errorf("unexpected error converting ObjectValue to ObjectValuable: %v", diags)
112+
}
113+
114+
return objectValuable, nil
115+
}
116+
117+
func (t ObjectType[T]) ValueType(_ context.Context) attr.Value {
118+
return ObjectValue[T]{}
119+
}
120+
121+
type ObjectValue[T any] struct {
122+
basetypes.ObjectValue
123+
}
124+
125+
type ObjectValueInterface interface {
126+
basetypes.ObjectValuable
127+
ValuePtrAsAny(ctx context.Context) (any, diag.Diagnostics)
128+
NewObjectValue(ctx context.Context, value any) ObjectValueInterface
129+
NewObjectValueNull(ctx context.Context) ObjectValueInterface
130+
}
131+
132+
func (v ObjectValue[T]) NewObjectValue(ctx context.Context, value any) ObjectValueInterface {
133+
return NewObjectValue[T](ctx, value)
134+
}
135+
136+
func NewObjectValue[T any](ctx context.Context, value any) ObjectValue[T] {
137+
attrTypes, diags := getAttributeTypes[T](ctx)
138+
if diags.HasError() {
139+
panic(fmt.Errorf("error creating ObjectValue: %v", diags))
140+
}
141+
142+
newValue, diags := basetypes.NewObjectValueFrom(ctx, attrTypes, value)
143+
if diags.HasError() {
144+
return NewObjectValueUnknown[T](ctx)
145+
}
146+
147+
return ObjectValue[T]{ObjectValue: newValue}
148+
}
149+
150+
func (v ObjectValue[T]) NewObjectValueNull(ctx context.Context) ObjectValueInterface {
151+
return NewObjectValueNull[T](ctx)
152+
}
153+
154+
func NewObjectValueNull[T any](ctx context.Context) ObjectValue[T] {
155+
attrTypes, diags := getAttributeTypes[T](ctx)
156+
if diags.HasError() {
157+
panic(fmt.Errorf("error creating null ObjectValue: %v", diags))
158+
}
159+
return ObjectValue[T]{ObjectValue: basetypes.NewObjectNull(attrTypes)}
160+
}
161+
162+
func NewObjectValueUnknown[T any](ctx context.Context) ObjectValue[T] {
163+
attrTypes, diags := getAttributeTypes[T](ctx)
164+
if diags.HasError() {
165+
panic(fmt.Errorf("error creating unknown ObjectValue: %v", diags))
166+
}
167+
return ObjectValue[T]{ObjectValue: basetypes.NewObjectUnknown(attrTypes)}
168+
}
169+
170+
func (v ObjectValue[T]) Equal(o attr.Value) bool {
171+
other, ok := o.(ObjectValue[T])
172+
if !ok {
173+
return false
174+
}
175+
return v.ObjectValue.Equal(other.ObjectValue)
176+
}
177+
178+
func (v ObjectValue[T]) Type(ctx context.Context) attr.Type {
179+
return NewObjectType[T](ctx)
180+
}
181+
182+
func (v ObjectValue[T]) ValuePtrAsAny(ctx context.Context) (any, diag.Diagnostics) {
183+
valuePtr := new(T)
184+
185+
if v.IsNull() || v.IsUnknown() {
186+
return valuePtr, nil
187+
}
188+
189+
diags := v.As(ctx, valuePtr, basetypes.ObjectAsOptions{})
190+
if diags.HasError() {
191+
return nil, diags
192+
}
193+
194+
return valuePtr, diags
195+
}
196+
197+
func getAttributeTypes[T any](ctx context.Context) (map[string]attr.Type, diag.Diagnostics) {
198+
var t T
199+
return valueToAttributeTypes(ctx, reflect.ValueOf(t))
200+
}
201+
202+
func valueToAttributeTypes(ctx context.Context, value reflect.Value) (map[string]attr.Type, diag.Diagnostics) {
203+
valueType := value.Type()
204+
205+
if valueType.Kind() != reflect.Struct {
206+
return nil, diag.Diagnostics{diag.NewErrorDiagnostic(
207+
"Error getting value attribute types",
208+
fmt.Sprintf(`%T has usupported type: %s`, value.Interface(), valueType),
209+
)}
210+
}
211+
212+
attributeTypes := make(map[string]attr.Type)
213+
for i := range valueType.NumField() {
214+
typeField := valueType.Field(i)
215+
valueField := value.Field(i)
216+
217+
tfName := typeField.Tag.Get(`tfsdk`)
218+
if tfName == "" {
219+
return nil, diag.Diagnostics{diag.NewErrorDiagnostic(
220+
"Error getting value attribute types",
221+
fmt.Sprintf(`%T has no tfsdk tag on field %s`, value.Interface(), typeField.Name),
222+
)}
223+
}
224+
225+
attrValue, ok := valueField.Interface().(attr.Value)
226+
if !ok {
227+
return nil, diag.Diagnostics{diag.NewErrorDiagnostic(
228+
"Error getting value attribute types",
229+
fmt.Sprintf(`%T has unsupported type in field %s: %T`, value.Interface(), typeField.Name, valueField.Interface()),
230+
)}
231+
}
232+
233+
attributeTypes[tfName] = attrValue.Type(ctx)
234+
}
235+
236+
return attributeTypes, nil
237+
}

internal/common/autogen/marshal.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package autogen
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"reflect"
78

89
"github.com/hashicorp/terraform-plugin-framework-jsontypes/jsontypes"
910
"github.com/hashicorp/terraform-plugin-framework/attr"
1011
"github.com/hashicorp/terraform-plugin-framework/types"
12+
"github.com/mongodb/terraform-provider-mongodbatlas/internal/common/autogen/customtype"
1113
"github.com/mongodb/terraform-provider-mongodbatlas/internal/common/autogen/stringcase"
1214
)
1315

@@ -43,7 +45,7 @@ func Marshal(model any, isUpdate bool) ([]byte, error) {
4345

4446
func marshalAttrs(valModel reflect.Value, isUpdate bool) (map[string]any, error) {
4547
objJSON := make(map[string]any)
46-
for i := 0; i < valModel.NumField(); i++ {
48+
for i := range valModel.NumField() {
4749
attrTypeModel := valModel.Type().Field(i)
4850
tag := attrTypeModel.Tag.Get(tagKey)
4951
if tag == tagValOmitJSON {
@@ -68,7 +70,7 @@ func marshalAttr(attrNameModel string, attrValModel reflect.Value, objJSON map[s
6870
if !ok {
6971
panic("marshal expects only Terraform types in the model")
7072
}
71-
val, err := getModelAttr(obj)
73+
val, err := getModelAttr(obj, isUpdate)
7274
if err != nil {
7375
return err
7476
}
@@ -86,7 +88,7 @@ func marshalAttr(attrNameModel string, attrValModel reflect.Value, objJSON map[s
8688
return nil
8789
}
8890

89-
func getModelAttr(val attr.Value) (any, error) {
91+
func getModelAttr(val attr.Value, isUpdate bool) (any, error) {
9092
if val.IsNull() || val.IsUnknown() {
9193
return nil, nil // skip null or unknown values
9294
}
@@ -100,28 +102,36 @@ func getModelAttr(val attr.Value) (any, error) {
100102
case types.Float64:
101103
return v.ValueFloat64(), nil
102104
case types.Object:
103-
return getMapAttr(v.Attributes(), false)
105+
return getMapAttr(v.Attributes(), false, isUpdate)
104106
case types.Map:
105-
return getMapAttr(v.Elements(), true)
107+
return getMapAttr(v.Elements(), true, isUpdate)
106108
case types.List:
107-
return getListAttr(v.Elements())
109+
return getListAttr(v.Elements(), isUpdate)
108110
case types.Set:
109-
return getListAttr(v.Elements())
111+
return getListAttr(v.Elements(), isUpdate)
110112
case jsontypes.Normalized:
111113
var valueJSON any
112114
if err := json.Unmarshal([]byte(v.ValueString()), &valueJSON); err != nil {
113115
return nil, fmt.Errorf("marshal failed for JSON custom type: %v", err)
114116
}
115117
return valueJSON, nil
118+
case customtype.ObjectValueInterface:
119+
valuePtr, diags := v.ValuePtrAsAny(context.Background())
120+
if diags.HasError() {
121+
return nil, fmt.Errorf("marshal failed for type: %v", diags)
122+
}
123+
124+
result, err := marshalAttrs(reflect.ValueOf(valuePtr).Elem(), isUpdate)
125+
return result, err
116126
default:
117127
return nil, fmt.Errorf("marshal not supported yet for type %T", v)
118128
}
119129
}
120130

121-
func getListAttr(elms []attr.Value) (any, error) {
131+
func getListAttr(elms []attr.Value, isUpdate bool) (any, error) {
122132
arr := make([]any, 0)
123133
for _, attr := range elms {
124-
valChild, err := getModelAttr(attr)
134+
valChild, err := getModelAttr(attr, isUpdate)
125135
if err != nil {
126136
return nil, err
127137
}
@@ -134,10 +144,10 @@ func getListAttr(elms []attr.Value) (any, error) {
134144

135145
// getMapAttr gets a map of attributes and returns a map of JSON attributes.
136146
// keepKeyCase is used for types.Map to keep key case. However, we want to use JSON key case for types.Object
137-
func getMapAttr(elms map[string]attr.Value, keepKeyCase bool) (any, error) {
147+
func getMapAttr(elms map[string]attr.Value, keepKeyCase, isUpdate bool) (any, error) {
138148
objJSON := make(map[string]any)
139149
for name, attr := range elms {
140-
valChild, err := getModelAttr(attr)
150+
valChild, err := getModelAttr(attr, isUpdate)
141151
if err != nil {
142152
return nil, err
143153
}

0 commit comments

Comments
 (0)