35 #include "llvm/ADT/PostOrderIterator.h"
36 #include "llvm/ADT/SetVector.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/CFG.h"
41 #include "llvm/IR/Constants.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/InlineAsm.h"
45 #include "llvm/IR/IntrinsicsNVPTX.h"
46 #include "llvm/IR/LLVMContext.h"
47 #include "llvm/IR/MDBuilder.h"
48 #include "llvm/IR/Module.h"
49 #include "llvm/IR/Verifier.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/ModuleUtils.h"
59 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
66 std::optional<Location> loc = std::nullopt) {
71 std::string llvmDataLayout;
72 llvm::raw_string_ostream layoutStream(llvmDataLayout);
73 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
74 auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
77 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
78 auto value = cast<StringAttr>(entry.getValue());
80 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
81 layoutStream <<
"-" << (isLittleEndian ?
"e" :
"E");
85 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
86 auto value = cast<IntegerAttr>(entry.getValue());
87 uint64_t space = value.getValue().getZExtValue();
91 layoutStream <<
"-A" << space;
95 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
96 auto value = cast<IntegerAttr>(entry.getValue());
97 uint64_t alignment = value.getValue().getZExtValue();
101 layoutStream <<
"-S" << alignment;
102 layoutStream.flush();
105 emitError(*loc) <<
"unsupported data layout key " << key;
112 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
113 auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
117 if (isa<IndexType>(type))
122 .Case<IntegerType, Float16Type, Float32Type, Float64Type,
124 if (
auto intType = dyn_cast<IntegerType>(type)) {
125 if (intType.getSignedness() != IntegerType::Signless)
127 <<
"unsupported data layout for non-signless integer "
137 layoutStream << size <<
":" << abi;
138 if (abi != preferred)
139 layoutStream <<
":" << preferred;
142 .Case([&](LLVMPointerType ptrType) {
143 layoutStream <<
"p" << ptrType.getAddressSpace() <<
":";
148 layoutStream << size <<
":" << abi <<
":" << preferred;
151 layoutStream <<
":" << *index;
154 .Default([loc](
Type type) {
156 <<
"unsupported type in data layout: " << type;
161 layoutStream.flush();
162 StringRef layoutSpec(llvmDataLayout);
163 if (layoutSpec.startswith(
"-"))
164 layoutSpec = layoutSpec.drop_front();
166 return llvm::DataLayout(layoutSpec);
173 static llvm::Constant *
178 llvm::Constant *result = constants.front();
179 constants = constants.drop_front();
183 llvm::Type *elementType;
184 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
185 elementType = arrayTy->getElementType();
186 }
else if (
auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
187 elementType = vectorTy->getElementType();
189 emitError(loc) <<
"expected sequential LLVM types wrapping a scalar";
194 nested.reserve(shape.front());
195 for (int64_t i = 0; i < shape.front(); ++i) {
202 if (shape.size() == 1 && type->isVectorTy())
211 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
212 type = arrayTy->getElementType();
213 }
else if (
auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
214 type = vectorTy->getElementType();
228 static llvm::Constant *
230 llvm::Type *llvmType,
232 if (!denseElementsAttr)
236 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
239 ShapedType type = denseElementsAttr.
getType();
240 if (type.getNumElements() == 0)
248 unsigned elementByteSize = denseElementsAttr.
getRawData().size() /
250 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
255 bool hasVectorElementType = isa<VectorType>(type.getElementType());
256 unsigned numAggregates =
258 (hasVectorElementType ? 1
259 : denseElementsAttr.
getType().getShape().back());
261 if (!hasVectorElementType)
262 outerShape = outerShape.drop_back();
265 if (denseElementsAttr.
isSplat() &&
266 (isa<VectorType>(type) || hasVectorElementType)) {
270 llvm::Constant *splatVector =
271 llvm::ConstantDataVector::getSplat(0, splatValue);
276 if (denseElementsAttr.
isSplat())
281 std::function<llvm::Constant *(StringRef)> buildCstData;
282 if (isa<TensorType>(type)) {
283 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
284 if (vectorElementType && vectorElementType.getRank() == 1) {
285 buildCstData = [&](StringRef data) {
286 return llvm::ConstantDataVector::getRaw(
287 data, vectorElementType.getShape().back(), innermostLLVMType);
289 }
else if (!vectorElementType) {
290 buildCstData = [&](StringRef data) {
291 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
295 }
else if (isa<VectorType>(type)) {
296 buildCstData = [&](StringRef data) {
297 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
307 unsigned aggregateSize = denseElementsAttr.
getType().getShape().back() *
308 (innermostLLVMType->getScalarSizeInBits() / 8);
309 constants.reserve(numAggregates);
310 for (
unsigned i = 0; i < numAggregates; ++i) {
311 StringRef data(denseElementsAttr.
getRawData().data() + i * aggregateSize,
313 constants.push_back(buildCstData(data));
330 if (
auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
331 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
332 if (!arrayAttr || arrayAttr.size() != 2) {
333 emitError(loc,
"expected struct type to be a complex number");
336 llvm::Type *elementType = structType->getElementType(0);
337 llvm::Constant *real =
341 llvm::Constant *imag =
349 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
352 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
353 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
354 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
359 unsigned floatWidth = APFloat::getSizeInBits(sem);
360 if (llvmType->isIntegerTy(floatWidth))
362 floatAttr.getValue().bitcastToAPInt());
364 llvm::Type::getFloatingPointTy(llvmType->getContext(),
365 floatAttr.getValue().getSemantics())) {
366 emitError(loc,
"FloatAttr does not match expected type of the constant");
371 if (
auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
372 return llvm::ConstantExpr::getBitCast(
374 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
375 llvm::Type *elementType;
376 uint64_t numElements;
377 bool isScalable =
false;
378 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
379 elementType = arrayTy->getElementType();
380 numElements = arrayTy->getNumElements();
381 }
else if (
auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
382 elementType = fVectorTy->getElementType();
383 numElements = fVectorTy->getNumElements();
384 }
else if (
auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
385 elementType = sVectorTy->getElementType();
386 numElements = sVectorTy->getMinNumElements();
389 llvm_unreachable(
"unrecognized constant vector type");
394 bool elementTypeSequential =
395 isa<llvm::ArrayType, llvm::VectorType>(elementType);
398 elementTypeSequential ? splatAttr
400 loc, moduleTranslation);
403 if (llvmType->isVectorTy())
404 return llvm::ConstantVector::getSplat(
406 if (llvmType->isArrayTy()) {
414 if (llvm::Constant *result =
416 llvmType, moduleTranslation)) {
421 if (
auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
422 assert(elementsAttr.getShapedType().hasStaticShape());
423 assert(!elementsAttr.getShapedType().getShape().empty() &&
424 "unexpected empty elements attribute shape");
427 constants.reserve(elementsAttr.getNumElements());
429 for (
auto n : elementsAttr.getValues<
Attribute>()) {
432 if (!constants.back())
437 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
438 assert(constantsRef.empty() &&
"did not consume all elemental constants");
442 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
446 stringAttr.getValue().size()});
448 emitError(loc,
"unsupported constant value");
452 ModuleTranslation::ModuleTranslation(
Operation *module,
453 std::unique_ptr<llvm::Module> llvmModule)
454 : mlirModule(module), llvmModule(std::move(llvmModule)),
458 *this, *this->llvmModule)),
459 typeTranslator(this->llvmModule->
getContext()),
462 "mlirModule should honor LLVM's module semantics.");
465 ModuleTranslation::~ModuleTranslation() {
467 ompBuilder->finalize();
472 toProcess.push_back(®ion);
473 while (!toProcess.empty()) {
474 Region *current = toProcess.pop_back_val();
475 for (
Block &block : *current) {
476 blockMapping.erase(&block);
477 for (
Value arg : block.getArguments())
478 valueMapping.erase(arg);
481 valueMapping.erase(value);
483 branchMapping.erase(&op);
484 if (isa<LLVM::GlobalOp>(op))
485 globalsMapping.erase(&op);
497 unsigned numArguments,
unsigned index) {
499 if (isa<LLVM::BrOp>(terminator))
506 auto branch = cast<BranchOpInterface>(terminator);
509 (!seenSuccessors.contains(successor) || successorOperands.
empty()) &&
510 "successors with arguments in LLVM branches must be different blocks");
511 seenSuccessors.insert(successor);
517 if (
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
520 return condBranchOp.getSuccessor(0) == current
521 ? condBranchOp.getTrueDestOperands()[index]
522 : condBranchOp.getFalseDestOperands()[index];
525 if (
auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
528 if (switchOp.getDefaultDestination() == current)
529 return switchOp.getDefaultOperands()[index];
531 if (i.value() == current)
532 return switchOp.getCaseOperands(i.index())[index];
535 if (
auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
536 return invokeOp.getNormalDest() == current
537 ? invokeOp.getNormalDestOperands()[index]
538 : invokeOp.getUnwindDestOperands()[index];
542 "only branch, switch or invoke operations can be terminators "
543 "of a block that has successors");
551 for (
Block &bb : llvm::drop_begin(region)) {
552 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
553 auto phis = llvmBB->phis();
554 auto numArguments = bb.getNumArguments();
555 assert(numArguments == std::distance(phis.begin(), phis.end()));
557 for (
auto *pred : bb.getPredecessors()) {
563 llvm::Instruction *terminator =
564 state.lookupBranch(pred->getTerminator());
565 assert(terminator &&
"missing the mapping for a terminator");
567 &bb, pred, numArguments, index)),
568 terminator->getParent());
580 for (
Block &b : region) {
581 if (blocks.count(&b) == 0) {
582 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
583 blocks.insert(traversal.begin(), traversal.end());
586 assert(blocks.size() == region.
getBlocks().size() &&
587 "some blocks are not sorted");
595 llvm::Module *module = builder.GetInsertBlock()->getModule();
596 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
597 return builder.CreateCall(fn, args);
603 ModuleTranslation::convertOperation(
Operation &op,
604 llvm::IRBuilderBase &builder) {
607 return op.
emitError(
"cannot be converted to LLVM IR: missing "
608 "`LLVMTranslationDialectInterface` registration for "
613 return op.
emitError(
"LLVM Translation failed for operation: ")
616 return convertDialectAttributes(&op);
627 llvm::IRBuilderBase &builder) {
629 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
637 if (!ignoreArguments) {
639 unsigned numPredecessors =
640 std::distance(predecessors.begin(), predecessors.end());
642 auto wrappedType = arg.getType();
645 "block argument does not have an LLVM type");
647 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
653 for (
auto &op : bb) {
655 builder.SetCurrentDebugLocation(
656 debugTranslation->translateLoc(op.
getLoc(), subprogram));
658 if (
failed(convertOperation(op, builder)))
662 if (
auto iface = dyn_cast<BranchWeightOpInterface>(op))
681 llvm::Constant *cst) {
682 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
683 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
689 llvm::GlobalValue *gv) {
690 if (dsoLocalRequested)
691 gv->setDSOLocal(
true);
697 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
699 llvm::Constant *cst =
nullptr;
700 if (op.getValueOrNull()) {
703 if (
auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
704 cst = llvm::ConstantDataArray::getString(
705 llvmModule->getContext(), strAttr.getValue(),
false);
706 type = cst->getType();
713 auto linkage = convertLinkageToLLVM(op.getLinkage());
714 auto addrSpace = op.getAddrSpace();
720 if (!dropInitializer && !cst)
722 else if (dropInitializer && cst)
725 auto *var =
new llvm::GlobalVariable(
726 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
728 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
729 : llvm::GlobalValue::NotThreadLocal,
732 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
733 auto selectorOp = cast<ComdatSelectorOp>(
735 var->setComdat(comdatMapping.lookup(selectorOp));
738 if (op.getUnnamedAddr().has_value())
739 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
741 if (op.getSection().has_value())
742 var->setSection(*op.getSection());
746 std::optional<uint64_t> alignment = op.getAlignment();
747 if (alignment.has_value())
748 var->setAlignment(llvm::MaybeAlign(alignment.value()));
750 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
752 globalsMapping.try_emplace(op, var);
758 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
759 if (
Block *initializer = op.getInitializerBlock()) {
760 llvm::IRBuilder<> builder(llvmModule->getContext());
761 for (
auto &op : initializer->without_terminator()) {
762 if (
failed(convertOperation(op, builder)) ||
766 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
767 llvm::Constant *cst =
768 cast<llvm::Constant>(
lookupValue(ret.getOperand(0)));
769 auto *global = cast<llvm::GlobalVariable>(
lookupGlobal(op));
771 global->setInitializer(cst);
777 auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
778 auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
779 if (!ctorOp && !dtorOp)
781 auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
782 :
llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
783 auto appendGlobalFn =
784 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
785 for (
auto symbolAndPriority : range) {
787 cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
788 appendGlobalFn(*llvmModule, f,
789 cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
794 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
795 if (
failed(convertDialectAttributes(op)))
808 llvm::Function *llvmFunc,
810 StringRef value = StringRef()) {
811 auto kind = llvm::Attribute::getAttrKindFromName(key);
813 llvmFunc->addFnAttr(key, value);
817 if (llvm::Attribute::isIntAttrKind(kind)) {
819 return emitError(loc) <<
"LLVM attribute '" << key <<
"' expects a value";
822 if (!value.getAsInteger(0, result))
826 llvmFunc->addFnAttr(key, value);
831 return emitError(loc) <<
"LLVM attribute '" << key
832 <<
"' does not expect a value, found '" << value
835 llvmFunc->addFnAttr(kind);
848 llvm::Function *llvmFunc) {
853 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
860 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
861 if (!arrayAttr || arrayAttr.size() != 2)
863 <<
"expected 'passthrough' to contain string or array attributes";
865 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
866 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
867 if (!keyAttr || !valueAttr)
869 <<
"expected arrays within 'passthrough' to contain two strings";
872 valueAttr.getValue())))
878 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
881 blockMapping.clear();
882 valueMapping.clear();
883 branchMapping.clear();
887 debugTranslation->translate(func, *llvmFunc);
890 for (
auto [mlirArg, llvmArg] :
891 llvm::zip(func.getArguments(), llvmFunc->args()))
895 if (func.getPersonality()) {
896 llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
897 if (llvm::Constant *pfunc =
getLLVMConstant(ty, func.getPersonalityAttr(),
898 func.getLoc(), *
this))
899 llvmFunc->setPersonalityFn(pfunc);
902 if (std::optional<StringRef> section = func.getSection())
903 llvmFunc->setSection(*section);
905 if (func.getArmStreaming())
906 llvmFunc->addFnAttr(
"aarch64_pstate_sm_enabled");
907 else if (func.getArmLocallyStreaming())
908 llvmFunc->addFnAttr(
"aarch64_pstate_sm_body");
910 if (
auto attr = func.getVscaleRange())
911 llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
913 attr->getMaxRange().getInt()));
916 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
917 for (
auto &bb : func) {
918 auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
919 llvmBB->insertInto(llvmFunc);
926 for (
Block *bb : blocks) {
927 llvm::IRBuilder<> builder(llvmContext);
937 return convertDialectAttributes(func);
950 llvm::Function *llvmFunc) {
951 if (!func.getMemory())
954 MemoryEffectsAttr memEffects = func.getMemoryAttr();
957 llvm::MemoryEffects newMemEffects =
958 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
959 convertModRefInfoToLLVM(memEffects.getArgMem()));
960 newMemEffects |= llvm::MemoryEffects(
961 llvm::MemoryEffects::Location::InaccessibleMem,
962 convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
964 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
965 convertModRefInfoToLLVM(memEffects.getOther()));
966 llvmFunc->setMemoryEffects(newMemEffects);
970 ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
971 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
974 Attribute attr = paramAttrs.get(mlirName);
980 llvm::Attribute::AttrKind llvmKindCap = llvmKind;
983 .Case<TypeAttr>([&](
auto typeAttr) {
984 attrBuilder.addTypeAttr(llvmKindCap,
987 .Case<IntegerAttr>([&](
auto intAttr) {
988 attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
990 .Case<UnitAttr>([&](
auto) { attrBuilder.addAttribute(llvmKindCap); });
996 LogicalResult ModuleTranslation::convertFunctionSignatures() {
999 for (
auto function :
getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1000 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1002 cast<llvm::FunctionType>(
convertType(
function.getFunctionType())));
1003 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1004 llvmFunc->setLinkage(convertLinkageToLLVM(
function.getLinkage()));
1005 llvmFunc->setCallingConv(convertCConvToLLVM(
function.getCConv()));
1013 if (std::optional<uint64_t> entryCount =
function.getFunctionEntryCount())
1014 llvmFunc->setEntryCount(entryCount.value());
1017 if (ArrayAttr allResultAttrs =
function.getAllResultAttrs()) {
1018 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1019 llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
1025 llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
1026 llvmArg.addAttrs(attrBuilder);
1032 function.getLoc(),
function.getPassthrough(), llvmFunc)))
1036 llvmFunc->setVisibility(convertVisibilityToLLVM(
function.getVisibility_()));
1039 if (std::optional<mlir::SymbolRefAttr> comdat =
function.getComdat()) {
1040 auto selectorOp = cast<ComdatSelectorOp>(
1042 llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1045 if (
auto gc =
function.getGarbageCollector())
1046 llvmFunc->setGC(gc->str());
1048 if (
auto unnamedAddr =
function.getUnnamedAddr())
1049 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1051 if (
auto alignment =
function.getAlignment())
1052 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1060 for (
auto function :
getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1063 if (
function.isExternal()) {
1064 if (
failed(convertDialectAttributes(
function)))
1069 if (
failed(convertOneFunction(
function)))
1077 for (
auto comdatOp :
getModuleBody(mlirModule).getOps<ComdatOp>()) {
1078 for (
auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1080 if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1082 <<
"comdat selection symbols must be unique even in different "
1084 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1085 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1086 comdatMapping.try_emplace(selectorOp, comdat);
1093 llvm::Instruction *inst) {
1094 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1095 inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1100 auto [scopeIt, scopeInserted] =
1101 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr,
nullptr);
1103 return scopeIt->second;
1104 llvm::LLVMContext &ctx = llvmModule->getContext();
1106 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1107 aliasScopeAttr.getDomain(),
nullptr);
1108 if (insertedDomain) {
1111 operands.push_back({});
1112 if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1116 domainIt->second->replaceOperandWith(0, domainIt->second);
1119 assert(domainIt->second &&
"Scope's domain should already be valid");
1122 operands.push_back({});
1123 operands.push_back(domainIt->second);
1124 if (StringAttr description = aliasScopeAttr.getDescription())
1128 scopeIt->second->replaceOperandWith(0, scopeIt->second);
1129 return scopeIt->second;
1135 nodes.reserve(aliasScopeAttrs.size());
1136 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1142 llvm::Instruction *inst) {
1143 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs,
unsigned kind) {
1144 if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1147 llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1148 inst->setMetadata(kind, node);
1151 populateScopeMetadata(op.getAliasScopesOrNull(),
1152 llvm::LLVMContext::MD_alias_scope);
1153 populateScopeMetadata(op.getNoAliasScopesOrNull(),
1154 llvm::LLVMContext::MD_noalias);
1157 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr)
const {
1158 return tbaaMetadataMapping.lookup(tbaaAttr);
1162 llvm::Instruction *inst) {
1163 ArrayAttr tagRefs = op.getTBAATagsOrNull();
1164 if (!tagRefs || tagRefs.empty())
1171 if (tagRefs.size() > 1) {
1172 op.
emitWarning() <<
"TBAA access tags were not translated, because LLVM "
1173 "IR only supports a single tag per instruction";
1177 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1178 inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1187 assert(inst &&
"expected the operation to have a mapping to an instruction");
1190 llvm::LLVMContext::MD_prof,
1195 llvm::LLVMContext &ctx = llvmModule->getContext();
1207 walker.
addWalk([&](TBAARootAttr root) {
1208 tbaaMetadataMapping.insert(
1212 walker.
addWalk([&](TBAATypeDescriptorAttr descriptor) {
1215 for (TBAAMemberAttr member : descriptor.getMembers()) {
1216 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
1224 walker.
addWalk([&](TBAATagAttr tag) {
1227 operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
1228 operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
1232 if (tag.getConstant())
1239 mlirModule->
walk([&](AliasAnalysisOpInterface analysisOpInterface) {
1240 if (
auto attr = analysisOpInterface.getTBAATagsOrNull())
1248 llvm::Instruction *inst) {
1249 LoopAnnotationAttr attr =
1251 .Case<LLVM::BrOp, LLVM::CondBrOp>(
1252 [](
auto branchOp) {
return branchOp.getLoopAnnotationAttr(); });
1255 llvm::MDNode *loopMD =
1256 loopAnnotationTranslation->translateLoopAnnotation(attr, op);
1257 inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
1267 remapped.reserve(values.size());
1268 for (
Value v : values)
1275 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
1276 ompBuilder->initialize();
1281 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
1289 return ompBuilder.get();
1293 llvm::DILocalScope *scope) {
1294 return debugTranslation->translateLoc(loc, scope);
1298 return debugTranslation->translate(attr);
1303 return llvmModule->getOrInsertNamedMetadata(name);
1306 void ModuleTranslation::StackFrame::anchor() {}
1308 static std::unique_ptr<llvm::Module>
1312 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1313 if (
auto dataLayoutAttr =
1315 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
1318 if (
auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1319 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1323 }
else if (
auto mod = dyn_cast<ModuleOp>(m)) {
1324 if (DataLayoutSpecInterface spec =
mod.getDataLayoutSpec()) {
1329 if (
failed(llvmDataLayout))
1331 llvmModule->setDataLayout(*llvmDataLayout);
1333 if (
auto targetTripleAttr =
1335 llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
1339 llvm::IRBuilder<> builder(llvmContext);
1340 llvmModule->getOrInsertFunction(
"malloc", builder.getInt8PtrTy(),
1341 builder.getInt64Ty());
1342 llvmModule->getOrInsertFunction(
"free", builder.getVoidTy(),
1343 builder.getInt8PtrTy());
1348 std::unique_ptr<llvm::Module>
1352 module->
emitOpError(
"can not be translated to an LLVMIR module");
1356 std::unique_ptr<llvm::Module> llvmModule =
1364 llvm::IRBuilder<> llvmBuilder(llvmContext);
1370 if (
failed(translator.convertOperation(*module, llvmBuilder)))
1373 if (
failed(translator.convertComdats()))
1375 if (
failed(translator.convertFunctionSignatures()))
1377 if (
failed(translator.convertGlobals()))
1379 if (
failed(translator.createTBAAMetadata()))
1384 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1385 LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
1387 failed(translator.convertOperation(o, llvmBuilder))) {
1395 if (
failed(translator.convertFunctions()))
1398 if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1401 return std::move(translator.llvmModule);
static MLIRContext * getContext(OpFoldResult val)
static Value getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index)
Get the SSA value passed to the current block from the terminator operation of its predecessor.
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
static Block & getModuleBody(Operation *module)
A helper method to get the single Block in an operation honoring LLVM's module requirements.
static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, llvm::GlobalValue *gv)
Sets the runtime preemption specifier of gv to dso_local if dsoLocalRequested is true, otherwise it is left unchanged.
static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvm::Function *llvmFunc, StringRef key, StringRef value=StringRef())
Attempts to add an attribute identified by key, optionally with the given value, to the LLVM function llvmFunc.
static void convertFunctionAttributes(LLVMFuncOp func, llvm::Function *llvmFunc)
Converts the function attributes from LLVMFuncOp and attaches them to the llvm::Function.
static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, llvm::Constant *cst)
A helper method to decide if a constant must not be set as a global variable initializer.
static llvm::Type * getInnermostElementType(llvm::Type *type)
Returns the first non-sequential type nested in sequential types.
static std::unique_ptr< llvm::Module > prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name)
static LogicalResult forwardPassthroughAttributes(Location loc, std::optional< ArrayAttr > attributes, llvm::Function *llvmFunc)
Attaches the attributes listed in the given array attribute to llvmFunc.
static llvm::Constant * buildSequentialConstant(ArrayRef< llvm::Constant * > &constants, ArrayRef< int64_t > shape, llvm::Type *type, Location loc)
Builds a constant of a sequential LLVM type type, potentially containing other sequential types recursively, from the individual constant values provided in constants.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
The main mechanism for performing data layout queries.
unsigned getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
unsigned getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
unsigned getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class provides support for representing a failure result, or a valid value of type T.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given operation using the interface implemented by the dialect of one of the operation's dialect attributes.
This class represents the base attribute for all debug info attributes.
Implementation class for module translation.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::NamedMDNode * getOrInsertNamedModuleMetadata(StringRef name)
Gets the named metadata in the LLVM IR module being constructed, creating it if it does not exist.
llvm::Instruction * lookupBranch(Operation *op) const
Finds an LLVM IR instruction that corresponds to the given MLIR operation with successors.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
void setBranchWeightsMetadata(BranchWeightOpInterface op)
Sets LLVM profiling metadata for operations that have branch weights.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::CallInst * lookupCall(Operation *op) const
Finds an LLVM call instruction that corresponds to the given MLIR call operation.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::MDNode * getOrCreateAliasScopes(ArrayRef< AliasScopeAttr > aliasScopeAttrs)
Returns the LLVM metadata corresponding to an array of mlir LLVM dialect alias scope attributes.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::MDNode * getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr)
Returns the LLVM metadata corresponding to an MLIR LLVM dialect alias scope attribute.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
llvm::Type * translateType(Type type)
Translates the given MLIR LLVM dialect type to LLVM IR.
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
NamedAttribute represents a combination of a name and an Attribute value.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Attribute getDiscardableAttr(StringRef name)
Access a discardable attribute by name, returns a null Attribute if the discardable attribute does n...
Value getOperand(unsigned idx)
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
dialect_attr_range getDialectAttrs()
Return a range corresponding to the dialect attributes for this operation.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class models how operands are forwarded to block arguments in control flow.
bool empty() const
Returns true if there are no successor operands.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
ArrayRef< T > asArrayRef() const
Include the generated interface declarations.
static llvm::ArrayRef< std::pair< llvm::Attribute::AttrKind, llvm::StringRef > > getAttrKindToNameMapping()
Returns a list of pairs that each hold a mapping from LLVM attribute kinds to their corresponding str...
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
SetVector< Block * > getTopologicallySortedBlocks(Region &region)
Get a topologically sorted list of blocks of the given region.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
std::optional< unsigned > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void ensureDistinctSuccessors(Operation *op)
Make argument-taking successors of each block distinct.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
DataLayoutSpecInterface translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context)
Translate the given LLVM data layout into an MLIR equivalent using the DLTI dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule")
Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in ...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.