1 change: 1 addition & 0 deletions xls/passes/BUILD
@@ -4287,6 +4287,7 @@ cc_test(
"//xls/common/fuzzing:fuzztest",
"//xls/common/status:matchers",
"//xls/common/status:status_macros",
"//xls/estimators/delay_model:delay_estimator",
"//xls/fuzzer/ir_fuzzer:ir_fuzz_domain",
"//xls/fuzzer/ir_fuzzer:ir_fuzz_test_library",
"//xls/ir",
357 changes: 181 additions & 176 deletions xls/passes/select_lifting_pass.cc
@@ -447,6 +447,141 @@ absl::StatusOr<std::optional<LiftedOpInfo>> CanLiftSelect(
return std::nullopt;
}

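// Builds a replacement select of the same kind as `old_select` (Select or
// PrioritySelect), reusing its selector with the given cases and default.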
absl::StatusOr<Node*> MakeSelectNode(FunctionBase* func, Node* old_select,
const std::vector<Node*>& new_cases,
std::optional<Node*> new_default) {
if (old_select->Is<PrioritySelect>()) {
return func->MakeNode<PrioritySelect>(
SourceInfo(), old_select->As<PrioritySelect>()->selector(), new_cases,
*new_default);
} else {
return func->MakeNode<Select>(SourceInfo(),
old_select->As<Select>()->selector(),
new_cases, new_default);
}
}

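// Returns true if lifting the operation described by `info` out of
// `select_to_optimize` is estimated to increase the (unscheduled)
// critical-path delay. Temporary nodes modeling the post-lift IR are built,
// measured against the delay analysis, and removed before returning.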
absl::StatusOr<bool> CheckLatencyIncrease(
FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info,
const OptimizationPassOptions& options, OptimizationContext& context) {
CriticalPathDelayAnalysis* analysis =
context.SharedNodeData<CriticalPathDelayAnalysis>(func, options);
if (analysis == nullptr) {
return absl::InternalError(absl::StrCat(
"Failed to get CriticalPathDelayAnalysis for delay model: ",
*options.delay_model));
}

  // Record the (unscheduled) critical-path delay through the select we're
  // optimizing.
int64_t t_before = *analysis->GetInfo(select_to_optimize);

// To make it easy to estimate the critical path after lifting the select, we
// add nodes to represent the post-optimization result.
Node* tmp_new_select = nullptr;
Node* tmp_lifted_op = nullptr;
absl::flat_hash_set<Node*> tmp_identity_literals;

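  // Remove the temporary nodes in dependency order: the lifted op uses the
  // new select, which may in turn use the identity literals.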
absl::Cleanup cleanup = [&] {
if (tmp_lifted_op != nullptr) {
CHECK_OK(func->RemoveNode(tmp_lifted_op));
}
if (tmp_new_select != nullptr) {
CHECK_OK(func->RemoveNode(tmp_new_select));
}
for (Node* literal : tmp_identity_literals) {
CHECK_OK(func->RemoveNode(literal));
}
};

if (info.lifted_op == Op::kArrayIndex) {
XLS_ASSIGN_OR_RETURN(
tmp_new_select,
MakeSelectNode(func, select_to_optimize, info.other_operands,
info.default_other_operand));
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op,
func->MakeNode<ArrayIndex>(SourceInfo(), info.shared_node,
absl::Span<Node* const>{tmp_new_select}));
} else {
Type* other_operand_type = nullptr;
if (!info.other_operands.empty()) {
other_operand_type = info.other_operands[0]->GetType();
} else {
other_operand_type = (*info.default_other_operand)->GetType();
}

std::vector<Node*> tmp_new_cases;
std::optional<Node*> tmp_new_default;
absl::Span<Node* const> original_cases = GetCases(select_to_optimize);
int64_t other_operand_idx = 0;
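    // Rebuild the cases: each identity case becomes an identity literal for
    // the lifted op, and the remaining cases reuse the other operands in
    // order.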
for (int64_t i = 0; i < original_cases.size(); ++i) {
if (info.identity_case_indices.contains(i)) {
XLS_ASSIGN_OR_RETURN(
Node * identity_literal,
GetIdentityLiteral(info.lifted_op, other_operand_type, func));
tmp_new_cases.push_back(identity_literal);
tmp_identity_literals.insert(identity_literal);
} else {
tmp_new_cases.push_back(info.other_operands[other_operand_idx++]);
}
}
std::optional<Node*> original_default = GetDefaultValue(select_to_optimize);
if (original_default.has_value()) {
if (info.default_is_identity) {
XLS_ASSIGN_OR_RETURN(
Node * identity_literal,
GetIdentityLiteral(info.lifted_op, other_operand_type, func));
tmp_new_default = identity_literal;
tmp_identity_literals.insert(identity_literal);
} else {
tmp_new_default = *info.default_other_operand;
}
}
XLS_ASSIGN_OR_RETURN(tmp_new_select,
MakeSelectNode(func, select_to_optimize, tmp_new_cases,
tmp_new_default));
Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select;
Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node;
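    // Materialize the lifted op, with the new select standing in for the
    // original non-shared operand.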
switch (info.lifted_op) {
case Op::kAdd:
case Op::kSub:
case Op::kShll:
case Op::kShrl:
case Op::kShra: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op,
func->MakeNode<BinOp>(SourceInfo(), lhs, rhs, info.lifted_op));
break;
}
case Op::kAnd:
case Op::kOr:
case Op::kXor: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op,
func->MakeNode<NaryOp>(SourceInfo(), std::vector<Node*>{lhs, rhs},
info.lifted_op));
break;
}
case Op::kUMul:
case Op::kSMul: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op, func->MakeNode<ArithOp>(
SourceInfo(), lhs, rhs,
select_to_optimize->GetType()->GetFlatBitCount(),
info.lifted_op));
break;
}
default:
return absl::InternalError(
absl::StrCat("Unsupported binary operation in latency check: ",
OpToString(info.lifted_op)));
}
}

int64_t t_after = *analysis->GetInfo(tmp_lifted_op);
return t_after > t_before;
}

absl::StatusOr<bool> ProfitabilityGuardForArrayIndex(FunctionBase* func,
Node* select_to_optimize,
Node* array_reference) {
@@ -554,23 +689,30 @@ absl::StatusOr<bool> ProfitabilityGuardForArrayIndex(FunctionBase* func,
return true;
}

absl::StatusOr<Node*> MakeSelectNode(FunctionBase* func, Node* old_select,
const std::vector<Node*>& new_cases,
std::optional<Node*> new_default) {
if (old_select->Is<PrioritySelect>()) {
return func->MakeNode<PrioritySelect>(
SourceInfo(), old_select->As<PrioritySelect>()->selector(), new_cases,
*new_default);
} else {
return func->MakeNode<Select>(SourceInfo(),
old_select->As<Select>()->selector(),
new_cases, new_default);
}
}

absl::StatusOr<bool> ProfitabilityGuardForBinaryOperation(
FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info,
const OptimizationPassOptions& options, OptimizationContext& context) {
// If lifting a shift operation combines literal shift amounts into a select,
// this creates a variable shift from constant shifts, which is more
// expensive. If we don't have a delay model, avoid lifting in this case.
// If we do have a delay model, we trust it to determine if lifting is
// profitable.
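  // For example (hypothetical IR), lifting here would rewrite
  //   sel(p, cases=[shll(x, 2), shll(x, 5)]) => shll(x, sel(p, cases=[2, 5])),
  // turning two constant shifts into one variable shift.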
bool all_other_operands_are_literals = absl::c_all_of(
info.other_operands, [](Node* n) { return n->Is<Literal>(); });
bool default_is_literal_or_identity_or_nonexistent =
!info.default_other_operand.has_value() || info.default_is_identity ||
(*info.default_other_operand)->Is<Literal>();
if (!options.delay_model.has_value() && info.shared_is_lhs &&
(info.lifted_op == Op::kShll || info.lifted_op == Op::kShrl ||
info.lifted_op == Op::kShra) &&
all_other_operands_are_literals &&
default_is_literal_or_identity_or_nonexistent) {
VLOG(3) << " Not lifting " << OpToString(info.lifted_op)
<< " because it would replace constant shifts with a variable "
"shift, and no delay model is provided.";
return false;
}

  // Heuristic: if the selector depends on `shared_node`, lifting will likely
  // serialize more operations and worsen the critical path.
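  // For example (hypothetical IR), rewriting sel(f(x), cases=[x + a, x + b])
  // as x + sel(f(x), cases=[a, b]) moves the add after the select, so it can
  // no longer overlap with the computation of f(x).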
Node* selector = select_to_optimize->Is<Select>()
@@ -582,134 +724,6 @@ absl::StatusOr<bool> ProfitabilityGuardForBinaryOperation(
return false;
}

// If latency-aware profitability is enabled, check if lifting increases
// latency.
if (options.delay_model.has_value()) {
CriticalPathDelayAnalysis* analysis =
context.SharedNodeData<CriticalPathDelayAnalysis>(func, options);
if (analysis == nullptr) {
return absl::InternalError(absl::StrCat(
"Failed to get CriticalPathDelayAnalysis for delay model: ",
*options.delay_model));
}

// Check the (unscheduled) critical path through the select we're optimizing
int64_t t_before = *analysis->GetInfo(select_to_optimize);

// To make it easy to estimate the critical path after lifting the select,
// we add nodes to represent the post-optimization result.
Type* other_operand_type = nullptr;
if (!info.other_operands.empty()) {
other_operand_type = info.other_operands[0]->GetType();
} else {
other_operand_type = (*info.default_other_operand)->GetType();
}

std::vector<Node*> tmp_new_cases;
std::optional<Node*> tmp_new_default;
absl::flat_hash_set<Node*> tmp_identity_literals;
Node* tmp_new_select = nullptr;
Node* tmp_lifted_op = nullptr;

absl::Cleanup cleanup = [&] {
if (tmp_lifted_op != nullptr) {
CHECK_OK(func->RemoveNode(tmp_lifted_op));
}
if (tmp_new_select != nullptr) {
CHECK_OK(func->RemoveNode(tmp_new_select));
}
for (Node* literal : tmp_identity_literals) {
CHECK_OK(func->RemoveNode(literal));
}
};

absl::Span<Node* const> original_cases = GetCases(select_to_optimize);
int64_t other_operand_idx = 0;
for (int64_t i = 0; i < original_cases.size(); ++i) {
if (info.identity_case_indices.contains(i)) {
XLS_ASSIGN_OR_RETURN(
Node * identity_literal,
GetIdentityLiteral(info.lifted_op, other_operand_type, func));
tmp_new_cases.push_back(identity_literal);
tmp_identity_literals.insert(identity_literal);
} else {
tmp_new_cases.push_back(info.other_operands[other_operand_idx++]);
}
}
std::optional<Node*> original_default = GetDefaultValue(select_to_optimize);
if (original_default.has_value()) {
if (info.default_is_identity) {
XLS_ASSIGN_OR_RETURN(
Node * identity_literal,
GetIdentityLiteral(info.lifted_op, other_operand_type, func));
tmp_new_default = identity_literal;
tmp_identity_literals.insert(identity_literal);
} else {
tmp_new_default = *info.default_other_operand;
}
}

XLS_ASSIGN_OR_RETURN(tmp_new_select,
MakeSelectNode(func, select_to_optimize, tmp_new_cases,
tmp_new_default));

Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select;
Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node;
switch (info.lifted_op) {
case Op::kAdd:
case Op::kSub:
case Op::kShll:
case Op::kShrl:
case Op::kShra: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op,
func->MakeNode<BinOp>(SourceInfo(), lhs, rhs, info.lifted_op));
break;
}
case Op::kAnd:
case Op::kOr:
case Op::kXor: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op,
func->MakeNode<NaryOp>(SourceInfo(), std::vector<Node*>{lhs, rhs},
info.lifted_op));
break;
}
case Op::kUMul:
case Op::kSMul: {
XLS_ASSIGN_OR_RETURN(
tmp_lifted_op, func->MakeNode<ArithOp>(
SourceInfo(), lhs, rhs,
select_to_optimize->GetType()->GetFlatBitCount(),
info.lifted_op));
break;
}
default:
return absl::InternalError(
absl::StrCat("Unsupported binary operation in latency check: ",
OpToString(info.lifted_op)));
}

int64_t t_after = *analysis->GetInfo(tmp_lifted_op);

if (t_after > t_before) {
VLOG(3) << " Not lifting " << OpToString(info.lifted_op)
<< " because latency increases from " << t_before << " to "
<< t_after << ".";
return false;
}
} else {
// Fallback heuristic: if identity cases are present, don't lift
// mul/smul.
if ((!info.identity_case_indices.empty() || info.default_is_identity) &&
(info.lifted_op == Op::kUMul || info.lifted_op == Op::kSMul)) {
VLOG(3) << " Not lifting high-latency op "
<< OpToString(info.lifted_op)
<< " with identity cases because no delay model is provided.";
return false;
}
}

// Calculate Cost Before:
// Sum of bitwidths of the original select and any single-use non-identity
// case nodes.
@@ -762,42 +776,33 @@ absl::StatusOr<bool> ShouldLiftSelect(FunctionBase* func,
OptimizationContext& context) {
VLOG(3) << " Checking the profitability guard";

// Check if the transformation is profitable.
//
// Only "select" nodes with specific properties should be optimized.
// Such properties depend on the inputs of the "select" node.
//
// The next code checks to see if the "select" node given as input should be
// transformed.
switch (info.lifted_op) {
case Op::kArrayIndex:
// TODO(epastor): This case needs to be handled by the new logic.
return ProfitabilityGuardForArrayIndex(func, select_to_optimize,
info.shared_node);

// The applicability guard checked the operation has exactly 2 operands.
case Op::kAnd:
case Op::kOr:
case Op::kNand:
case Op::kNor:
case Op::kXor:
case Op::kAdd:
case Op::kSub:
case Op::kUMul:
case Op::kSMul:
case Op::kUDiv:
case Op::kSDiv:
case Op::kShll:
case Op::kShrl:
case Op::kShra:
return ProfitabilityGuardForBinaryOperation(func, select_to_optimize,
info, options, context);
break;

default:
VLOG(3) << " The current input of the select is not handled";
if (options.delay_model.has_value()) {
    // If a delay model is provided, check whether lifting increases latency.
XLS_ASSIGN_OR_RETURN(
bool latency_increases,
CheckLatencyIncrease(func, select_to_optimize, info, options, context));
if (latency_increases) {
VLOG(3) << " Not lifting " << OpToString(info.lifted_op)
<< " because latency increases.";
return false;
}
} else if ((!info.identity_case_indices.empty() ||
info.default_is_identity) &&
(info.lifted_op == Op::kUMul || info.lifted_op == Op::kSMul)) {
    // Without a delay model, fall back to a heuristic for high-latency
    // multiplies with identity cases.
VLOG(3) << " Not lifting high-latency op " << OpToString(info.lifted_op)
<< " with identity cases because no delay model is provided.";
return false;
}

// If we pass latency checks or they don't apply, proceed to
// op-specific profitability guards for bitwidth/cost checks.
if (info.lifted_op == Op::kArrayIndex) {
return ProfitabilityGuardForArrayIndex(func, select_to_optimize,
info.shared_node);
}
return ProfitabilityGuardForBinaryOperation(func, select_to_optimize, info,
options, context);
}

absl::StatusOr<TransformationResult> LiftSelectForArrayIndex(