Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@ class CIRSymbolLinkerInterface

static bool isComdat(Operation *op);

static std::optional<mlir::link::ComdatSelector>
getComdatSelector(Operation *op);
static bool hasComdat(Operation *op);

static const link::Comdat *getComdatResolution(Operation *op);

static bool selectedByComdat(Operation *op);

static void updateNoDeduplicate(Operation *op);

static Visibility getVisibility(Operation *op);

Expand Down
19 changes: 15 additions & 4 deletions clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,23 @@ bool CIRSymbolLinkerInterface::isComdat(Operation *op) {
return false;
}

std::optional<link::ComdatSelector>
CIRSymbolLinkerInterface::getComdatSelector(Operation *op) {
// TODO(frabert): Extracting comdat info from CIR is not implemented yet
return std::nullopt;
bool CIRSymbolLinkerInterface::hasComdat(Operation *op) {
// TODO: Extracting comdat info from CIR is not implemented yet
return false;
}

const link::Comdat *
CIRSymbolLinkerInterface::getComdatResolution(Operation *op) {
return nullptr;
}

bool CIRSymbolLinkerInterface::selectedByComdat(Operation *op) {
// TODO: Extracting comdat info from CIR is not implemented yet
llvm_unreachable("comdat resolution not implemented for CIR");
}

void CIRSymbolLinkerInterface::updateNoDeduplicate(Operation *op) {}

Visibility CIRSymbolLinkerInterface::getVisibility(Operation *op) {
if (auto gv = dyn_cast<GlobalOp>(op))
return toLLVMVisibility(gv.getGlobalVisibility());
Expand Down
30 changes: 25 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ class LLVMSymbolLinkerInterface
static Visibility getVisibility(Operation *op);
static void setVisibility(Operation *op, Visibility visibility);
static bool isComdat(Operation *op);
static std::optional<link::ComdatSelector> getComdatSelector(Operation *op);
static bool hasComdat(Operation *op);
static SymbolRefAttr getComdatSymbol(Operation *op);
static LLVM::comdat::Comdat getComdatSelectionKind(Operation *op);
static bool isDeclaration(Operation *op);
static unsigned getBitWidth(Operation *op);
static UnnamedAddr getUnnamedAddr(Operation *op);
Expand All @@ -34,7 +36,19 @@ class LLVMSymbolLinkerInterface
dependencies(Operation *op, SymbolTableCollection &collection) const override;
LogicalResult initialize(ModuleOp src) override;
LogicalResult finalize(ModuleOp dst) const override;
LogicalResult moduleOpSummary(ModuleOp src,
SymbolTableCollection &collection) override;
Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state);
Operation *appendComdatOps(ArrayRef<Operation *> globs, LLVM::ComdatOp comdat,
link::LinkState &state);
link::ComdatResolution
computeComdatResolution(Operation *, SymbolTableCollection &, link::Comdat *);
LogicalResult resolveComdats(ModuleOp srcMod,
SymbolTableCollection &collection);
const link::Comdat *getComdatResolution(Operation *op) const;
bool selectedByComdat(Operation *op) const;
void dropReplacedComdat(Operation *op) const;
static void updateNoDeduplicate(Operation *op);

LogicalResult link(link::LinkState &state) override;

Expand Down Expand Up @@ -90,22 +104,27 @@ class LLVMSymbolLinkerInterface
ArrayRef<Attribute> priorities = structor.getPriorities().getValue();
ArrayRef<Attribute> data = structor.getData().getValue();

for (auto [idx, dataAttr] : llvm::enumerate(data)) {
for (auto [idx, structor] : llvm::enumerate(structorList)) {
auto structorSymbol = cast<FlatSymbolRefAttr>(structor);
// Skip constructors not included based on COMDAT
if (!summary.contains(structorSymbol.getValue()))
continue;

auto dataAttr = data[idx];
// data value is either #llvm.zero or symbol ref
// if it is zero, we always have to include the value
// if it is a symbol ref, we have to check if the symbol
// from the same module is being used
//
if (auto globalSymbol = dyn_cast<FlatSymbolRefAttr>(dataAttr)) {
auto globalOp = summary.lookup(globalSymbol.getValue());
Operation *globalOp = summary.lookup(globalSymbol.getValue());
assert(globalOp && "structor referenced global not in summary?");
// globals are definde at module level
if (globalOp->getParentOp() != op->getParentOp())
continue;
}

newData.push_back(dataAttr);
newStructorList.push_back(structorList[idx]);
newStructorList.push_back(structor);
newPriorities.push_back(priorities[idx]);
}
}
Expand Down Expand Up @@ -141,6 +160,7 @@ class LLVMSymbolLinkerInterface
/// When a function is defined with different signatures in different modules,
/// we track both types here so we can later fix call sites.
mutable llvm::StringMap<std::pair<Type, Type>> mismatchedFunctions;
mutable llvm::StringMap<link::Comdat> comdatResolution;
};

} // namespace LLVM
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,8 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias",
UnitAttr:$dso_local,
UnitAttr:$thread_local_,
OptionalAttr<UnnamedAddr>:$unnamed_addr,
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
OptionalAttr<SymbolRefAttr>:$comdat
);
let summary = "LLVM dialect alias.";
let description = [{
Expand Down
71 changes: 38 additions & 33 deletions mlir/include/mlir/Linker/LLVMLinkerMixin.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,17 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) {

using ComdatKind = LLVM::comdat::Comdat;

struct ComdatSelector {
StringRef name;
struct Comdat {
ComdatKind kind;
Operation *selectorOp;
llvm::SmallPtrSet<Operation *, 2> users;
};

enum class ComdatResolution {
LinkFromSrc,
LinkFromDst,
LinkFromBoth,
Failure,
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -176,6 +184,20 @@ class LLVMLinkerMixin {
if (derived.isComdat(pair.src))
return true;

// Thrown away symbol can affect the visibility
if (pair.dst) {
Visibility srcVisibility = derived.getVisibility(pair.src);
Visibility dstVisibility = derived.getVisibility(pair.dst);
Visibility visibility = getMinVisibility(srcVisibility, dstVisibility);

derived.setVisibility(pair.src, visibility);
derived.setVisibility(pair.dst, visibility);
}
if (derived.hasComdat(pair.src)) {
// operations with COMDAT are selected as a group
return derived.selectedByComdat(pair.src);
}

Linkage srcLinkage = derived.getLinkage(pair.src);

// Always import variables with appending linkage.
Expand Down Expand Up @@ -222,6 +244,10 @@ class LLVMLinkerMixin {
return pair.src->emitError(error) << " dst: " << pair.dst->getLoc();
};

if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) {
return linkError("Linking ComdatOp with non-comdat op");
}

Linkage srcLinkage = derived.getLinkage(pair.src);
Linkage dstLinkage = derived.getLinkage(pair.dst);

Expand Down Expand Up @@ -349,6 +375,11 @@ class LLVMLinkerMixin {
if (isWeakForLinker(srcLinkage)) {
assert(!isExternalWeakLinkage(dstLinkage));
assert(!isAvailableExternallyLinkage(dstLinkage));
const Comdat *comdat = derived.getComdatResolution(pair.src);
if (comdat && comdat->kind == ComdatKind::NoDeduplicate) {
derived.updateNoDeduplicate(pair.src);
return ConflictResolution::LinkFromBothAndRenameSrc;
}
if (isLinkOnceLinkage(dstLinkage) && isWeakLinkage(srcLinkage)) {
return ConflictResolution::LinkFromSrc;
}
Expand All @@ -358,38 +389,12 @@ class LLVMLinkerMixin {

if (isWeakForLinker(dstLinkage)) {
assert(isExternalLinkage(srcLinkage));
return ConflictResolution::LinkFromSrc;
}

std::optional<ComdatSelector> srcComdatSel =
derived.getComdatSelector(pair.src);
std::optional<ComdatSelector> dstComdatSel =
derived.getComdatSelector(pair.dst);
if (srcComdatSel.has_value() && dstComdatSel.has_value()) {
auto srcComdatName = srcComdatSel->name;
auto dstComdatName = dstComdatSel->name;
auto srcComdat = srcComdatSel->kind;
auto dstComdat = dstComdatSel->kind;
if (srcComdatName != dstComdatName) {
llvm_unreachable("Comdat selector names don't match");
}
if (srcComdat != dstComdat) {
llvm_unreachable("Comdat selector kinds don't match");
}

if (srcComdat == mlir::LLVM::comdat::Comdat::Any) {
return ConflictResolution::LinkFromDst;
const Comdat *comdat = derived.getComdatResolution(pair.dst);
if (comdat && comdat->kind == ComdatKind::NoDeduplicate) {
derived.updateNoDeduplicate(pair.dst);
return ConflictResolution::LinkFromBothAndRenameDst;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) {
return ConflictResolution::Failure;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) {
return ConflictResolution::LinkFromDst;
}
if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) {
return ConflictResolution::LinkFromDst;
}
llvm_unreachable("unimplemented comdat kind");
return ConflictResolution::LinkFromSrc;
}

// If we reach here, we have two external definitions that can't be resolved
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Linker/LinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Error.h"
#include <memory>
Expand Down Expand Up @@ -62,6 +63,8 @@ class LinkState {
return builder.create<Op>(location, std::forward<Args>(args)...);
}

OpBuilder &getBuilder() { return builder; };

private:
// Private constructor used by nest()
LinkState(ModuleOp dst, std::shared_ptr<IRMapping> mapping,
Expand Down Expand Up @@ -149,6 +152,13 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
return state.clone(src);
}

/// Perform tasks that need to be computed on whole-module basis before actual summary.
/// E.g. Pre-compute COMDAT resolution before actually linking the modules.
virtual LogicalResult moduleOpSummary(ModuleOp module,
SymbolTableCollection &collection) {
return success();
}

/// Dependencies of the given operation required to be linked.
virtual SmallVector<Operation *>
dependencies(Operation *op, SymbolTableCollection &collection) const = 0;
Expand Down Expand Up @@ -286,6 +296,14 @@ class SymbolLinkerInterfaces {
return Conflict::noConflict(src);
}

LogicalResult moduleOpSummary(ModuleOp src,
SymbolTableCollection &collection) {
return failableParallelForEach(src.getContext(), interfaces,
[&](SymbolLinkerInterface *linker) {
return linker->moduleOpSummary(src, collection);
});
}

private:
SetVector<SymbolLinkerInterface *> interfaces;
};
Expand Down
Loading
Loading