@@ -1274,7 +1274,8 @@ class AstCloner : public AstNodeVisitor {
1274
1274
if (node == nullptr ) {
1275
1275
return absl::OkStatus ();
1276
1276
}
1277
- XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> replacement, replacer_ (node));
1277
+ XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> replacement,
1278
+ replacer_ (node, module (node), old_to_new_));
1278
1279
if (replacement.has_value ()) {
1279
1280
old_to_new_[node] = *replacement;
1280
1281
return absl::OkStatus ();
@@ -1307,17 +1308,20 @@ class AstCloner : public AstNodeVisitor {
1307
1308
1308
1309
} // namespace
1309
1310
1310
- std::optional<AstNode*> PreserveTypeDefinitionsReplacer (const AstNode* node) {
1311
+ std::optional<AstNode*> PreserveTypeDefinitionsReplacer (
1312
+ const AstNode* node, Module* module ,
1313
+ const absl::flat_hash_map<const AstNode*, AstNode*>&) {
1311
1314
if (node->kind () == AstNodeKind::kTypeRef ) {
1312
1315
const auto * type_ref = down_cast<const TypeRef*>(node);
1313
- return node->owner ()->Make <TypeRef>(type_ref->span (),
1314
- type_ref->type_definition ());
1316
+ return module ->Make <TypeRef>(type_ref->span (), type_ref->type_definition ());
1315
1317
}
1316
1318
return std::nullopt ;
1317
1319
}
1318
1320
1319
1321
CloneReplacer NameRefReplacer (const NameDef* def, Expr* replacement) {
1320
- return [=](const AstNode* node) -> std::optional<AstNode*> {
1322
+ return [=](const AstNode* node, Module* new_module,
1323
+ const absl::flat_hash_map<const AstNode*, AstNode*>&)
1324
+ -> std::optional<AstNode*> {
1321
1325
if (node->kind () == AstNodeKind::kNameRef ) {
1322
1326
const auto * name_ref = down_cast<const NameRef*>(node);
1323
1327
if (std::holds_alternative<const NameDef*>(name_ref->name_def ()) &&
@@ -1331,14 +1335,16 @@ CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
1331
1335
1332
1336
CloneReplacer NameRefReplacer (
1333
1337
const absl::flat_hash_map<const NameDef*, NameDef*>* replacement_defs) {
1334
- return [=](const AstNode* original_node) -> std::optional<AstNode*> {
1338
+ return [=](const AstNode* original_node, Module* new_module,
1339
+ const absl::flat_hash_map<const AstNode*, AstNode*>&)
1340
+ -> std::optional<AstNode*> {
1335
1341
if (original_node->kind () == AstNodeKind::kNameRef ) {
1336
1342
const auto * original_ref = down_cast<const NameRef*>(original_node);
1337
1343
const AstNode* def = ToAstNode (original_ref->name_def ());
1338
1344
if (def->kind () == AstNodeKind::kNameDef ) {
1339
1345
const auto it = replacement_defs->find (down_cast<const NameDef*>(def));
1340
1346
if (it != replacement_defs->end ()) {
1341
- return original_node-> owner () ->Make <NameRef>(
1347
+ return new_module ->Make <NameRef>(
1342
1348
original_ref->span (), original_ref->identifier (), it->second );
1343
1349
}
1344
1350
}
@@ -1354,8 +1360,11 @@ CloneAstAndGetAllPairs(const AstNode* root,
1354
1360
if (root->kind () == AstNodeKind::kModule ) {
1355
1361
return absl::InvalidArgumentError (" Clone a module via 'CloneModule'." );
1356
1362
}
1363
+ Module* new_module =
1364
+ target_module.has_value () ? *target_module : root->owner ();
1365
+ absl::flat_hash_map<const AstNode*, AstNode*> empty_old_to_new;
1357
1366
XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> root_replacement,
1358
- replacer (root));
1367
+ replacer (root, new_module, empty_old_to_new ));
1359
1368
if (root_replacement.has_value ()) {
1360
1369
return absl::flat_hash_map<const AstNode*, AstNode*>{
1361
1370
{root, *root_replacement}};
@@ -1386,15 +1395,19 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
1386
1395
}
1387
1396
1388
1397
CloneReplacer ChainCloneReplacers (CloneReplacer first, CloneReplacer second) {
1389
- return [first = std::move (first),
1390
- second = std::move (second)](const AstNode* node) mutable
1391
- -> absl::StatusOr<std::optional<AstNode*>> {
1392
- XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> first_result, first (node));
1393
- XLS_ASSIGN_OR_RETURN (
1394
- std::optional<AstNode*> second_result,
1395
- second (first_result.has_value () ? *first_result : node));
1396
- return second_result.has_value () ? second_result : first_result;
1397
- };
1398
+ return
1399
+ [first = std::move (first), second = std::move (second)](
1400
+ const AstNode* node, Module* module ,
1401
+ const absl::flat_hash_map<const AstNode*, AstNode*>&
1402
+ old_to_new) mutable -> absl::StatusOr<std::optional<AstNode*>> {
1403
+ XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> first_result,
1404
+ first (node, module , old_to_new));
1405
+ XLS_ASSIGN_OR_RETURN (
1406
+ std::optional<AstNode*> second_result,
1407
+ second (first_result.has_value () ? *first_result : node, module ,
1408
+ old_to_new));
1409
+ return second_result.has_value () ? second_result : first_result;
1410
+ };
1398
1411
}
1399
1412
1400
1413
// Verifies that `node` consists solely of "new" AST nodes and none that are
0 commit comments