Skip to content

Commit 02d78ef

Browse files
ericastorcopybara-github
authored andcommitted
[opt] Unconditionally remove sends' data if their predicates are false
Previously, when we are unable to remove an unconditionally-disabled send, we did nothing; this left the sends marked as using their data, which could prevent other optimizations from kicking in. Now, if we can't remove a send, we replace its data input with a zero-valued literal. PiperOrigin-RevId: 809231139
1 parent 52e22eb commit 02d78ef

File tree

2 files changed

+72
-22
lines changed

2 files changed

+72
-22
lines changed

xls/passes/useless_io_removal_pass.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,23 @@ absl::StatusOr<bool> UselessIORemovalPass::RunInternal(
8888
}
8989
XLS_ASSIGN_OR_RETURN(ChannelRef channel_ref, send->GetChannelRef());
9090
Node* predicate = send->predicate().value();
91-
if (query_engine.IsAllZeros(predicate) &&
92-
channel_maps.to_send.at(channel_ref).size() >= 2) {
93-
channel_maps.to_send.at(channel_ref).erase(send);
94-
replacement = send->token();
91+
if (query_engine.IsAllZeros(predicate)) {
92+
// We can remove the send if this is not the last send left on the
93+
// channel.
94+
if (channel_maps.to_send.at(channel_ref).size() > 1) {
95+
channel_maps.to_send.at(channel_ref).erase(send);
96+
replacement = send->token();
97+
} else if (!send->data()->Is<Literal>() ||
98+
!send->data()->As<Literal>()->value().IsAllZeros()) {
99+
// If we aren't removing the send, at least replace its data
100+
// input with zero, as it will never be used.
101+
XLS_ASSIGN_OR_RETURN(
102+
Literal * zero,
103+
proc->MakeNode<Literal>(send->loc(),
104+
ZeroOfType(send->data()->GetType())));
105+
send->ReplaceOperand(send->data(), zero);
106+
changed = true;
107+
}
95108
} else if (query_engine.IsAllOnes(predicate)) {
96109
XLS_ASSIGN_OR_RETURN(
97110
replacement,
@@ -108,6 +121,8 @@ absl::StatusOr<bool> UselessIORemovalPass::RunInternal(
108121
Node* predicate = receive->predicate().value();
109122
if (query_engine.IsAllZeros(predicate) &&
110123
channel_maps.to_receive.at(channel_ref).size() >= 2) {
124+
// We can remove the receive if this is not the last receive left on
125+
// the channel.
111126
XLS_ASSIGN_OR_RETURN(Channel * channel, GetChannelUsedByNode(node));
112127
channel_maps.to_receive.at(channel_ref).erase(receive);
113128
XLS_ASSIGN_OR_RETURN(Literal * zero,

xls/passes/useless_io_removal_pass_test.cc

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
#include "xls/ir/function_builder.h"
3333
#include "xls/ir/ir_matcher.h"
3434
#include "xls/ir/ir_test_base.h"
35+
#include "xls/ir/nodes.h"
3536
#include "xls/ir/package.h"
37+
#include "xls/ir/proc.h"
3638
#include "xls/ir/source_location.h"
3739
#include "xls/ir/value.h"
3840
#include "xls/passes/dce_pass.h"
@@ -77,27 +79,18 @@ TEST_F(UselessIORemovalPassTest, DontRemoveOnlySend) {
7779
p->CreateStreamingChannel("test_channel", ChannelOps::kSendOnly,
7880
p->GetBitsType(32)));
7981
ProcBuilder pb(TestName(), p.get());
82+
BValue token = pb.StateElement("tkn", Value::Token());
8083
pb.StateElement("state", Value(UBits(0, 0)));
81-
pb.SendIf(channel, pb.Literal(Value::Token()), pb.Literal(UBits(0, 1)),
82-
pb.Literal(UBits(1, 32)));
83-
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({pb.Literal(UBits(0, 0))}));
84-
int64_t original_node_count = proc->node_count();
85-
EXPECT_THAT(Run(p.get()), IsOkAndHolds(false));
86-
EXPECT_EQ(proc->node_count(), original_node_count);
87-
}
88-
89-
TEST_F(UselessIORemovalPassTest, DontRemoveOnlySendNewStyle) {
90-
auto p = CreatePackage();
91-
TokenlessProcBuilder pb(NewStyleProc(), TestName(), "tkn", p.get());
92-
XLS_ASSERT_OK_AND_ASSIGN(
93-
SendChannelInterface * channel,
94-
pb.AddOutputChannel("test_channel", p->GetBitsType(32)));
95-
pb.StateElement("state", Value(UBits(0, 0)));
96-
pb.SendIf(channel, pb.Literal(UBits(0, 1)), pb.Literal(UBits(1, 32)));
97-
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({pb.Literal(UBits(0, 0))}));
84+
token = pb.SendIf(channel, token, pb.Literal(UBits(0, 1)),
85+
pb.Literal(UBits(1, 32)), SourceInfo(), "my_send");
86+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc,
87+
pb.Build({token, pb.Literal(UBits(0, 0))}));
9888
int64_t original_node_count = proc->node_count();
99-
EXPECT_THAT(Run(p.get()), IsOkAndHolds(false));
89+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));
10090
EXPECT_EQ(proc->node_count(), original_node_count);
91+
XLS_ASSERT_OK_AND_ASSIGN(Node * send_node, proc->GetNode("my_send"));
92+
ASSERT_TRUE(send_node->Is<Send>());
93+
EXPECT_THAT(send_node->As<Send>()->data(), m::Literal(0));
10194
}
10295

