Skip to content

Commit bf310d1

Browse files
dplassgitcopybara-github
authored andcommitted
[Proc-scoped channels] Spawn a proc with channel arrays. When figuring out the proc instantiation args, expand any channel arrays to the list of all the channel interfaces it contains.
PiperOrigin-RevId: 807775728
1 parent 65aa30d commit bf310d1

8 files changed

+357
-40
lines changed

xls/dslx/ir_convert/channel_scope.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ class ChannelArray {
5252
// For logging purposes.
5353
std::string ToString() const { return base_channel_name_; }
5454

55+
std::vector<ChannelRef> channels() const {
56+
std::vector<ChannelRef> channels;
57+
channels.reserve(flattened_names_in_order_.size());
58+
for (auto& name : flattened_names_in_order_) {
59+
std::optional<ChannelRef> channel = FindChannel(name);
60+
if (channel.has_value()) {
61+
channels.push_back(*channel);
62+
}
63+
}
64+
return channels;
65+
}
66+
5567
private:
5668
explicit ChannelArray(std::string_view base_channel_name,
5769
bool subarray = false)
@@ -69,7 +81,7 @@ class ChannelArray {
6981
flattened_name_to_channel_.emplace(flattened_name, channel);
7082
}
7183

72-
std::optional<ChannelRef> FindChannel(std::string_view flattened_name) {
84+
std::optional<ChannelRef> FindChannel(std::string_view flattened_name) const {
7385
const auto it = flattened_name_to_channel_.find(flattened_name);
7486
if (it == flattened_name_to_channel_.end()) {
7587
return std::nullopt;

xls/dslx/ir_convert/function_converter.cc

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ absl::Status FunctionConverter::Visit(const AstNode* node) {
480480
ChannelDirectionToString(ci->direction()),
481481
ci);
482482
},
483+
[](ChannelArray* ca) { return absl::StrFormat("%p", ca); },
483484
},
484485
value);
485486
}
@@ -745,6 +746,23 @@ absl::Status FunctionConverter::HandleTupleIndex(const TupleIndex* node) {
745746

746747
absl::Status FunctionConverter::HandleXlsTuple(const XlsTuple* node) {
747748
VLOG(5) << "FunctionConverter::HandleXlsTuple: " << node->ToString();
749+
if (current_fn_tag_ == FunctionTag::kProcConfig) {
750+
std::vector<IrValue> ir_operands;
751+
std::vector<BValue> b_operands;
752+
for (Expr* operand : node->members()) {
753+
std::optional<IrValue> v = GetNodeToIr(operand);
754+
XLS_RET_CHECK(v.has_value());
755+
if (std::holds_alternative<BValue>(*v)) {
756+
b_operands.push_back(std::get<BValue>(*v));
757+
}
758+
ir_operands.push_back(*v);
759+
}
760+
last_tuple_ = ir_operands;
761+
Def(node, [this, &b_operands](const SourceInfo& loc) {
762+
return function_builder_->Tuple(b_operands, loc);
763+
});
764+
return absl::OkStatus();
765+
}
748766
std::vector<BValue> operands;
749767
for (Expr* o : node->members()) {
750768
XLS_ASSIGN_OR_RETURN(BValue v, Use(o));
@@ -753,9 +771,6 @@ absl::Status FunctionConverter::HandleXlsTuple(const XlsTuple* node) {
753771
Def(node, [this, &operands](const SourceInfo& loc) {
754772
return function_builder_->Tuple(operands, loc);
755773
});
756-
if (current_fn_tag_ == FunctionTag::kProcConfig) {
757-
last_tuple_ = operands;
758-
}
759774
return absl::OkStatus();
760775
}
761776

@@ -850,9 +865,7 @@ absl::Status FunctionConverter::HandleNameRef(const NameRef* node) {
850865
} else if (std::holds_alternative<ChannelArray*>(v)) {
851866
VLOG(4) << "Reference to Proc member: " << k << " : Chan array : "
852867
<< std::get<ChannelArray*>(v)->ToString();
853-
// There is no IR equivalent for a channel array. The Index nodes
854-
// that refer to it have to be lowered to refer to specific channels.
855-
return absl::OkStatus();
868+
SetNodeToIr(from, std::get<ChannelArray*>(v));
856869
} else if (std::holds_alternative<Channel*>(v)) {
857870
VLOG(4) << "Reference to Proc member: " << k
858871
<< " : Chan : " << std::get<Channel*>(v)->ToString();
@@ -2635,6 +2648,10 @@ absl::StatusOr<ChanRef> FunctionConverter::IrValueToChannelRef(
26352648
return absl::InvalidArgumentError(
26362649
"Unexpected ChannelInterface in IrValue.");
26372650
},
2651+
[](ChannelArray* chan) -> absl::StatusOr<ChanRef> {
2652+
return absl::InvalidArgumentError(
2653+
"Unexpected ChannelArray in IrValue.");
2654+
},
26382655
},
26392656
ir_value);
26402657
}
@@ -3198,20 +3215,30 @@ absl::Status FunctionConverter::HandleSpawn(const Spawn* node) {
31983215

31993216
const Invocation* invocation = node->config();
32003217
XLS_RET_CHECK(package_data_.invocation_to_ir_proc.contains(invocation));
3201-
xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];
32023218

3203-
// Lookup the channel interface for each actual arg and add it to the args
3204-
// vector.
32053219
std::vector<ChannelInterface*> channel_args;
3220+
channel_args.reserve(invocation->args().size());
32063221
for (Expr* arg : invocation->args()) {
32073222
XLS_RETURN_IF_ERROR(Visit(arg));
3208-
std::optional<IrValue> arg_value = GetNodeToIr(arg);
3209-
XLS_RET_CHECK(arg_value.has_value());
3210-
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
3211-
IrValueToChannelInterface(*arg_value));
3212-
channel_args.push_back(channel_interface);
3223+
std::optional<IrValue> arg_value_opt = GetNodeToIr(arg);
3224+
XLS_RET_CHECK(arg_value_opt.has_value());
3225+
IrValue arg_value = *arg_value_opt;
3226+
if (std::holds_alternative<ChannelArray*>(arg_value)) {
3227+
// Get all the channel interfaces of this array and
3228+
// add them to the channel_args
3229+
ChannelArray* channel_array = std::get<ChannelArray*>(arg_value);
3230+
for (ChannelRef channel : channel_array->channels()) {
3231+
XLS_RET_CHECK(std::holds_alternative<ChannelInterface*>(channel));
3232+
channel_args.push_back(std::get<ChannelInterface*>(channel));
3233+
}
3234+
} else {
3235+
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
3236+
IrValueToChannelInterface(arg_value));
3237+
channel_args.push_back(channel_interface);
3238+
}
32133239
}
32143240

