37 #include "llvm/ADT/PostOrderIterator.h"
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Analysis/TargetFolder.h"
42 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
43 #include "llvm/IR/BasicBlock.h"
44 #include "llvm/IR/CFG.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DerivedTypes.h"
47 #include "llvm/IR/IRBuilder.h"
48 #include "llvm/IR/InlineAsm.h"
49 #include "llvm/IR/IntrinsicsNVPTX.h"
50 #include "llvm/IR/LLVMContext.h"
51 #include "llvm/IR/MDBuilder.h"
52 #include "llvm/IR/Module.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/raw_ostream.h"
56 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
57 #include "llvm/Transforms/Utils/Cloning.h"
58 #include "llvm/Transforms/Utils/ModuleUtils.h"
62 #define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
70 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
93 class InstructionCapturingInserter :
public llvm::IRBuilderCallbackInserter {
96 InstructionCapturingInserter()
97 : llvm::IRBuilderCallbackInserter([
this](llvm::Instruction *instruction) {
98 if (LLVM_LIKELY(enabled))
99 capturedInstructions.push_back(instruction);
104 return capturedInstructions;
/// Discards every instruction recorded so far by calling `clear()` on
/// `capturedInstructions`. Does not touch the `enabled` flag.
108 void clearCapturedInstructions() { capturedInstructions.clear(); }
111 class CollectionScope {
/// Opens a collection scope on `irBuilder`.
///
/// \param irBuilder           builder whose inserter is used; the out-of-line
///                            definition `static_cast`s it to
///                            `CapturingIRBuilder`, so it must really be one
///                            whenever `isBuilderCapturing` is true.
/// \param isBuilderCapturing  when true, the definition stores the builder's
///                            inserter, remembers its current `enabled` state,
///                            swaps the inserter's captured list with
///                            `previouslyCollectedInstructions`, and enables
///                            capturing.
///
/// NOTE(review): the branch taken when `isBuilderCapturing` is false is not
/// visible here; presumably the scope is left inert (`inserter` stays
/// nullptr) — confirm against the full definition.
114 CollectionScope(llvm::IRBuilderBase &irBuilder,
bool isBuilderCapturing);
122 return inserter->getCapturedInstructions();
/// Inserter this scope operates on. Starts as nullptr and is pointed at the
/// builder's inserter by the constructor when the builder is capturing.
/// NOTE(review): accessors and the destructor dereference this pointer; the
/// null guards are not visible in this view — confirm they exist.
127 InstructionCapturingInserter *inserter =
nullptr;
/// Turns instruction capturing on or off by storing the argument in the
/// `enabled` member.
///
/// \param enabled  new state of the flag; defaults to true, so a bare
///                 `setEnabled()` call switches capturing on. The parameter
///                 shadows the member, hence the explicit `this->`.
137 void setEnabled(
bool enabled =
true) { this->enabled = enabled; }
/// Gate checked by the inserter callback before it appends an instruction to
/// `capturedInstructions`. Defaults to false, so nothing is recorded until
/// `setEnabled` is called.
144 bool enabled =
false;
/// IRBuilder specialization that folds constants with `llvm::TargetFolder`
/// and routes insertions through `InstructionCapturingInserter`. This is the
/// concrete type `CollectionScope` casts an `llvm::IRBuilderBase` to in order
/// to reach the capturing inserter.
147 using CapturingIRBuilder =
148 llvm::IRBuilder<llvm::TargetFolder, InstructionCapturingInserter>;
151 InstructionCapturingInserter::CollectionScope::CollectionScope(
152 llvm::IRBuilderBase &irBuilder,
bool isBuilderCapturing) {
154 if (!isBuilderCapturing)
157 auto &capturingIRBuilder =
static_cast<CapturingIRBuilder &
>(irBuilder);
158 inserter = &capturingIRBuilder.getInserter();
159 wasEnabled = inserter->enabled;
161 previouslyCollectedInstructions.swap(inserter->capturedInstructions);
162 inserter->setEnabled(
true);
165 InstructionCapturingInserter::CollectionScope::~CollectionScope() {
169 previouslyCollectedInstructions.swap(inserter->capturedInstructions);
173 llvm::append_range(inserter->capturedInstructions,
174 previouslyCollectedInstructions);
176 inserter->setEnabled(wasEnabled);
181 static FailureOr<llvm::DataLayout>
184 std::optional<Location> loc = std::nullopt) {
189 std::string llvmDataLayout;
190 llvm::raw_string_ostream layoutStream(llvmDataLayout);
191 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
192 auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
195 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
196 auto value = cast<StringAttr>(entry.getValue());
197 bool isLittleEndian =
198 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
199 layoutStream <<
"-" << (isLittleEndian ?
"e" :
"E");
202 if (key.getValue() == DLTIDialect::kDataLayoutManglingModeKey) {
203 auto value = cast<StringAttr>(entry.getValue());
204 layoutStream <<
"-m:" << value.getValue();
207 if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) {
208 auto value = cast<IntegerAttr>(entry.getValue());
209 uint64_t space = value.getValue().getZExtValue();
213 layoutStream <<
"-P" << space;
216 if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) {
217 auto value = cast<IntegerAttr>(entry.getValue());
218 uint64_t space = value.getValue().getZExtValue();
222 layoutStream <<
"-G" << space;
225 if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
226 auto value = cast<IntegerAttr>(entry.getValue());
227 uint64_t space = value.getValue().getZExtValue();
231 layoutStream <<
"-A" << space;
234 if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
235 auto value = cast<IntegerAttr>(entry.getValue());
236 uint64_t alignment = value.getValue().getZExtValue();
240 layoutStream <<
"-S" << alignment;
243 emitError(*loc) <<
"unsupported data layout key " << key;
250 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
251 auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
255 if (isa<IndexType>(type))
258 LogicalResult result =
260 .Case<IntegerType, Float16Type, Float32Type, Float64Type,
261 Float80Type, Float128Type>([&](
Type type) -> LogicalResult {
262 if (
auto intType = dyn_cast<IntegerType>(type)) {
263 if (intType.getSignedness() != IntegerType::Signless)
265 <<
"unsupported data layout for non-signless integer "
275 layoutStream << size <<
":" << abi;
276 if (abi != preferred)
277 layoutStream <<
":" << preferred;
280 .Case([&](LLVMPointerType type) {
281 layoutStream <<
"p" << type.getAddressSpace() <<
":";
287 layoutStream << size <<
":" << abi <<
":" << preferred <<
":"
291 .Default([loc](
Type type) {
293 <<
"unsupported type in data layout: " << type;
298 StringRef layoutSpec(llvmDataLayout);
299 layoutSpec.consume_front(
"-");
301 return llvm::DataLayout(layoutSpec);
308 static llvm::Constant *
313 llvm::Constant *result = constants.front();
314 constants = constants.drop_front();
318 llvm::Type *elementType;
319 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
320 elementType = arrayTy->getElementType();
321 }
else if (
auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
322 elementType = vectorTy->getElementType();
324 emitError(loc) <<
"expected sequential LLVM types wrapping a scalar";
329 nested.reserve(shape.front());
330 for (int64_t i = 0; i < shape.front(); ++i) {
337 if (shape.size() == 1 && type->isVectorTy())
346 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
347 type = arrayTy->getElementType();
348 }
else if (
auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
349 type = vectorTy->getElementType();
363 static llvm::Constant *
365 llvm::Type *llvmType,
367 if (!denseElementsAttr)
371 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
374 ShapedType type = denseElementsAttr.
getType();
375 if (type.getNumElements() == 0)
383 int64_t elementByteSize = denseElementsAttr.
getRawData().size() /
385 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
390 bool hasVectorElementType = isa<VectorType>(type.getElementType());
391 int64_t numAggregates =
393 (hasVectorElementType ? 1
394 : denseElementsAttr.
getType().getShape().back());
396 if (!hasVectorElementType)
397 outerShape = outerShape.drop_back();
400 if (denseElementsAttr.
isSplat() &&
401 (isa<VectorType>(type) || hasVectorElementType)) {
405 llvm::Constant *splatVector =
406 llvm::ConstantDataVector::getSplat(0, splatValue);
411 if (denseElementsAttr.
isSplat())
416 std::function<llvm::Constant *(StringRef)> buildCstData;
417 if (isa<TensorType>(type)) {
418 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
419 if (vectorElementType && vectorElementType.getRank() == 1) {
420 buildCstData = [&](StringRef data) {
421 return llvm::ConstantDataVector::getRaw(
422 data, vectorElementType.getShape().back(), innermostLLVMType);
424 }
else if (!vectorElementType) {
425 buildCstData = [&](StringRef data) {
426 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
430 }
else if (isa<VectorType>(type)) {
431 buildCstData = [&](StringRef data) {
432 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
442 int64_t aggregateSize = denseElementsAttr.
getType().getShape().back() *
443 (innermostLLVMType->getScalarSizeInBits() / 8);
444 constants.reserve(numAggregates);
445 for (
unsigned i = 0; i < numAggregates; ++i) {
446 StringRef data(denseElementsAttr.
getRawData().data() + i * aggregateSize,
448 constants.push_back(buildCstData(data));
464 assert(denseResourceAttr &&
"expected non-null attribute");
467 if (!llvm::ConstantDataSequential::isElementTypeCompatible(
468 innermostLLVMType)) {
469 emitError(loc,
"no known conversion for innermost element type");
473 ShapedType type = denseResourceAttr.getType();
474 assert(type.getNumElements() > 0 &&
"Expected non-empty elements attribute");
478 emitError(loc,
"resource does not exist");
489 int64_t numElements = denseResourceAttr.getType().getNumElements();
490 int64_t elementByteSize = rawData.size() / numElements;
491 if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
492 emitError(loc,
"raw data size does not match element type size");
498 bool hasVectorElementType = isa<VectorType>(type.getElementType());
499 int64_t numAggregates =
500 numElements / (hasVectorElementType
502 : denseResourceAttr.getType().getShape().back());
504 if (!hasVectorElementType)
505 outerShape = outerShape.drop_back();
508 std::function<llvm::Constant *(StringRef)> buildCstData;
509 if (isa<TensorType>(type)) {
510 auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
511 if (vectorElementType && vectorElementType.getRank() == 1) {
512 buildCstData = [&](StringRef data) {
513 return llvm::ConstantDataVector::getRaw(
514 data, vectorElementType.getShape().back(), innermostLLVMType);
516 }
else if (!vectorElementType) {
517 buildCstData = [&](StringRef data) {
518 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
522 }
else if (isa<VectorType>(type)) {
523 buildCstData = [&](StringRef data) {
524 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
529 emitError(loc,
"unsupported dense_resource type");
536 int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
537 (innermostLLVMType->getScalarSizeInBits() / 8);
538 constants.reserve(numAggregates);
539 for (
unsigned i = 0; i < numAggregates; ++i) {
540 StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
541 constants.push_back(buildCstData(data));
558 if (
auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
559 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
561 emitError(loc,
"expected an array attribute for a struct constant");
565 structElements.reserve(structType->getNumElements());
566 for (
auto [elemType, elemAttr] :
567 zip_equal(structType->elements(), arrayAttr)) {
568 llvm::Constant *element =
572 structElements.push_back(element);
578 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
581 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
582 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
583 const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
588 unsigned floatWidth = APFloat::getSizeInBits(sem);
589 if (llvmType->isIntegerTy(floatWidth))
591 floatAttr.getValue().bitcastToAPInt());
593 llvm::Type::getFloatingPointTy(llvmType->getContext(),
594 floatAttr.getValue().getSemantics())) {
595 emitError(loc,
"FloatAttr does not match expected type of the constant");
600 if (
auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
601 return llvm::ConstantExpr::getBitCast(
603 if (
auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
604 llvm::Type *elementType;
605 uint64_t numElements;
606 bool isScalable =
false;
607 if (
auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
608 elementType = arrayTy->getElementType();
609 numElements = arrayTy->getNumElements();
610 }
else if (
auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
611 elementType = fVectorTy->getElementType();
612 numElements = fVectorTy->getNumElements();
613 }
else if (
auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
614 elementType = sVectorTy->getElementType();
615 numElements = sVectorTy->getMinNumElements();
618 llvm_unreachable(
"unrecognized constant vector type");
623 bool elementTypeSequential =
624 isa<llvm::ArrayType, llvm::VectorType>(elementType);
627 elementTypeSequential ? splatAttr
629 loc, moduleTranslation);
632 if (llvmType->isVectorTy())
633 return llvm::ConstantVector::getSplat(
635 if (llvmType->isArrayTy()) {
637 if (child->isZeroValue()) {
640 if (llvm::ConstantDataSequential::isElementTypeCompatible(
643 if (isa<llvm::IntegerType>(elementType)) {
644 if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
645 if (ci->getBitWidth() == 8) {
650 if (ci->getBitWidth() == 16) {
655 if (ci->getBitWidth() == 32) {
660 if (ci->getBitWidth() == 64) {
670 std::vector<llvm::Constant *> constants(numElements, child);
677 if (llvm::Constant *result =
679 llvmType, moduleTranslation)) {
683 if (
auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
689 if (
auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
690 assert(elementsAttr.getShapedType().hasStaticShape());
691 assert(!elementsAttr.getShapedType().getShape().empty() &&
692 "unexpected empty elements attribute shape");
695 constants.reserve(elementsAttr.getNumElements());
697 for (
auto n : elementsAttr.getValues<
Attribute>()) {
700 if (!constants.back())
705 constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
706 assert(constantsRef.empty() &&
"did not consume all elemental constants");
710 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
714 stringAttr.getValue().size()});
716 emitError(loc,
"unsupported constant value");
720 ModuleTranslation::ModuleTranslation(
Operation *module,
721 std::unique_ptr<llvm::Module> llvmModule)
722 : mlirModule(module), llvmModule(std::move(llvmModule)),
726 *this, *this->llvmModule)),
727 typeTranslator(this->llvmModule->
getContext()),
730 "mlirModule should honor LLVM's module semantics.");
733 ModuleTranslation::~ModuleTranslation() {
735 ompBuilder->finalize();
740 toProcess.push_back(®ion);
741 while (!toProcess.empty()) {
742 Region *current = toProcess.pop_back_val();
743 for (
Block &block : *current) {
744 blockMapping.erase(&block);
745 for (
Value arg : block.getArguments())
746 valueMapping.erase(arg);
748 for (
Value value : op.getResults())
749 valueMapping.erase(value);
750 if (op.hasSuccessors())
751 branchMapping.erase(&op);
752 if (isa<LLVM::GlobalOp>(op))
753 globalsMapping.erase(&op);
754 if (isa<LLVM::AliasOp>(op))
755 aliasesMapping.erase(&op);
756 if (isa<LLVM::CallOp>(op))
757 callMapping.erase(&op);
760 llvm::map_range(op.getRegions(), [](
Region &r) { return &r; }));
769 unsigned numArguments,
unsigned index) {
771 if (isa<LLVM::BrOp>(terminator))
778 auto branch = cast<BranchOpInterface>(terminator);
781 (!seenSuccessors.contains(successor) || successorOperands.
empty()) &&
782 "successors with arguments in LLVM branches must be different blocks");
783 seenSuccessors.insert(successor);
789 if (
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
792 return condBranchOp.getSuccessor(0) == current
793 ? condBranchOp.getTrueDestOperands()[index]
794 : condBranchOp.getFalseDestOperands()[index];
797 if (
auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
800 if (switchOp.getDefaultDestination() == current)
801 return switchOp.getDefaultOperands()[index];
803 if (i.value() == current)
804 return switchOp.getCaseOperands(i.index())[index];
807 if (
auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(terminator)) {
810 if (indBrOp->getSuccessor(i.index()) == current)
811 return indBrOp.getSuccessorOperands(i.index())[index];
815 if (
auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
816 return invokeOp.getNormalDest() == current
817 ? invokeOp.getNormalDestOperands()[index]
818 : invokeOp.getUnwindDestOperands()[index];
822 "only branch, switch or invoke operations can be terminators "
823 "of a block that has successors");
831 for (
Block &bb : llvm::drop_begin(region)) {
832 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
833 auto phis = llvmBB->phis();
834 auto numArguments = bb.getNumArguments();
835 assert(numArguments == std::distance(phis.begin(), phis.end()));
837 for (
auto *pred : bb.getPredecessors()) {
843 llvm::Instruction *terminator =
844 state.lookupBranch(pred->getTerminator());
845 assert(terminator &&
"missing the mapping for a terminator");
847 &bb, pred, numArguments, index)),
848 terminator->getParent());
857 llvm::Module *module = builder.GetInsertBlock()->getModule();
859 llvm::Intrinsic::getOrInsertDeclaration(module, intrinsic, tys);
860 return builder.CreateCall(fn, args);
869 assert(immArgPositions.size() == immArgAttrNames.size() &&
870 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
874 size_t numOpBundleOperands = 0;
875 auto opBundleSizesAttr = cast_if_present<DenseI32ArrayAttr>(
876 intrOp->
getAttr(LLVMDialect::getOpBundleSizesAttrName()));
877 auto opBundleTagsAttr = cast_if_present<ArrayAttr>(
878 intrOp->
getAttr(LLVMDialect::getOpBundleTagsAttrName()));
880 if (opBundleSizesAttr && opBundleTagsAttr) {
881 ArrayRef<int> opBundleSizes = opBundleSizesAttr.asArrayRef();
882 assert(opBundleSizes.size() == opBundleTagsAttr.size() &&
883 "operand bundles and tags do not match");
885 numOpBundleOperands =
886 std::accumulate(opBundleSizes.begin(), opBundleSizes.end(),
size_t(0));
887 assert(numOpBundleOperands <= intrOp->getNumOperands() &&
888 "operand bundle operands is more than the number of operands");
891 size_t nextOperandIdx = 0;
892 opBundles.reserve(opBundleSizesAttr.size());
894 for (
auto [opBundleTagAttr, bundleSize] :
895 llvm::zip(opBundleTagsAttr, opBundleSizes)) {
896 auto bundleTag = cast<StringAttr>(opBundleTagAttr).str();
898 operands.slice(nextOperandIdx, bundleSize));
899 opBundles.emplace_back(std::move(bundleTag), std::move(bundleOperands));
900 nextOperandIdx += bundleSize;
905 auto opOperands = intrOp->
getOperands().drop_back(numOpBundleOperands);
906 auto operands = moduleTranslation.
lookupValues(opOperands);
908 for (
auto [immArgPos, immArgName] :
909 llvm::zip(immArgPositions, immArgAttrNames)) {
910 auto attr = llvm::cast<TypedAttr>(intrOp->
getAttr(immArgName));
911 assert(attr.getType().isIntOrFloat() &&
"expected int or float immarg");
912 auto *type = moduleTranslation.
convertType(attr.getType());
914 type, attr, intrOp->
getLoc(), moduleTranslation);
917 for (
auto &arg : args) {
919 arg = operands[opArg++];
924 for (
unsigned overloadedResultIdx : overloadedResults) {
925 if (numResults > 1) {
927 overloadedTypes.push_back(moduleTranslation.
convertType(
929 .getBody()[overloadedResultIdx]));
931 overloadedTypes.push_back(
935 for (
unsigned overloadedOperandIdx : overloadedOperands)
936 overloadedTypes.push_back(args[overloadedOperandIdx]->
getType());
937 llvm::Module *module = builder.GetInsertBlock()->getModule();
938 llvm::Function *llvmIntr = llvm::Intrinsic::getOrInsertDeclaration(
939 module, intrinsic, overloadedTypes);
941 return builder.CreateCall(llvmIntr, args, opBundles);
946 LogicalResult ModuleTranslation::convertOperation(
Operation &op,
947 llvm::IRBuilderBase &builder,
948 bool recordInsertions) {
951 return op.
emitError(
"cannot be converted to LLVM IR: missing "
952 "`LLVMTranslationDialectInterface` registration for "
956 InstructionCapturingInserter::CollectionScope scope(builder,
959 return op.
emitError(
"LLVM Translation failed for operation: ")
962 return convertDialectAttributes(&op, scope.getCapturedInstructions());
972 LogicalResult ModuleTranslation::convertBlockImpl(
Block &bb,
973 bool ignoreArguments,
974 llvm::IRBuilderBase &builder,
975 bool recordInsertions) {
977 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
985 if (!ignoreArguments) {
987 unsigned numPredecessors =
988 std::distance(predecessors.begin(), predecessors.end());
990 auto wrappedType = arg.getType();
993 "block argument does not have an LLVM type");
994 builder.SetCurrentDebugLocation(
995 debugTranslation->translateLoc(arg.getLoc(), subprogram));
997 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
1003 for (
auto &op : bb) {
1005 builder.SetCurrentDebugLocation(
1006 debugTranslation->translateLoc(op.
getLoc(), subprogram));
1008 if (failed(convertOperation(op, builder, recordInsertions)))
1012 if (
auto iface = dyn_cast<BranchWeightOpInterface>(op))
1031 llvm::Constant *cst) {
1032 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
1033 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
1039 llvm::GlobalValue *gv) {
1040 if (dsoLocalRequested)
1041 gv->setDSOLocal(
true);
1044 LogicalResult ModuleTranslation::convertGlobalsAndAliases() {
1053 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1055 llvm::Constant *cst =
nullptr;
1056 if (op.getValueOrNull()) {
1059 if (
auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
1060 cst = llvm::ConstantDataArray::getString(
1061 llvmModule->getContext(), strAttr.getValue(),
false);
1062 type = cst->getType();
1069 auto linkage = convertLinkageToLLVM(op.getLinkage());
1075 if (!dropInitializer && !cst)
1077 else if (dropInitializer && cst)
1080 auto *var =
new llvm::GlobalVariable(
1081 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
1083 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1084 : llvm::GlobalValue::NotThreadLocal,
1085 op.getAddrSpace(), op.getExternallyInitialized());
1087 if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
1088 auto selectorOp = cast<ComdatSelectorOp>(
1090 var->setComdat(comdatMapping.lookup(selectorOp));
1093 if (op.getUnnamedAddr().has_value())
1094 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1096 if (op.getSection().has_value())
1097 var->setSection(*op.getSection());
1101 std::optional<uint64_t> alignment = op.getAlignment();
1102 if (alignment.has_value())
1103 var->setAlignment(llvm::MaybeAlign(alignment.value()));
1105 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1107 globalsMapping.try_emplace(op, var);
1110 if (op.getDbgExprs()) {
1111 for (
auto exprAttr :
1112 op.getDbgExprs()->getAsRange<DIGlobalVariableExpressionAttr>()) {
1113 llvm::DIGlobalVariableExpression *diGlobalExpr =
1114 debugTranslation->translateGlobalVariableExpression(exprAttr);
1115 llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable();
1116 var->addDebugInfo(diGlobalExpr);
1135 llvm::DIScope *scope = diGlobalVar->getScope();
1136 if (
auto *mod = dyn_cast_if_present<llvm::DIModule>(scope))
1137 scope = mod->getScope();
1138 else if (
auto *cb = dyn_cast_if_present<llvm::DICommonBlock>(scope)) {
1140 dyn_cast_if_present<llvm::DISubprogram>(cb->getScope()))
1141 scope = sp->getUnit();
1142 }
else if (
auto *sp = dyn_cast_if_present<llvm::DISubprogram>(scope))
1143 scope = sp->getUnit();
1146 if (llvm::DICompileUnit *compileUnit =
1147 dyn_cast_if_present<llvm::DICompileUnit>(scope)) {
1150 allGVars[compileUnit].push_back(diGlobalExpr);
1157 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) {
1159 llvm::Constant *cst =
nullptr;
1160 llvm::GlobalValue::LinkageTypes linkage =
1161 convertLinkageToLLVM(op.getLinkage());
1162 llvm::Module &llvmMod = *llvmModule;
1165 llvm::GlobalAlias *var = llvm::GlobalAlias::create(
1166 type, op.getAddrSpace(), linkage, op.getSymName(), cst,
1169 var->setThreadLocalMode(op.getThreadLocal_()
1170 ? llvm::GlobalAlias::GeneralDynamicTLSModel
1171 : llvm::GlobalAlias::NotThreadLocal);
1176 if (op.getUnnamedAddr().has_value())
1177 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1179 var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1181 aliasesMapping.try_emplace(op, var);
1185 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1186 if (
Block *initializer = op.getInitializerBlock()) {
1187 llvm::IRBuilder<llvm::TargetFolder> builder(
1188 llvmModule->getContext(),
1189 llvm::TargetFolder(llvmModule->getDataLayout()));
1191 [[maybe_unused]]
int numConstantsHit = 0;
1192 [[maybe_unused]]
int numConstantsErased = 0;
1195 for (
auto &op : initializer->without_terminator()) {
1196 if (failed(convertOperation(op, builder)))
1209 if (
auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
1213 auto [iterator, inserted] =
1214 constantAggregateUseMap.try_emplace(agg, numUsers);
1217 iterator->second += numUsers;
1223 auto cst = dyn_cast<llvm::ConstantAggregate>(
lookupValue(v));
1226 auto iter = constantAggregateUseMap.find(cst);
1227 assert(iter != constantAggregateUseMap.end() &&
"constant not found");
1229 if (iter->second == 0) {
1232 if (cst->user_empty()) {
1233 cst->destroyConstant();
1234 numConstantsErased++;
1236 constantAggregateUseMap.erase(iter);
1241 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
1242 llvm::Constant *cst =
1243 cast<llvm::Constant>(
lookupValue(ret.getOperand(0)));
1244 auto *global = cast<llvm::GlobalVariable>(
lookupGlobal(op));
1246 global->setInitializer(cst);
1250 for (
auto it : constantAggregateUseMap) {
1251 auto cst = it.first;
1252 cst->removeDeadConstantUsers();
1253 if (cst->user_empty()) {
1254 cst->destroyConstant();
1255 numConstantsErased++;
1259 LLVM_DEBUG(llvm::dbgs()
1260 <<
"Convert initializer for " << op.
getName() <<
"\n";
1261 llvm::dbgs() << numConstantsHit <<
" new constants hit\n";
1263 << numConstantsErased <<
" dangling constants erased\n";);
1269 auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
1270 auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
1271 if (!ctorOp && !dtorOp)
1277 if ((ctorOp && ctorOp.getCtors().empty()) ||
1278 (dtorOp && dtorOp.getDtors().empty())) {
1279 llvm::IRBuilder<llvm::TargetFolder> builder(
1280 llvmModule->getContext(),
1281 llvm::TargetFolder(llvmModule->getDataLayout()));
1283 builder.getInt32Ty(), builder.getPtrTy(), builder.getPtrTy());
1285 llvm::Constant *zeroInit = llvm::Constant::getNullValue(at);
1286 (void)
new llvm::GlobalVariable(
1287 *llvmModule, zeroInit->getType(),
false,
1288 llvm::GlobalValue::AppendingLinkage, zeroInit,
1289 ctorOp ?
"llvm.global_ctors" :
"llvm.global_dtors");
1292 ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
1293 :
llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
1294 auto appendGlobalFn =
1295 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
1296 for (
const auto &[sym, prio] : range) {
1299 appendGlobalFn(*llvmModule, f, cast<IntegerAttr>(prio).getInt(),
1305 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
1306 if (failed(convertDialectAttributes(op, {})))
1311 for (
const auto &[compileUnit, globals] : allGVars) {
1312 compileUnit->replaceGlobalVariables(
1317 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::AliasOp>()) {
1318 Block &initializer = op.getInitializerBlock();
1319 llvm::IRBuilder<llvm::TargetFolder> builder(
1320 llvmModule->getContext(),
1321 llvm::TargetFolder(llvmModule->getDataLayout()));
1324 if (failed(convertOperation(op, builder)))
1331 auto *cst = cast<llvm::Constant>(
lookupValue(ret.getOperand(0)));
1332 assert(aliasesMapping.count(op));
1333 auto *alias = cast<llvm::GlobalAlias>(aliasesMapping[op]);
1334 alias->setAliasee(cst);
1337 for (
auto op :
getModuleBody(mlirModule).getOps<LLVM::AliasOp>())
1338 if (failed(convertDialectAttributes(op, {})))
1351 llvm::Function *llvmFunc,
1353 StringRef value = StringRef()) {
1354 auto kind = llvm::Attribute::getAttrKindFromName(key);
1356 llvmFunc->addFnAttr(key, value);
1360 if (llvm::Attribute::isIntAttrKind(
kind)) {
1362 return emitError(loc) <<
"LLVM attribute '" << key <<
"' expects a value";
1365 if (!value.getAsInteger(0, result))
1366 llvmFunc->addFnAttr(
1369 llvmFunc->addFnAttr(key, value);
1374 return emitError(loc) <<
"LLVM attribute '" << key
1375 <<
"' does not expect a value, found '" << value
1378 llvmFunc->addFnAttr(
kind);
1384 const llvm::APInt &value) {
1391 const llvm::APInt &value) {
1399 llvm::Metadata *typeMD =
1401 llvm::Metadata *isSignedMD =
1411 values, std::back_inserter(mdValues), [&context](int32_t value) {
1424 static LogicalResult
1426 llvm::Function *llvmFunc) {
1431 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1438 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
1439 if (!arrayAttr || arrayAttr.size() != 2)
1441 <<
"expected 'passthrough' to contain string or array attributes";
1443 auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
1444 auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
1445 if (!keyAttr || !valueAttr)
1447 <<
"expected arrays within 'passthrough' to contain two strings";
1450 valueAttr.getValue())))
1456 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1459 blockMapping.clear();
1460 valueMapping.clear();
1461 branchMapping.clear();
1465 for (
auto [mlirArg, llvmArg] :
1466 llvm::zip(func.getArguments(), llvmFunc->args()))
1470 if (func.getPersonality()) {
1471 llvm::Type *ty = llvm::PointerType::getUnqual(llvmFunc->getContext());
1472 if (llvm::Constant *pfunc =
getLLVMConstant(ty, func.getPersonalityAttr(),
1473 func.getLoc(), *
this))
1474 llvmFunc->setPersonalityFn(pfunc);
1477 if (std::optional<StringRef> section = func.getSection())
1478 llvmFunc->setSection(*section);
1480 if (func.getArmStreaming())
1481 llvmFunc->addFnAttr(
"aarch64_pstate_sm_enabled");
1482 else if (func.getArmLocallyStreaming())
1483 llvmFunc->addFnAttr(
"aarch64_pstate_sm_body");
1484 else if (func.getArmStreamingCompatible())
1485 llvmFunc->addFnAttr(
"aarch64_pstate_sm_compatible");
1487 if (func.getArmNewZa())
1488 llvmFunc->addFnAttr(
"aarch64_new_za");
1489 else if (func.getArmInZa())
1490 llvmFunc->addFnAttr(
"aarch64_in_za");
1491 else if (func.getArmOutZa())
1492 llvmFunc->addFnAttr(
"aarch64_out_za");
1493 else if (func.getArmInoutZa())
1494 llvmFunc->addFnAttr(
"aarch64_inout_za");
1495 else if (func.getArmPreservesZa())
1496 llvmFunc->addFnAttr(
"aarch64_preserves_za");
1498 if (
auto targetCpu = func.getTargetCpu())
1499 llvmFunc->addFnAttr(
"target-cpu", *targetCpu);
1501 if (
auto tuneCpu = func.getTuneCpu())
1502 llvmFunc->addFnAttr(
"tune-cpu", *tuneCpu);
1504 if (
auto attr = func.getVscaleRange())
1505 llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
1507 attr->getMaxRange().getInt()));
1509 if (
auto unsafeFpMath = func.getUnsafeFpMath())
1510 llvmFunc->addFnAttr(
"unsafe-fp-math", llvm::toStringRef(*unsafeFpMath));
1512 if (
auto noInfsFpMath = func.getNoInfsFpMath())
1513 llvmFunc->addFnAttr(
"no-infs-fp-math", llvm::toStringRef(*noInfsFpMath));
1515 if (
auto noNansFpMath = func.getNoNansFpMath())
1516 llvmFunc->addFnAttr(
"no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
1518 if (
auto approxFuncFpMath = func.getApproxFuncFpMath())
1519 llvmFunc->addFnAttr(
"approx-func-fp-math",
1520 llvm::toStringRef(*approxFuncFpMath));
1522 if (
auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
1523 llvmFunc->addFnAttr(
"no-signed-zeros-fp-math",
1524 llvm::toStringRef(*noSignedZerosFpMath));
1526 if (
auto denormalFpMath = func.getDenormalFpMath())
1527 llvmFunc->addFnAttr(
"denormal-fp-math", *denormalFpMath);
1529 if (
auto denormalFpMathF32 = func.getDenormalFpMathF32())
1530 llvmFunc->addFnAttr(
"denormal-fp-math-f32", *denormalFpMathF32);
1532 if (
auto fpContract = func.getFpContract())
1533 llvmFunc->addFnAttr(
"fp-contract", *fpContract);
1535 if (
auto instrumentFunctionEntry = func.getInstrumentFunctionEntry())
1536 llvmFunc->addFnAttr(
"instrument-function-entry", *instrumentFunctionEntry);
1538 if (
auto instrumentFunctionExit = func.getInstrumentFunctionExit())
1539 llvmFunc->addFnAttr(
"instrument-function-exit", *instrumentFunctionExit);
1542 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1543 for (
auto &bb : func) {
1544 auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
1545 llvmBB->insertInto(llvmFunc);
1552 for (
Block *bb : blocks) {
1553 CapturingIRBuilder builder(llvmContext,
1554 llvm::TargetFolder(llvmModule->getDataLayout()));
1555 if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
1565 return convertDialectAttributes(func, {});
1568 LogicalResult ModuleTranslation::convertDialectAttributes(
1571 if (failed(iface.
amendOperation(op, instructions, attribute, *
this)))
1579 llvm::Function *llvmFunc) {
1580 if (!func.getMemoryEffects())
1583 MemoryEffectsAttr memEffects = func.getMemoryEffectsAttr();
1586 llvm::MemoryEffects newMemEffects =
1587 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
1588 convertModRefInfoToLLVM(memEffects.getArgMem()));
1589 newMemEffects |= llvm::MemoryEffects(
1590 llvm::MemoryEffects::Location::InaccessibleMem,
1591 convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
1593 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
1594 convertModRefInfoToLLVM(memEffects.getOther()));
1595 llvmFunc->setMemoryEffects(newMemEffects);
1600 llvm::Function *llvmFunc) {
1601 if (func.getNoInlineAttr())
1602 llvmFunc->addFnAttr(llvm::Attribute::NoInline);
1603 if (func.getAlwaysInlineAttr())
1604 llvmFunc->addFnAttr(llvm::Attribute::AlwaysInline);
1605 if (func.getOptimizeNoneAttr())
1606 llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
1607 if (func.getConvergentAttr())
1608 llvmFunc->addFnAttr(llvm::Attribute::Convergent);
1609 if (func.getNoUnwindAttr())
1610 llvmFunc->addFnAttr(llvm::Attribute::NoUnwind);
1611 if (func.getWillReturnAttr())
1612 llvmFunc->addFnAttr(llvm::Attribute::WillReturn);
1613 if (TargetFeaturesAttr targetFeatAttr = func.getTargetFeaturesAttr())
1614 llvmFunc->addFnAttr(
"target-features", targetFeatAttr.getFeaturesString());
1615 if (FramePointerKindAttr fpAttr = func.getFramePointerAttr())
1616 llvmFunc->addFnAttr(
"frame-pointer", stringifyFramePointerKind(
1617 fpAttr.getFramePointerKind()));
1618 if (UWTableKindAttr uwTableKindAttr = func.getUwtableKindAttr())
1619 llvmFunc->setUWTableKind(
1620 convertUWTableKindToLLVM(uwTableKindAttr.getUwtableKind()));
1626 llvm::Function *llvmFunc,
1628 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1630 if (VecTypeHintAttr vecTypeHint = func.getVecTypeHintAttr()) {
1631 Type type = vecTypeHint.getHint().getValue();
1632 llvm::Type *llvmType = translation.
convertType(type);
1633 bool isSigned = vecTypeHint.getIsSigned();
1634 llvmFunc->setMetadata(
1635 func.getVecTypeHintAttrName(),
1640 func.getWorkGroupSizeHint()) {
1641 llvmFunc->setMetadata(
1642 func.getWorkGroupSizeHintAttrName(),
1647 func.getReqdWorkGroupSize()) {
1648 llvmFunc->setMetadata(
1649 func.getReqdWorkGroupSizeAttrName(),
1653 if (std::optional<uint32_t> intelReqdSubGroupSize =
1654 func.getIntelReqdSubGroupSize()) {
1655 llvmFunc->setMetadata(
1656 func.getIntelReqdSubGroupSizeAttrName(),
1658 llvm::APInt(32, *intelReqdSubGroupSize)));
1663 llvm::Attribute::AttrKind llvmKind,
1668 .Case<TypeAttr>([&](
auto typeAttr) {
1669 attrBuilder.addTypeAttr(
1670 llvmKind, moduleTranslation.
convertType(typeAttr.getValue()));
1673 .Case<IntegerAttr>([&](
auto intAttr) {
1674 attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1677 .Case<UnitAttr>([&](
auto) {
1678 attrBuilder.addAttribute(llvmKind);
1681 .Case<LLVM::ConstantRangeAttr>([&](
auto rangeAttr) {
1682 attrBuilder.addConstantRangeAttr(
1684 llvm::ConstantRange(rangeAttr.getLower(), rangeAttr.getUpper()));
1687 .Default([loc](
auto) {
1688 return emitError(loc,
"unsupported parameter attribute type");
1692 FailureOr<llvm::AttrBuilder>
1694 DictionaryAttr paramAttrs) {
1695 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1699 for (
auto namedAttr : paramAttrs) {
1700 auto it = attrNameToKindMapping.find(namedAttr.getName());
1701 if (it != attrNameToKindMapping.end()) {
1702 llvm::Attribute::AttrKind llvmKind = it->second;
1706 }
else if (namedAttr.getNameDialect()) {
1715 FailureOr<llvm::AttrBuilder>
1717 DictionaryAttr paramAttrs) {
1718 llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1721 for (
auto namedAttr : paramAttrs) {
1722 auto it = attrNameToKindMapping.find(namedAttr.getName());
1723 if (it != attrNameToKindMapping.end()) {
1724 llvm::Attribute::AttrKind llvmKind = it->second;
1734 LogicalResult ModuleTranslation::convertFunctionSignatures() {
1737 for (
auto function :
getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1738 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1740 cast<llvm::FunctionType>(
convertType(
function.getFunctionType())));
1741 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1742 llvmFunc->setLinkage(convertLinkageToLLVM(
function.getLinkage()));
1743 llvmFunc->setCallingConv(convertCConvToLLVM(
function.getCConv()));
1754 if (std::optional<uint64_t> entryCount =
function.getFunctionEntryCount())
1755 llvmFunc->setEntryCount(entryCount.value());
1758 if (ArrayAttr allResultAttrs =
function.getAllResultAttrs()) {
1759 DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1760 FailureOr<llvm::AttrBuilder> attrBuilder =
1762 if (failed(attrBuilder))
1764 llvmFunc->addRetAttrs(*attrBuilder);
1770 FailureOr<llvm::AttrBuilder> attrBuilder =
1772 if (failed(attrBuilder))
1774 llvmArg.addAttrs(*attrBuilder);
1780 function.getLoc(),
function.getPassthrough(), llvmFunc)))
1784 llvmFunc->setVisibility(convertVisibilityToLLVM(
function.getVisibility_()));
1787 if (std::optional<mlir::SymbolRefAttr> comdat =
function.getComdat()) {
1788 auto selectorOp = cast<ComdatSelectorOp>(
1790 llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1793 if (
auto gc =
function.getGarbageCollector())
1794 llvmFunc->setGC(gc->str());
1796 if (
auto unnamedAddr =
function.getUnnamedAddr())
1797 llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1799 if (
auto alignment =
function.getAlignment())
1800 llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1803 debugTranslation->translate(
function, *llvmFunc);
1809 LogicalResult ModuleTranslation::convertFunctions() {
1811 for (
auto function :
getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1814 if (
function.isExternal()) {
1815 if (failed(convertDialectAttributes(
function, {})))
1820 if (failed(convertOneFunction(
function)))
1827 LogicalResult ModuleTranslation::convertComdats() {
1828 for (
auto comdatOp :
getModuleBody(mlirModule).getOps<ComdatOp>()) {
1829 for (
auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1831 if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1833 <<
"comdat selection symbols must be unique even in different "
1835 llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1836 comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1837 comdatMapping.try_emplace(selectorOp, comdat);
1843 LogicalResult ModuleTranslation::convertUnresolvedBlockAddress() {
1844 for (
auto &[blockAddressOp, llvmCst] : unresolvedBlockAddressMapping) {
1845 BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
1847 assert(llvmBlock &&
"expected LLVM blocks to be already translated");
1851 lookupFunction(blockAddressAttr.getFunction().getValue()), llvmBlock);
1852 llvmCst->replaceAllUsesWith(llvmBlockAddr);
1853 assert(llvmCst->use_empty() &&
"expected all uses to be replaced");
1854 cast<llvm::GlobalVariable>(llvmCst)->eraseFromParent();
1856 unresolvedBlockAddressMapping.clear();
1861 llvm::Instruction *inst) {
1862 if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1863 inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1868 auto [scopeIt, scopeInserted] =
1869 aliasScopeMetadataMapping.try_emplace(aliasScopeAttr,
nullptr);
1871 return scopeIt->second;
1872 llvm::LLVMContext &ctx = llvmModule->getContext();
1873 auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
1875 auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1876 aliasScopeAttr.getDomain(),
nullptr);
1877 if (insertedDomain) {
1880 operands.push_back(dummy.get());
1881 if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1885 llvm::Metadata *replacement;
1886 if (
auto stringAttr =
1887 dyn_cast<StringAttr>(aliasScopeAttr.getDomain().getId()))
1890 replacement = domainIt->second;
1891 domainIt->second->replaceOperandWith(0, replacement);
1894 assert(domainIt->second &&
"Scope's domain should already be valid");
1897 operands.push_back(dummy.get());
1898 operands.push_back(domainIt->second);
1899 if (StringAttr description = aliasScopeAttr.getDescription())
1903 llvm::Metadata *replacement;
1904 if (
auto stringAttr = dyn_cast<StringAttr>(aliasScopeAttr.getId()))
1907 replacement = scopeIt->second;
1908 scopeIt->second->replaceOperandWith(0, replacement);
1909 return scopeIt->second;
1915 nodes.reserve(aliasScopeAttrs.size());
1916 for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1922 llvm::Instruction *inst) {
1923 auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs,
unsigned kind) {
1924 if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1927 llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1928 inst->setMetadata(
kind, node);
1931 populateScopeMetadata(op.getAliasScopesOrNull(),
1932 llvm::LLVMContext::MD_alias_scope);
1933 populateScopeMetadata(op.getNoAliasScopesOrNull(),
1934 llvm::LLVMContext::MD_noalias);
1937 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr)
const {
1938 return tbaaMetadataMapping.lookup(tbaaAttr);
1942 llvm::Instruction *inst) {
1943 ArrayAttr tagRefs = op.getTBAATagsOrNull();
1944 if (!tagRefs || tagRefs.empty())
1951 if (tagRefs.size() > 1) {
1952 op.emitWarning() <<
"TBAA access tags were not translated, because LLVM "
1953 "IR only supports a single tag per instruction";
1957 llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1958 inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1962 DereferenceableOpInterface op, llvm::Instruction *inst) {
1963 DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
1971 unsigned kindId = derefAttr.getMayBeNull()
1972 ? llvm::LLVMContext::MD_dereferenceable_or_null
1973 : llvm::LLVMContext::MD_dereferenceable;
1974 inst->setMetadata(kindId, derefSizeNode);
1983 assert(inst &&
"expected the operation to have a mapping to an instruction");
1986 llvm::LLVMContext::MD_prof,
1990 LogicalResult ModuleTranslation::createTBAAMetadata() {
1991 llvm::LLVMContext &ctx = llvmModule->getContext();
2003 walker.
addWalk([&](TBAARootAttr root) {
2004 tbaaMetadataMapping.insert(
2008 walker.
addWalk([&](TBAATypeDescriptorAttr descriptor) {
2011 for (TBAAMemberAttr member : descriptor.getMembers()) {
2012 operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
2020 walker.
addWalk([&](TBAATagAttr tag) {
2023 operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
2024 operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
2028 if (tag.getConstant())
2035 mlirModule->
walk([&](AliasAnalysisOpInterface analysisOpInterface) {
2036 if (
auto attr = analysisOpInterface.getTBAATagsOrNull())
2043 LogicalResult ModuleTranslation::createIdentMetadata() {
2045 LLVMDialect::getIdentAttrName())) {
2046 StringRef ident = attr;
2047 llvm::LLVMContext &ctx = llvmModule->getContext();
2048 llvm::NamedMDNode *namedMd =
2049 llvmModule->getOrInsertNamedMetadata(LLVMDialect::getIdentAttrName());
2051 namedMd->addOperand(md);
2057 LogicalResult ModuleTranslation::createCommandlineMetadata() {
2059 LLVMDialect::getCommandlineAttrName())) {
2060 StringRef cmdLine = attr;
2061 llvm::LLVMContext &ctx = llvmModule->getContext();
2062 llvm::NamedMDNode *nmd = llvmModule->getOrInsertNamedMetadata(
2063 LLVMDialect::getCommandlineAttrName());
2066 nmd->addOperand(md);
2072 LogicalResult ModuleTranslation::createDependentLibrariesMetadata() {
2074 LLVM::LLVMDialect::getDependentLibrariesAttrName())) {
2076 llvmModule->getOrInsertNamedMetadata(
"llvm.dependent-libraries");
2077 llvm::LLVMContext &ctx = llvmModule->getContext();
2079 cast<ArrayAttr>(dependentLibrariesAttr).getAsRange<StringAttr>()) {
2082 nmd->addOperand(md);
2089 llvm::Instruction *inst) {
2090 LoopAnnotationAttr attr =
2092 .Case<LLVM::BrOp, LLVM::CondBrOp>(
2093 [](
auto branchOp) {
return branchOp.getLoopAnnotationAttr(); });
2096 llvm::MDNode *loopMD =
2097 loopAnnotationTranslation->translateLoopAnnotation(attr, op);
2098 inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
2102 auto iface = cast<DisjointFlagInterface>(op);
2104 if (
auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
2105 disjointInst->setIsDisjoint(iface.getIsDisjoint());
2115 remapped.reserve(values.size());
2116 for (
Value v : values)
2123 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
2124 ompBuilder->initialize();
2129 ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
2137 return ompBuilder.get();
2141 llvm::DILocalScope *scope) {
2142 return debugTranslation->translateLoc(loc, scope);
2145 llvm::DIExpression *
2147 return debugTranslation->translateExpression(attr);
2150 llvm::DIGlobalVariableExpression *
2152 LLVM::DIGlobalVariableExpressionAttr attr) {
2153 return debugTranslation->translateGlobalVariableExpression(attr);
2157 return debugTranslation->translate(attr);
2162 return convertRoundingModeToLLVM(rounding);
2166 LLVM::FPExceptionBehavior exceptionBehavior) {
2167 return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
2172 return llvmModule->getOrInsertNamedMetadata(name);
// Out-of-line definition of ModuleTranslation::StackFrame::anchor(); the body is empty.
// NOTE(review): presumably the usual LLVM "virtual method anchor" idiom that keeps
// the class's vtable in this translation unit - confirm against the StackFrame declaration.
2175 void ModuleTranslation::StackFrame::anchor() {}
2177 static std::unique_ptr<llvm::Module>
2180 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
2181 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
2184 llvmModule->setNewDbgInfoFormatFlag(
false);
2185 if (
auto dataLayoutAttr =
2186 m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
2187 llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
2189 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(
""));
2190 if (
auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
2191 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
2195 }
else if (
auto mod = dyn_cast<ModuleOp>(m)) {
2196 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
2201 if (failed(llvmDataLayout))
2203 llvmModule->setDataLayout(*llvmDataLayout);
2205 if (
auto targetTripleAttr =
2206 m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
2207 llvmModule->setTargetTriple(
2208 llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
2213 std::unique_ptr<llvm::Module>
2215 StringRef name,
bool disableVerification) {
2217 module->
emitOpError(
"can not be translated to an LLVMIR module");
2221 std::unique_ptr<llvm::Module> llvmModule =
2230 llvm::IRBuilder<llvm::TargetFolder> llvmBuilder(
2232 llvm::TargetFolder(translator.getLLVMModule()->getDataLayout()));
2238 if (failed(translator.convertOperation(*module, llvmBuilder)))
2241 if (failed(translator.convertComdats()))
2243 if (failed(translator.convertFunctionSignatures()))
2245 if (failed(translator.convertGlobalsAndAliases()))
2247 if (failed(translator.createTBAAMetadata()))
2249 if (failed(translator.createIdentMetadata()))
2251 if (failed(translator.createCommandlineMetadata()))
2253 if (failed(translator.createDependentLibrariesMetadata()))
2258 if (!isa<LLVM::LLVMFuncOp, LLVM::AliasOp, LLVM::GlobalOp,
2259 LLVM::GlobalCtorsOp, LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
2261 failed(translator.convertOperation(o, llvmBuilder))) {
2269 if (failed(translator.convertFunctions()))
2274 if (failed(translator.convertUnresolvedBlockAddress()))
2284 translator.debugTranslation->addModuleFlagsIfNotPresent();
2286 if (!disableVerification &&
2287 llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
2290 return std::move(translator.llvmModule);
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
static Value getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index)
Get the SSA value passed to the current block from the terminator operation of its predecessor.
static llvm::MDNode * convertIntegerToMDNode(llvm::LLVMContext &context, const llvm::APInt &value)
Return a representation of value as an MDNode.
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
static llvm::MDNode * convertVecTypeHintToMDNode(llvm::LLVMContext &context, llvm::Type *type, bool isSigned)
Return an MDNode encoding vec_type_hint metadata.
static Block & getModuleBody(Operation *module)
A helper method to get the single Block in an operation honoring LLVM's module requirements.
static llvm::MDNode * convertIntegerArrayToMDNode(llvm::LLVMContext &context, ArrayRef< int32_t > values)
Return an MDNode with a tuple given by the values in values.
static llvm::Metadata * convertIntegerToMetadata(llvm::LLVMContext &context, const llvm::APInt &value)
Return a representation of value as metadata.
static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, llvm::GlobalValue *gv)
Sets the runtime preemption specifier of gv to dso_local if dsoLocalRequested is true,...
static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvm::Function *llvmFunc, StringRef key, StringRef value=StringRef())
Attempts to add an attribute identified by key, optionally with the given value to LLVM function llvm...
static void convertFunctionAttributes(LLVMFuncOp func, llvm::Function *llvmFunc)
Converts function attributes from func and attaches them to llvmFunc.
llvm::cl::opt< bool > UseNewDbgInfoFormat
static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, llvm::Constant *cst)
A helper method to decide if a constant must not be set as a global variable initializer.
static llvm::Type * getInnermostElementType(llvm::Type *type)
Returns the first non-sequential type nested in sequential types.
static void convertFunctionKernelAttributes(LLVMFuncOp func, llvm::Function *llvmFunc, ModuleTranslation &translation)
Converts function attributes from func and attaches them to llvmFunc.
static std::unique_ptr< llvm::Module > prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name)
static llvm::Constant * convertDenseResourceElementsAttr(Location loc, DenseResourceElementsAttr denseResourceAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense resource elements attribute to an LLVM IR constant using its raw data storage if poss...
static LogicalResult convertParameterAttr(llvm::AttrBuilder &attrBuilder, llvm::Attribute::AttrKind llvmKind, NamedAttribute namedAttr, ModuleTranslation &moduleTranslation, Location loc)
static void convertFunctionMemoryAttributes(LLVMFuncOp func, llvm::Function *llvmFunc)
Converts memory effect attributes from func and attaches them to llvmFunc.
static LogicalResult forwardPassthroughAttributes(Location loc, std::optional< ArrayAttr > attributes, llvm::Function *llvmFunc)
Attaches the attributes listed in the given array attribute to llvmFunc.
static llvm::Constant * buildSequentialConstant(ArrayRef< llvm::Constant * > &constants, ArrayRef< int64_t > shape, llvm::Type *type, Location loc)
Builds a constant of a sequential LLVM type type, potentially containing other sequential types recur...
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< pred_iterator > getPredecessors()
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
virtual LogicalResult convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given function operation using the interface implemented by the dialect of one of the fun...
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given operation using the interface implemented by the dialect of one of the operation's ...
This class represents the base attribute for all debug info attributes.
Implementation class for module translation.
llvm::fp::ExceptionBehavior translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior)
Translates the given LLVM FP exception behavior metadata.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::DIGlobalVariableExpression * translateGlobalVariableExpression(LLVM::DIGlobalVariableExpressionAttr attr)
Translates the given LLVM global variable expression metadata.
FailureOr< llvm::AttrBuilder > convertParameterAttrs(mlir::Location loc, DictionaryAttr paramAttrs)
Translates parameter attributes of a call and adds them to the returned AttrBuilder.
llvm::NamedMDNode * getOrInsertNamedModuleMetadata(StringRef name)
Gets the named metadata in the LLVM IR module being constructed, creating it if it does not exist.
llvm::Instruction * lookupBranch(Operation *op) const
Finds an LLVM IR instruction that corresponds to the given MLIR operation with successors.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
void setDereferenceableMetadata(DereferenceableOpInterface op, llvm::Instruction *inst)
Sets LLVM dereferenceable metadata for operations that have dereferenceable attributes.
void setBranchWeightsMetadata(BranchWeightOpInterface op)
Sets LLVM profiling metadata for operations that have branch weights.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding)
Translates the given LLVM rounding mode metadata.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::DIExpression * translateExpression(LLVM::DIExpressionAttr attr)
Translates the given LLVM DWARF expression metadata.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::CallInst * lookupCall(Operation *op) const
Finds an LLVM call instruction that corresponds to the given MLIR call operation.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
void setDisjointFlag(Operation *op, llvm::Value *value)
Sets the disjoint flag attribute for the exported instruction value given the original operation op.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::BasicBlock * lookupBlockAddress(BlockAddressAttr attr) const
Finds the LLVM basic block that corresponds to the given BlockAddressAttr.
llvm::MDNode * getOrCreateAliasScopes(ArrayRef< AliasScopeAttr > aliasScopeAttrs)
Returns the LLVM metadata corresponding to an array of mlir LLVM dialect alias scope attributes.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::MDNode * getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr)
Returns the LLVM metadata corresponding to a mlir LLVM dialect alias scope attribute.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
llvm::Type * translateType(Type type)
Translates the given MLIR LLVM dialect type to LLVM IR.
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
Attribute getValue() const
Return the value of the attribute.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Attribute getDiscardableAttr(StringRef name)
Access a discardable attribute by name, returns a null Attribute if the discardable attribute does n...
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Block * getSuccessor(unsigned index)
unsigned getNumSuccessors()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
dialect_attr_range getDialectAttrs()
Return a range corresponding to the dialect attributes for this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class models how operands are forwarded to block arguments in control flow.
bool empty() const
Returns true if there are no successor operands.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_iterator use_end() const
Type getType() const
Return the type of this value.
use_iterator use_begin() const
ArrayRef< T > asArrayRef() const
The OpAsmOpInterface, see OpAsmInterface.td for more details.
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::DenseMap< llvm::StringRef, llvm::Attribute::AttrKind > getAttrNameToKindMapping()
Returns a dense map from LLVM attribute names to their kinds in the LLVM IR dialect.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
void legalizeDIExpressionsRecursively(Operation *op)
Register all known legalization patterns declared here and apply them to all ops in op.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void ensureDistinctSuccessors(Operation *op)
Make argument-taking successors of each block distinct.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
Include the generated interface declarations.
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule", bool disableVerification=false)
Translates a given LLVM dialect module into an LLVM IR module living in the given context.
SetVector< Block * > getBlocksSortedByDominance(Region &region)
Gets a list of blocks that is sorted according to dominance.
DataLayoutSpecInterface translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context)
Translate the given LLVM data layout into an MLIR equivalent using the DLTI dialect.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...