19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InlineAsm.h"
21 #include "llvm/IR/MDBuilder.h"
22 #include "llvm/IR/MatrixBuilder.h"
23 #include "llvm/IR/Operator.h"
29 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
32 using llvmFMF = llvm::FastMathFlags;
33 using FuncT = void (llvmFMF::*)(bool);
34 const std::pair<FastmathFlags, FuncT> handlers[] = {
36 {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
37 {FastmathFlags::ninf, &llvmFMF::setNoInfs},
38 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
39 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
41 {FastmathFlags::afn, &llvmFMF::setApproxFunc},
42 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
45 llvm::FastMathFlags ret;
46 ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
47 for (
auto it : handlers)
48 if (bitEnumContainsAll(fmfMlir, it.first))
49 (ret.*(it.second))(
true);
56 llvm::append_range(position, indices);
61 static std::string
diagStr(
const llvm::Type *type) {
63 llvm::raw_string_ostream os(str);
71 static FailureOr<llvm::Function *>
76 for (
Type type : op->getOperandTypes())
77 allArgTys.push_back(moduleTranslation.
convertType(type));
80 if (op.getNumResults() == 0)
81 resTy = llvm::Type::getVoidTy(module->getContext());
83 resTy = moduleTranslation.
convertType(op.getResult(0).getType());
89 getIntrinsicInfoTableEntries(
id,
table);
93 if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
95 llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
97 <<
diagStr(ft) <<
" to overloaded intrinsic " << op.getIntrinAttr()
98 <<
" does not match any of the overloads";
102 return llvm::Intrinsic::getOrInsertDeclaration(module,
id,
103 overloadedArgTysRef);
106 static llvm::OperandBundleDef
109 std::vector<llvm::Value *> operands;
110 operands.reserve(bundleOperands.size());
111 for (
Value bundleArg : bundleOperands)
112 operands.push_back(moduleTranslation.
lookupValue(bundleArg));
113 return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
120 bundles.reserve(bundleOperands.size());
122 for (
auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
123 StringRef tag = cast<StringAttr>(tagAttr).getValue();
131 std::optional<ArrayAttr> bundleTags,
140 ArrayAttr resAttrsArray, llvm::CallBase *call,
144 if (
auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
146 FailureOr<llvm::AttrBuilder> attrBuilder =
148 if (failed(attrBuilder))
150 call->addParamAttrs(argIdx, *attrBuilder);
155 if (resAttrsArray && resAttrsArray.size() > 0) {
156 if (resAttrsArray.size() != 1)
158 if (
auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
160 FailureOr<llvm::AttrBuilder> attrBuilder =
162 if (failed(attrBuilder))
164 call->addRetAttrs(*attrBuilder);
174 callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
182 llvm::Module *module = builder.GetInsertBlock()->getModule();
184 llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr());
187 << op.getIntrinAttr();
189 llvm::Function *fn =
nullptr;
190 if (llvm::Intrinsic::isOverloaded(
id)) {
193 if (failed(fnOrFailure))
197 fn = llvm::Intrinsic::getOrInsertDeclaration(module,
id, {});
201 const llvm::Type *intrinType =
202 op.getNumResults() == 0
203 ? llvm::Type::getVoidTy(module->getContext())
204 : moduleTranslation.
convertType(op.getResultTypes().front());
205 if (intrinType != fn->getReturnType()) {
207 <<
diagStr(intrinType) <<
" but " << op.getIntrinAttr()
208 <<
" actually returns " <<
diagStr(fn->getReturnType());
213 if (!fn->getFunctionType()->isVarArg() &&
214 op.getArgs().size() != fn->arg_size()) {
216 << op.getArgs().size() <<
" operands but " << op.getIntrinAttr()
217 <<
" expects " << fn->arg_size();
219 if (fn->getFunctionType()->isVarArg() &&
220 op.getArgs().size() < fn->arg_size()) {
222 << op.getArgs().size() <<
" operands but variadic "
223 << op.getIntrinAttr() <<
" expects at least " << fn->arg_size();
226 for (
unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
227 const llvm::Type *expected = fn->getArg(i)->getType();
228 const llvm::Type *actual =
229 moduleTranslation.
convertType(op.getOperandTypes()[i]);
230 if (actual != expected) {
232 << i <<
" has type " <<
diagStr(actual) <<
" but "
233 << op.getIntrinAttr() <<
" expects " <<
diagStr(expected);
237 FastmathFlagsInterface itf = op;
240 auto *inst = builder.CreateCall(
246 op.getResAttrsAttr(), inst,
250 if (op.getNumResults() == 1)
251 moduleTranslation.
mapValue(op->getResults().front()) = inst;
256 llvm::IRBuilderBase &builder,
258 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
259 llvm::LLVMContext &context = llvmModule->getContext();
260 llvm::NamedMDNode *linkerMDNode =
261 llvmModule->getOrInsertNamedMetadata(
"llvm.linker.options");
263 MDNodes.reserve(
options.size());
264 for (
auto s :
options.getAsRange<StringAttr>()) {
266 MDNodes.push_back(MDNode);
270 linkerMDNode->addOperand(listMDNode);
275 llvm::Module *llvmModule = moduleTranslation.
getLLVMModule();
276 for (
auto flagAttr : flags.getAsRange<ModuleFlagAttr>())
277 llvmModule->addModuleFlag(
278 convertModFlagBehaviorToLLVM(flagAttr.getBehavior()),
279 flagAttr.getKey().getValue(), flagAttr.getValue());
286 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
287 if (
auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
290 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
291 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
297 if (
auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
298 auto operands = moduleTranslation.
lookupValues(callOp.getCalleeOperands());
301 callOp.getOpBundleTags(), moduleTranslation);
303 llvm::CallInst *call;
304 if (
auto attr = callOp.getCalleeAttr()) {
306 builder.CreateCall(moduleTranslation.
lookupFunction(attr.getValue()),
307 operandsRef, opBundles);
309 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
310 moduleTranslation.
convertType(callOp.getCalleeFunctionType()));
311 call = builder.CreateCall(calleeType, operandsRef.front(),
312 operandsRef.drop_front(), opBundles);
314 call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
315 call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
316 if (callOp.getConvergentAttr())
317 call->addFnAttr(llvm::Attribute::Convergent);
318 if (callOp.getNoUnwindAttr())
319 call->addFnAttr(llvm::Attribute::NoUnwind);
320 if (callOp.getWillReturnAttr())
321 call->addFnAttr(llvm::Attribute::WillReturn);
322 if (callOp.getNoInlineAttr())
323 call->addFnAttr(llvm::Attribute::NoInline);
324 if (callOp.getAlwaysInlineAttr())
325 call->addFnAttr(llvm::Attribute::AlwaysInline);
326 if (callOp.getInlineHintAttr())
327 call->addFnAttr(llvm::Attribute::InlineHint);
332 if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
333 llvm::MemoryEffects memEffects =
334 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
335 convertModRefInfoToLLVM(memAttr.getArgMem())) |
337 llvm::MemoryEffects::Location::InaccessibleMem,
338 convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
339 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
340 convertModRefInfoToLLVM(memAttr.getOther()));
341 call->setMemoryEffects(memEffects);
352 else if (!call->getType()->isVoidTy())
354 moduleTranslation.
mapCall(callOp, call);
358 if (
auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
362 llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
365 if (inlineAsmOp.getNumResults() == 0) {
368 assert(inlineAsmOp.getNumResults() == 1);
369 resultType = inlineAsmOp.getResultTypes()[0];
372 llvm::InlineAsm *inlineAsmInst =
373 inlineAsmOp.getAsmDialect()
375 static_cast<llvm::FunctionType *
>(
377 inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
378 inlineAsmOp.getHasSideEffects(),
379 inlineAsmOp.getIsAlignStack(),
380 convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
383 inlineAsmOp.getAsmString(),
384 inlineAsmOp.getConstraints(),
385 inlineAsmOp.getHasSideEffects(),
386 inlineAsmOp.getIsAlignStack());
387 llvm::CallInst *inst = builder.CreateCall(
389 moduleTranslation.
lookupValues(inlineAsmOp.getOperands()));
390 if (
auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
391 llvm::AttributeList attrList;
396 DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
398 cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
400 llvm::Type *ty = moduleTranslation.
convertType(tAttr.getValue());
401 b.addTypeAttr(llvm::Attribute::ElementType, ty);
405 attrList = attrList.addAttributesAtIndex(
408 inst->setAttributes(attrList);
416 if (
auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
417 auto operands = moduleTranslation.
lookupValues(invOp.getCalleeOperands());
420 invOp.getOpBundleTags(), moduleTranslation);
422 llvm::InvokeInst *result;
424 result = builder.CreateInvoke(
426 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
427 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)), operandsRef,
430 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
431 moduleTranslation.
convertType(invOp.getCalleeFunctionType()));
432 result = builder.CreateInvoke(
433 calleeType, operandsRef.front(),
434 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
435 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)),
436 operandsRef.drop_front(), opBundles);
438 result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
442 moduleTranslation.
mapBranch(invOp, result);
444 if (invOp->getNumResults() != 0) {
448 return success(result->getType()->isVoidTy());
451 if (
auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
452 llvm::Type *ty = moduleTranslation.
convertType(lpOp.getType());
453 llvm::LandingPadInst *lpi =
454 builder.CreateLandingPad(ty, lpOp.getNumOperands());
455 lpi->setCleanup(lpOp.getCleanup());
458 for (llvm::Value *operand :
461 if (
auto *constOperand = dyn_cast<llvm::Constant>(operand))
462 lpi->addClause(constOperand);
464 moduleTranslation.
mapValue(lpOp.getResult(), lpi);
470 if (
auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
471 llvm::BranchInst *branch =
472 builder.CreateBr(moduleTranslation.
lookupBlock(brOp.getSuccessor()));
473 moduleTranslation.
mapBranch(&opInst, branch);
477 if (
auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
478 llvm::BranchInst *branch = builder.CreateCondBr(
479 moduleTranslation.
lookupValue(condbrOp.getOperand(0)),
480 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(0)),
481 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(1)));
482 moduleTranslation.
mapBranch(&opInst, branch);
486 if (
auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
487 llvm::SwitchInst *switchInst = builder.CreateSwitch(
488 moduleTranslation.
lookupValue(switchOp.getValue()),
489 moduleTranslation.
lookupBlock(switchOp.getDefaultDestination()),
490 switchOp.getCaseDestinations().size());
493 if (!switchOp.getCaseValues())
496 auto *ty = llvm::cast<llvm::IntegerType>(
497 moduleTranslation.
convertType(switchOp.getValue().getType()));
499 llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
500 switchOp.getCaseDestinations()))
505 moduleTranslation.
mapBranch(&opInst, switchInst);
512 if (
auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
513 LLVM::GlobalOp global =
514 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
515 LLVM::LLVMFuncOp
function =
516 addressOfOp.getFunction(moduleTranslation.
symbolTable());
517 LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.
symbolTable());
520 assert((global ||
function || alias) &&
521 "referencing an undefined global, function, or alias");
523 llvm::Value *llvmValue =
nullptr;
531 moduleTranslation.
mapValue(addressOfOp.getResult(), llvmValue);
537 if (
auto dsoLocalEquivalentOp =
538 dyn_cast<LLVM::DSOLocalEquivalentOp>(opInst)) {
539 LLVM::LLVMFuncOp
function =
540 dsoLocalEquivalentOp.getFunction(moduleTranslation.
symbolTable());
541 LLVM::AliasOp alias =
542 dsoLocalEquivalentOp.getAlias(moduleTranslation.
symbolTable());
545 assert((
function || alias) &&
546 "referencing an undefined function, or alias");
548 llvm::Value *llvmValue =
nullptr;
555 dsoLocalEquivalentOp.getResult(),
562 if (
auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
565 BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
566 BlockTagOp blockTagOp = moduleTranslation.
lookupBlockTag(blockAddressAttr);
568 blockTagOp = blockAddressOp.getBlockTagOp();
569 moduleTranslation.
mapBlockTag(blockAddressAttr, blockTagOp);
572 llvm::Value *llvmValue =
nullptr;
573 StringRef fnName = blockAddressAttr.getFunction().getValue();
574 if (llvm::BasicBlock *llvmBlock =
575 moduleTranslation.
lookupBlock(blockTagOp->getBlock())) {
576 llvm::Function *llvmFn = moduleTranslation.
lookupFunction(fnName);
585 llvmValue =
new llvm::GlobalVariable(
588 true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
590 Twine(
"__mlir_block_address_")
592 .
concat(Twine((uint64_t)blockAddressOp.getOperation())));
596 moduleTranslation.
mapValue(blockAddressOp.getResult(), llvmValue);
602 if (
auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
603 auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
608 blockTagOp.getTag());
609 moduleTranslation.
mapBlockTag(blockAddressAttr, blockTagOp);
619 class LLVMDialectLLVMIRTranslationInterface
627 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
635 registry.
insert<LLVM::LLVMDialect>();
637 dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::string diagStr(const llvm::Type *type)
Convert an LLVM type to a string for printing in diagnostics.
static LogicalResult convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, ArrayAttr resAttrsArray, llvm::CallBase *call, LLVM::ModuleTranslation &moduleTranslation)
static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static FailureOr< llvm::Function * > getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation)
Get the declaration of an overloaded llvm intrinsic.
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static SmallVector< llvm::OperandBundleDef > convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Builder for LLVM_CallIntrinsicOp.
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
static llvm::OperandBundleDef convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, LLVM::ModuleTranslation &moduleTranslation)
static void convertLinkerOptionsOp(ArrayAttr options, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
Attributes are known-constant values of operations.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst)
Maps a blockaddress operation to its corresponding placeholder LLVM value.
void mapBlockTag(BlockAddressAttr attr, BlockTagOp blockTag)
Maps a block address attribute to the BlockTagOp it refers to.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapCall(Operation *mlir, llvm::CallInst *llvm)
Stores a mapping between an MLIR call operation and a corresponding LLVM call instruction.
FailureOr< llvm::AttrBuilder > convertParameterAttrs(mlir::Location loc, DictionaryAttr paramAttrs)
Translates parameter attributes of a call and adds them to the returned AttrBuilder.
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up the LLVM IR values corresponding to a list of MLIR values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::GlobalValue * lookupAlias(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global alias va...
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
BlockTagOp lookupBlockTag(BlockAddressAttr attr) const
Finds the BlockTagOp that corresponds to the given block address attribute.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
This class represents a contiguous range of operand ranges, e.g.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and the translation from it to the LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...