3241+
xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];
32153242
xls::Proc* current_proc = builder_ptr->proc();
32163243
XLS_RETURN_IF_ERROR(
32173244
current_proc
@@ -3449,8 +3476,8 @@ absl::Status FunctionConverter::HandleProcNextFunction(
34493476
proc_scoped_channel_scope_->DefineBoundaryChannelOrArray(
34503477
param, type_info));
34513478
XLS_RET_CHECK(std::holds_alternative<ChannelArray*>(channel_or_array));
3452-
// TODO: davidplass - assign the param name to an appropriate BValue
3453-
// for this channel or array, for the Def call.
3479+
SetNodeToIr(param->name_def(),
3480+
std::get<ChannelArray*>(channel_or_array));
34543481
} else {
34553482
ChannelType* channel_type = dynamic_cast<ChannelType*>(type);
34563483
XLS_RET_CHECK_NE(channel_type, nullptr)
@@ -3509,28 +3536,39 @@ absl::Status FunctionConverter::HandleProcNextFunction(
35093536
XLS_RET_CHECK_EQ(last_tuple_.size(), proc->members().size());
35103537
int i = 0;
35113538
for (const ProcMember* member : proc->members()) {
3512-
BValue tuple_entry = last_tuple_[i++];
3513-
Def(member, [tuple_entry](const SourceInfo& loc) { return tuple_entry; });
3514-
3515-
// Store the ChannelInterface for this entry in the id_to_members map.
3516-
Node* node = tuple_entry.node();
3517-
switch (node->op()) {
3518-
case Op::kSendChannelEnd: {
3519-
SendChannelEnd* sce = node->As<SendChannelEnd>();
3520-
proc_data_->id_to_members.at(proc_id)[member->identifier()] =
3521-
sce->channel_interface();
3522-
break;
3523-
}
3524-
case Op::kRecvChannelEnd: {
3525-
RecvChannelEnd* rce = node->As<RecvChannelEnd>();
3526-
proc_data_->id_to_members.at(proc_id)[member->identifier()] =
3527-
rce->channel_interface();
3528-
break;
3539+
IrValue tuple_entry = last_tuple_[i++];
3540+
if (std::holds_alternative<ChannelArray*>(tuple_entry)) {
3541+
ChannelArray* channel_array = std::get<ChannelArray*>(tuple_entry);
3542+
SetNodeToIr(member->name_def(), tuple_entry);
3543+
proc_data_->id_to_members.at(proc_id)[member->identifier()] =
3544+
channel_array;
3545+
XLS_RETURN_IF_ERROR(channel_scope_->AssociateWithExistingChannelOrArray(
3546+
*proc_id_, member->name_def(), channel_array));
3547+
} else {
3548+
BValue bvalue = std::get<BValue>(tuple_entry);
3549+
Def(member->name_def(),
3550+
[bvalue](const SourceInfo& loc) { return bvalue; });
3551+
3552+
// Store the ChannelInterface for this entry in the id_to_members map.
3553+
Node* node = bvalue.node();
3554+
switch (node->op()) {
3555+
case Op::kSendChannelEnd: {
3556+
SendChannelEnd* sce = node->As<SendChannelEnd>();
3557+
proc_data_->id_to_members.at(proc_id)[member->identifier()] =
3558+
sce->channel_interface();
3559+
break;
3560+
}
3561+
case Op::kRecvChannelEnd: {
3562+
RecvChannelEnd* rce = node->As<RecvChannelEnd>();
3563+
proc_data_->id_to_members.at(proc_id)[member->identifier()] =
3564+
rce->channel_interface();
3565+
break;
3566+
}
3567+
default:
3568+
return absl::InternalError(absl::StrFormat(
3569+
"Cannot process config return tuple element %d of type %s",
3570+
i - 1, OpToString(node->op())));
35293571
}
3530-
default:
3531-
return absl::InternalError(absl::StrFormat(
3532-
"Cannot process config return tuple element %d of type %s", i - 1,
3533-
OpToString(node->op())));
35343572
}
35353573
}
35363574
}

xls/dslx/ir_convert/function_converter.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ class FunctionConverter {
225225
// Every AST node has an "IR value" that is either a function builder value
226226
// (BValue) or its IR-conversion-time-constant-decorated cousin (CValue), or
227227
// an inter-proc Channel.
228-
using IrValue = std::variant<BValue, CValue, Channel*, ChannelInterface*>;
228+
using IrValue =
229+
std::variant<BValue, CValue, Channel*, ChannelInterface*, ChannelArray*>;
229230

230231
// Helper for converting an IR value to its BValue pointer for use in
231232
// debugging.
@@ -629,7 +630,7 @@ class FunctionConverter {
629630

630631
// The last tuple converted. Used for mapping the return tuple of a proc
631632
// `config` method to actual proc members.
632-
std::vector<BValue> last_tuple_;
633+
std::vector<IrValue> last_tuple_;
633634

634635
std::unique_ptr<ProcScopedChannelScope> proc_scoped_channel_scope_;
635636
};

xls/dslx/ir_convert/ir_converter_test.cc

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5016,6 +5016,150 @@ pub proc main {
50165016
HasSubstr("Cannot have non-channel parameters")));
50175017
}
50185018

5019+
TEST_P(ProcScopedChannelsIrConverterTest, ChannelArrayMembers) {
5020+
constexpr std::string_view kProgram = R"(
5021+
pub proc main {
5022+
ins: chan<u32>[3] in;
5023+
outs: chan<u16>[2] out;
5024+
5025+
init { }
5026+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5027+
(in_chans, out_chans)
5028+
}
5029+
next(state: ()) {
5030+
send(token(), outs[0], u16:42);
5031+
recv(token(), ins[1]);
5032+
()
5033+
}
5034+
}
5035+
)";
5036+
5037+
auto import_data = CreateImportDataForTest();
5038+
XLS_ASSERT_OK_AND_ASSIGN(
5039+
std::string converted,
5040+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5041+
ConvertOptions{
5042+
.emit_positions = false,
5043+
.lower_to_proc_scoped_channels = true,
5044+
}));
5045+
ExpectIr(converted);
5046+
}
5047+
5048+
TEST_P(ProcScopedChannelsIrConverterTest, SpawnFromChannelArrayParams) {
5049+
if (GetParam() == TypeInferenceVersion::kVersion1) {
5050+
// v1 fails figuring out the send argument type for some reason.
5051+
return;
5052+
}
5053+
5054+
constexpr std::string_view kProgram = R"(
5055+
pub proc spawnee {
5056+
ins: chan<u32>[3] in;
5057+
outs: chan<u16>[2] out;
5058+
5059+
init { }
5060+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5061+
(in_chans, out_chans)
5062+
}
5063+
next(state: ()) {
5064+
send(token(), outs[0], u16:42);
5065+
recv(token(), ins[1]);
5066+
()
5067+
}
5068+
}
5069+
5070+
pub proc main {
5071+
init { }
5072+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5073+
spawn spawnee(in_chans, out_chans);
5074+
()
5075+
}
5076+
next(state: ()) { () }
5077+
}
5078+
)";
5079+
5080+
auto import_data = CreateImportDataForTest();
5081+
XLS_ASSERT_OK_AND_ASSIGN(
5082+
std::string converted,
5083+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5084+
ConvertOptions{
5085+
.emit_positions = false,
5086+
.lower_to_proc_scoped_channels = true,
5087+
}));
5088+
ExpectIr(converted);
5089+
}
5090+
5091+
TEST_P(ProcScopedChannelsIrConverterTest, ChannelArrayAndChannelMembers) {
5092+
constexpr std::string_view kProgram = R"(
5093+
pub proc main {
5094+
ins: chan<u32>[3] in;
5095+
outs: chan<u16> out;
5096+
5097+
init { }
5098+
config(in_chans: chan<u32>[3] in, out_chan: chan<u16> out) {
5099+
(in_chans, out_chan)
5100+
}
5101+
next(state: ()) {
5102+
send(token(), outs, u16:42);
5103+
recv(token(), ins[1]);
5104+
()
5105+
}
5106+
}
5107+
)";
5108+
5109+
auto import_data = CreateImportDataForTest();
5110+
XLS_ASSERT_OK_AND_ASSIGN(
5111+
std::string converted,
5112+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5113+
ConvertOptions{
5114+
.emit_positions = false,
5115+
.lower_to_proc_scoped_channels = true,
5116+
}));
5117+
ExpectIr(converted);
5118+
}
5119+
5120+
TEST_P(ProcScopedChannelsIrConverterTest,
5121+
SpawnFromChannelArrayAndChannelParams) {
5122+
if (GetParam() == TypeInferenceVersion::kVersion1) {
5123+
// v1 fails figuring out the send argument type for some reason.
5124+
return;
5125+
}
5126+
5127+
constexpr std::string_view kProgram = R"(
5128+
pub proc spawnee {
5129+
ins: chan<u32>[3] in;
5130+
outch: chan<u16> out;
5131+
5132+
init { }
5133+
config(in_chans: chan<u32>[3] in, out_chan: chan<u16> out) {
5134+
(in_chans, out_chan)
5135+
}
5136+
next(state: ()) {
5137+
send(token(), outch, u16:42);
5138+
recv(token(), ins[1]);
5139+
()
5140+
}
5141+
}
5142+
5143+
pub proc main {
5144+
init { }
5145+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5146+
spawn spawnee(in_chans, out_chans[1]);
5147+
()
5148+
}
5149+
next(state: ()) { () }
5150+
}
5151+
)";
5152+
5153+
auto import_data = CreateImportDataForTest();
5154+
XLS_ASSERT_OK_AND_ASSIGN(
5155+
std::string converted,
5156+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5157+
ConvertOptions{
5158+
.emit_positions = false,
5159+
.lower_to_proc_scoped_channels = true,
5160+
}));
5161+
ExpectIr(converted);
5162+
}
50195163
TEST_P(ProcScopedChannelsIrConverterTest, LoopbackChannelMember) {
50205164
constexpr std::string_view kProgram = R"(
50215165
proc main {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
top proc __test_module__main_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chan: bits[16] out>(__state: (), init={()}) {
6+
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
8+
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
9+
chan_interface out_chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
10+
literal.9: token = literal(value=token, id=9)
11+
literal.3: bits[1] = literal(value=1, id=3)
12+
out_chan: bits[16] = send_channel_end(id=4)
13+
literal.6: token = literal(value=token, id=6)
14+
literal.7: bits[16] = literal(value=42, id=7)
15+
receive.10: (token, bits[32]) = receive(literal.9, predicate=literal.3, channel=in_chans__1, id=10)
16+
__state: () = state_read(state_element=__state, id=2)
17+
tuple.12: () = tuple(id=12)
18+
__token: token = literal(value=token, id=1)
19+
tuple.5: (bits[16]) = tuple(out_chan, id=5)
20+
send.8: token = send(literal.6, literal.7, predicate=literal.3, channel=out_chan, id=8)
21+
tuple_index.11: token = tuple_index(receive.10, index=0, id=11)
22+
next_value.13: () = next_value(param=__state, value=tuple.12, id=13)
23+
}

0 commit comments

Comments
 (0)