Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,16 +949,19 @@ class AstCloner : public AstNodeVisitor {
XLS_RETURN_IF_ERROR(VisitChildren(n));

XLS_RETURN_IF_ERROR(ReplaceOrVisit(&n->fn()));
old_to_new_[n] = module(n)->Make<TestFunction>(
n->span(), *down_cast<Function*>(old_to_new_.at(&n->fn())));
XLS_ASSIGN_OR_RETURN(Function * new_fn, CastIfNotVerbatim<Function*>(
old_to_new_.at(&n->fn())));
old_to_new_[n] = module(n)->Make<TestFunction>(n->span(), *new_fn);
return absl::OkStatus();
}

absl::Status HandleTestProc(const TestProc* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));

old_to_new_[n] = module(n)->Make<TestProc>(
down_cast<Proc*>(old_to_new_.at(n->proc())), n->expected_fail_label());
XLS_ASSIGN_OR_RETURN(Proc * new_proc,
CastIfNotVerbatim<Proc*>(old_to_new_.at(n->proc())));
old_to_new_[n] =
module(n)->Make<TestProc>(new_proc, n->expected_fail_label());
return absl::OkStatus();
}

Expand Down Expand Up @@ -1190,9 +1193,7 @@ class AstCloner : public AstNodeVisitor {
// already been processed.
absl::Status VisitChildren(const AstNode* node) {
for (const auto& child : node->GetChildren(/*want_types=*/true)) {
if (!old_to_new_.contains(child)) {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(child));
}
XLS_RETURN_IF_ERROR(ReplaceOrVisit(child));
}
return absl::OkStatus();
}
Expand All @@ -1201,6 +1202,9 @@ class AstCloner : public AstNodeVisitor {
if (node == nullptr) {
return absl::OkStatus();
}
if (old_to_new_.contains(node)) {
return absl::OkStatus();
}
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement, replacer_(node));
if (replacement.has_value()) {
old_to_new_[node] = *replacement;
Expand Down
124 changes: 122 additions & 2 deletions xls/dslx/frontend/ast_cloner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
#include <variant>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "xls/common/casts.h"
#include "xls/common/status/matchers.h"
#include "xls/common/status/status_macros.h"
Expand Down Expand Up @@ -1634,6 +1634,126 @@ fn bar() -> u32{
EXPECT_EQ(orig_ref->name_def(), new_ref->name_def());
}

TEST(AstClonerTest, DoesntCloneTypeAliasTwice) {
constexpr std::string_view kProgram = R"(
type MyU32 = u32;

fn foo(x: MyU32) -> MyU32 {
x
}
)";

FileTable file_table;
XLS_ASSERT_OK_AND_ASSIGN(auto module, ParseModule(kProgram, "fake_path.x",
"the_module", file_table));
XLS_ASSERT_OK_AND_ASSIGN(Function * foo,
module->GetMemberOrError<Function>("foo"));
XLS_ASSERT_OK(module->GetMemberOrError<TypeAlias>("MyU32"));

// Clone just the function into the same target module.
XLS_ASSERT_OK_AND_ASSIGN(AstNode * clone, CloneAst(foo));
Function* cloned_foo = down_cast<Function*>(clone);

// Param type and the return type should reference the same alias.
auto* param_trta = down_cast<TypeRefTypeAnnotation*>(
cloned_foo->params()[0]->type_annotation());
TypeRef* param_tr = param_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<TypeAlias*>(param_tr->type_definition()));
EXPECT_EQ(param_tr->owner(), clone->owner());
TypeAlias* param_alias = std::get<TypeAlias*>(param_tr->type_definition());
EXPECT_EQ(param_alias->owner(), clone->owner());
auto* ret_trta = down_cast<TypeRefTypeAnnotation*>(cloned_foo->return_type());
TypeRef* ret_tr = ret_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<TypeAlias*>(ret_tr->type_definition()));
EXPECT_EQ(std::get<TypeAlias*>(param_tr->type_definition()),
std::get<TypeAlias*>(ret_tr->type_definition()));
}