10396
TEST_F(UselessIORemovalPassTest, RemoveSendIfLiteralFalse) {
@@ -243,6 +236,48 @@ TEST_F(UselessIORemovalPassTest, RemoveReceivePredIfLiteralTrue) {
243236
ElementsAre(m::Next(proc->GetStateRead(1), m::TupleIndex(tuple, 1))));
244237
}
245238

239+
TEST_F(UselessIORemovalPassTest, DontRemoveLastSendIfOnSendOnlyChannel) {
240+
auto p = CreatePackage();
241+
XLS_ASSERT_OK_AND_ASSIGN(
242+
StreamingChannel * channel,
243+
p->CreateStreamingChannel("test_channel", ChannelOps::kSendOnly,
244+
p->GetBitsType(32)));
245+
ProcBuilder pb(TestName(), p.get());
246+
BValue token = pb.StateElement("tkn", Value::Token());
247+
pb.StateElement("state", Value(UBits(0, 0)));
248+
token = pb.SendIf(channel, token, pb.Literal(UBits(0, 1)),
249+
pb.Literal(UBits(1, 32)), SourceInfo(), "my_send");
250+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc,
251+
pb.Build({token, pb.Literal(UBits(0, 0))}));
252+
ASSERT_EQ(p->channels().size(), 1);
253+
int64_t original_node_count = proc->node_count();
254+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));
255+
EXPECT_EQ(proc->node_count(), original_node_count);
256+
XLS_ASSERT_OK_AND_ASSIGN(Node * send_node, proc->GetNode("my_send"));
257+
ASSERT_TRUE(send_node->Is<Send>());
258+
EXPECT_THAT(send_node->As<Send>()->data(), m::Literal(0));
259+
}
260+
261+
TEST_F(UselessIORemovalPassTest, DontRemoveLastReceiveIfOnReceiveOnlyChannel) {
262+
auto p = CreatePackage();
263+
XLS_ASSERT_OK_AND_ASSIGN(
264+
StreamingChannel * channel,
265+
p->CreateStreamingChannel("test_channel", ChannelOps::kReceiveOnly,
266+
p->GetBitsType(32)));
267+
ProcBuilder pb(TestName(), p.get());
268+
BValue token = pb.StateElement("tkn", Value::Token());
269+
pb.StateElement("state", Value(UBits(0, 32)));
270+
BValue token_and_result =
271+
pb.ReceiveIf(channel, token, pb.Literal(UBits(0, 1)));
272+
token = pb.TupleIndex(token_and_result, 0);
273+
BValue result = pb.TupleIndex(token_and_result, 1);
274+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({token, result}));
275+
ASSERT_EQ(p->channels().size(), 1);
276+
int64_t original_node_count = proc->node_count();
277+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(false));
278+
EXPECT_EQ(proc->node_count(), original_node_count);
279+
}
280+
246281
void IrFuzzUselessIORemoval(FuzzPackageWithArgs fuzz_package_with_args) {
247282
UselessIORemovalPass pass;
248283
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);

0 commit comments

Comments
 (0)