40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/FormatVariadic.h"
51 #define GEN_PASS_DEF_CONVERTFUNCTOLLVM
52 #include "mlir/Conversion/Passes.h.inc"
57 #define PASS_NAME "convert-func-to-llvm"
69 attr.getName() == func.getFunctionTypeAttrName() ||
72 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName() ||
73 (filterArgAndResAttrs &&
74 (attr.getName() == func.getArgAttrsAttrName() ||
75 attr.getName() == func.getResAttrsAttrName())))
77 result.push_back(attr);
83 return DictionaryAttr::get(
85 b.
getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs));
99 size_t numArguments = func.getNumArguments();
101 numArguments + 1, DictionaryAttr::get(builder.
getContext()));
103 for (
auto *it = attributes.begin(); it != attributes.end();) {
104 if (it->getName() == func.getArgAttrsAttrName()) {
106 assert(arrayAttrs.size() == numArguments &&
107 "Number of arg attrs and args should match");
108 std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
110 }
else if (it->getName() == func.getResAttrsAttrName()) {
112 assert(!arrayAttrs.empty() &&
"expected array to be non-empty");
113 allAttrs[0] = (arrayAttrs.size() == 1)
116 it = attributes.erase(it);
122 auto newArgAttrs = builder.
getNamedAttr(func.getArgAttrsAttrName(),
125 attributes.emplace_back(newArgAttrs);
128 *argAttrs = newArgAttrs;
142 LLVM::LLVMFuncOp newFuncOp) {
143 auto type = funcOp.getFunctionType();
146 auto [wrapperFuncType, resultIsNowArg] =
150 auto wrapperFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
151 loc, llvm::formatv(
"_mlir_ciface_{0}", funcOp.getName()).str(),
152 wrapperFuncType, LLVM::Linkage::External,
false,
153 LLVM::CConv::C, attributes);
159 size_t argOffset = resultIsNowArg ? 1 : 0;
161 Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
162 if (
auto memrefType = en.value().dyn_cast<MemRefType>()) {
163 Value loaded = rewriter.
create<LLVM::LoadOp>(loc, arg);
168 Value loaded = rewriter.
create<LLVM::LoadOp>(loc, arg);
176 auto call = rewriter.
create<LLVM::CallOp>(loc, newFuncOp, args);
178 if (resultIsNowArg) {
179 rewriter.
create<LLVM::StoreOp>(loc, call.getResult(),
180 wrapperFuncOp.getArgument(0));
183 rewriter.
create<LLVM::ReturnOp>(loc, call.getResults());
199 LLVM::LLVMFuncOp newFuncOp) {
202 auto [wrapperType, resultIsNowArg] =
207 assert(wrapperType &&
"unexpected type conversion failure");
215 auto wrapperFunc = builder.
create<LLVM::LLVMFuncOp>(
216 loc, llvm::formatv(
"_mlir_ciface_{0}", funcOp.getName()).str(),
217 wrapperType, LLVM::Linkage::External,
false,
218 LLVM::CConv::C, attributes);
223 FunctionType type = funcOp.getFunctionType();
225 args.reserve(type.getNumInputs());
226 ValueRange wrapperArgsRange(newFuncOp.getArguments());
228 if (resultIsNowArg) {
231 wrapperType.
cast<LLVM::LLVMFunctionType>().getParamType(0);
235 Value result = builder.
create<LLVM::AllocaOp>(loc, resultType, one);
236 args.push_back(result);
244 auto memRefType = en.value().
dyn_cast<MemRefType>();
246 if (memRefType || unrankedMemRefType) {
247 numToDrop = memRefType
253 wrapperArgsRange.take_front(numToDrop))
255 builder, loc, typeConverter, unrankedMemRefType,
256 wrapperArgsRange.take_front(numToDrop));
258 auto ptrTy = LLVM::LLVMPointerType::get(packed.
getType());
263 builder.
create<LLVM::AllocaOp>(loc, ptrTy, one, 0);
264 builder.
create<LLVM::StoreOp>(loc, packed, allocated);
267 arg = wrapperArgsRange[0];
271 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
273 assert(wrapperArgsRange.empty() &&
"did not map some of the arguments");
275 auto call = builder.
create<LLVM::CallOp>(loc, wrapperFunc, args);
277 if (resultIsNowArg) {
278 Value result = builder.
create<LLVM::LoadOp>(loc, args.front());
279 builder.
create<LLVM::ReturnOp>(loc, result);
281 builder.
create<LLVM::ReturnOp>(loc, call.getResults());
294 convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
300 auto llvmType = getTypeConverter()->convertFunctionSignature(
301 funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
310 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
311 assert(!resAttrDicts.empty() &&
"expected array to be non-empty");
312 auto newResAttrDicts =
313 (funcOp.getNumResults() == 1)
317 attributes.push_back(
318 rewriter.
getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
320 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
322 llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
323 for (
unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
328 auto attrsDict = argAttrDicts[i].cast<DictionaryAttr>();
329 convertedAttrs.reserve(attrsDict.size());
332 return TypeAttr::get(getTypeConverter()->convertType(
333 attr.getValue().cast<TypeAttr>().getValue()));
335 if (attr.getName().getValue() ==
336 LLVM::LLVMDialect::getByValAttrName()) {
338 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
339 }
else if (attr.getName().getValue() ==
340 LLVM::LLVMDialect::getByRefAttrName()) {
342 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
343 }
else if (attr.getName().getValue() ==
344 LLVM::LLVMDialect::getStructRetAttrName()) {
346 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
347 }
else if (attr.getName().getValue() ==
348 LLVM::LLVMDialect::getInAllocaAttrName()) {
350 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
352 convertedAttrs.push_back(attr);
355 auto mapping = result.getInputMapping(i);
356 assert(mapping &&
"unexpected deletion of function argument");
360 if (mapping->size == 1) {
361 newArgAttrs[mapping->inputNo] =
362 DictionaryAttr::get(rewriter.
getContext(), convertedAttrs);
367 for (
size_t j = 0;
j < mapping->size; ++
j)
368 newArgAttrs[mapping->inputNo +
j] =
369 DictionaryAttr::get(rewriter.
getContext(), {});
372 funcOp.getArgAttrsAttrName(), rewriter.
getArrayAttr(newArgAttrs)));
377 LLVM::Linkage linkage = LLVM::Linkage::External;
383 <<
" attribute not of type LLVM::LinkageAttr";
386 linkage = attr.getLinkage();
390 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
391 LLVM::MemoryEffectsAttr memoryAttr = {};
392 if (funcOp->hasAttr(readnoneAttrName)) {
393 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
395 funcOp->emitError() <<
"Contains " << readnoneAttrName
396 <<
" attribute not of type UnitAttr";
399 memoryAttr = LLVM::MemoryEffectsAttr::get(rewriter.
getContext(),
400 {LLVM::ModRefInfo::NoModRef,
401 LLVM::ModRefInfo::NoModRef,
402 LLVM::ModRefInfo::NoModRef});
404 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
405 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
406 false, LLVM::CConv::C, attributes);
409 newFuncOp.setMemoryAttr(memoryAttr);
423 struct FuncOpConversion :
public FuncOpConversionBase {
425 : FuncOpConversionBase(converter) {}
428 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
430 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
434 if (funcOp->getAttrOfType<UnitAttr>(
435 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
436 if (newFuncOp.isVarArg())
437 return funcOp->emitError(
"C interface for variadic functions is not "
440 if (newFuncOp.isExternal())
455 struct BarePtrFuncOpConversion :
public FuncOpConversionBase {
456 using FuncOpConversionBase::FuncOpConversionBase;
459 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
470 llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
472 auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
475 if (newFuncOp.getBody().empty()) {
483 Block *entryBlock = &newFuncOp.getBody().
front();
485 assert(blockArgs.size() == oldArgTypes.size() &&
486 "The number of arguments and types doesn't match");
490 for (
auto it : llvm::zip(blockArgs, oldArgTypes)) {
492 Type argTy = std::get<1>(it);
498 "Unranked memref is not supported");
499 auto memrefTy = argTy.
dyn_cast<MemRefType>();
510 auto placeholder = rewriter.
create<LLVM::UndefOp>(
511 loc, getTypeConverter()->convertType(memrefTy));
515 rewriter, loc, *getTypeConverter(), memrefTy, arg);
528 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
530 auto type = typeConverter->convertType(op.getResult().getType());
535 rewriter.
create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
537 if (attr.getName().strref() ==
"value")
539 newOp->
setAttr(attr.getName(), attr.getValue());
541 rewriter.
replaceOp(op, newOp->getResults());
548 template <
typename CallOpType>
551 using Super = CallOpInterfaceLowering<CallOpType>;
555 matchAndRewrite(CallOpType callOp,
typename CallOpType::Adaptor adaptor,
558 Type packedResult =
nullptr;
559 unsigned numResults = callOp.getNumResults();
560 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
562 if (numResults != 0) {
564 this->getTypeConverter()->packFunctionResults(resultTypes)))
568 auto promoted = this->getTypeConverter()->promoteOperands(
569 callOp.getLoc(), callOp->getOperands(),
570 adaptor.getOperands(), rewriter);
571 auto newOp = rewriter.
create<LLVM::CallOp>(
573 promoted, callOp->getAttrs());
576 if (numResults < 2) {
578 results.append(newOp.result_begin(), newOp.result_end());
582 results.reserve(numResults);
583 for (
unsigned i = 0; i < numResults; ++i) {
584 results.push_back(rewriter.
create<LLVM::ExtractValueOp>(
585 callOp.getLoc(), newOp->getResult(0), i));
589 if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
592 assert(results.size() == resultTypes.size() &&
593 "The number of arguments and types doesn't match");
594 this->getTypeConverter()->promoteBarePtrsToDescriptors(
595 rewriter, callOp.getLoc(), resultTypes, results);
596 }
else if (
failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
597 resultTypes, results,
607 struct CallOpLowering :
public CallOpInterfaceLowering<func::CallOp> {
611 struct CallIndirectOpLowering
612 :
public CallOpInterfaceLowering<func::CallIndirectOp> {
616 struct UnrealizedConversionCastOpLowering
622 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
625 if (
succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
627 convertedTypes == adaptor.getInputs().getTypes()) {
628 rewriter.
replaceOp(op, adaptor.getInputs());
632 convertedTypes.clear();
633 if (
succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
635 convertedTypes == op.getOutputs().getType()) {
636 rewriter.
replaceOp(op, adaptor.getInputs());
653 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
656 unsigned numArguments = op.getNumOperands();
659 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
662 for (
auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
663 Type oldTy = std::get<0>(it).getType();
664 Value newOperand = std::get<1>(it);
665 if (oldTy.
isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
668 newOperand = memrefDesc.alignedPtr(rewriter, loc);
674 updatedOperands.push_back(newOperand);
677 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
678 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
684 if (numArguments <= 1) {
686 op,
TypeRange(), updatedOperands, op->getAttrs());
693 getTypeConverter()->packFunctionResults(op.getOperandTypes());
695 Value packed = rewriter.
create<LLVM::UndefOp>(loc, packedType);
697 packed = rewriter.
create<LLVM::InsertValueOp>(loc, packed, it.value(),
710 patterns.
add<BarePtrFuncOpConversion>(converter);
712 patterns.
add<FuncOpConversion>(converter);
720 CallIndirectOpLowering,
723 ReturnOpLowering>(converter);
729 struct ConvertFuncToLLVMPass
730 :
public impl::ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
731 ConvertFuncToLLVMPass() =
default;
732 ConvertFuncToLLVMPass(
bool useBarePtrCallConv,
unsigned indexBitwidth,
733 bool useAlignedAlloc,
734 const llvm::DataLayout &dataLayout) {
735 this->useBarePtrCallConv = useBarePtrCallConv;
736 this->indexBitwidth = indexBitwidth;
737 this->dataLayout = dataLayout.getStringRepresentation();
741 void runOnOperation()
override {
742 if (
failed(LLVM::LLVMDialect::verifyDataLayoutString(
743 this->dataLayout, [
this](
const Twine &message) {
744 getOperation().emitError() << message.str();
750 ModuleOp m = getOperation();
751 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
754 dataLayoutAnalysis.getAtOrAbove(m));
755 options.useBarePtrCallConv = useBarePtrCallConv;
757 options.overrideIndexBitwidth(indexBitwidth);
758 options.dataLayout = llvm::DataLayout(this->dataLayout);
761 &dataLayoutAnalysis);
774 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
775 StringAttr::get(m.getContext(), this->dataLayout));
781 return std::make_unique<ConvertFuncToLLVMPass>();
784 std::unique_ptr<OperationPass<ModuleOp>>
786 auto allocLowering =
options.allocLowering;
790 "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
791 bool useAlignedAlloc =
792 (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
793 return std::make_unique<ConvertFuncToLLVMPass>(
794 options.useBarePtrCallConv,
options.getIndexBitwidth(), useAlignedAlloc,
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl< NamedAttribute > &attributes, func::FuncOp func)
Combines all result attributes into a single DictionaryAttr and prepends it to the argument attributes.
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs)
Helper function for wrapping all attributes into a single DictionaryAttr.
static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
static constexpr StringRef varargsAttrName
static constexpr StringRef linkageAttrName
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
std::pair< Type, bool > convertFunctionTypeCWrapper(FunctionType type)
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
NamedAttribute represents a combination of a name and an Attribute value.
Attribute getValue() const
Return the value of the attribute.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createConvertFuncToLLVMPass()
Creates a pass to convert the Func dialect into the LLVMIR dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
This class represents an efficient way to signal success or failure.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.