@@ -1205,7 +1205,8 @@ class AstCloner : public AstNodeVisitor {
1205
1205
if (old_to_new_.contains (node)) {
1206
1206
return absl::OkStatus ();
1207
1207
}
1208
- XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> replacement, replacer_ (node));
1208
+ XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> replacement,
1209
+ replacer_ (node, module (node), old_to_new_));
1209
1210
if (replacement.has_value ()) {
1210
1211
old_to_new_[node] = *replacement;
1211
1212
return absl::OkStatus ();
@@ -1238,17 +1239,20 @@ class AstCloner : public AstNodeVisitor {
1238
1239
1239
1240
} // namespace
1240
1241
1241
- std::optional<AstNode*> PreserveTypeDefinitionsReplacer (const AstNode* node) {
1242
+ std::optional<AstNode*> PreserveTypeDefinitionsReplacer (
1243
+ const AstNode* node, Module* module ,
1244
+ const absl::flat_hash_map<const AstNode*, AstNode*>&) {
1242
1245
if (node->kind () == AstNodeKind::kTypeRef ) {
1243
1246
const auto * type_ref = down_cast<const TypeRef*>(node);
1244
- return node->owner ()->Make <TypeRef>(type_ref->span (),
1245
- type_ref->type_definition ());
1247
+ return module ->Make <TypeRef>(type_ref->span (), type_ref->type_definition ());
1246
1248
}
1247
1249
return std::nullopt ;
1248
1250
}
1249
1251
1250
1252
CloneReplacer NameRefReplacer (const NameDef* def, Expr* replacement) {
1251
- return [=](const AstNode* node) -> std::optional<AstNode*> {
1253
+ return [=](const AstNode* node, Module* new_module,
1254
+ const absl::flat_hash_map<const AstNode*, AstNode*>&)
1255
+ -> std::optional<AstNode*> {
1252
1256
if (node->kind () == AstNodeKind::kNameRef ) {
1253
1257
const auto * name_ref = down_cast<const NameRef*>(node);
1254
1258
if (std::holds_alternative<const NameDef*>(name_ref->name_def ()) &&
@@ -1262,14 +1266,16 @@ CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
1262
1266
1263
1267
CloneReplacer NameRefReplacer (
1264
1268
const absl::flat_hash_map<const NameDef*, NameDef*>* replacement_defs) {
1265
- return [=](const AstNode* original_node) -> std::optional<AstNode*> {
1269
+ return [=](const AstNode* original_node, Module* new_module,
1270
+ const absl::flat_hash_map<const AstNode*, AstNode*>&)
1271
+ -> std::optional<AstNode*> {
1266
1272
if (original_node->kind () == AstNodeKind::kNameRef ) {
1267
1273
const auto * original_ref = down_cast<const NameRef*>(original_node);
1268
1274
const AstNode* def = ToAstNode (original_ref->name_def ());
1269
1275
if (def->kind () == AstNodeKind::kNameDef ) {
1270
1276
const auto it = replacement_defs->find (down_cast<const NameDef*>(def));
1271
1277
if (it != replacement_defs->end ()) {
1272
- return original_node-> owner () ->Make <NameRef>(
1278
+ return new_module ->Make <NameRef>(
1273
1279
original_ref->span (), original_ref->identifier (), it->second );
1274
1280
}
1275
1281
}
@@ -1285,8 +1291,11 @@ CloneAstAndGetAllPairs(const AstNode* root,
1285
1291
if (root->kind () == AstNodeKind::kModule ) {
1286
1292
return absl::InvalidArgumentError (" Clone a module via 'CloneModule'." );
1287
1293
}
1294
+ Module* new_module =
1295
+ target_module.has_value () ? *target_module : root->owner ();
1296
+ absl::flat_hash_map<const AstNode*, AstNode*> empty_old_to_new;
1288
1297
XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> root_replacement,
1289
- replacer (root));
1298
+ replacer (root, new_module, empty_old_to_new ));
1290
1299
if (root_replacement.has_value ()) {
1291
1300
return absl::flat_hash_map<const AstNode*, AstNode*>{
1292
1301
{root, *root_replacement}};
@@ -1317,15 +1326,19 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
1317
1326
}
1318
1327
1319
1328
CloneReplacer ChainCloneReplacers (CloneReplacer first, CloneReplacer second) {
1320
- return [first = std::move (first),
1321
- second = std::move (second)](const AstNode* node) mutable
1322
- -> absl::StatusOr<std::optional<AstNode*>> {
1323
- XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> first_result, first (node));
1324
- XLS_ASSIGN_OR_RETURN (
1325
- std::optional<AstNode*> second_result,
1326
- second (first_result.has_value () ? *first_result : node));
1327
- return second_result.has_value () ? second_result : first_result;
1328
- };
1329
+ return
1330
+ [first = std::move (first), second = std::move (second)](
1331
+ const AstNode* node, Module* module ,
1332
+ const absl::flat_hash_map<const AstNode*, AstNode*>&
1333
+ old_to_new) mutable -> absl::StatusOr<std::optional<AstNode*>> {
1334
+ XLS_ASSIGN_OR_RETURN (std::optional<AstNode*> first_result,
1335
+ first (node, module , old_to_new));
1336
+ XLS_ASSIGN_OR_RETURN (
1337
+ std::optional<AstNode*> second_result,
1338
+ second (first_result.has_value () ? *first_result : node, module ,
1339
+ old_to_new));
1340
+ return second_result.has_value () ? second_result : first_result;
1341
+ };
1329
1342
}
1330
1343
1331
1344
// Verifies that `node` consists solely of "new" AST nodes and none that are
0 commit comments