Skip to content

Commit 0d83df2

Browse files
committed
Add target module and old_to_new mapping to replacer
1 parent 007e8f6 commit 0d83df2

File tree

7 files changed

+319
-82
lines changed

7 files changed

+319
-82
lines changed

xls/dslx/fmt/ast_fmt.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3167,7 +3167,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
31673167
FormatDisabler disabler(vfs, comments, *m.fs_path());
31683168
XLS_ASSIGN_OR_RETURN(
31693169
std::unique_ptr<Module> clone,
3170-
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
3170+
CloneModule(m, [&](const AstNode* node, Module*,
3171+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
3172+
return disabler(node);
3173+
}));
31713174
return AutoFmt(*clone, comments, text_width);
31723175
}
31733176

@@ -3177,7 +3180,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
31773180
FormatDisabler disabler(vfs, comments, contents);
31783181
XLS_ASSIGN_OR_RETURN(
31793182
std::unique_ptr<Module> clone,
3180-
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
3183+
CloneModule(m, [&](const AstNode* node, Module*,
3184+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
3185+
return disabler(node);
3186+
}));
31813187
return AutoFmt(*clone, comments, text_width);
31823188
}
31833189

xls/dslx/frontend/ast_cloner.cc

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,8 @@ class AstCloner : public AstNodeVisitor {
12741274
if (node == nullptr) {
12751275
return absl::OkStatus();
12761276
}
1277-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement, replacer_(node));
1277+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement,
1278+
replacer_(node, module(node), old_to_new_));
12781279
if (replacement.has_value()) {
12791280
old_to_new_[node] = *replacement;
12801281
return absl::OkStatus();
@@ -1307,17 +1308,20 @@ class AstCloner : public AstNodeVisitor {
13071308

13081309
} // namespace
13091310

1310-
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(const AstNode* node) {
1311+
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
1312+
const AstNode* node, Module* module,
1313+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
13111314
if (node->kind() == AstNodeKind::kTypeRef) {
13121315
const auto* type_ref = down_cast<const TypeRef*>(node);
1313-
return node->owner()->Make<TypeRef>(type_ref->span(),
1314-
type_ref->type_definition());
1316+
return module->Make<TypeRef>(type_ref->span(), type_ref->type_definition());
13151317
}
13161318
return std::nullopt;
13171319
}
13181320

13191321
CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
1320-
return [=](const AstNode* node) -> std::optional<AstNode*> {
1322+
return [=](const AstNode* node, Module* new_module,
1323+
const absl::flat_hash_map<const AstNode*, AstNode*>&)
1324+
-> std::optional<AstNode*> {
13211325
if (node->kind() == AstNodeKind::kNameRef) {
13221326
const auto* name_ref = down_cast<const NameRef*>(node);
13231327
if (std::holds_alternative<const NameDef*>(name_ref->name_def()) &&
@@ -1331,14 +1335,16 @@ CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
13311335

13321336
CloneReplacer NameRefReplacer(
13331337
const absl::flat_hash_map<const NameDef*, NameDef*>* replacement_defs) {
1334-
return [=](const AstNode* original_node) -> std::optional<AstNode*> {
1338+
return [=](const AstNode* original_node, Module* new_module,
1339+
const absl::flat_hash_map<const AstNode*, AstNode*>&)
1340+
-> std::optional<AstNode*> {
13351341
if (original_node->kind() == AstNodeKind::kNameRef) {
13361342
const auto* original_ref = down_cast<const NameRef*>(original_node);
13371343
const AstNode* def = ToAstNode(original_ref->name_def());
13381344
if (def->kind() == AstNodeKind::kNameDef) {
13391345
const auto it = replacement_defs->find(down_cast<const NameDef*>(def));
13401346
if (it != replacement_defs->end()) {
1341-
return original_node->owner()->Make<NameRef>(
1347+
return new_module->Make<NameRef>(
13421348
original_ref->span(), original_ref->identifier(), it->second);
13431349
}
13441350
}
@@ -1354,8 +1360,11 @@ CloneAstAndGetAllPairs(const AstNode* root,
13541360
if (root->kind() == AstNodeKind::kModule) {
13551361
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
13561362
}
1363+
Module* new_module =
1364+
target_module.has_value() ? *target_module : root->owner();
1365+
absl::flat_hash_map<const AstNode*, AstNode*> empty_old_to_new;
13571366
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> root_replacement,
1358-
replacer(root));
1367+
replacer(root, new_module, empty_old_to_new));
13591368
if (root_replacement.has_value()) {
13601369
return absl::flat_hash_map<const AstNode*, AstNode*>{
13611370
{root, *root_replacement}};
@@ -1386,15 +1395,19 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
13861395
}
13871396

13881397
CloneReplacer ChainCloneReplacers(CloneReplacer first, CloneReplacer second) {
1389-
return [first = std::move(first),
1390-
second = std::move(second)](const AstNode* node) mutable
1391-
-> absl::StatusOr<std::optional<AstNode*>> {
1392-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result, first(node));
1393-
XLS_ASSIGN_OR_RETURN(
1394-
std::optional<AstNode*> second_result,
1395-
second(first_result.has_value() ? *first_result : node));
1396-
return second_result.has_value() ? second_result : first_result;
1397-
};
1398+
return
1399+
[first = std::move(first), second = std::move(second)](
1400+
const AstNode* node, Module* module,
1401+
const absl::flat_hash_map<const AstNode*, AstNode*>&
1402+
old_to_new) mutable -> absl::StatusOr<std::optional<AstNode*>> {
1403+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result,
1404+
first(node, module, old_to_new));
1405+
XLS_ASSIGN_OR_RETURN(
1406+
std::optional<AstNode*> second_result,
1407+
second(first_result.has_value() ? *first_result : node, module,
1408+
old_to_new));
1409+
return second_result.has_value() ? second_result : first_result;
1410+
};
13981411
}
13991412

14001413
// Verifies that `node` consists solely of "new" AST nodes and none that are

xls/dslx/frontend/ast_cloner.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,21 @@ namespace xls::dslx {
3434
// nodes during a `CloneAst` operations. A replacer can be used to replace
3535
// targeted nodes with something else entirely, or it can just "clone" those
3636
// nodes differently than the default logic.
37+
//
38+
// The replacer is invoked with:
39+
// - the original AST node under consideration
40+
// - the target `Module*` where any new nodes should be created
41+
// - a pointer to the current old->new mapping accumulated so far during clone
3742
using CloneReplacer =
38-
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(const AstNode*)>;
43+
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(
44+
const AstNode*, Module*,
45+
const absl::flat_hash_map<const AstNode*, AstNode*>&)>;
3946

4047
// This function is directly usable as the `replacer` argument for `CloneAst`
4148
// when a direct clone with no replacements is desired.
42-
inline std::optional<AstNode*> NoopCloneReplacer(const AstNode* original_node) {
49+
inline std::optional<AstNode*> NoopCloneReplacer(
50+
const AstNode* original_node, Module*,
51+
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
4352
return std::nullopt;
4453
}
4554

@@ -50,8 +59,11 @@ class ObservableCloneReplacer {
5059
explicit ObservableCloneReplacer(bool* flag, CloneReplacer replacer)
5160
: flag_(flag), replacer_(std::move(replacer)) {}
5261

53-
absl::StatusOr<std::optional<AstNode*>> operator()(const AstNode* node) {
54-
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result, replacer_(node));
62+
absl::StatusOr<std::optional<AstNode*>> operator()(
63+
const AstNode* node, Module* module,
64+
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new) {
65+
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result,
66+
replacer_(node, module, old_to_new));
5567
*flag_ |= result.has_value();
5668
return result;
5769
}
@@ -66,7 +78,8 @@ class ObservableCloneReplacer {
6678
// cloning return types without recursing into cloned definitions which would
6779
// change nominal types.
6880
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
69-
const AstNode* original_node);
81+
const AstNode* original_node, Module* module,
82+
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new);
7083

7184
// Creates a `CloneReplacer` that replaces references to the given `def` with
7285
// the given `replacement`.

0 commit comments

Comments
 (0)