Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xls/build_rules/xls_ir_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def _convert_to_ir(ctx, src):
"default_fifo_config",
"proc_scoped_channels",
"lower_to_proc_scoped_channels",
"force_implicit_token_calling_convention",
)

# With runs outside a monorepo, the execution root for the workspace of
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/ir_convert/convert_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ struct ConvertOptions {
// intermediate step? See https://github.com/google/xls/issues/2078
bool lower_to_proc_scoped_channels = false;

// Force every DSLX function to use the implicit-token calling convention,
// regardless of what type inference determined.
bool force_implicit_token_calling_convention = false;

// Configured values to override for use in IR conversion.
std::vector<std::string> configured_values;
};
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ absl::StatusOr<xls::Function*> EmitImplicitTokenEntryWrapper(

bool GetRequiresImplicitToken(const dslx::Function& f, ImportData* import_data,
const ConvertOptions& options) {
if (options.force_implicit_token_calling_convention) {
return true;
}
std::optional<bool> requires_opt =
import_data->GetRootTypeInfo(f.owner()).value()->GetRequiresImplicitToken(
f);
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ absl::StatusOr<std::vector<ConstantDef*>> GetConstantDepFreevars(
AstNode* node, TypeInfo& type_info);

// Wrapper around the type information query for whether DSL function "f"
// requires an implicit token calling convention.
// requires an implicit token calling convention, including overrides provided
// via ConvertOptions.
bool GetRequiresImplicitToken(const dslx::Function& f, ImportData* import_data,
const ConvertOptions& options);

Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/ir_convert/ir_converter_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ absl::Status RealMain(absl::Span<const std::string_view> paths) {
bool warnings_as_errors = ir_converter_options.warnings_as_errors();
bool proc_scoped_channels = ir_converter_options.proc_scoped_channels();
bool type_inference_v2 = ir_converter_options.type_inference_v2();
bool force_implicit_token_calling_convention =
ir_converter_options.force_implicit_token_calling_convention();

// Start with the default set, then enable the to-enable and then disable the
// to-disable.
Expand All @@ -124,6 +126,8 @@ absl::Status RealMain(absl::Span<const std::string_view> paths) {
.default_fifo_config = default_fifo_config,
.proc_scoped_channels = proc_scoped_channels,
.type_inference_v2 = type_inference_v2,
.force_implicit_token_calling_convention =
force_implicit_token_calling_convention,
.configured_values = configured_values,
};

Expand Down
23 changes: 23 additions & 0 deletions xls/dslx/ir_convert/ir_converter_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,29 @@ def test_b_dot_x(self) -> None:
"""),
)

def test_force_implicit_token_calling_convention(self) -> None:
program = "fn f(x: u32) -> u32 { x }"
result = self._ir_convert(
{"a.x": program},
top="f",
extra_flags=["--force_implicit_token_calling_convention"],
)
implicit_funcs = [
f
for f in result.interface.functions
if f.base.name == "__itok__a__f"
]
self.assertLen(implicit_funcs, 1)
self.assertTrue(implicit_funcs[0].base.top)
self.assertGreaterEqual(len(implicit_funcs[0].parameters), 2)
self.assertEqual(implicit_funcs[0].parameters[0].type, _TOKEN)
self.assertEqual(implicit_funcs[0].parameters[1].type, _1BITS)
self.assertTrue(
any(f.base.name == "__a__f" for f in result.interface.functions)
)
self.assertIn("fn __itok__a__f(", result.ir)
self.assertIn("fn __a__f(", result.ir)

def test_multi_file(self) -> None:
self.assertEqual(
self._ir_convert(
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/ir_convert/ir_converter_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ ABSL_FLAG(bool, lower_to_proc_scoped_channels, false,
"false, generates global channels. This is a temporary flag that "
"will not be used after the full implementation is complete. Cannot "
"be combined with proc_scoped_channels");
ABSL_FLAG(bool, force_implicit_token_calling_convention, false,
"Force every DSLX function to use the implicit-token calling "
"convention during IR conversion.");
ABSL_FLAG(
std::optional<std::vector<std::string>>, configured_values, std::nullopt,
"Dictionary of overrides to use for overridable constants "
Expand Down Expand Up @@ -140,6 +143,7 @@ absl::StatusOr<bool> SetOptionsFromFlags(IrConverterOptionsFlagsProto& proto) {
POPULATE_OPTIONAL_FLAG(interface_textproto_file);
POPULATE_FLAG(type_inference_v2);
POPULATE_FLAG(lower_to_proc_scoped_channels);
POPULATE_FLAG(force_implicit_token_calling_convention);
POPULATE_REPEATED_FLAG(configured_values);

#undef POPULATE_FLAG
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/ir_convert/ir_converter_options_flags.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ message IrConverterOptionsFlagsProto {
optional bool type_inference_v2 = 16;
optional bool lower_to_proc_scoped_channels = 17;
repeated string configured_values = 18;
optional bool force_implicit_token_calling_convention = 19;
}
18 changes: 17 additions & 1 deletion xls/dslx/ir_convert/ir_converter_options_flags_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
#include <optional>
#include <string>

#include "gtest/gtest.h"
#include "absl/cleanup/cleanup.h"
#include "absl/flags/declare.h"
#include "absl/flags/flag.h"
#include "gtest/gtest.h"
#include "xls/common/status/matchers.h"
#include "xls/dslx/ir_convert/ir_converter_options_flags.pb.h"

ABSL_DECLARE_FLAG(std::optional<std::string>, enable_warnings);
ABSL_DECLARE_FLAG(std::optional<std::string>, disable_warnings);
ABSL_DECLARE_FLAG(bool, warnings_as_errors);
ABSL_DECLARE_FLAG(bool, force_implicit_token_calling_convention);

namespace xls {

Expand All @@ -37,6 +38,8 @@ TEST(IrConverterOptionsFlagsTest, WarningOptionsNoFlagSettings) {
EXPECT_TRUE(options.disable_warnings().empty());
EXPECT_TRUE(options.has_warnings_as_errors());
EXPECT_EQ(options.warnings_as_errors(), true);
EXPECT_TRUE(options.has_force_implicit_token_calling_convention());
EXPECT_FALSE(options.force_implicit_token_calling_convention());
}

TEST(IrConverterOptionsFlagsTest,
Expand Down Expand Up @@ -66,4 +69,17 @@ TEST(IrConverterOptionsFlagsTest,
EXPECT_EQ(options.warnings_as_errors(), false);
}

TEST(IrConverterOptionsFlagsTest, ForceImplicitTokenFlagSet) {
ASSERT_FALSE(absl::GetFlag(FLAGS_force_implicit_token_calling_convention));
absl::SetFlag(&FLAGS_force_implicit_token_calling_convention, true);
absl::Cleanup reset_flag([] {
absl::SetFlag(&FLAGS_force_implicit_token_calling_convention, false);
});

XLS_ASSERT_OK_AND_ASSIGN(IrConverterOptionsFlagsProto options,
GetIrConverterOptionsFlagsProto());
EXPECT_TRUE(options.has_force_implicit_token_calling_convention());
EXPECT_TRUE(options.force_implicit_token_calling_convention());
}

} // namespace xls
1 change: 1 addition & 0 deletions xls/dslx/run_routines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cc_library(
"//xls/dslx/frontend:pos",
"//xls/dslx/ir_convert:conversion_info",
"//xls/dslx/ir_convert:convert_options",
"//xls/dslx/ir_convert:function_converter",
"//xls/dslx/ir_convert:ir_converter",
"//xls/dslx/type_system:parametric_env",
"//xls/dslx/type_system:type",
Expand Down
12 changes: 5 additions & 7 deletions xls/dslx/run_routines/run_routines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "re2/re2.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/inline_bitmap.h"
Expand All @@ -65,6 +66,7 @@
#include "xls/dslx/interp_value_utils.h"
#include "xls/dslx/ir_convert/conversion_info.h"
#include "xls/dslx/ir_convert/convert_options.h"
#include "xls/dslx/ir_convert/function_converter.h"
#include "xls/dslx/ir_convert/ir_converter.h"
#include "xls/dslx/mangle.h"
#include "xls/dslx/parse_and_typecheck.h"
Expand All @@ -88,7 +90,6 @@
#include "xls/passes/optimization_pass_pipeline.h"
#include "xls/passes/pass_base.h"
#include "xls/solvers/z3_ir_translator.h"
#include "re2/re2.h"

namespace xls::dslx {
namespace {
Expand Down Expand Up @@ -973,13 +974,10 @@ absl::StatusOr<TestResultData> AbstractTestRunner::ParseAndTest(
const ParametricEnv* parametric_env,
const InterpValue& got) -> absl::Status {
XLS_RET_CHECK(f != nullptr);
std::optional<bool> requires_implicit_token =
import_data.GetRootTypeInfoForNode(f)
.value()
->GetRequiresImplicitToken(*f);
XLS_RET_CHECK(requires_implicit_token.has_value());
bool requires_implicit_token =
GetRequiresImplicitToken(*f, &import_data, options.convert_options);
return options.run_comparator->RunComparison(ir_package.get(),
*requires_implicit_token, f,
requires_implicit_token, f,
args, parametric_env, got);
};
}
Expand Down
15 changes: 11 additions & 4 deletions xls/public/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ bool xls_convert_dslx_to_ir_with_warnings(
size_t additional_search_paths_count, const char* enable_warnings[],
size_t enable_warnings_count, const char* disable_warnings[],
size_t disable_warnings_count, bool warnings_as_errors,
char*** warnings_out, size_t* warnings_out_count, char** error_out,
char** ir_out) {
bool force_implicit_token_calling_convention, char*** warnings_out,
size_t* warnings_out_count, char** error_out, char** ir_out) {
CHECK(dslx != nullptr);
CHECK(path != nullptr);
CHECK(dslx_stdlib_path != nullptr);
Expand All @@ -96,6 +96,8 @@ bool xls_convert_dslx_to_ir_with_warnings(
.disable_warnings = disable_warnings_cpp,
.warnings_as_errors = warnings_as_errors,
.warnings_out = &warnings_out_cpp,
.force_implicit_token_calling_convention =
force_implicit_token_calling_convention,
};

absl::StatusOr<std::string> result =
Expand All @@ -122,6 +124,7 @@ bool xls_convert_dslx_to_ir(const char* dslx, const char* path,
dslx, path, module_name, dslx_stdlib_path, additional_search_paths,
additional_search_paths_count, enable_warnings, 0, disable_warnings, 0,
/*warnings_as_errors=*/false,
/*force_implicit_token_calling_convention=*/false,
/*warnings_out=*/nullptr,
/*warnings_out_count=*/nullptr, error_out, ir_out);
}
Expand All @@ -131,8 +134,9 @@ bool xls_convert_dslx_path_to_ir_with_warnings(
const char* additional_search_paths[], size_t additional_search_paths_count,
const char* enable_warnings[], size_t enable_warnings_count,
const char* disable_warnings[], size_t disable_warnings_count,
bool warnings_as_errors, char*** warnings_out, size_t* warnings_out_count,
char** error_out, char** ir_out) {
bool warnings_as_errors, bool force_implicit_token_calling_convention,
char*** warnings_out, size_t* warnings_out_count, char** error_out,
char** ir_out) {
CHECK(path != nullptr);
CHECK(dslx_stdlib_path != nullptr);
CHECK(error_out != nullptr);
Expand All @@ -156,6 +160,8 @@ bool xls_convert_dslx_path_to_ir_with_warnings(
.disable_warnings = disable_warnings_cpp,
.warnings_as_errors = warnings_as_errors,
.warnings_out = &warnings_out_cpp,
.force_implicit_token_calling_convention =
force_implicit_token_calling_convention,
};
absl::StatusOr<std::string> result = xls::ConvertDslxPathToIr(path, options);

Expand All @@ -178,6 +184,7 @@ bool xls_convert_dslx_path_to_ir(const char* path, const char* dslx_stdlib_path,
path, dslx_stdlib_path, additional_search_paths,
additional_search_paths_count, enable_warnings, 0, disable_warnings, 0,
/*warnings_as_errors=*/false,
/*force_implicit_token_calling_convention=*/false,
/*warnings_out=*/nullptr,
/*warnings_out_count=*/nullptr, error_out, ir_out);
}
Expand Down
9 changes: 5 additions & 4 deletions xls/public/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ bool xls_convert_dslx_to_ir_with_warnings(
size_t additional_search_paths_count, const char* enable_warnings[],
size_t enable_warnings_count, const char* disable_warnings[],
size_t disable_warnings_count, bool warnings_as_errors,
char*** warnings_out, size_t* warnings_out_count, char** error_out,
char** ir_out);
bool force_implicit_token_calling_convention, char*** warnings_out,
size_t* warnings_out_count, char** error_out, char** ir_out);

bool xls_convert_dslx_path_to_ir(const char* path, const char* dslx_stdlib_path,
const char* additional_search_paths[],
Expand All @@ -104,8 +104,9 @@ bool xls_convert_dslx_path_to_ir_with_warnings(
const char* additional_search_paths[], size_t additional_search_paths_count,
const char* enable_warnings[], size_t enable_warnings_count,
const char* disable_warnings[], size_t disable_warnings_count,
bool warnings_as_errors, char*** warnings_out, size_t* warnings_out_count,
char** error_out, char** ir_out);
bool warnings_as_errors, bool force_implicit_token_calling_convention,
char*** warnings_out, size_t* warnings_out_count, char** error_out,
char** ir_out);

bool xls_optimize_ir(const char* ir, const char* top, char** error_out,
char** ir_out);
Expand Down
37 changes: 33 additions & 4 deletions xls/public/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,9 @@ fn id() { let x = u32:1; })";
/*dslx_stdlib_path=*/dslx_stdlib_path.c_str(), additional_search_paths,
0,
/*enable_warnings=*/nullptr, 0, /*disable_warnings=*/nullptr, 0,
/*warnings_as_errors=*/true, &warnings, &warnings_count, &error_out,
&ir_out);
/*warnings_as_errors=*/true,
/*force_implicit_token_calling_convention=*/false, &warnings,
&warnings_count, &error_out, &ir_out);

// Check we got the warning data even though the return code is non-ok.
ASSERT_EQ(warnings_count, 1);
Expand Down Expand Up @@ -232,6 +233,7 @@ fn id() { let x = u32:1; })";
kProgram.c_str(), "my_module.x", "my_module",
/*dslx_stdlib_path=*/dslx_stdlib_path.c_str(), additional_search_paths,
0, enable_warnings, 0, disable_warnings, 1, /*warnings_as_errors=*/true,
/*force_implicit_token_calling_convention=*/false,
/*warnings_out=*/nullptr, /*warnings_out_count=*/nullptr, &error_out,
&ir_out);
ASSERT_TRUE(ok);
Expand Down Expand Up @@ -261,8 +263,9 @@ TEST(XlsCApiTest, ConvertWithNoWarnings) {
kProgram.c_str(), "my_module.x", "my_module",
/*dslx_stdlib_path=*/dslx_stdlib_path.c_str(), additional_search_paths, 0,
/*enable_warnings=*/nullptr, 0, /*disable_warnings=*/nullptr, 0,
/*warnings_as_errors=*/true, &warnings, &warnings_count, &error_out,
&ir_out);
/*warnings_as_errors=*/true,
/*force_implicit_token_calling_convention=*/false, &warnings,
&warnings_count, &error_out, &ir_out);
ASSERT_TRUE(ok);
ASSERT_EQ(error_out, nullptr);
ASSERT_NE(ir_out, nullptr);
Expand All @@ -273,6 +276,32 @@ TEST(XlsCApiTest, ConvertWithNoWarnings) {
ASSERT_EQ(warnings, nullptr);
}

TEST(XlsCApiTest, ConvertWithForcedImplicitToken) {
const std::string kProgram = "fn id(x: u32) -> u32 { x }";
const std::string dslx_stdlib_path = std::string(xls::kDefaultDslxStdlibPath);
const char* additional_search_paths[] = {};
char* error_out = nullptr;
char* ir_out = nullptr;

absl::Cleanup free_cstrs([&] {
xls_c_str_free(error_out);
xls_c_str_free(ir_out);
});

bool ok = xls_convert_dslx_to_ir_with_warnings(
kProgram.c_str(), "my_module.x", "my_module",
/*dslx_stdlib_path=*/dslx_stdlib_path.c_str(), additional_search_paths, 0,
/*enable_warnings=*/nullptr, 0, /*disable_warnings=*/nullptr, 0,
/*warnings_as_errors=*/false,
/*force_implicit_token_calling_convention=*/true,
/*warnings_out=*/nullptr, /*warnings_out_count=*/nullptr, &error_out,
&ir_out);
ASSERT_TRUE(ok);
ASSERT_EQ(error_out, nullptr);
ASSERT_NE(ir_out, nullptr);
EXPECT_THAT(ir_out, HasSubstr("fn __itok__my_module__id"));
}

TEST(XlsCApiTest, ConvertDslxToIrError) {
const std::string kInvalidProgram = "@!";
const char* additional_search_paths[] = {};
Expand Down
2 changes: 2 additions & 0 deletions xls/public/runtime_build_actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ absl::StatusOr<std::string> ConvertDslxToIr(
typechecked.module, &import_data,
dslx::ConvertOptions{
.warnings_as_errors = options.warnings_as_errors,
.force_implicit_token_calling_convention =
options.force_implicit_token_calling_convention,
});
}

Expand Down
1 change: 1 addition & 0 deletions xls/public/runtime_build_actions.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct ConvertDslxToIrOptions {
absl::Span<const std::string_view> disable_warnings;
bool warnings_as_errors = true;
std::vector<std::string>* warnings_out = nullptr;
bool force_implicit_token_calling_convention = false;
};

// Converts the specified DSLX text into XLS IR text.
Expand Down
Loading