19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InlineAsm.h"
21 #include "llvm/IR/MDBuilder.h"
22 #include "llvm/IR/MatrixBuilder.h"
23 #include "llvm/IR/Operator.h"
30 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
33 using llvmFMF = llvm::FastMathFlags;
34 using FuncT = void (llvmFMF::*)(bool);
35 const std::pair<FastmathFlags, FuncT> handlers[] = {
37 {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
38 {FastmathFlags::ninf, &llvmFMF::setNoInfs},
39 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
40 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
42 {FastmathFlags::afn, &llvmFMF::setApproxFunc},
43 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
46 llvm::FastMathFlags ret;
47 ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
48 for (
auto it : handlers)
49 if (bitEnumContainsAll(fmfMlir, it.first))
50 (ret.*(it.second))(
true);
57 llvm::append_range(position, indices);
62 static std::string
diagStr(
const llvm::Type *type) {
64 llvm::raw_string_ostream os(str);
78 allArgTys.push_back(moduleTranslation.
convertType(type));
82 resTy = llvm::Type::getVoidTy(module->getContext());
90 getIntrinsicInfoTableEntries(
id, table);
94 if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
96 llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
98 <<
diagStr(ft) <<
" to overloaded intrinsic " << op.getIntrinAttr()
99 <<
" does not match any of the overloads";
103 return llvm::Intrinsic::getDeclaration(module,
id, overloadedArgTysRef);
110 llvm::Module *module = builder.GetInsertBlock()->getModule();
112 llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
115 << op.getIntrinAttr();
117 llvm::Function *fn =
nullptr;
118 if (llvm::Intrinsic::isOverloaded(
id)) {
125 fn = llvm::Intrinsic::getDeclaration(module,
id, {});
129 const llvm::Type *intrinType =
131 ? llvm::Type::getVoidTy(module->getContext())
133 if (intrinType != fn->getReturnType()) {
135 <<
diagStr(intrinType) <<
" but " << op.getIntrinAttr()
136 <<
" actually returns " <<
diagStr(fn->getReturnType());
141 if (!fn->getFunctionType()->isVarArg() &&
145 <<
" expects " << fn->arg_size();
147 if (fn->getFunctionType()->isVarArg() &&
151 << op.getIntrinAttr() <<
" expects at least " << fn->arg_size();
154 for (
unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
155 const llvm::Type *expected = fn->getArg(i)->getType();
156 const llvm::Type *actual =
158 if (actual != expected) {
160 << i <<
" has type " <<
diagStr(actual) <<
" but "
161 << op.getIntrinAttr() <<
" expects " <<
diagStr(expected);
165 FastmathFlagsInterface itf = op;
179 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
180 if (
auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
183 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
184 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
194 Type resultType = resultTypes.empty()
196 : resultTypes.front();
197 return llvm::cast<llvm::FunctionType>(moduleTranslation.
convertType(
199 llvm::to_vector(args.getTypes()),
false)));
206 if (
auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
207 auto operands = moduleTranslation.
lookupValues(callOp.getOperands());
209 llvm::CallInst *call;
210 if (
auto attr = callOp.getCalleeAttr()) {
211 call = builder.CreateCall(
214 call = builder.CreateCall(getCalleeFunctionType(callOp.getResultTypes(),
215 callOp.getArgOperands()),
216 operandsRef.front(), operandsRef.drop_front());
226 else if (!call->getType()->isVoidTy())
228 moduleTranslation.
mapCall(callOp, call);
232 if (
auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
236 llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
239 if (inlineAsmOp.getNumResults() == 0) {
242 assert(inlineAsmOp.getNumResults() == 1);
243 resultType = inlineAsmOp.getResultTypes()[0];
246 llvm::InlineAsm *inlineAsmInst =
247 inlineAsmOp.getAsmDialect()
249 static_cast<llvm::FunctionType *
>(
251 inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
252 inlineAsmOp.getHasSideEffects(),
253 inlineAsmOp.getIsAlignStack(),
254 convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
257 inlineAsmOp.getAsmString(),
258 inlineAsmOp.getConstraints(),
259 inlineAsmOp.getHasSideEffects(),
260 inlineAsmOp.getIsAlignStack());
261 llvm::CallInst *inst = builder.CreateCall(
263 moduleTranslation.
lookupValues(inlineAsmOp.getOperands()));
264 if (
auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
265 llvm::AttributeList attrList;
270 DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
272 cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
274 llvm::Type *ty = moduleTranslation.
convertType(tAttr.getValue());
275 b.addTypeAttr(llvm::Attribute::ElementType, ty);
279 attrList = attrList.addAttributesAtIndex(
282 inst->setAttributes(attrList);
290 if (
auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
291 auto operands = moduleTranslation.
lookupValues(invOp.getCalleeOperands());
293 llvm::Instruction *result;
295 result = builder.CreateInvoke(
297 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
298 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)), operandsRef);
300 result = builder.CreateInvoke(
301 getCalleeFunctionType(invOp.getResultTypes(), invOp.getArgOperands()),
303 moduleTranslation.
lookupBlock(invOp.getSuccessor(0)),
304 moduleTranslation.
lookupBlock(invOp.getSuccessor(1)),
305 operandsRef.drop_front());
307 moduleTranslation.
mapBranch(invOp, result);
309 if (invOp->getNumResults() != 0) {
313 return success(result->getType()->isVoidTy());
316 if (
auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
317 llvm::Type *ty = moduleTranslation.
convertType(lpOp.getType());
318 llvm::LandingPadInst *lpi =
319 builder.CreateLandingPad(ty, lpOp.getNumOperands());
320 lpi->setCleanup(lpOp.getCleanup());
323 for (llvm::Value *operand :
326 if (
auto *constOperand = dyn_cast<llvm::Constant>(operand))
327 lpi->addClause(constOperand);
329 moduleTranslation.
mapValue(lpOp.getResult(), lpi);
335 if (
auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
336 llvm::BranchInst *branch =
337 builder.CreateBr(moduleTranslation.
lookupBlock(brOp.getSuccessor()));
338 moduleTranslation.
mapBranch(&opInst, branch);
342 if (
auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
343 llvm::BranchInst *branch = builder.CreateCondBr(
344 moduleTranslation.
lookupValue(condbrOp.getOperand(0)),
345 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(0)),
346 moduleTranslation.
lookupBlock(condbrOp.getSuccessor(1)));
347 moduleTranslation.
mapBranch(&opInst, branch);
351 if (
auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
352 llvm::SwitchInst *switchInst = builder.CreateSwitch(
353 moduleTranslation.
lookupValue(switchOp.getValue()),
354 moduleTranslation.
lookupBlock(switchOp.getDefaultDestination()),
355 switchOp.getCaseDestinations().size());
358 if (!switchOp.getCaseValues())
361 auto *ty = llvm::cast<llvm::IntegerType>(
362 moduleTranslation.
convertType(switchOp.getValue().getType()));
364 llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
365 switchOp.getCaseDestinations()))
370 moduleTranslation.
mapBranch(&opInst, switchInst);
377 if (
auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
378 LLVM::GlobalOp global =
379 addressOfOp.getGlobal(moduleTranslation.
symbolTable());
380 LLVM::LLVMFuncOp
function =
381 addressOfOp.getFunction(moduleTranslation.
symbolTable());
384 assert((global ||
function) &&
385 "referencing an undefined global or function");
388 addressOfOp.getResult(),
400 class LLVMDialectLLVMIRTranslationInterface
408 convertOperation(
Operation *op, llvm::IRBuilderBase &builder,
416 registry.
insert<LLVM::LLVMDialect>();
418 dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::string diagStr(const llvm::Type *type)
Convert an LLVM type to a string for printing in diagnostics.
static FailureOr< llvm::Function * > getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation)
Get the declaration of an overloaded LLVM intrinsic.
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Builder for LLVM_CallIntrinsicOp.
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
Attributes are known-constant values of operations.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
A symbol reference with a reference path containing a single element.
Base class for dialect interfaces providing translation to LLVM IR.
LLVMTranslationDialectInterface(Dialect *dialect)
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapCall(Operation *mlir, llvm::CallInst *llvm)
Stores a mapping between an MLIR call operation and a corresponding LLVM call instruction.
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
MLIRContext is the top-level object for a collection of MLIR operations.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
Type getType() const
Return the type of this value.
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and the translation from it to LLVM IR in the given registry.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.