Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions xls/dslx/ir_convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ cc_library(
],
)

cc_library(
name = "proc_scoped_channel_scope",
srcs = ["proc_scoped_channel_scope.cc"],
hdrs = ["proc_scoped_channel_scope.h"],
deps = [
":channel_scope",
":conversion_info",
":convert_options",
"//xls/common/status:status_macros",
"//xls/dslx:import_data",
"//xls/dslx/frontend:ast",
"//xls/dslx/type_system:type_info",
"//xls/ir:channel",
"//xls/ir:channel_ops",
"//xls/ir:function_builder",
"//xls/ir:type",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

cc_test(
name = "channel_scope_test",
srcs = ["channel_scope_test.cc"],
Expand Down Expand Up @@ -288,6 +309,7 @@ cc_library(
":convert_options",
":ir_conversion_utils",
":proc_config_ir_converter",
":proc_scoped_channel_scope",
"//xls/common:casts",
"//xls/common:visitor",
"//xls/common/status:ret_check",
Expand Down
92 changes: 51 additions & 41 deletions xls/dslx/ir_convert/channel_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ constexpr std::string_view kBetweenDimsSeparator = "_";

} // namespace

ChannelScope::ChannelScope(PackageConversionData* conversion_info,
ImportData* import_data,
const ConvertOptions& convert_options,
std::optional<FifoConfig> default_fifo_config)
GlobalChannelScope::GlobalChannelScope(
PackageConversionData* conversion_info, ImportData* import_data,
const ConvertOptions& convert_options,
std::optional<FifoConfig> default_fifo_config)
: conversion_info_(conversion_info),
import_data_(import_data),
convert_options_(convert_options),
Expand All @@ -77,9 +77,9 @@ ChannelScope::ChannelScope(PackageConversionData* conversion_info,
}
}

absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArray(
absl::StatusOr<ChannelOrArray> GlobalChannelScope::DefineChannelOrArray(
const ChannelDecl* decl) {
VLOG(4) << "ChannelScope::DefineChannelOrArray: " << decl->ToString();
VLOG(4) << "GlobalChannelScope::DefineChannelOrArray: " << decl->ToString();
XLS_RET_CHECK(function_context_.has_value());
XLS_ASSIGN_OR_RETURN(
InterpValue name_interp_value,
Expand All @@ -95,15 +95,15 @@ absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArray(
XLS_ASSIGN_OR_RETURN(
ChannelOrArray channel_or_array,
DefineChannelOrArrayInternal(short_name, ChannelOps::kSendReceive, type,
channel_config, decl->dims()));
channel_config, decl->dims(), false));
decl_to_channel_or_array_[decl] = channel_or_array;
return channel_or_array;
}

absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArrayInternal(
absl::StatusOr<ChannelOrArray> GlobalChannelScope::DefineChannelOrArrayInternal(
std::string_view short_name, ChannelOps ops, xls::Type* type,
std::optional<ChannelConfig> channel_config,
const std::optional<std::vector<Expr*>>& dims) {
const std::optional<std::vector<Expr*>>& dims, bool interface_channel) {
std::string base_channel_name;
if (convert_options_.lower_to_proc_scoped_channels) {
// When using proc scoped channels the channel names do not have to be
Expand All @@ -115,10 +115,14 @@ absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArrayInternal(
}
std::vector<std::string> channel_names;
if (!dims.has_value()) {
XLS_ASSIGN_OR_RETURN(
Channel * channel,
CreateChannel(base_channel_name, ops, type, channel_config));
return channel;
if (interface_channel) {
return absl::InternalError("Cannot deal with interface channels yet");
} else {
XLS_ASSIGN_OR_RETURN(
Channel * channel,
CreateChannel(base_channel_name, ops, type, channel_config));
return channel;
}
}
ChannelArray* array = &arrays_.emplace_back(ChannelArray(base_channel_name));
XLS_RET_CHECK(channel_arrays_.has_value());
Expand All @@ -127,17 +131,22 @@ absl::StatusOr<ChannelOrArray> ChannelScope::DefineChannelOrArrayInternal(
for (const std::string& suffix : suffixes) {
std::string channel_name =
absl::StrCat(base_channel_name, kNameAndDimsSeparator, suffix);
XLS_ASSIGN_OR_RETURN(
Channel * channel,
CreateChannel(channel_name, ops, type, channel_config));
array->AddChannel(channel_name, channel);
if (interface_channel) {
return absl::InternalError(
"Cannot deal with interface channels arrays yet");
} else {
XLS_ASSIGN_OR_RETURN(
Channel * channel,
CreateChannel(channel_name, ops, type, channel_config));
array->AddChannel(channel_name, channel);
}
}
return array;
}

absl::StatusOr<ChannelOrArray> ChannelScope::DefineBoundaryChannelOrArray(
const Param* param, TypeInfo* type_info) {
VLOG(4) << "ChannelScope::DefineBoundaryChannelOrArray: "
absl::StatusOr<ChannelOrArray> GlobalChannelScope::DefineBoundaryChannelOrArray(
const Param* param, TypeInfo* type_info, bool interface_channel) {
VLOG(4) << "GlobalChannelScope::DefineBoundaryChannelOrArray: "
<< param->ToString();
auto* type_annot =
dynamic_cast<ChannelTypeAnnotation*>(param->type_annotation());
Expand All @@ -154,7 +163,7 @@ absl::StatusOr<ChannelOrArray> ChannelScope::DefineBoundaryChannelOrArray(
ChannelOrArray channel_or_array,
DefineChannelOrArrayInternal(param->identifier(), op, ir_type,
/*channel_config=*/std::nullopt,
type_annot->dims()));
type_annot->dims(), interface_channel));
XLS_RETURN_IF_ERROR(DefineProtoChannelOrArray(channel_or_array, type_annot,
ir_type, type_info));
return channel_or_array;
Expand All @@ -170,7 +179,7 @@ std::string_view GetChannelName(ChannelOrArray channel_or_array) {
channel_or_array);
}

absl::Status ChannelScope::DefineProtoChannelOrArray(
absl::Status GlobalChannelScope::DefineProtoChannelOrArray(
ChannelOrArray channel_or_array, dslx::ChannelTypeAnnotation* type_annot,
xls::Type* ir_type, TypeInfo* type_info) {
if (std::holds_alternative<ChannelArray*>(channel_or_array)) {
Expand Down Expand Up @@ -202,10 +211,9 @@ absl::Status ChannelScope::DefineProtoChannelOrArray(
}

absl::StatusOr<ChannelOrArray>
ChannelScope::AssociateWithExistingChannelOrArray(const ProcId& proc_id,
const NameDef* name_def,
const ChannelDecl* decl) {
VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : "
GlobalChannelScope::AssociateWithExistingChannelOrArray(
const ProcId& proc_id, const NameDef* name_def, const ChannelDecl* decl) {
VLOG(4) << "GlobalChannelScope::AssociateWithExistingChannelOrArray : "
<< name_def->ToString() << " -> " << decl->ToString();
if (!decl_to_channel_or_array_.contains(decl)) {
return absl::NotFoundError(absl::StrCat(
Expand All @@ -217,10 +225,10 @@ ChannelScope::AssociateWithExistingChannelOrArray(const ProcId& proc_id,
return channel_or_array;
}

absl::Status ChannelScope::AssociateWithExistingChannelOrArray(
absl::Status GlobalChannelScope::AssociateWithExistingChannelOrArray(
const ProcId& proc_id, const NameDef* name_def,
ChannelOrArray channel_or_array) {
VLOG(4) << "ChannelScope::AssociateWithExistingChannelOrArray : "
VLOG(4) << "GlobalChannelScope::AssociateWithExistingChannelOrArray : "
<< name_def->ToString() << " -> "
<< GetBaseNameForChannelOrArray(channel_or_array) << " (array: "
<< std::holds_alternative<ChannelArray*>(channel_or_array) << ")";
Expand All @@ -229,7 +237,7 @@ absl::Status ChannelScope::AssociateWithExistingChannelOrArray(
return absl::OkStatus();
}

absl::StatusOr<Channel*> ChannelScope::GetChannelForArrayIndex(
absl::StatusOr<Channel*> GlobalChannelScope::GetChannelForArrayIndex(
const ProcId& proc_id, const Index* index) {
XLS_ASSIGN_OR_RETURN(
ChannelOrArray result,
Expand All @@ -238,14 +246,16 @@ absl::StatusOr<Channel*> ChannelScope::GetChannelForArrayIndex(
return std::get<Channel*>(result);
}

absl::StatusOr<ChannelOrArray> ChannelScope::GetChannelOrArrayForArrayIndex(
const ProcId& proc_id, const Index* index) {
absl::StatusOr<ChannelOrArray>
GlobalChannelScope::GetChannelOrArrayForArrayIndex(const ProcId& proc_id,
const Index* index) {
return EvaluateIndex(proc_id, index, /*allow_subarray_reference=*/true);
}

absl::StatusOr<ChannelOrArray> ChannelScope::EvaluateIndex(
absl::StatusOr<ChannelOrArray> GlobalChannelScope::EvaluateIndex(
const ProcId& proc_id, const Index* index, bool allow_subarray_reference) {
VLOG(4) << "ChannelScope::GetChannelForArrayIndex : " << index->ToString();
VLOG(4) << "GlobalChannelScope::GetChannelForArrayIndex : "
<< index->ToString();
XLS_RET_CHECK(function_context_.has_value());
std::string suffix;
for (;;) {
Expand Down Expand Up @@ -279,7 +289,7 @@ absl::StatusOr<ChannelOrArray> ChannelScope::EvaluateIndex(
}
}

std::string_view ChannelScope::GetBaseNameForChannelOrArray(
std::string_view GlobalChannelScope::GetBaseNameForChannelOrArray(
ChannelOrArray channel_or_array) {
return absl::visit(
Visitor{[](Channel* channel) { return channel->name(); },
Expand All @@ -290,13 +300,13 @@ std::string_view ChannelScope::GetBaseNameForChannelOrArray(
channel_or_array);
}

absl::StatusOr<std::string> ChannelScope::CreateBaseChannelName(
absl::StatusOr<std::string> GlobalChannelScope::CreateBaseChannelName(
std::string_view short_name) {
return channel_name_uniquer_.GetSanitizedUniqueName(absl::StrCat(
conversion_info_->package->name(), kNameAndDimsSeparator, short_name));
}

absl::StatusOr<xls::Type*> ChannelScope::GetChannelType(
absl::StatusOr<xls::Type*> GlobalChannelScope::GetChannelType(
const ChannelDecl* decl) const {
std::optional<Type*> type =
function_context_->type_info->GetItem(decl->type());
Expand All @@ -305,8 +315,8 @@ absl::StatusOr<xls::Type*> ChannelScope::GetChannelType(
function_context_->bindings);
}

absl::StatusOr<std::optional<ChannelConfig>> ChannelScope::CreateChannelConfig(
const ChannelDecl* decl) const {
absl::StatusOr<std::optional<ChannelConfig>>
GlobalChannelScope::CreateChannelConfig(const ChannelDecl* decl) const {
if (decl->channel_config().has_value()) {
XLS_RET_CHECK(!decl->fifo_depth().has_value())
<< "Cannot specify both fifo_depth and channel_config.";
Expand Down Expand Up @@ -346,7 +356,7 @@ absl::StatusOr<std::optional<ChannelConfig>> ChannelScope::CreateChannelConfig(
return ChannelConfig().WithFifoConfig(default_fifo_config_);
}

absl::StatusOr<Channel*> ChannelScope::CreateChannel(
absl::StatusOr<Channel*> GlobalChannelScope::CreateChannel(
std::string_view name, ChannelOps ops, xls::Type* type,
std::optional<ChannelConfig> channel_config) {
if (channel_config.has_value()) {
Expand All @@ -358,7 +368,7 @@ absl::StatusOr<Channel*> ChannelScope::CreateChannel(
return conversion_info_->package->CreateStreamingChannel(name, ops, type);
}

absl::StatusOr<ChannelOrArray> ChannelScope::GetChannelArrayElement(
absl::StatusOr<ChannelOrArray> GlobalChannelScope::GetChannelArrayElement(
const ProcId& proc_id, const NameRef* name_ref,
std::string_view flattened_name_suffix, bool allow_subarray_reference) {
const auto* name_def = std::get<const NameDef*>(name_ref->name_def());
Expand Down Expand Up @@ -392,7 +402,7 @@ absl::StatusOr<ChannelOrArray> ChannelScope::GetChannelArrayElement(
"No array element with flattened name: ", flattened_channel_name));
}

absl::StatusOr<ChannelArray*> ChannelScope::GetOrDefineSubarray(
absl::StatusOr<ChannelArray*> GlobalChannelScope::GetOrDefineSubarray(
ChannelArray* array, std::string_view subarray_name) {
const auto it = subarrays_.find(subarray_name);
if (it != subarrays_.end()) {
Expand Down
Loading