19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InlineAsm.h" 21 #include "llvm/IR/MDBuilder.h" 22 #include "llvm/IR/MatrixBuilder.h" 23 #include "llvm/IR/Operator.h" 29 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" 34 case LLVM::ICmpPredicate::eq:
35 return llvm::CmpInst::Predicate::ICMP_EQ;
36 case LLVM::ICmpPredicate::ne:
37 return llvm::CmpInst::Predicate::ICMP_NE;
38 case LLVM::ICmpPredicate::slt:
39 return llvm::CmpInst::Predicate::ICMP_SLT;
40 case LLVM::ICmpPredicate::sle:
41 return llvm::CmpInst::Predicate::ICMP_SLE;
42 case LLVM::ICmpPredicate::sgt:
43 return llvm::CmpInst::Predicate::ICMP_SGT;
44 case LLVM::ICmpPredicate::sge:
45 return llvm::CmpInst::Predicate::ICMP_SGE;
46 case LLVM::ICmpPredicate::ult:
47 return llvm::CmpInst::Predicate::ICMP_ULT;
48 case LLVM::ICmpPredicate::ule:
49 return llvm::CmpInst::Predicate::ICMP_ULE;
50 case LLVM::ICmpPredicate::ugt:
51 return llvm::CmpInst::Predicate::ICMP_UGT;
52 case LLVM::ICmpPredicate::uge:
53 return llvm::CmpInst::Predicate::ICMP_UGE;
55 llvm_unreachable(
"incorrect comparison predicate");
60 case LLVM::FCmpPredicate::_false:
61 return llvm::CmpInst::Predicate::FCMP_FALSE;
62 case LLVM::FCmpPredicate::oeq:
63 return llvm::CmpInst::Predicate::FCMP_OEQ;
64 case LLVM::FCmpPredicate::ogt:
65 return llvm::CmpInst::Predicate::FCMP_OGT;
66 case LLVM::FCmpPredicate::oge:
67 return llvm::CmpInst::Predicate::FCMP_OGE;
68 case LLVM::FCmpPredicate::olt:
69 return llvm::CmpInst::Predicate::FCMP_OLT;
70 case LLVM::FCmpPredicate::ole:
71 return llvm::CmpInst::Predicate::FCMP_OLE;
72 case LLVM::FCmpPredicate::one:
73 return llvm::CmpInst::Predicate::FCMP_ONE;
74 case LLVM::FCmpPredicate::ord:
75 return llvm::CmpInst::Predicate::FCMP_ORD;
76 case LLVM::FCmpPredicate::ueq:
77 return llvm::CmpInst::Predicate::FCMP_UEQ;
78 case LLVM::FCmpPredicate::ugt:
79 return llvm::CmpInst::Predicate::FCMP_UGT;
80 case LLVM::FCmpPredicate::uge:
81 return llvm::CmpInst::Predicate::FCMP_UGE;
82 case LLVM::FCmpPredicate::ult:
83 return llvm::CmpInst::Predicate::FCMP_ULT;
84 case LLVM::FCmpPredicate::ule:
85 return llvm::CmpInst::Predicate::FCMP_ULE;
86 case LLVM::FCmpPredicate::une:
87 return llvm::CmpInst::Predicate::FCMP_UNE;
88 case LLVM::FCmpPredicate::uno:
89 return llvm::CmpInst::Predicate::FCMP_UNO;
90 case LLVM::FCmpPredicate::_true:
91 return llvm::CmpInst::Predicate::FCMP_TRUE;
93 llvm_unreachable(
"incorrect comparison predicate");
98 case LLVM::AtomicBinOp::xchg:
99 return llvm::AtomicRMWInst::BinOp::Xchg;
100 case LLVM::AtomicBinOp::add:
101 return llvm::AtomicRMWInst::BinOp::Add;
102 case LLVM::AtomicBinOp::sub:
103 return llvm::AtomicRMWInst::BinOp::Sub;
104 case LLVM::AtomicBinOp::_and:
105 return llvm::AtomicRMWInst::BinOp::And;
106 case LLVM::AtomicBinOp::nand:
107 return llvm::AtomicRMWInst::BinOp::Nand;
108 case LLVM::AtomicBinOp::_or:
109 return llvm::AtomicRMWInst::BinOp::Or;
110 case LLVM::AtomicBinOp::_xor:
111 return llvm::AtomicRMWInst::BinOp::Xor;
113 return llvm::AtomicRMWInst::BinOp::Max;
115 return llvm::AtomicRMWInst::BinOp::Min;
116 case LLVM::AtomicBinOp::umax:
117 return llvm::AtomicRMWInst::BinOp::UMax;
118 case LLVM::AtomicBinOp::umin:
119 return llvm::AtomicRMWInst::BinOp::UMin;
120 case LLVM::AtomicBinOp::fadd:
121 return llvm::AtomicRMWInst::BinOp::FAdd;
122 case LLVM::AtomicBinOp::fsub:
123 return llvm::AtomicRMWInst::BinOp::FSub;
125 llvm_unreachable(
"incorrect atomic binary operator");
130 case LLVM::AtomicOrdering::not_atomic:
131 return llvm::AtomicOrdering::NotAtomic;
132 case LLVM::AtomicOrdering::unordered:
133 return llvm::AtomicOrdering::Unordered;
134 case LLVM::AtomicOrdering::monotonic:
135 return llvm::AtomicOrdering::Monotonic;
136 case LLVM::AtomicOrdering::acquire:
137 return llvm::AtomicOrdering::Acquire;
138 case LLVM::AtomicOrdering::release:
139 return llvm::AtomicOrdering::Release;
140 case LLVM::AtomicOrdering::acq_rel:
141 return llvm::AtomicOrdering::AcquireRelease;
142 case LLVM::AtomicOrdering::seq_cst:
143 return llvm::AtomicOrdering::SequentiallyConsistent;
145 llvm_unreachable(
"incorrect atomic ordering");
149 using llvmFMF = llvm::FastMathFlags;
150 using FuncT =
void (llvmFMF::*)(bool);
151 const std::pair<FastmathFlags, FuncT> handlers[] = {
153 {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
154 {FastmathFlags::ninf, &llvmFMF::setNoInfs},
155 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
156 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
158 {FastmathFlags::afn, &llvmFMF::setApproxFunc},
159 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
162 llvm::FastMathFlags ret;
163 auto fmf = op.getFastmathFlags();
164 for (
auto it : handlers)
165 if (bitEnumContains(fmf, it.first))
166 (ret.*(it.second))(
true);
173 LoopOptionCase option,
178 case LoopOptionCase::disable_licm:
179 name =
"llvm.licm.disable";
180 cstValue = llvm::ConstantInt::getBool(ctx, value);
182 case LoopOptionCase::disable_unroll:
183 name =
"llvm.loop.unroll.disable";
184 cstValue = llvm::ConstantInt::getBool(ctx, value);
186 case LoopOptionCase::interleave_count:
187 name =
"llvm.loop.interleave.count";
188 cstValue = llvm::ConstantInt::get(
189 llvm::IntegerType::get(ctx, 32), value);
191 case LoopOptionCase::disable_pipeline:
192 name =
"llvm.loop.pipeline.disable";
193 cstValue = llvm::ConstantInt::getBool(ctx, value);
195 case LoopOptionCase::pipeline_initiation_interval:
196 name =
"llvm.loop.pipeline.initiationinterval";
197 cstValue = llvm::ConstantInt::get(
198 llvm::IntegerType::get(ctx, 32), value);
201 return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
202 llvm::ConstantAsMetadata::get(cstValue)});
206 llvm::IRBuilderBase &builder,
209 llvm::Module *module = builder.GetInsertBlock()->getModule();
212 llvm::LLVMContext &ctx = module->getContext();
216 auto dummy = llvm::MDNode::getTemporary(ctx, llvm::None);
217 loopOptions.push_back(dummy.get());
219 auto loopAttr = attr.cast<DictionaryAttr>();
220 auto parallelAccessGroup =
221 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
222 if (parallelAccessGroup.hasValue()) {
224 parallelAccess.push_back(
225 llvm::MDString::get(ctx,
"llvm.loop.parallel_accesses"));
226 for (SymbolRefAttr accessGroupRef : parallelAccessGroup->getValue()
228 .getAsRange<SymbolRefAttr>())
229 parallelAccess.push_back(
231 loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess));
234 if (
auto loopOptionsAttr = loopAttr.getAs<LoopOptionsAttr>(
235 LLVMDialect::getLoopOptionsAttrName())) {
236 for (
auto option : loopOptionsAttr.getOptions())
237 loopOptions.push_back(
242 loopMD = llvm::MDNode::get(ctx, loopOptions);
243 loopMD->replaceOperandWith(0, loopMD);
250 llvmInst.setMetadata(module->getMDKindID(
"llvm.loop"), loopMD);
257 auto extractPosition = [](ArrayAttr attr) {
259 position.reserve(attr.size());
261 position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
265 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
266 if (
auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
269 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" 270 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc" 277 auto convertCall = [&](
Operation &op) -> llvm::Value * {
278 auto operands = moduleTranslation.
lookupValues(op.getOperands());
281 return builder.CreateCall(
285 auto *calleeFunctionType = cast<llvm::FunctionType>(
286 moduleTranslation.
convertType(calleeType.getElementType()));
287 return builder.CreateCall(calleeFunctionType, operandsRef.front(),
288 operandsRef.drop_front());
293 if (isa<LLVM::CallOp>(opInst)) {
294 llvm::Value *result = convertCall(opInst);
300 return success(result->getType()->isVoidTy());
303 if (
auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
307 llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
310 if (inlineAsmOp.getNumResults() == 0) {
311 resultType = LLVM::LLVMVoidType::get(&moduleTranslation.
getContext());
313 assert(inlineAsmOp.getNumResults() == 1);
314 resultType = inlineAsmOp.getResultTypes()[0];
317 llvm::InlineAsm *inlineAsmInst =
318 inlineAsmOp.getAsmDialect().hasValue()
319 ? llvm::InlineAsm::get(
320 static_cast<llvm::FunctionType *>(
322 inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
323 inlineAsmOp.getHasSideEffects(),
324 inlineAsmOp.getIsAlignStack(),
325 convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
326 : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
328 inlineAsmOp.getAsmString(),
329 inlineAsmOp.getConstraints(),
330 inlineAsmOp.getHasSideEffects(),
331 inlineAsmOp.getIsAlignStack());
332 llvm::CallInst *inst = builder.CreateCall(
334 moduleTranslation.
lookupValues(inlineAsmOp.getOperands()));
335 if (
auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
336 llvm::AttributeList attrList;
341 DictionaryAttr dAttr = attr.
cast<DictionaryAttr>();
343 dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast<TypeAttr>();
345 llvm::Type *ty = moduleTranslation.
convertType(tAttr.getValue());
346 b.addTypeAttr(llvm::Attribute::ElementType, ty);
350 attrList = attrList.addAttributesAtIndex(
353 inst->setAttributes(attrList);
361 if (
auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
362 auto operands = moduleTranslation.
lookupValues(invOp.getCalleeOperands());
364 llvm::Instruction *result;
366 result = builder.CreateInvoke(
368 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
369 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)), operandsRef);
373 auto *calleeFunctionType = cast<llvm::FunctionType>(
374 moduleTranslation.
convertType(calleeType.getElementType()));
375 result = builder.CreateInvoke(
376 calleeFunctionType, operandsRef.front(),
377 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
378 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)),
379 operandsRef.drop_front());
381 moduleTranslation.
mapBranch(invOp, result);
383 if (invOp->getNumResults() != 0) {
387 return success(result->getType()->isVoidTy());
390 if (
auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
391 llvm::Type *ty = moduleTranslation.
convertType(lpOp.getType());
392 llvm::LandingPadInst *lpi =
393 builder.CreateLandingPad(ty, lpOp.getNumOperands());
394 lpi->setCleanup(lpOp.getCleanup());
397 for (llvm::Value *operand :
400 if (
auto *constOperand = dyn_cast<llvm::Constant>(operand))
401 lpi->addClause(constOperand);
403 moduleTranslation.
mapValue(lpOp.getResult(), lpi);
409 if (
auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
410 llvm::BranchInst *branch =
411 builder.CreateBr(moduleTranslation.
lookupBlock(brOp.getSuccessor()));
412 moduleTranslation.
mapBranch(&opInst, branch);
416 if (
auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
417 llvm::MDNode *branchWeights =
nullptr;
418 if (
auto weights = condbrOp.getBranchWeights()) {
420 auto weightValues = weights->getValues<APInt>();
421 auto trueWeight = weightValues[0].getSExtValue();
422 auto falseWeight = weightValues[1].getSExtValue();
425 .createBranchWeights(static_cast<uint32_t>(trueWeight),
426 static_cast<uint32_t
>(falseWeight));
428 llvm::BranchInst *branch = builder.CreateCondBr(
429 moduleTranslation.
lookupValue(condbrOp.getOperand(0)),
430 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(0)),
431 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
432 moduleTranslation.
mapBranch(&opInst, branch);
436 if (
auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
437 llvm::MDNode *branchWeights =
nullptr;
438 if (
auto weights = switchOp.getBranchWeights()) {
440 weightValues.reserve(weights->size());
442 weightValues.push_back(weight.getLimitedValue());
443 branchWeights = llvm::MDBuilder(moduleTranslation.
getLLVMContext())
444 .createBranchWeights(weightValues);
447 llvm::SwitchInst *switchInst = builder.CreateSwitch(
448 moduleTranslation.
lookupValue(switchOp.getValue()),
449 moduleTranslation.
lookupBlock(switchOp.getDefaultDestination()),
450 switchOp.getCaseDestinations().size(), branchWeights);
452 auto *ty = llvm::cast<llvm::IntegerType>(
453 moduleTranslation.
convertType(switchOp.getValue().getType()));
456 switchOp.getCaseDestinations()))
458 llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
461 moduleTranslation.
mapBranch(&opInst, switchInst);
468 if (
auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
469 LLVM::GlobalOp global = addressOfOp.getGlobal();
470 LLVM::LLVMFuncOp
function = addressOfOp.getFunction();
473 assert((global ||
function) &&
474 "referencing an undefined global or function");
477 addressOfOp.getResult(),
489 class LLVMDialectLLVMIRTranslationInterface
497 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
505 registry.
insert<LLVM::LLVMDialect>();
507 dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
A symbol reference with a reference path containing a single element.
static llvm::MDNode * getLoopOptionMetadata(llvm::LLVMContext &ctx, LoopOptionCase option, int64_t value)
Returns an LLVM metadata node corresponding to a loop option.
AttrClass getAttrOfType(StringAttr name)
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context...
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
void mapLoopOptionsMetadata(Attribute options, llvm::MDNode *metadata)
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op)
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
static constexpr const bool value
Implementation class for module translation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
Base class for dialect interfaces providing translation to LLVM IR.
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
Attributes are known-constant values of operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p)
Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and the translation from it to the LLVM IR in the given registry;...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction...
llvm::MDNode * getAccessGroup(Operation &opInst, SymbolRefAttr accessGroupRef) const
Returns the LLVM metadata corresponding to a reference to an mlir LLVM dialect access group operation...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value...
MLIRContext is the top-level object for a collection of MLIR operations.
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
LLVM dialect pointer type.
LLVMTranslationDialectInterface(Dialect *dialect)
unsigned getNumResults()
Return the number of results held by this operation.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
llvm::MDNode * lookupLoopOptionsMetadata(Attribute options) const
Returns the LLVM metadata corresponding to a llvm loop's codegen options attribute.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering)
An attribute that represents a reference to a dense integer vector or tensor object.