Skip to content

Commit 8c72089

Browse files
committed
Add target module and old_to_new mapping to replacer
1 parent a08c440 commit 8c72089

File tree

7 files changed

+309
-78
lines changed

7 files changed

+309
-78
lines changed

xls/dslx/fmt/ast_fmt.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3167,7 +3167,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
31673167
FormatDisabler disabler(vfs, comments, *m.fs_path());
31683168
XLS_ASSIGN_OR_RETURN(
31693169
std::unique_ptr<Module> clone,
3170-
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
3170+
CloneModule(m, [&](const AstNode* node, Module*,
3171+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
3172+
return disabler(node);
3173+
}));
31713174
return AutoFmt(*clone, comments, text_width);
31723175
}
31733176

@@ -3177,7 +3180,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
31773180
FormatDisabler disabler(vfs, comments, contents);
31783181
XLS_ASSIGN_OR_RETURN(
31793182
std::unique_ptr<Module> clone,
3180-
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
3183+
CloneModule(m, [&](const AstNode* node, Module*,
3184+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
3185+
return disabler(node);
3186+
}));
31813187
return AutoFmt(*clone, comments, text_width);
31823188
}
31833189

xls/dslx/frontend/ast_cloner.cc

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,8 @@ class AstCloner : public AstNodeVisitor {
12051205
if (old_to_new_.contains(node)) {
12061206
return absl::OkStatus();
12071207
}
1208-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement, replacer_(node));
1208+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement,
1209+
replacer_(node, module(node), old_to_new_));
12091210
if (replacement.has_value()) {
12101211
old_to_new_[node] = *replacement;
12111212
return absl::OkStatus();
@@ -1238,17 +1239,20 @@ class AstCloner : public AstNodeVisitor {
12381239

12391240
} // namespace
12401241

1241-
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(const AstNode* node) {
1242+
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
1243+
const AstNode* node, Module* module,
1244+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
12421245
if (node->kind() == AstNodeKind::kTypeRef) {
12431246
const auto* type_ref = down_cast<const TypeRef*>(node);
1244-
return node->owner()->Make<TypeRef>(type_ref->span(),
1245-
type_ref->type_definition());
1247+
return module->Make<TypeRef>(type_ref->span(), type_ref->type_definition());
12461248
}
12471249
return std::nullopt;
12481250
}
12491251

12501252
CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
1251-
return [=](const AstNode* node) -> std::optional<AstNode*> {
1253+
return [=](const AstNode* node, Module* new_module,
1254+
const absl::flat_hash_map<const AstNode*, AstNode*>&)
1255+
-> std::optional<AstNode*> {
12521256
if (node->kind() == AstNodeKind::kNameRef) {
12531257
const auto* name_ref = down_cast<const NameRef*>(node);
12541258
if (std::holds_alternative<const NameDef*>(name_ref->name_def()) &&
@@ -1262,14 +1266,16 @@ CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
12621266

12631267
CloneReplacer NameRefReplacer(
12641268
const absl::flat_hash_map<const NameDef*, NameDef*>* replacement_defs) {
1265-
return [=](const AstNode* original_node) -> std::optional<AstNode*> {
1269+
return [=](const AstNode* original_node, Module* new_module,
1270+
const absl::flat_hash_map<const AstNode*, AstNode*>&)
1271+
-> std::optional<AstNode*> {
12661272
if (original_node->kind() == AstNodeKind::kNameRef) {
12671273
const auto* original_ref = down_cast<const NameRef*>(original_node);
12681274
const AstNode* def = ToAstNode(original_ref->name_def());
12691275
if (def->kind() == AstNodeKind::kNameDef) {
12701276
const auto it = replacement_defs->find(down_cast<const NameDef*>(def));
12711277
if (it != replacement_defs->end()) {
1272-
return original_node->owner()->Make<NameRef>(
1278+
return new_module->Make<NameRef>(
12731279
original_ref->span(), original_ref->identifier(), it->second);
12741280
}
12751281
}
@@ -1285,8 +1291,11 @@ CloneAstAndGetAllPairs(const AstNode* root,
12851291
if (root->kind() == AstNodeKind::kModule) {
12861292
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
12871293
}
1294+
Module* new_module =
1295+
target_module.has_value() ? *target_module : root->owner();
1296+
absl::flat_hash_map<const AstNode*, AstNode*> empty_old_to_new;
12881297
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> root_replacement,
1289-
replacer(root));
1298+
replacer(root, new_module, empty_old_to_new));
12901299
if (root_replacement.has_value()) {
12911300
return absl::flat_hash_map<const AstNode*, AstNode*>{
12921301
{root, *root_replacement}};
@@ -1317,15 +1326,19 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
13171326
}
13181327

13191328
CloneReplacer ChainCloneReplacers(CloneReplacer first, CloneReplacer second) {
1320-
return [first = std::move(first),
1321-
second = std::move(second)](const AstNode* node) mutable
1322-
-> absl::StatusOr<std::optional<AstNode*>> {
1323-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result, first(node));
1324-
XLS_ASSIGN_OR_RETURN(
1325-
std::optional<AstNode*> second_result,
1326-
second(first_result.has_value() ? *first_result : node));
1327-
return second_result.has_value() ? second_result : first_result;
1328-
};
1329+
return
1330+
[first = std::move(first), second = std::move(second)](
1331+
const AstNode* node, Module* module,
1332+
const absl::flat_hash_map<const AstNode*, AstNode*>&
1333+
old_to_new) mutable -> absl::StatusOr<std::optional<AstNode*>> {
1334+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result,
1335+
first(node, module, old_to_new));
1336+
XLS_ASSIGN_OR_RETURN(
1337+
std::optional<AstNode*> second_result,
1338+
second(first_result.has_value() ? *first_result : node, module,
1339+
old_to_new));
1340+
return second_result.has_value() ? second_result : first_result;
1341+
};
13291342
}
13301343

13311344
// Verifies that `node` consists solely of "new" AST nodes and none that are

xls/dslx/frontend/ast_cloner.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,21 @@ namespace xls::dslx {
3434
// nodes during a `CloneAst` operations. A replacer can be used to replace
3535
// targeted nodes with something else entirely, or it can just "clone" those
3636
// nodes differently than the default logic.
37+
//
38+
// The replacer is invoked with:
39+
// - the original AST node under consideration
40+
// - the target `Module*` where any new nodes should be created
41+
// - a pointer to the current old->new mapping accumulated so far during clone
3742
using CloneReplacer =
38-
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(const AstNode*)>;
43+
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(
44+
const AstNode*, Module*,
45+
const absl::flat_hash_map<const AstNode*, AstNode*>&)>;
3946

4047
// This function is directly usable as the `replacer` argument for `CloneAst`
4148
// when a direct clone with no replacements is desired.
42-
inline std::optional<AstNode*> NoopCloneReplacer(const AstNode* original_node) {
49+
inline std::optional<AstNode*> NoopCloneReplacer(
50+
const AstNode* original_node, Module*,
51+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
4352
return std::nullopt;
4453
}
4554

@@ -50,8 +59,11 @@ class ObservableCloneReplacer {
5059
explicit ObservableCloneReplacer(bool* flag, CloneReplacer replacer)
5160
: flag_(flag), replacer_(std::move(replacer)) {}
5261

53-
absl::StatusOr<std::optional<AstNode*>> operator()(const AstNode* node) {
54-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result, replacer_(node));
62+
absl::StatusOr<std::optional<AstNode*>> operator()(
63+
const AstNode* node, Module* module,
64+
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new) {
65+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result,
66+
replacer_(node, module, old_to_new));
5567
*flag_ |= result.has_value();
5668
return result;
5769
}
@@ -66,7 +78,8 @@ class ObservableCloneReplacer {
6678
// cloning return types without recursing into cloned definitions which would
6779
// change nominal types.
6880
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
69-
const AstNode* original_node);
81+
const AstNode* original_node, Module* module,
82+
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new);
7083

7184
// Creates a `CloneReplacer` that replaces references to the given `def` with
7285
// the given `replacement`.

0 commit comments

Comments
 (0)