Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion xls/dslx/ir_convert/channel_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ class ChannelArray {
// For logging purposes.
std::string ToString() const { return base_channel_name_; }

std::vector<ChannelRef> channels() const {
std::vector<ChannelRef> channels;
channels.reserve(flattened_names_in_order_.size());
for (auto& name : flattened_names_in_order_) {
std::optional<ChannelRef> channel = FindChannel(name);
if (channel.has_value()) {
channels.push_back(*channel);
}
}
return channels;
}

private:
explicit ChannelArray(std::string_view base_channel_name,
bool subarray = false)
Expand All @@ -69,7 +81,7 @@ class ChannelArray {
flattened_name_to_channel_.emplace(flattened_name, channel);
}

std::optional<ChannelRef> FindChannel(std::string_view flattened_name) {
std::optional<ChannelRef> FindChannel(std::string_view flattened_name) const {
const auto it = flattened_name_to_channel_.find(flattened_name);
if (it == flattened_name_to_channel_.end()) {
return std::nullopt;
Expand Down
26 changes: 18 additions & 8 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3216,20 +3216,30 @@ absl::Status FunctionConverter::HandleSpawn(const Spawn* node) {

const Invocation* invocation = node->config();
XLS_RET_CHECK(package_data_.invocation_to_ir_proc.contains(invocation));
xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];

// Lookup the channel interface for each actual arg and add it to the args
// vector.
std::vector<ChannelInterface*> channel_args;
channel_args.reserve(invocation->args().size());
for (Expr* arg : invocation->args()) {
XLS_RETURN_IF_ERROR(Visit(arg));
std::optional<IrValue> arg_value = GetNodeToIr(arg);
XLS_RET_CHECK(arg_value.has_value());
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
IrValueToChannelInterface(*arg_value));
channel_args.push_back(channel_interface);
std::optional<IrValue> arg_value_opt = GetNodeToIr(arg);
XLS_RET_CHECK(arg_value_opt.has_value());
IrValue arg_value = *arg_value_opt;
if (std::holds_alternative<ChannelArray*>(arg_value)) {
// Get all the channel interfaces of this array and
// add them to the channel_args
ChannelArray* channel_array = std::get<ChannelArray*>(arg_value);
for (ChannelRef channel : channel_array->channels()) {
XLS_RET_CHECK(std::holds_alternative<ChannelInterface*>(channel));
channel_args.push_back(std::get<ChannelInterface*>(channel));
}
} else {
XLS_ASSIGN_OR_RETURN(ChannelInterface * channel_interface,
IrValueToChannelInterface(arg_value));
channel_args.push_back(channel_interface);
}
}

xls::Proc* ir_proc = package_data_.invocation_to_ir_proc[invocation];
xls::Proc* current_proc = builder_ptr->proc();
XLS_RETURN_IF_ERROR(
current_proc
Expand Down
88 changes: 88 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5033,6 +5033,50 @@ pub proc main {
}
}
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(kProgram, "main", import_data,
ConvertOptions{
.emit_positions = false,
.lower_to_proc_scoped_channels = true,
}));
ExpectIr(converted);
}

TEST_P(ProcScopedChannelsIrConverterTest, SpawnFromChannelArrayParams) {
if (GetParam() == TypeInferenceVersion::kVersion1) {
// v1 fails figuring out the send argument type for some reason.
return;
}

constexpr std::string_view kProgram = R"(
pub proc spawnee {
ins: chan<u32>[3] in;
outs: chan<u16>[2] out;

init { }
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
(in_chans, out_chans)
}
next(state: ()) {
send(token(), outs[0], u16:42);
recv(token(), ins[1]);
()
}
}

pub proc main {
init { }
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
spawn spawnee(in_chans, out_chans);
()
}
next(state: ()) { () }
}
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
Expand Down Expand Up @@ -5061,6 +5105,7 @@ pub proc main {
}
}
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
Expand All @@ -5072,6 +5117,49 @@ pub proc main {
ExpectIr(converted);
}

