diff --git a/xls/passes/BUILD b/xls/passes/BUILD
index b7b01c1cc8..4d13941b12 100644
--- a/xls/passes/BUILD
+++ b/xls/passes/BUILD
@@ -4287,6 +4287,7 @@ cc_test(
         "//xls/common/fuzzing:fuzztest",
         "//xls/common/status:matchers",
         "//xls/common/status:status_macros",
+        "//xls/estimators/delay_model:delay_estimator",
         "//xls/fuzzer/ir_fuzzer:ir_fuzz_domain",
         "//xls/fuzzer/ir_fuzzer:ir_fuzz_test_library",
         "//xls/ir",
diff --git a/xls/passes/select_lifting_pass.cc b/xls/passes/select_lifting_pass.cc
index 7ca37aa893..529755de03 100644
--- a/xls/passes/select_lifting_pass.cc
+++ b/xls/passes/select_lifting_pass.cc
@@ -447,6 +447,141 @@ absl::StatusOr<std::optional<LiftedOpInfo>> CanLiftSelect(
   return std::nullopt;
 }
 
+absl::StatusOr<Node*> MakeSelectNode(FunctionBase* func, Node* old_select,
+                                     const std::vector<Node*>& new_cases,
+                                     std::optional<Node*> new_default) {
+  if (old_select->Is<PrioritySelect>()) {
+    return func->MakeNode<PrioritySelect>(
+        SourceInfo(), old_select->As<PrioritySelect>()->selector(), new_cases,
+        *new_default);
+  } else {
+    return func->MakeNode<Select>(SourceInfo(),
+                                  old_select->As<Select>()->selector(),
+                                  new_cases, new_default);
+  }
+}
+
+absl::StatusOr<bool> CheckLatencyIncrease(
+    FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info,
+    const OptimizationPassOptions& options, OptimizationContext& context) {
+  CriticalPathDelayAnalysis* analysis =
+      context.SharedNodeData<CriticalPathDelayAnalysis>(func, options);
+  if (analysis == nullptr) {
+    return absl::InternalError(absl::StrCat(
+        "Failed to get CriticalPathDelayAnalysis for delay model: ",
+        *options.delay_model));
+  }
+
+  // Check the (unscheduled) critical path through the select we're optimizing
+  int64_t t_before = *analysis->GetInfo(select_to_optimize);
+
+  // To make it easy to estimate the critical path after lifting the select, we
+  // add nodes to represent the post-optimization result.
+  Node* tmp_new_select = nullptr;
+  Node* tmp_lifted_op = nullptr;
+  absl::flat_hash_set<Node*> tmp_identity_literals;
+
+  absl::Cleanup cleanup = [&] {
+    if (tmp_lifted_op != nullptr) {
+      CHECK_OK(func->RemoveNode(tmp_lifted_op));
+    }
+    if (tmp_new_select != nullptr) {
+      CHECK_OK(func->RemoveNode(tmp_new_select));
+    }
+    for (Node* literal : tmp_identity_literals) {
+      CHECK_OK(func->RemoveNode(literal));
+    }
+  };
+
+  if (info.lifted_op == Op::kArrayIndex) {
+    XLS_ASSIGN_OR_RETURN(
+        tmp_new_select,
+        MakeSelectNode(func, select_to_optimize, info.other_operands,
+                       info.default_other_operand));
+    XLS_ASSIGN_OR_RETURN(
+        tmp_lifted_op,
+        func->MakeNode<ArrayIndex>(SourceInfo(), info.shared_node,
+                                   absl::Span<Node* const>{tmp_new_select}));
+  } else {
+    Type* other_operand_type = nullptr;
+    if (!info.other_operands.empty()) {
+      other_operand_type = info.other_operands[0]->GetType();
+    } else {
+      other_operand_type = (*info.default_other_operand)->GetType();
+    }
+
+    std::vector<Node*> tmp_new_cases;
+    std::optional<Node*> tmp_new_default;
+    absl::Span<Node* const> original_cases = GetCases(select_to_optimize);
+    int64_t other_operand_idx = 0;
+    for (int64_t i = 0; i < original_cases.size(); ++i) {
+      if (info.identity_case_indices.contains(i)) {
+        XLS_ASSIGN_OR_RETURN(
+            Node * identity_literal,
+            GetIdentityLiteral(info.lifted_op, other_operand_type, func));
+        tmp_new_cases.push_back(identity_literal);
+        tmp_identity_literals.insert(identity_literal);
+      } else {
+        tmp_new_cases.push_back(info.other_operands[other_operand_idx++]);
+      }
+    }
+    std::optional<Node*> original_default = GetDefaultValue(select_to_optimize);
+    if (original_default.has_value()) {
+      if (info.default_is_identity) {
+        XLS_ASSIGN_OR_RETURN(
+            Node * identity_literal,
+            GetIdentityLiteral(info.lifted_op, other_operand_type, func));
+        tmp_new_default = identity_literal;
+        tmp_identity_literals.insert(identity_literal);
+      } else {
+        tmp_new_default = *info.default_other_operand;
+      }
+    }
+    XLS_ASSIGN_OR_RETURN(tmp_new_select,
+                         MakeSelectNode(func, select_to_optimize,
+                                        tmp_new_cases,
+                                        tmp_new_default));
+    Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select;
+    Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node;
+    switch (info.lifted_op) {
+      case Op::kAdd:
+      case Op::kSub:
+      case Op::kShll:
+      case Op::kShrl:
+      case Op::kShra: {
+        XLS_ASSIGN_OR_RETURN(
+            tmp_lifted_op,
+            func->MakeNode<BinOp>(SourceInfo(), lhs, rhs, info.lifted_op));
+        break;
+      }
+      case Op::kAnd:
+      case Op::kOr:
+      case Op::kXor: {
+        XLS_ASSIGN_OR_RETURN(
+            tmp_lifted_op,
+            func->MakeNode<NaryOp>(SourceInfo(), std::vector<Node*>{lhs, rhs},
+                                   info.lifted_op));
+        break;
+      }
+      case Op::kUMul:
+      case Op::kSMul: {
+        XLS_ASSIGN_OR_RETURN(
+            tmp_lifted_op, func->MakeNode<ArithOp>(
+                               SourceInfo(), lhs, rhs,
+                               select_to_optimize->GetType()->GetFlatBitCount(),
+                               info.lifted_op));
+        break;
+      }
+      default:
+        return absl::InternalError(
+            absl::StrCat("Unsupported binary operation in latency check: ",
+                         OpToString(info.lifted_op)));
+    }
+  }
+
+  int64_t t_after = *analysis->GetInfo(tmp_lifted_op);
+  return t_after > t_before;
+}
+
 absl::StatusOr<bool> ProfitabilityGuardForArrayIndex(FunctionBase* func,
                                                      Node* select_to_optimize,
                                                      Node* array_reference) {
@@ -554,23 +689,30 @@ absl::StatusOr<bool> ProfitabilityGuardForArrayIndex(FunctionBase* func,
   return true;
 }
 
-absl::StatusOr<Node*> MakeSelectNode(FunctionBase* func, Node* old_select,
-                                     const std::vector<Node*>& new_cases,
-                                     std::optional<Node*> new_default) {
-  if (old_select->Is<PrioritySelect>()) {
-    return func->MakeNode<PrioritySelect>(
-        SourceInfo(), old_select->As<PrioritySelect>()->selector(), new_cases,
-        *new_default);
-  } else {
-    return func->MakeNode<Select>(SourceInfo(),
-                                  old_select->As<Select>()->selector(),
-                                  new_cases, new_default);
-  }
-}
-
 absl::StatusOr<bool> ProfitabilityGuardForBinaryOperation(
     FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info,
     const OptimizationPassOptions& options, OptimizationContext& context) {
+  // If lifting a shift operation combines literal shift amounts into a select,
+  // this creates a variable shift from constant shifts, which is more
+  // expensive. If we don't have a delay model, avoid lifting in this case.
+  // If we do have a delay model, we trust it to determine if lifting is
+  // profitable.
+  bool all_other_operands_are_literals = absl::c_all_of(
+      info.other_operands, [](Node* n) { return n->Is<Literal>(); });
+  bool default_is_literal_or_identity_or_nonexistent =
+      !info.default_other_operand.has_value() || info.default_is_identity ||
+      (*info.default_other_operand)->Is<Literal>();
+  if (!options.delay_model.has_value() && info.shared_is_lhs &&
+      (info.lifted_op == Op::kShll || info.lifted_op == Op::kShrl ||
+       info.lifted_op == Op::kShra) &&
+      all_other_operands_are_literals &&
+      default_is_literal_or_identity_or_nonexistent) {
+    VLOG(3) << "  Not lifting " << OpToString(info.lifted_op)
+            << " because it would replace constant shifts with a variable "
+               "shift, and no delay model is provided.";
+    return false;
+  }
+
   // Heuristically: If the selector depends on `shared_node`, lifting will
   // likely serialize more operations & worsen the critical path.
   Node* selector = select_to_optimize->Is