@@ -80,6 +80,28 @@ class DeductionWorklist {
8080 AddAll (context_.inst_blocks ().Get (params), args, needs_substitution);
8181 }
8282
83+ auto AddAll (SemIR::StructTypeFieldsId params, SemIR::StructTypeFieldsId args,
84+ bool needs_substitution) -> void {
85+ const auto & param_fields = context_.struct_type_fields ().Get (params);
86+ const auto & arg_fields = context_.struct_type_fields ().Get (args);
87+ if (param_fields.size () != arg_fields.size ()) {
88+ // TODO: Decide whether to error on this or just treat the parameter list
89+ // as non-deduced. For now we treat it as non-deduced.
90+ return ;
91+ }
92+ // Don't do deduction unless the names match in order.
93+ // TODO: Support reordering of names.
94+ for (auto [param, arg] : llvm::zip_equal (param_fields, arg_fields)) {
95+ if (param.name_id != arg.name_id ) {
96+ return ;
97+ }
98+ }
99+ for (auto [param, arg] :
100+ llvm::reverse (llvm::zip_equal (param_fields, arg_fields))) {
101+ Add (param.type_id , arg.type_id , needs_substitution);
102+ }
103+ }
104+
83105 auto AddAll (SemIR::InstBlockId params, SemIR::InstBlockId args,
84106 bool needs_substitution) -> void {
85107 AddAll (context_.inst_blocks ().Get (params), context_.inst_blocks ().Get (args),
@@ -107,6 +129,10 @@ class DeductionWorklist {
107129 case SemIR::IdKind::For<SemIR::TypeId>:
108130 Add (SemIR::TypeId (param), SemIR::TypeId (arg), needs_substitution);
109131 break ;
132+ case SemIR::IdKind::For<SemIR::StructTypeFieldsId>:
133+ AddAll (SemIR::StructTypeFieldsId (param), SemIR::StructTypeFieldsId (arg),
134+ needs_substitution);
135+ break ;
110136 case SemIR::IdKind::For<SemIR::InstBlockId>:
111137 AddAll (SemIR::InstBlockId (param), SemIR::InstBlockId (arg),
112138 needs_substitution);
@@ -378,6 +404,7 @@ auto DeductionContext::Deduce() -> bool {
378404 case SemIR::InterfaceType::Kind:
379405 case SemIR::IntType::Kind:
380406 case SemIR::PointerType::Kind:
407+ case SemIR::StructType::Kind:
381408 case SemIR::TupleType::Kind:
382409 case SemIR::TupleValue::Kind: {
383410 auto arg_inst = context ().insts ().Get (arg_id);
@@ -392,7 +419,6 @@ auto DeductionContext::Deduce() -> bool {
392419 continue ;
393420 }
394421
395- case SemIR::StructType::Kind:
396422 case SemIR::StructValue::Kind:
397423 // TODO: Match field name order between param and arg.
398424 break ;
0 commit comments