From 01a9be7b72b4b33896e0b766de70c83dd3440959 Mon Sep 17 00:00:00 2001 From: David Plass Date: Thu, 18 Sep 2025 12:55:09 -0700 Subject: [PATCH] [Proc-scoped channels] Spawn a proc with channel arrays. When figuring out the proc instantiation args, expand any channel arrays to the list of all the channel interfaces it contains. PiperOrigin-RevId: 808699596 --- xls/dslx/ir_convert/channel_scope.h | 14 ++- xls/dslx/ir_convert/function_converter.cc | 26 ++++-- xls/dslx/ir_convert/ir_converter_test.cc | 88 +++++++++++++++++++ ...t_SpawnFromChannelArrayAndChannelParams.ir | 38 ++++++++ ...verter_test_SpawnFromChannelArrayParams.ir | 38 ++++++++ 5 files changed, 195 insertions(+), 9 deletions(-) create mode 100644 xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayAndChannelParams.ir create mode 100644 xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayParams.ir diff --git a/xls/dslx/ir_convert/channel_scope.h b/xls/dslx/ir_convert/channel_scope.h index 10fa26d1b8..08bbb83147 100644 --- a/xls/dslx/ir_convert/channel_scope.h +++ b/xls/dslx/ir_convert/channel_scope.h @@ -52,6 +52,18 @@ class ChannelArray { // For logging purposes. std::string ToString() const { return base_channel_name_; } + std::vector channels() const { + std::vector channels; + channels.reserve(flattened_names_in_order_.size()); + for (auto& name : flattened_names_in_order_) { + std::optional channel = FindChannel(name); + if (channel.has_value()) { + channels.push_back(*channel); + } + } + return channels; + } + private: explicit ChannelArray(std::string_view base_channel_name, bool subarray = false) @@ -69,7 +81,7 @@ class ChannelArray { flattened_name_to_channel_.emplace(flattened_name, channel); } - std::optional FindChannel(std::string_view flattened_name) { + std::optional FindChannel(std::string_view flattened_name) const { const auto it = flattened_name_to_channel_.find(flattened_name); if (it == flattened_name_to_channel_.end()) { return std::nullopt; diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index ccaf649a95..1e87d2e9bc 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -3216,20 +3216,30 @@ absl::Status FunctionConverter::HandleSpawn(const Spawn* node) { const Invocation* invocation = node->config(); XLS_RET_CHECK(package_data_.invocation_to_ir_proc.contains(invocation)); - xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation]; - // Lookup the channel interface for each actual arg and add it to the args - // vector. std::vector channel_args; + channel_args.reserve(invocation->args().size()); for (Expr* arg : invocation->args()) { XLS_RETURN_IF_ERROR(Visit(arg)); - std::optional arg_value = GetNodeToIr(arg); - XLS_RET_CHECK(arg_value.has_value()); - XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface, - IrValueToChannelInterface(*arg_value)); - channel_args.push_back(channel_interface); + std::optional arg_value_opt = GetNodeToIr(arg); + XLS_RET_CHECK(arg_value_opt.has_value()); + IrValue arg_value = *arg_value_opt; + if (std::holds_alternative(arg_value)) { + // Get all the channel interfaces of this array and + // add them to the channel_args + ChannelArray* channel_array = std::get(arg_value); + for (ChannelRef channel : channel_array->channels()) { + XLS_RET_CHECK(std::holds_alternative(channel)); + channel_args.push_back(std::get(channel)); + } + } else { + XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface, + IrValueToChannelInterface(arg_value)); + channel_args.push_back(channel_interface); + } } + xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation]; xls::Proc* current_proc = builder_ptr->proc(); XLS_RETURN_IF_ERROR( current_proc diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index 5541ee5d00..10034089f6 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -5033,6 +5033,50 @@ pub proc main { } } )"; + + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(kProgram, "main", import_data, + ConvertOptions{ + .emit_positions = false, + .lower_to_proc_scoped_channels = true, + })); + ExpectIr(converted); +} + +TEST_P(ProcScopedChannelsIrConverterTest, SpawnFromChannelArrayParams) { + if (GetParam() == TypeInferenceVersion::kVersion1) { + // v1 fails figuring out the send argument type for some reason. + return; + } + + constexpr std::string_view kProgram = R"( +pub proc spawnee { + ins: chan[3] in; + outs: chan[2] out; + + init { } + config(in_chans: chan[3] in, out_chans: chan[2] out) { + (in_chans, out_chans) + } + next(state: ()) { + send(token(), outs[0], u16:42); + recv(token(), ins[1]); + () + } +} + +pub proc main { + init { } + config(in_chans: chan[3] in, out_chans: chan[2] out) { + spawn spawnee(in_chans, out_chans); + () + } + next(state: ()) { () } +} +)"; + auto import_data = CreateImportDataForTest(); XLS_ASSERT_OK_AND_ASSIGN( std::string converted, @@ -5061,6 +5105,7 @@ pub proc main { } } )"; + auto import_data = CreateImportDataForTest(); XLS_ASSERT_OK_AND_ASSIGN( std::string converted, @@ -5072,6 +5117,49 @@ pub proc main { ExpectIr(converted); } +TEST_P(ProcScopedChannelsIrConverterTest, + SpawnFromChannelArrayAndChannelParams) { + if (GetParam() == TypeInferenceVersion::kVersion1) { + // v1 fails figuring out the send argument type for some reason. + return; + } + + constexpr std::string_view kProgram = R"( +pub proc spawnee { + ins: chan[3] in; + outch: chan out; + + init { } + config(in_chans: chan[3] in, out_chan: chan out) { + (in_chans, out_chan) + } + next(state: ()) { + send(token(), outch, u16:42); + recv(token(), ins[1]); + () + } +} + +pub proc main { + init { } + config(in_chans: chan[3] in, out_chans: chan[2] out) { + spawn spawnee(in_chans, out_chans[1]); + () + } + next(state: ()) { () } +} +)"; + + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(kProgram, "main", import_data, + ConvertOptions{ + .emit_positions = false, + .lower_to_proc_scoped_channels = true, + })); + ExpectIr(converted); +} TEST_P(ProcScopedChannelsIrConverterTest, LoopbackChannelMember) { constexpr std::string_view kProgram = R"( proc main { diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayAndChannelParams.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayAndChannelParams.ir new file mode 100644 index 0000000000..19a589b5b6 --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayAndChannelParams.ir @@ -0,0 +1,38 @@ +package test_module + +file_number 0 "test_module.x" + +proc __test_module__spawnee_0_next(__state: (), init={()}) { + chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + literal.9: token = literal(value=token, id=9) + literal.3: bits[1] = literal(value=1, id=3) + out_chan: bits[16] = send_channel_end(id=4) + literal.6: token = literal(value=token, id=6) + literal.7: bits[16] = literal(value=42, id=7) + receive.10: (token, bits[32]) = receive(literal.9, predicate=literal.3, channel=in_chans__1, id=10) + __state: () = state_read(state_element=__state, id=2) + tuple.12: () = tuple(id=12) + __token: token = literal(value=token, id=1) + tuple.5: (bits[16]) = tuple(out_chan, id=5) + send.8: token = send(literal.6, literal.7, predicate=literal.3, channel=out_chan, id=8) + tuple_index.11: token = tuple_index(receive.10, index=0, id=11) + next_value.13: () = next_value(param=__state, value=tuple.12, id=13) +} + +top proc __test_module__main_0_next(__state: (), init={()}) { + chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__1, proc=__test_module__spawnee_0_next) + __state: () = state_read(state_element=__state, id=15) + tuple.18: () = tuple(id=18) + __token: token = literal(value=token, id=14) + literal.16: bits[1] = literal(value=1, id=16) + tuple.17: () = tuple(id=17) + next_value.19: () = next_value(param=__state, value=tuple.18, id=19) +} diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayParams.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayParams.ir new file mode 100644 index 0000000000..72fbe6afaa --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_SpawnFromChannelArrayParams.ir @@ -0,0 +1,38 @@ +package test_module + +file_number 0 "test_module.x" + +proc __test_module__spawnee_0_next(__state: (), init={()}) { + chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + literal.8: token = literal(value=token, id=8) + literal.3: bits[1] = literal(value=1, id=3) + literal.5: token = literal(value=token, id=5) + literal.6: bits[16] = literal(value=42, id=6) + receive.9: (token, bits[32]) = receive(literal.8, predicate=literal.3, channel=in_chans__1, id=9) + __state: () = state_read(state_element=__state, id=2) + tuple.11: () = tuple(id=11) + __token: token = literal(value=token, id=1) + tuple.4: () = tuple(id=4) + send.7: token = send(literal.5, literal.6, predicate=literal.3, channel=out_chans__0, id=7) + tuple_index.10: token = tuple_index(receive.9, index=0, id=10) + next_value.12: () = next_value(param=__state, value=tuple.11, id=12) +} + +top proc __test_module__main_0_next(__state: (), init={()}) { + chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__0, out_chans__1, proc=__test_module__spawnee_0_next) + __state: () = state_read(state_element=__state, id=14) + tuple.17: () = tuple(id=17) + __token: token = literal(value=token, id=13) + literal.15: bits[1] = literal(value=1, id=15) + tuple.16: () = tuple(id=16) + next_value.18: () = next_value(param=__state, value=tuple.17, id=18) +}