diff --git a/clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h b/clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h index 6295940dfd9d6..3da51862a7c89 100644 --- a/clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h +++ b/clang/include/clang/CIR/Interfaces/CIRLinkerInterface.h @@ -45,8 +45,13 @@ class CIRSymbolLinkerInterface static bool isComdat(Operation *op); - static std::optional - getComdatSelector(Operation *op); + static bool hasComdat(Operation *op); + + static const link::Comdat *getComdatResolution(Operation *op); + + static bool selectedByComdat(Operation *op); + + static void updateNoDeduplicate(Operation *op); static Visibility getVisibility(Operation *op); diff --git a/clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp b/clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp index 55463fddc6776..85640ce08bba4 100644 --- a/clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRLinkerInterface.cpp @@ -99,12 +99,23 @@ bool CIRSymbolLinkerInterface::isComdat(Operation *op) { return false; } -std::optional -CIRSymbolLinkerInterface::getComdatSelector(Operation *op) { - // TODO(frabert): Extracting comdat info from CIR is not implemented yet - return std::nullopt; +bool CIRSymbolLinkerInterface::hasComdat(Operation *op) { + // TODO: Extracting comdat info from CIR is not implemented yet + return false; +} + +const link::Comdat * +CIRSymbolLinkerInterface::getComdatResolution(Operation *op) { + return nullptr; } +bool CIRSymbolLinkerInterface::selectedByComdat(Operation *op) { + // TODO: Extracting comdat info from CIR is not implemented yet + llvm_unreachable("comdat resolution not implemented for CIR"); +} + +void CIRSymbolLinkerInterface::updateNoDeduplicate(Operation *op) {} + Visibility CIRSymbolLinkerInterface::getVisibility(Operation *op) { if (auto gv = dyn_cast(op)) return toLLVMVisibility(gv.getGlobalVisibility()); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h index 152f80aa6c479..3c3ab5d26d78e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h @@ -16,7 +16,9 @@ class LLVMSymbolLinkerInterface static Visibility getVisibility(Operation *op); static void setVisibility(Operation *op, Visibility visibility); static bool isComdat(Operation *op); - static std::optional getComdatSelector(Operation *op); + static bool hasComdat(Operation *op); + static SymbolRefAttr getComdatSymbol(Operation *op); + static LLVM::comdat::Comdat getComdatSelectionKind(Operation *op); static bool isDeclaration(Operation *op); static unsigned getBitWidth(Operation *op); static UnnamedAddr getUnnamedAddr(Operation *op); @@ -34,7 +36,19 @@ class LLVMSymbolLinkerInterface dependencies(Operation *op, SymbolTableCollection &collection) const override; LogicalResult initialize(ModuleOp src) override; LogicalResult finalize(ModuleOp dst) const override; + LogicalResult moduleOpSummary(ModuleOp src, + SymbolTableCollection &collection) override; Operation *appendGlobals(llvm::StringRef glob, link::LinkState &state); + Operation *appendComdatOps(ArrayRef globs, LLVM::ComdatOp comdat, + link::LinkState &state); + link::ComdatResolution + computeComdatResolution(Operation *, SymbolTableCollection &, link::Comdat *); + LogicalResult resolveComdats(ModuleOp srcMod, + SymbolTableCollection &collection); + const link::Comdat *getComdatResolution(Operation *op) const; + bool selectedByComdat(Operation *op) const; + void dropReplacedComdat(Operation *op) const; + static void updateNoDeduplicate(Operation *op); LogicalResult link(link::LinkState &state) override; @@ -90,14 +104,19 @@ class LLVMSymbolLinkerInterface ArrayRef priorities = structor.getPriorities().getValue(); ArrayRef data = structor.getData().getValue(); - for (auto [idx, dataAttr] : llvm::enumerate(data)) { + for (auto [idx, structor] : llvm::enumerate(structorList)) { + auto structorSymbol = cast(structor); + // Skip constructors not included based on COMDAT + if (!summary.contains(structorSymbol.getValue())) + continue; + + auto dataAttr = data[idx]; // data value is either #llvm.zero or symbol ref // if it is zero, we always have to include the value // if it is a symbol ref, we have to check if the symbol // from the same module is being used - // if (auto globalSymbol = dyn_cast(dataAttr)) { - auto globalOp = summary.lookup(globalSymbol.getValue()); + Operation *globalOp = summary.lookup(globalSymbol.getValue()); assert(globalOp && "structor referenced global not in summary?"); // globals are definde at module level if (globalOp->getParentOp() != op->getParentOp()) @@ -105,7 +124,7 @@ class LLVMSymbolLinkerInterface } newData.push_back(dataAttr); - newStructorList.push_back(structorList[idx]); + newStructorList.push_back(structor); newPriorities.push_back(priorities[idx]); } } @@ -141,6 +160,7 @@ class LLVMSymbolLinkerInterface /// When a function is defined with different signatures in different modules, /// we track both types here so we can later fix call sites. mutable llvm::StringMap> mismatchedFunctions; + mutable llvm::StringMap comdatResolution; }; } // namespace LLVM diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 2f3b26439cbd4..4a97bcb50368e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1533,7 +1533,8 @@ def LLVM_AliasOp : LLVM_Op<"mlir.alias", UnitAttr:$dso_local, UnitAttr:$thread_local_, OptionalAttr:$unnamed_addr, - DefaultValuedAttr:$visibility_ + DefaultValuedAttr:$visibility_, + OptionalAttr:$comdat ); let summary = "LLVM dialect alias."; let description = [{ diff --git a/mlir/include/mlir/Linker/LLVMLinkerMixin.h b/mlir/include/mlir/Linker/LLVMLinkerMixin.h index 6c85d446261c0..df0073479122f 100644 --- a/mlir/include/mlir/Linker/LLVMLinkerMixin.h +++ b/mlir/include/mlir/Linker/LLVMLinkerMixin.h @@ -139,9 +139,17 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) { using ComdatKind = LLVM::comdat::Comdat; -struct ComdatSelector { - StringRef name; +struct Comdat { ComdatKind kind; + Operation *selectorOp; + llvm::SmallPtrSet users; +}; + +enum class ComdatResolution { + LinkFromSrc, + LinkFromDst, + LinkFromBoth, + Failure, }; //===----------------------------------------------------------------------===// @@ -176,6 +184,20 @@ class LLVMLinkerMixin { if (derived.isComdat(pair.src)) return true; + // Thrown away symbol can affect the visibility + if (pair.dst) { + Visibility srcVisibility = derived.getVisibility(pair.src); + Visibility dstVisibility = derived.getVisibility(pair.dst); + Visibility visibility = getMinVisibility(srcVisibility, dstVisibility); + + derived.setVisibility(pair.src, visibility); + derived.setVisibility(pair.dst, visibility); + } + if (derived.hasComdat(pair.src)) { + // operations with COMDAT are selected as a group + return derived.selectedByComdat(pair.src); + } + Linkage srcLinkage = derived.getLinkage(pair.src); // Always import variables with appending linkage. @@ -222,6 +244,10 @@ class LLVMLinkerMixin { return pair.src->emitError(error) << " dst: " << pair.dst->getLoc(); }; + if (derived.isComdat(pair.src) != derived.isComdat(pair.dst)) { + return linkError("Linking ComdatOp with non-comdat op"); + } + Linkage srcLinkage = derived.getLinkage(pair.src); Linkage dstLinkage = derived.getLinkage(pair.dst); @@ -349,6 +375,11 @@ class LLVMLinkerMixin { if (isWeakForLinker(srcLinkage)) { assert(!isExternalWeakLinkage(dstLinkage)); assert(!isAvailableExternallyLinkage(dstLinkage)); + const Comdat *comdat = derived.getComdatResolution(pair.src); + if (comdat && comdat->kind == ComdatKind::NoDeduplicate) { + derived.updateNoDeduplicate(pair.src); + return ConflictResolution::LinkFromBothAndRenameSrc; + } if (isLinkOnceLinkage(dstLinkage) && isWeakLinkage(srcLinkage)) { return ConflictResolution::LinkFromSrc; } @@ -358,38 +389,12 @@ class LLVMLinkerMixin { if (isWeakForLinker(dstLinkage)) { assert(isExternalLinkage(srcLinkage)); - return ConflictResolution::LinkFromSrc; - } - - std::optional srcComdatSel = - derived.getComdatSelector(pair.src); - std::optional dstComdatSel = - derived.getComdatSelector(pair.dst); - if (srcComdatSel.has_value() && dstComdatSel.has_value()) { - auto srcComdatName = srcComdatSel->name; - auto dstComdatName = dstComdatSel->name; - auto srcComdat = srcComdatSel->kind; - auto dstComdat = dstComdatSel->kind; - if (srcComdatName != dstComdatName) { - llvm_unreachable("Comdat selector names don't match"); - } - if (srcComdat != dstComdat) { - llvm_unreachable("Comdat selector kinds don't match"); - } - - if (srcComdat == mlir::LLVM::comdat::Comdat::Any) { - return ConflictResolution::LinkFromDst; + const Comdat *comdat = derived.getComdatResolution(pair.dst); + if (comdat && comdat->kind == ComdatKind::NoDeduplicate) { + derived.updateNoDeduplicate(pair.dst); + return ConflictResolution::LinkFromBothAndRenameDst; } - if (srcComdat == mlir::LLVM::comdat::Comdat::NoDeduplicate) { - return ConflictResolution::Failure; - } - if (srcComdat == mlir::LLVM::comdat::Comdat::ExactMatch) { - return ConflictResolution::LinkFromDst; - } - if (srcComdat == mlir::LLVM::comdat::Comdat::SameSize) { - return ConflictResolution::LinkFromDst; - } - llvm_unreachable("unimplemented comdat kind"); + return ConflictResolution::LinkFromSrc; } // If we reach here, we have two external definitions that can't be resolved diff --git a/mlir/include/mlir/Linker/LinkerInterface.h b/mlir/include/mlir/Linker/LinkerInterface.h index 67b617f171bd3..78bde572abacb 100644 --- a/mlir/include/mlir/Linker/LinkerInterface.h +++ b/mlir/include/mlir/Linker/LinkerInterface.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectInterface.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Threading.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Error.h" #include @@ -62,6 +63,8 @@ class LinkState { return builder.create(location, std::forward(args)...); } + OpBuilder &getBuilder() { return builder; }; + private: // Private constructor used by nest() LinkState(ModuleOp dst, std::shared_ptr mapping, @@ -149,6 +152,13 @@ class SymbolLinkerInterface : public LinkerInterface { return state.clone(src); } + /// Perform tasks that need to be computed on whole-module basis before actual summary. + /// E.g. Pre-compute COMDAT resolution before actually linking the modules. + virtual LogicalResult moduleOpSummary(ModuleOp module, + SymbolTableCollection &collection) { + return success(); + } + /// Dependencies of the given operation required to be linked. virtual SmallVector dependencies(Operation *op, SymbolTableCollection &collection) const = 0; @@ -286,6 +296,14 @@ class SymbolLinkerInterfaces { return Conflict::noConflict(src); } + LogicalResult moduleOpSummary(ModuleOp src, + SymbolTableCollection &collection) { + return failableParallelForEach(src.getContext(), interfaces, + [&](SymbolLinkerInterface *linker) { + return linker->moduleOpSummary(src, collection); + }); + } + private: SetVector interfaces; }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp index f364e614e713b..87b74aa80837f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp @@ -27,6 +27,18 @@ LLVM::LLVMSymbolLinkerInterface::LLVMSymbolLinkerInterface(Dialect *dialect) : SymbolAttrLLVMLinkerInterface(dialect) {} bool LLVM::LLVMSymbolLinkerInterface::canBeLinked(Operation *op) const { + // Only link operations that are direct children of a top-level ModuleOp. + // This filters out: + // 1. Operations nested inside regions (e.g., inside an AliasOp's initializer) + // 2. Operations inside nested named modules + Operation *parent = op->getParentOp(); + if (!isa_and_nonnull(parent)) + return false; + + // Check that the parent module is a top-level module (has no parent module) + if (parent->getParentOfType()) + return false; + return isa(op); } @@ -75,20 +87,26 @@ void LLVM::LLVMSymbolLinkerInterface::setVisibility(Operation *op, llvm_unreachable("unexpected operation"); } -static bool hasComdat(Operation *op) { +bool LLVM::LLVMSymbolLinkerInterface::hasComdat(Operation *op) { if (auto gv = dyn_cast(op)) return gv.getComdat().has_value(); if (auto fn = dyn_cast(op)) return fn.getComdat().has_value(); + if (auto alias = dyn_cast(op)) + return alias.getComdat().has_value(); + if (isa(op)) + return false; llvm_unreachable("unexpected operation"); } -static SymbolRefAttr getComdatSymbol(Operation *op) { +SymbolRefAttr LLVM::LLVMSymbolLinkerInterface::getComdatSymbol(Operation *op) { assert(hasComdat(op) && "Operation with Comdat expected"); if (auto gv = dyn_cast(op)) return gv.getComdat().value(); if (auto fn = dyn_cast(op)) return fn.getComdat().value(); + if (auto alias = dyn_cast(op)) + return alias.getComdat().value(); llvm_unreachable("unexpected operation"); } @@ -96,16 +114,11 @@ bool LLVM::LLVMSymbolLinkerInterface::isComdat(Operation *op) { return isa(op); } -std::optional -LLVM::LLVMSymbolLinkerInterface::getComdatSelector(Operation *op) { - if (!hasComdat(op)) - return std::nullopt; - - auto symbol = getComdatSymbol(op); - auto *symTabOp = SymbolTable::getNearestSymbolTable(op); - auto comdatSelector = cast( - SymbolTable::lookupSymbolIn(symTabOp, symbol)); - return {{comdatSelector.getSymName(), comdatSelector.getComdat()}}; +LLVM::comdat::Comdat +LLVM::LLVMSymbolLinkerInterface::getComdatSelectionKind(Operation *op) { + if (auto selector = dyn_cast(op)) + return selector.getComdat(); + llvm_unreachable("expected selector op"); } // Return true if the primary definition of this global value is outside of @@ -248,9 +261,7 @@ Operation * LLVM::LLVMSymbolLinkerInterface::materialize(Operation *src, LinkState &state) { auto &derived = LinkerMixin::getDerived(); - // empty append means that we either have single module or that something went - // wrong - if (isAppendingLinkage(derived.getLinkage(src)) && !append.empty()) { + if (isAppendingLinkage(derived.getLinkage(src))) { return derived.appendGlobals(derived.getSymbol(src), state); } return SymbolAttrLinkerInterface::materialize(src, state); @@ -284,6 +295,11 @@ Conflict LLVM::LLVMSymbolLinkerInterface::findConflict( SmallVector LLVM::LLVMSymbolLinkerInterface::dependencies( Operation *op, SymbolTableCollection &collection) const { Operation *module = op->getParentOfType(); + // If this operation is nested inside a region (e.g., inside an AliasOp's + // initializer) or inside a nested module, it won't have dependencies that + // need linking. + if (!module || module->getParentOfType()) + return {}; SymbolTable &st = collection.getSymbolTable(module); SmallVector result; @@ -298,16 +314,23 @@ SmallVector LLVM::LLVMSymbolLinkerInterface::dependencies( // walk and reference the symbols in an attribute. We have to intercept on // these operations. ArrayAttr structors = {}; + ArrayAttr data = {}; if (auto ctor = dyn_cast(op)) { structors = ctor.getCtors(); + data = ctor.getData(); } if (auto dtor = dyn_cast(op)) { structors = dtor.getDtors(); + data = dtor.getData(); } if (structors) { for (auto structor : structors) insertDepIfExists(cast(structor)); + for (auto dataAttr : data) + if (auto symbolRef = dyn_cast(dataAttr)) + insertDepIfExists(symbolRef); + return result; } @@ -500,6 +523,11 @@ LogicalResult LLVM::LLVMSymbolLinkerInterface::finalize(ModuleOp dst) const { return success(); } +LogicalResult LLVM::LLVMSymbolLinkerInterface::moduleOpSummary( + ModuleOp src, SymbolTableCollection &collection) { + return resolveComdats(src, collection); +} + static std::pair getAppendedArrayAttr(llvm::ArrayRef globs, LinkState &state) { @@ -679,28 +707,17 @@ static Operation *appendGlobalOps(ArrayRef globs, llvm_unreachable("unknown value attribute type"); } -static Operation *appendComdatOps(ArrayRef globs, - LLVM::ComdatOp comdat, LinkState &state) { - auto result = cast(state.clone(comdat)); - llvm::StringMap selectors; +Operation *LLVM::LLVMSymbolLinkerInterface::appendComdatOps( + ArrayRef globs, LLVM::ComdatOp comdat, LinkState &state) { + auto result = + state.create(comdat.getLoc(), comdat.getSymName()); - for (auto selector : result.getOps()) { - selectors[selector.getSymName()] = selector; - } + auto guard = OpBuilder::InsertionGuard(state.getBuilder()); + state.getBuilder().setInsertionPointToStart(&result.getBody().front()); + + for (auto &&[name, comdatResPair] : comdatResolution) + state.clone(comdatResPair.selectorOp); - for (auto *glob : globs) { - comdat = dyn_cast(glob); - for (auto &op : comdat.getBody().getOps()) { - auto selector = cast(op); - auto selectorName = selector.getSymName(); - if (selectors.contains(selectorName)) { - continue; - } - auto *cloned = state.clone(selector); - cloned->moveBefore(&result.getBody().front().back()); - selectors[selectorName] = cloned; - } - } return result; } @@ -724,6 +741,200 @@ Operation *LLVM::LLVMSymbolLinkerInterface::appendGlobals(llvm::StringRef glob, llvm_unreachable("unexpected operation"); } +ComdatResolution LLVM::LLVMSymbolLinkerInterface::computeComdatResolution( + Operation *srcSelector, SymbolTableCollection &collection, + Comdat *dstComdat) { + // For reference check llvm/lib/Linker/LinkModules.cpp + // computeResultingSelectionKind + ComdatKind srcKind = getComdatSelectionKind(srcSelector); + ComdatKind dstKind = dstComdat->kind; + bool dstAnyOrLargest = + dstKind == ComdatKind::Any || dstKind == ComdatKind::Largest; + bool srcAnyOrLargest = + srcKind == ComdatKind::Any || srcKind == ComdatKind::Largest; + + ComdatKind resolutionKind; + if (dstAnyOrLargest && srcAnyOrLargest) { + if (dstKind == ComdatKind::Largest || srcKind == ComdatKind::Largest) { + resolutionKind = ComdatKind::Largest; + } else { + resolutionKind = ComdatKind::Any; + } + } else if (srcKind == dstKind) { + resolutionKind = dstKind; + } else { + return ComdatResolution::Failure; + } + + auto computeSize = [&](GlobalOp op) -> llvm::TypeSize { + auto dataLayout = DataLayout(op->getParentOfType()); + return dataLayout.getTypeSize(op.getType()); + }; + + auto getComdatLeader = [&](Operation *selector) -> GlobalOp { + SymbolTable &st = + collection.getSymbolTable(selector->getParentOfType()); + Operation *leader = st.lookup(getSymbol(selector)); + + while (auto alias = dyn_cast_if_present(leader)) + for (AddressOfOp addrOf : alias.getInitializer().getOps()) + leader = st.lookup(addrOf.getGlobalName()); + + if (hasComdat(leader) && + getComdatSymbol(leader).getLeafReference() == getSymbol(leader)) + return mlir::dyn_cast(leader); + + return {}; + }; + + switch (resolutionKind) { + case ComdatKind::Any: + return ComdatResolution::LinkFromDst; + case ComdatKind::NoDeduplicate: + return ComdatResolution::LinkFromBoth; + case ComdatKind::ExactMatch: { + GlobalOp srcLeader = getComdatLeader(srcSelector); + GlobalOp dstLeader = getComdatLeader(dstComdat->selectorOp); + assert(srcLeader && dstLeader && "Couldn't find comdat leader"); + return OperationEquivalence::isEquivalentTo( + dstLeader, srcLeader, + OperationEquivalence::Flags::IgnoreLocations) + ? ComdatResolution::Failure + : ComdatResolution::LinkFromDst; + } + case ComdatKind::Largest: { + GlobalOp srcLeader = getComdatLeader(srcSelector); + GlobalOp dstLeader = getComdatLeader(dstComdat->selectorOp); + assert(srcLeader && dstLeader && "Size based comdat without valid leader"); + return computeSize(srcLeader) > computeSize(dstLeader) + ? ComdatResolution::LinkFromSrc + : ComdatResolution::LinkFromDst; + } + case ComdatKind::SameSize: + GlobalOp srcLeader = getComdatLeader(srcSelector); + GlobalOp dstLeader = getComdatLeader(dstComdat->selectorOp); + assert(srcLeader && dstLeader && "Size based comdat without valid leader"); + return computeSize(srcLeader) == computeSize(dstLeader) + ? ComdatResolution::Failure + : ComdatResolution::LinkFromDst; + } +} + +void LLVM::LLVMSymbolLinkerInterface::dropReplacedComdat(Operation *op) const { + if (auto global = mlir::dyn_cast(op)) { + global.removeValueAttr(); + + Region &initializer = global.getInitializer(); + initializer.dropAllReferences(); + initializer.getBlocks().clear(); + + global.removeComdatAttr(); + global.setLinkage(Linkage::AvailableExternally); + } + + if (auto func = mlir::dyn_cast(op)) { + Region &body = func.getBody(); + body.dropAllReferences(); + body.getBlocks().clear(); + + func.removeComdatAttr(); + func.setLinkage(Linkage::External); + } + + // Handle aliases: just erase them. + // Aliases inherit COMDAT from their aliasee, so if the aliasee's COMDAT + // is dropped, the alias should be dropped too. + if (auto alias = mlir::dyn_cast(op)) { + alias.erase(); + } +} + +void LLVM::LLVMSymbolLinkerInterface::updateNoDeduplicate(Operation *op) { + if (auto global = mlir::dyn_cast(op)) { + global.setVisibility_(LLVM::Visibility::Default); + global.setDsoLocal(true); + global.setLinkage(LLVM::Linkage::Private); + } else { + llvm_unreachable("Only globals should have NoDeduplicate comdat"); + } +} + +LogicalResult LLVM::LLVMSymbolLinkerInterface::resolveComdats( + ModuleOp srcMod, SymbolTableCollection &collection) { + LLVM::ComdatOp srcComdatOp = + collection.getSymbolTable(srcMod).lookup( + "__llvm_global_comdat"); + + // Nothing to do + if (!srcComdatOp) + return success(); + + // TODO: Figure out how to share this map with the rest of the linker + SymbolUserMap srcSymbolUsers(collection, + srcComdatOp->getParentOfType()); + + for (Operation &op : srcComdatOp.getBody().front()) { + auto srcSelector = cast(op); + auto dstComdatIt = comdatResolution.find(getSymbol(&op)); + link::Comdat *dstComdat = + dstComdatIt == comdatResolution.end() ? nullptr : &dstComdatIt->second; + + // If no conflict choose src + ComdatResolution res = + dstComdat ? computeComdatResolution(srcSelector, collection, dstComdat) + : ComdatResolution::LinkFromSrc; + + switch (res) { + case ComdatResolution::LinkFromSrc: { + // COMDAT group is used or dropped as a whole, remove all users of dropped + // COMDAT if present + // Drop all users before replacing the value + if (dstComdat) + for (Operation *dstUser : dstComdat->users) + dropReplacedComdat(dstUser); + ArrayRef users = srcSymbolUsers.getUsers(&op); + comdatResolution[getSymbol(srcSelector)] = + link::Comdat{getComdatSelectionKind(srcSelector), + srcSelector, + {users.begin(), users.end()}}; + break; + } + case ComdatResolution::LinkFromDst: + continue; + case ComdatResolution::LinkFromBoth: { + ArrayRef users = srcSymbolUsers.getUsers(&op); + dstComdat->users.insert(users.begin(), users.end()); + break; + } + case ComdatResolution::Failure: + return failure(); + } + } + return success(); +} + +const link::Comdat * +LLVM::LLVMSymbolLinkerInterface::getComdatResolution(Operation *op) const { + if (hasComdat(op)) { + auto resolutionIt = comdatResolution.find( + getComdatSymbol(op).getLeafReference().getValue()); + return resolutionIt != comdatResolution.end() ? &resolutionIt->second + : nullptr; + } + return nullptr; +} + +bool LLVM::LLVMSymbolLinkerInterface::selectedByComdat(Operation *op) const { + assert(hasComdat(op) && "expected operation with comdat"); + + if (auto comdatIt = comdatResolution.find( + getComdatSymbol(op).getLeafReference().getValue()); + comdatIt != comdatResolution.end()) { + return comdatIt->second.users.contains(op); + } + return false; +} + //===----------------------------------------------------------------------===// // registerLinkerInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinLinkerInterface.cpp b/mlir/lib/IR/BuiltinLinkerInterface.cpp index bd61830b95676..2be3121f8089b 100644 --- a/mlir/lib/IR/BuiltinLinkerInterface.cpp +++ b/mlir/lib/IR/BuiltinLinkerInterface.cpp @@ -37,6 +37,8 @@ class BuiltinLinkerInterface : public ModuleLinkerInterface { LogicalResult summarize(ModuleOp src, unsigned flags, SymbolTableCollection &collection) override { + if (symbolLinkers.moduleOpSummary(src, symbolTableCollection).failed()) + return failure(); // Collect all operations to process in parallel SmallVector ops; src.walk([&](Operation *op) { diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index d3d4582c2046d..e402d478a9c1f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1376,6 +1376,15 @@ LogicalResult ModuleImport::convertAlias(llvm::GlobalAlias *alias) { aliasOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(alias->getUnnamedAddr())); aliasOp.setVisibility_(convertVisibilityFromLLVM(alias->getVisibility())); + // Aliases inherit COMDAT from their aliasee (per LLVM LangRef). + // The aliasee should be a GlobalValue that may have a COMDAT. + if (llvm::GlobalValue *aliaseeObj = alias->getAliaseeObject()) { + if (auto *aliasee = dyn_cast(aliaseeObj)) { + if (aliasee->hasComdat()) + aliasOp.setComdatAttr(comdatMapping.lookup(aliasee->getComdat())); + } + } + return success(); } diff --git a/mlir/test/mlir-link/adapted/comdat-nodeduplicate-1.mlir b/mlir/test/mlir-link/adapted/comdat-nodeduplicate-1.mlir new file mode 100644 index 0000000000000..d363931562606 --- /dev/null +++ b/mlir/test/mlir-link/adapted/comdat-nodeduplicate-1.mlir @@ -0,0 +1,15 @@ +// RUN: not mlir-link -split-input-file %s 2>&1 | FileCheck %s + +// CHECK: error: Linker error + +llvm.comdat @__llvm_global_comdat { + llvm.comdat_selector @foo nodeduplicate +} +llvm.mlir.global external @foo(43 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64 + +// ----- + +llvm.comdat @__llvm_global_comdat { + llvm.comdat_selector @foo nodeduplicate +} +llvm.mlir.global external @foo(43 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64 diff --git a/mlir/test/mlir-link/adapted/comdat-nodeduplicate-2.mlir b/mlir/test/mlir-link/adapted/comdat-nodeduplicate-2.mlir new file mode 100644 index 0000000000000..9164378bf3518 --- /dev/null +++ b/mlir/test/mlir-link/adapted/comdat-nodeduplicate-2.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-link %s -split-input-file | FileCheck %s + +// CHECK-DAG: llvm.mlir.global private @foo.0(0 : i64) +// CHECK-DAG: llvm.mlir.global private @bar.0(0 : i64) +// CHECK-DAG: llvm.mlir.global external hidden @foo(2 : i64) +// CHECK-DAG: llvm.mlir.global external @bar(3 : i64) +// CHECK-DAG: llvm.mlir.global weak_odr @qux(4 : i64) +// CHECK-DAG: llvm.mlir.global linkonce @fred(5 : i64) + +llvm.comdat @__llvm_global_comdat { + llvm.comdat_selector @foo nodeduplicate +} +llvm.mlir.global external @foo(2 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32, alignment = 8 : i64, section = "data"} : i64 +llvm.mlir.global weak @bar(0 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32, section = "cnts"} : i64 +llvm.mlir.global weak_odr @qux(4 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64 + +// ----- + +llvm.comdat @__llvm_global_comdat { + llvm.comdat_selector @foo nodeduplicate +} +llvm.mlir.global weak hidden @foo(0 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32, dso_local, section = "data"} : i64 +llvm.mlir.global external @bar(3 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32, alignment = 16 : i64, dso_local, section = "cnts"} : i64 +llvm.mlir.global linkonce @fred(5 : i64) comdat(@__llvm_global_comdat::@foo) {addr_space = 0 : i32} : i64 diff --git a/mlir/test/mlir-link/adapted/comdat.mlir b/mlir/test/mlir-link/adapted/comdat.mlir index e5ec527e8b390..f2e356c2d39e3 100644 --- a/mlir/test/mlir-link/adapted/comdat.mlir +++ b/mlir/test/mlir-link/adapted/comdat.mlir @@ -1,8 +1,5 @@ // RUN: mlir-link %s %p/Inputs/comdat.mlir -o - | FileCheck %s -// unimplemented conflict resolution -// XFAIL: * - module { llvm.comdat @__llvm_global_comdat { llvm.comdat_selector @foo largest @@ -23,13 +20,13 @@ module { } // CHECK-DAG: llvm.comdat_selector @qux largest -// CHECK-DAG: llvm.comdat_selector @foo comdat largest +// CHECK-DAG: llvm.comdat_selector @foo largest // CHECK-DAG: llvm.comdat_selector @any any -// CHECK-DAG: llvm.mlir.global @foo(43 : i64) comdat{{$}} -// CHECK-DAG: llvm.mlir.global @qux(12 : i64) comdat{{$}} -// CHECK-DAG: llvm.mlir.global @any(6 : i64) comdat{{$}} -// CHECK-NOT: llvm.mlir.global @in_unselected_group(13 : i32) comdat(@__llvm_global_comdat::@qux) - +// CHECK-DAG: llvm.mlir.global external @foo(43 : i64) comdat(@__llvm_global_comdat::@foo) +// CHECK-DAG: llvm.mlir.global external @qux(12 : i64) comdat(@__llvm_global_comdat::@qux) +// CHECK-DAG: llvm.mlir.global external @any(6 : i64) comdat(@__llvm_global_comdat::@any) // CHECK-DAG: llvm.func @baz() -> i32 comdat(@__llvm_global_comdat::@qux) { // CHECK-DAG: llvm.func @bar() -> i32 comdat(@__llvm_global_comdat::@foo) { +// CHECK-NOT: llvm.mlir.global @in_unselected_group(13 : i32) comdat(@__llvm_global_comdat::@qux) + diff --git a/mlir/test/mlir-link/adapted/comdat_group.mlir b/mlir/test/mlir-link/adapted/comdat_group.mlir index 4df2fbbdb26d3..e94480a3d7076 100644 --- a/mlir/test/mlir-link/adapted/comdat_group.mlir +++ b/mlir/test/mlir-link/adapted/comdat_group.mlir @@ -1,12 +1,9 @@ // RUN: mlir-link %s -o - | FileCheck %s -// comdat disappears when linking -// XFAIL: * - -// CHECK: llvm.comdat_selector @linkoncecomdat any -// CHECK: llvm.mlir.global linkonce @linkoncecomdat (2 : i32) -// CHECK: llvm.mlir.global linkonce @linkoncecomdat_unref_var (2 : i32) comdat(@__llvm_global_comdat::@linkoncecomdat) -// CHECK: llvm.func linkonce @linkoncecomdat_unref_func() comdat(@__llvm_global_comdat::@linkoncecomdat) { +// CHECK-DAG: llvm.comdat_selector @linkoncecomdat any +// CHECK-DAG: llvm.mlir.global linkonce @linkoncecomdat(2 : i32) +// CHECK-DAG: llvm.mlir.global linkonce @linkoncecomdat_unref_var(2 : i32) comdat(@__llvm_global_comdat::@linkoncecomdat) +// CHECK-DAG: llvm.func linkonce @linkoncecomdat_unref_func() comdat(@__llvm_global_comdat::@linkoncecomdat) { module { llvm.comdat @__llvm_global_comdat { diff --git a/mlir/test/mlir-link/adapted/ctors2.mlir b/mlir/test/mlir-link/adapted/ctors2.mlir index eb3ab31ee1626..ee4c15ded6a43 100644 --- a/mlir/test/mlir-link/adapted/ctors2.mlir +++ b/mlir/test/mlir-link/adapted/ctors2.mlir @@ -1,10 +1,7 @@ // RUN: mlir-link %s %S/Inputs/ctors2.mlir | FileCheck %s // CHECK: llvm.mlir.global_ctors ctors = [], priorities = [], data = [] -// CHECK: llvm.mlir.global external @foo(1 : i8) comdat - -// comdat not yet supported -// XFAIL: * +// CHECK: llvm.mlir.global external @foo(0 : i8) comdat module { llvm.comdat @__llvm_global_comdat { diff --git a/mlir/test/mlir-link/adapted/ctors3.mlir b/mlir/test/mlir-link/adapted/ctors3.mlir index 5d355415ff97e..a685de7a61b50 100644 --- a/mlir/test/mlir-link/adapted/ctors3.mlir +++ b/mlir/test/mlir-link/adapted/ctors3.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-link %s %S/Inputs/ctors2.mlir | FileCheck %s +// RUN: mlir-link %s %S/Inputs/ctors3.mlir | FileCheck %s // CHECK: llvm.mlir.global_ctors ctors = [], priorities = [], data = [] // CHECK llvm.mlir.global external @foo() comdat diff --git a/mlir/test/mlir-link/adapted/visibility.mlir b/mlir/test/mlir-link/adapted/visibility.mlir index bf0b8ccc1ffb2..ea78561c47c07 100644 --- a/mlir/test/mlir-link/adapted/visibility.mlir +++ b/mlir/test/mlir-link/adapted/visibility.mlir @@ -1,9 +1,6 @@ // RUN: mlir-link %s %p/Inputs/visibility.mlir -o - | FileCheck %s // RUN: mlir-link %p/Inputs/visibility.mlir %s -o - | FileCheck %s -// mlir.alias not yet supported; comdat not yet supported -// XFAIL: * - module { llvm.comdat @__llvm_global_comdat { llvm.comdat_selector @c1 any @@ -14,19 +11,19 @@ module { llvm.mlir.global external @v2(0 : i32) {addr_space = 0 : i32} : i32 // CHECK-DAG: llvm.mlir.global external hidden @v3(0 : i32) llvm.mlir.global external protected @v3(0 : i32) {addr_space = 0 : i32, dso_local} : i32 -// CHECK-DAG: llvm.mlir.global external hidden @v4(0 : i32) +// CHECK-DAG: llvm.mlir.global external hidden @v4(1 : i32) llvm.mlir.global external @v4(1 : i32) comdat(@__llvm_global_comdat::@c1) {addr_space = 0 : i32} : i32 -// CHECK-DAG: llvm.mlir.alias external hidden @a1: i32 { +// CHECK-DAG: llvm.mlir.alias external hidden @a1 : i32 { llvm.mlir.alias external @a1 : i32 { %0 = llvm.mlir.addressof @v1 : !llvm.ptr llvm.return %0 : !llvm.ptr } -// CHECK-DAG: llvm.mlir.alias external protected @a2: i32 { +// CHECK-DAG: llvm.mlir.alias external protected @a2 : i32 { llvm.mlir.alias external @a2 : i32 { %0 = llvm.mlir.addressof @v2 : !llvm.ptr llvm.return %0 : !llvm.ptr } -// CHECK-DAG: llvm.mlir.alias external hidden @a3: i32 { +// CHECK-DAG: llvm.mlir.alias external hidden @a3 {dso_local} : i32 { llvm.mlir.alias external protected @a3 {dso_local} : i32 { %0 = llvm.mlir.addressof @v3 : !llvm.ptr llvm.return %0 : !llvm.ptr @@ -39,7 +36,7 @@ module { llvm.func @f2() { llvm.return } -// CHECK-DAG: llvm.func hidden @f3() { +// CHECK-DAG: llvm.func hidden @f3() attributes {dso_local} { llvm.func protected @f3() attributes {dso_local} { llvm.return }