Skip to content

Commit 51bc99c

Browse files
dplassgitcopybara-github
authored andcommitted
[Proc-scoped channels] Spawn a proc with channel arrays. When figuring out the proc instantiation args, expand any channel arrays to the list of all the channel interfaces it contains.
PiperOrigin-RevId: 807775728
1 parent 0c8f13f commit 51bc99c

File tree

5 files changed

+195
-9
lines changed

5 files changed

+195
-9
lines changed

xls/dslx/ir_convert/channel_scope.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ class ChannelArray {
5252
// For logging purposes.
5353
std::string ToString() const { return base_channel_name_; }
5454

55+
std::vector<ChannelRef> channels() const {
56+
std::vector<ChannelRef> channels;
57+
channels.reserve(flattened_names_in_order_.size());
58+
for (auto& name : flattened_names_in_order_) {
59+
std::optional<ChannelRef> channel = FindChannel(name);
60+
if (channel.has_value()) {
61+
channels.push_back(*channel);
62+
}
63+
}
64+
return channels;
65+
}
66+
5567
private:
5668
explicit ChannelArray(std::string_view base_channel_name,
5769
bool subarray = false)
@@ -69,7 +81,7 @@ class ChannelArray {
6981
flattened_name_to_channel_.emplace(flattened_name, channel);
7082
}
7183

72-
std::optional<ChannelRef> FindChannel(std::string_view flattened_name) {
84+
std::optional<ChannelRef> FindChannel(std::string_view flattened_name) const {
7385
const auto it = flattened_name_to_channel_.find(flattened_name);
7486
if (it == flattened_name_to_channel_.end()) {
7587
return std::nullopt;

xls/dslx/ir_convert/function_converter.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3216,20 +3216,30 @@ absl::Status FunctionConverter::HandleSpawn(const Spawn* node) {
32163216

32173217
const Invocation* invocation = node->config();
32183218
XLS_RET_CHECK(package_data_.invocation_to_ir_proc.contains(invocation));
3219-
xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];
32203219

3221-
// Lookup the channel interface for each actual arg and add it to the args
3222-
// vector.
32233220
std::vector<ChannelInterface*> channel_args;
3221+
channel_args.reserve(invocation->args().size());
32243222
for (Expr* arg : invocation->args()) {
32253223
XLS_RETURN_IF_ERROR(Visit(arg));
3226-
std::optional<IrValue> arg_value = GetNodeToIr(arg);
3227-
XLS_RET_CHECK(arg_value.has_value());
3228-
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
3229-
IrValueToChannelInterface(*arg_value));
3230-
channel_args.push_back(channel_interface);
3224+
std::optional<IrValue> arg_value_opt = GetNodeToIr(arg);
3225+
XLS_RET_CHECK(arg_value_opt.has_value());
3226+
IrValue arg_value = *arg_value_opt;
3227+
if (std::holds_alternative<ChannelArray*>(arg_value)) {
3228+
// Get all the channel interfaces of this array and
3229+
// add them to the channel_args
3230+
ChannelArray* channel_array = std::get<ChannelArray*>(arg_value);
3231+
for (ChannelRef channel : channel_array->channels()) {
3232+
XLS_RET_CHECK(std::holds_alternative<ChannelInterface*>(channel));
3233+
channel_args.push_back(std::get<ChannelInterface*>(channel));
3234+
}
3235+
} else {
3236+
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
3237+
IrValueToChannelInterface(arg_value));
3238+
channel_args.push_back(channel_interface);
3239+
}
32313240
}
32323241

3242+
xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];
32333243
xls::Proc* current_proc = builder_ptr->proc();
32343244
XLS_RETURN_IF_ERROR(
32353245
current_proc

xls/dslx/ir_convert/ir_converter_test.cc

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5033,6 +5033,50 @@ pub proc main {
50335033
}
50345034
}
50355035
)";
5036+
5037+
auto import_data = CreateImportDataForTest();
5038+
XLS_ASSERT_OK_AND_ASSIGN(
5039+
std::string converted,
5040+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5041+
ConvertOptions{
5042+
.emit_positions = false,
5043+
.lower_to_proc_scoped_channels = true,
5044+
}));
5045+
ExpectIr(converted);
5046+
}
5047+
5048+
TEST_P(ProcScopedChannelsIrConverterTest, SpawnFromChannelArrayParams) {
5049+
if (GetParam() == TypeInferenceVersion::kVersion1) {
5050+
// v1 fails figuring out the send argument type for some reason.
5051+
return;
5052+
}
5053+
5054+
constexpr std::string_view kProgram = R"(
5055+
pub proc spawnee {
5056+
ins: chan<u32>[3] in;
5057+
outs: chan<u16>[2] out;
5058+
5059+
init { }
5060+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5061+
(in_chans, out_chans)
5062+
}
5063+
next(state: ()) {
5064+
send(token(), outs[0], u16:42);
5065+
recv(token(), ins[1]);
5066+
()
5067+
}
5068+
}
5069+
5070+
pub proc main {
5071+
init { }
5072+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5073+
spawn spawnee(in_chans, out_chans);
5074+
()
5075+
}
5076+
next(state: ()) { () }
5077+
}
5078+
)";
5079+
50365080
auto import_data = CreateImportDataForTest();
50375081
XLS_ASSERT_OK_AND_ASSIGN(
50385082
std::string converted,
@@ -5061,6 +5105,7 @@ pub proc main {
50615105
}
50625106
}
50635107
)";
5108+
50645109
auto import_data = CreateImportDataForTest();
50655110
XLS_ASSERT_OK_AND_ASSIGN(
50665111
std::string converted,
@@ -5072,6 +5117,49 @@ pub proc main {
50725117
ExpectIr(converted);
50735118
}
50745119

5120+
TEST_P(ProcScopedChannelsIrConverterTest,
5121+
SpawnFromChannelArrayAndChannelParams) {
5122+
if (GetParam() == TypeInferenceVersion::kVersion1) {
5123+
// v1 fails figuring out the send argument type for some reason.
5124+
return;
5125+
}
5126+
5127+
constexpr std::string_view kProgram = R"(
5128+
pub proc spawnee {
5129+
ins: chan<u32>[3] in;
5130+
outch: chan<u16> out;
5131+
5132+
init { }
5133+
config(in_chans: chan<u32>[3] in, out_chan: chan<u16> out) {
5134+
(in_chans, out_chan)
5135+
}
5136+
next(state: ()) {
5137+
send(token(), outch, u16:42);
5138+
recv(token(), ins[1]);
5139+
()
5140+
}
5141+
}
5142+
5143+
pub proc main {
5144+
init { }
5145+
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
5146+
spawn spawnee(in_chans, out_chans[1]);
5147+
()
5148+
}
5149+
next(state: ()) { () }
5150+
}
5151+
)";
5152+
5153+
auto import_data = CreateImportDataForTest();
5154+
XLS_ASSERT_OK_AND_ASSIGN(
5155+
std::string converted,
5156+
ConvertOneFunctionForTest(kProgram, "main", import_data,
5157+
ConvertOptions{
5158+
.emit_positions = false,
5159+
.lower_to_proc_scoped_channels = true,
5160+
}));
5161+
ExpectIr(converted);
5162+
}
50755163
TEST_P(ProcScopedChannelsIrConverterTest, LoopbackChannelMember) {
50765164
constexpr std::string_view kProgram = R"(
50775165
proc main {
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
proc __test_module__spawnee_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chan: bits[16] out>(__state: (), init={()}) {
6+
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
8+
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
9+
chan_interface out_chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
10+
literal.9: token = literal(value=token, id=9)
11+
literal.3: bits[1] = literal(value=1, id=3)
12+
out_chan: bits[16] = send_channel_end(id=4)
13+
literal.6: token = literal(value=token, id=6)
14+
literal.7: bits[16] = literal(value=42, id=7)
15+
receive.10: (token, bits[32]) = receive(literal.9, predicate=literal.3, channel=in_chans__1, id=10)
16+
__state: () = state_read(state_element=__state, id=2)
17+
tuple.12: () = tuple(id=12)
18+
__token: token = literal(value=token, id=1)
19+
tuple.5: (bits[16]) = tuple(out_chan, id=5)
20+
send.8: token = send(literal.6, literal.7, predicate=literal.3, channel=out_chan, id=8)
21+
tuple_index.11: token = tuple_index(receive.10, index=0, id=11)
22+
next_value.13: () = next_value(param=__state, value=tuple.12, id=13)
23+
}
24+
25+
top proc __test_module__main_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
26+
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
27+
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
28+
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
29+
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
30+
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
31+
proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__1, proc=__test_module__spawnee_0_next)
32+
__state: () = state_read(state_element=__state, id=15)
33+
tuple.18: () = tuple(id=18)
34+
__token: token = literal(value=token, id=14)
35+
literal.16: bits[1] = literal(value=1, id=16)
36+
tuple.17: () = tuple(id=17)
37+
next_value.19: () = next_value(param=__state, value=tuple.18, id=19)
38+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
proc __test_module__spawnee_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
6+
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
8+
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
9+
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
10+
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
11+
literal.8: token = literal(value=token, id=8)
12+
literal.3: bits[1] = literal(value=1, id=3)
13+
literal.5: token = literal(value=token, id=5)
14+
literal.6: bits[16] = literal(value=42, id=6)
15+
receive.9: (token, bits[32]) = receive(literal.8, predicate=literal.3, channel=in_chans__1, id=9)
16+
__state: () = state_read(state_element=__state, id=2)
17+
tuple.11: () = tuple(id=11)
18+
__token: token = literal(value=token, id=1)
19+
tuple.4: () = tuple(id=4)
20+
send.7: token = send(literal.5, literal.6, predicate=literal.3, channel=out_chans__0, id=7)
21+
tuple_index.10: token = tuple_index(receive.9, index=0, id=10)
22+
next_value.12: () = next_value(param=__state, value=tuple.11, id=12)
23+
}
24+
25+
top proc __test_module__main_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
26+
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
27+
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
28+
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
29+
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
30+
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
31+
proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__0, out_chans__1, proc=__test_module__spawnee_0_next)
32+
__state: () = state_read(state_element=__state, id=14)
33+
tuple.17: () = tuple(id=17)
34+
__token: token = literal(value=token, id=13)
35+
literal.15: bits[1] = literal(value=1, id=15)
36+
tuple.16: () = tuple(id=16)
37+
next_value.18: () = next_value(param=__state, value=tuple.17, id=18)
38+
}

0 commit comments

Comments
 (0)