Skip to content

Commit caba03d

Browse files
authored
Support deduction of the types of struct fields (#4500)
Follow-on to #4492 . --------- Co-authored-by: Josh L <[email protected]>
1 parent 138ecf1 commit caba03d

File tree

2 files changed

+456
-13
lines changed

2 files changed

+456
-13
lines changed

toolchain/check/deduce.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,28 @@ class DeductionWorklist {
8080
AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
8181
}
8282

83+
auto AddAll(SemIR::StructTypeFieldsId params, SemIR::StructTypeFieldsId args,
84+
bool needs_substitution) -> void {
85+
const auto& param_fields = context_.struct_type_fields().Get(params);
86+
const auto& arg_fields = context_.struct_type_fields().Get(args);
87+
if (param_fields.size() != arg_fields.size()) {
88+
// TODO: Decide whether to error on this or just treat the parameter list
89+
// as non-deduced. For now we treat it as non-deduced.
90+
return;
91+
}
92+
// Don't do deduction unless the names match in order.
93+
// TODO: Support reordering of names.
94+
for (auto [param, arg] : llvm::zip_equal(param_fields, arg_fields)) {
95+
if (param.name_id != arg.name_id) {
96+
return;
97+
}
98+
}
99+
for (auto [param, arg] :
100+
llvm::reverse(llvm::zip_equal(param_fields, arg_fields))) {
101+
Add(param.type_id, arg.type_id, needs_substitution);
102+
}
103+
}
104+
83105
auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
84106
bool needs_substitution) -> void {
85107
AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
@@ -107,6 +129,10 @@ class DeductionWorklist {
107129
case SemIR::IdKind::For<SemIR::TypeId>:
108130
Add(SemIR::TypeId(param), SemIR::TypeId(arg), needs_substitution);
109131
break;
132+
case SemIR::IdKind::For<SemIR::StructTypeFieldsId>:
133+
AddAll(SemIR::StructTypeFieldsId(param), SemIR::StructTypeFieldsId(arg),
134+
needs_substitution);
135+
break;
110136
case SemIR::IdKind::For<SemIR::InstBlockId>:
111137
AddAll(SemIR::InstBlockId(param), SemIR::InstBlockId(arg),
112138
needs_substitution);
@@ -378,6 +404,7 @@ auto DeductionContext::Deduce() -> bool {
378404
case SemIR::InterfaceType::Kind:
379405
case SemIR::IntType::Kind:
380406
case SemIR::PointerType::Kind:
407+
case SemIR::StructType::Kind:
381408
case SemIR::TupleType::Kind:
382409
case SemIR::TupleValue::Kind: {
383410
auto arg_inst = context().insts().Get(arg_id);
@@ -392,7 +419,6 @@ auto DeductionContext::Deduce() -> bool {
392419
continue;
393420
}
394421

395-
case SemIR::StructType::Kind:
396422
case SemIR::StructValue::Kind:
397423
// TODO: Match field name order between param and arg.
398424
break;

0 commit comments

Comments
 (0)