TEST_P(ProcScopedChannelsIrConverterTest,
SpawnFromChannelArrayAndChannelParams) {
if (GetParam() == TypeInferenceVersion::kVersion1) {
// v1 fails figuring out the send argument type for some reason.
return;
}

constexpr std::string_view kProgram = R"(
pub proc spawnee {
ins: chan<u32>[3] in;
outch: chan<u16> out;

init { }
config(in_chans: chan<u32>[3] in, out_chan: chan<u16> out) {
(in_chans, out_chan)
}
next(state: ()) {
send(token(), outch, u16:42);
recv(token(), ins[1]);
()
}
}

pub proc main {
init { }
config(in_chans: chan<u32>[3] in, out_chans: chan<u16>[2] out) {
spawn spawnee(in_chans, out_chans[1]);
()
}
next(state: ()) { () }
}
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(kProgram, "main", import_data,
ConvertOptions{
.emit_positions = false,
.lower_to_proc_scoped_channels = true,
}));
ExpectIr(converted);
}
TEST_P(ProcScopedChannelsIrConverterTest, LoopbackChannelMember) {
constexpr std::string_view kProgram = R"(
proc main {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__spawnee_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chan: bits[16] out>(__state: (), init={()}) {
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
literal.9: token = literal(value=token, id=9)
literal.3: bits[1] = literal(value=1, id=3)
out_chan: bits[16] = send_channel_end(id=4)
literal.6: token = literal(value=token, id=6)
literal.7: bits[16] = literal(value=42, id=7)
receive.10: (token, bits[32]) = receive(literal.9, predicate=literal.3, channel=in_chans__1, id=10)
__state: () = state_read(state_element=__state, id=2)
tuple.12: () = tuple(id=12)
__token: token = literal(value=token, id=1)
tuple.5: (bits[16]) = tuple(out_chan, id=5)
send.8: token = send(literal.6, literal.7, predicate=literal.3, channel=out_chan, id=8)
tuple_index.11: token = tuple_index(receive.10, index=0, id=11)
next_value.13: () = next_value(param=__state, value=tuple.12, id=13)
}

top proc __test_module__main_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__1, proc=__test_module__spawnee_0_next)
__state: () = state_read(state_element=__state, id=15)
tuple.18: () = tuple(id=18)
__token: token = literal(value=token, id=14)
literal.16: bits[1] = literal(value=1, id=16)
tuple.17: () = tuple(id=17)
next_value.19: () = next_value(param=__state, value=tuple.18, id=19)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__spawnee_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
literal.8: token = literal(value=token, id=8)
literal.3: bits[1] = literal(value=1, id=3)
literal.5: token = literal(value=token, id=5)
literal.6: bits[16] = literal(value=42, id=6)
receive.9: (token, bits[32]) = receive(literal.8, predicate=literal.3, channel=in_chans__1, id=9)
__state: () = state_read(state_element=__state, id=2)
tuple.11: () = tuple(id=11)
__token: token = literal(value=token, id=1)
tuple.4: () = tuple(id=4)
send.7: token = send(literal.5, literal.6, predicate=literal.3, channel=out_chans__0, id=7)
tuple_index.10: token = tuple_index(receive.9, index=0, id=10)
next_value.12: () = next_value(param=__state, value=tuple.11, id=12)
}

top proc __test_module__main_0_next<in_chans__0: bits[32] in, in_chans__1: bits[32] in, in_chans__2: bits[32] in, out_chans__0: bits[16] out, out_chans__1: bits[16] out>(__state: (), init={()}) {
chan_interface in_chans__0(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__1(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface in_chans__2(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__0(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface out_chans__1(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
proc_instantiation __test_module__spawnee_0_next_inst(in_chans__0, in_chans__1, in_chans__2, out_chans__0, out_chans__1, proc=__test_module__spawnee_0_next)
__state: () = state_read(state_element=__state, id=14)
tuple.17: () = tuple(id=17)
__token: token = literal(value=token, id=13)
literal.15: bits[1] = literal(value=1, id=15)
tuple.16: () = tuple(id=16)
next_value.18: () = next_value(param=__state, value=tuple.17, id=18)
}