TEST(AstClonerTest, DoesntCloneEnumTwice) {
constexpr std::string_view kProgram = R"(
enum MyEnum : u32 {
A = 0,
B = 1,
}

fn id(x: MyEnum) -> MyEnum { x }
)";

FileTable file_table;
XLS_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Module> module,
ParseModule(kProgram, "fake_path.x", "the_module", file_table));

// Clone the whole module.
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Module> clone,
CloneModule(*module.get()));
XLS_ASSERT_OK(VerifyClone(module.get(), clone.get(), file_table));

// Get the cloned top-level enum and function.
XLS_ASSERT_OK_AND_ASSIGN(EnumDef * enum_def,
clone->GetMemberOrError<EnumDef>("MyEnum"));
XLS_ASSERT_OK_AND_ASSIGN(Function * id,
clone->GetMemberOrError<Function>("id"));

// Param type should reference the cloned top-level enum, not a detached copy.
auto* param_trta =
down_cast<TypeRefTypeAnnotation*>(id->params()[0]->type_annotation());
TypeRef* param_tr = param_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<EnumDef*>(param_tr->type_definition()));
EXPECT_EQ(std::get<EnumDef*>(param_tr->type_definition()), enum_def);

// Return type should also reference the same cloned top-level enum.
auto* ret_trta = down_cast<TypeRefTypeAnnotation*>(id->return_type());
TypeRef* ret_tr = ret_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<EnumDef*>(ret_tr->type_definition()));
EXPECT_EQ(std::get<EnumDef*>(ret_tr->type_definition()), enum_def);
}

TEST(AstClonerTest, DoesntCloneStructDefTwice) {
constexpr std::string_view kProgram = R"(
struct MyStruct { a: u32 }

fn id(x: MyStruct) -> MyStruct { x }
)";

constexpr std::string_view kExpectedFunction =
R"(fn id(x: MyStruct) -> MyStruct {
x
})";

FileTable file_table;
XLS_ASSERT_OK_AND_ASSIGN(auto module, ParseModule(kProgram, "fake_path.x",
"the_module", file_table));

XLS_ASSERT_OK_AND_ASSIGN(StructDef * struct_def,
module->GetMemberOrError<StructDef>("MyStruct"));
XLS_ASSERT_OK_AND_ASSIGN(Function * id,
module->GetMemberOrError<Function>("id"));

// Param and return types should be backed by the StructDef.
auto* param_trta =
down_cast<TypeRefTypeAnnotation*>(id->params()[0]->type_annotation());
TypeRef* param_tr = param_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<StructDef*>(param_tr->type_definition()));
EXPECT_EQ(std::get<StructDef*>(param_tr->type_definition()), struct_def);

auto* ret_trta = down_cast<TypeRefTypeAnnotation*>(id->return_type());
TypeRef* ret_tr = ret_trta->type_ref();
ASSERT_TRUE(std::holds_alternative<StructDef*>(ret_tr->type_definition()));
EXPECT_EQ(std::get<StructDef*>(ret_tr->type_definition()), struct_def);

// Clone and ensure ToString and variant preservation.
XLS_ASSERT_OK_AND_ASSIGN(AstNode * clone, CloneAst(id));
EXPECT_EQ(kExpectedFunction, clone->ToString());
XLS_ASSERT_OK(VerifyClone(id, clone, *module->file_table()));

auto* cloned_id = down_cast<Function*>(clone);
auto* cloned_param_trta = down_cast<TypeRefTypeAnnotation*>(
cloned_id->params()[0]->type_annotation());
ASSERT_TRUE(std::holds_alternative<StructDef*>(
cloned_param_trta->type_ref()->type_definition()));
}

TEST(AstClonerTest, CloneAstClonesVerbatimNode) {
constexpr std::string_view kProgram = "const FOO = u32:42;";
FileTable file_table;
Expand Down