40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
60 #define PASS_NAME "convert-func-to-llvm"
80 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
82 result.push_back(attr);
88 FunctionOpInterface funcOp,
89 LLVM::LLVMFuncOp wrapperFuncOp) {
90 auto argAttrs = funcOp.getAllArgAttrs();
91 if (!resultStructType) {
92 if (
auto resAttrs = funcOp.getAllResultAttrs())
93 wrapperFuncOp.setAllResultAttrs(resAttrs);
95 wrapperFuncOp.setAllArgAttrs(argAttrs);
102 argAttributes.append(argAttrs.begin(), argAttrs.end());
103 wrapperFuncOp.setAllArgAttrs(argAttributes);
106 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107 .setVisibility(funcOp.getVisibility());
120 FunctionOpInterface funcOp,
121 LLVM::LLVMFuncOp newFuncOp) {
122 auto type = cast<FunctionType>(funcOp.getFunctionType());
123 auto [wrapperFuncType, resultStructType] =
129 auto wrapperFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
130 loc, llvm::formatv(
"_mlir_ciface_{0}", funcOp.getName()).str(),
131 wrapperFuncType, LLVM::Linkage::External,
false,
132 LLVM::CConv::C,
nullptr, attributes);
139 size_t argOffset = resultStructType ? 1 : 0;
141 Value arg = wrapperFuncOp.getArgument(index + argOffset);
142 if (
auto memrefType = dyn_cast<MemRefType>(argType)) {
148 if (isa<UnrankedMemRefType>(argType)) {
158 auto call = rewriter.
create<LLVM::CallOp>(loc, newFuncOp, args);
160 if (resultStructType) {
161 rewriter.
create<LLVM::StoreOp>(loc, call.getResult(),
162 wrapperFuncOp.getArgument(0));
165 rewriter.
create<LLVM::ReturnOp>(loc, call.getResults());
180 FunctionOpInterface funcOp,
181 LLVM::LLVMFuncOp newFuncOp) {
184 auto [wrapperType, resultStructType] =
186 cast<FunctionType>(funcOp.getFunctionType()));
190 assert(wrapperType &&
"unexpected type conversion failure");
196 auto wrapperFunc = builder.
create<LLVM::LLVMFuncOp>(
197 loc, llvm::formatv(
"_mlir_ciface_{0}", funcOp.getName()).str(),
198 wrapperType, LLVM::Linkage::External,
false,
199 LLVM::CConv::C,
nullptr, attributes);
203 newFuncOp.setLinkage(LLVM::Linkage::Private);
207 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
209 args.reserve(type.getNumInputs());
210 ValueRange wrapperArgsRange(newFuncOp.getArguments());
212 if (resultStructType) {
214 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
219 builder.
create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220 args.push_back(result);
225 for (
Type input : type.getInputs()) {
228 auto memRefType = dyn_cast<MemRefType>(input);
229 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230 if (memRefType || unrankedMemRefType) {
231 numToDrop = memRefType
237 wrapperArgsRange.take_front(numToDrop))
239 builder, loc, typeConverter, unrankedMemRefType,
240 wrapperArgsRange.take_front(numToDrop));
247 loc, ptrTy, packed.
getType(), one, 0);
248 builder.
create<LLVM::StoreOp>(loc, packed, allocated);
251 arg = wrapperArgsRange[0];
255 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
257 assert(wrapperArgsRange.empty() &&
"did not map some of the arguments");
259 auto call = builder.
create<LLVM::CallOp>(loc, wrapperFunc, args);
261 if (resultStructType) {
263 builder.
create<LLVM::LoadOp>(loc, resultStructType, args.front());
264 builder.
create<LLVM::ReturnOp>(loc, result);
266 builder.
create<LLVM::ReturnOp>(loc, call.getResults());
270 FailureOr<LLVM::LLVMFuncOp>
275 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
278 funcOp,
"Only support FunctionOpInterface with FunctionType");
285 funcTy, varargsAttr && varargsAttr.getValue(),
292 LLVM::Linkage linkage = LLVM::Linkage::External;
298 <<
" attribute not of type LLVM::LinkageAttr";
300 funcOp,
"Contains linkage attribute not of type LLVM::LinkageAttr");
302 linkage = attr.getLinkage();
307 auto newFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
308 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
309 false, LLVM::CConv::C,
nullptr,
311 cast<FunctionOpInterface>(newFuncOp.getOperation())
312 .setVisibility(funcOp.getVisibility());
315 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
316 if (funcOp->hasAttr(readnoneAttrName)) {
317 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
319 funcOp->emitError() <<
"Contains " << readnoneAttrName
320 <<
" attribute not of type UnitAttr";
322 funcOp,
"Contains readnone attribute not of type UnitAttr");
326 {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
327 LLVM::ModRefInfo::NoModRef});
328 newFuncOp.setMemoryEffectsAttr(memoryAttr);
333 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
334 assert(!resAttrDicts.empty() &&
"expected array to be non-empty");
335 if (funcOp.getNumResults() == 1)
336 newFuncOp.setAllResultAttrs(resAttrDicts);
338 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
340 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
341 for (
unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
346 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
347 convertedAttrs.reserve(attrsDict.size());
351 cast<TypeAttr>(attr.getValue()).getValue()));
353 if (attr.getName().getValue() ==
354 LLVM::LLVMDialect::getByValAttrName()) {
356 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
357 }
else if (attr.getName().getValue() ==
358 LLVM::LLVMDialect::getByRefAttrName()) {
360 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
361 }
else if (attr.getName().getValue() ==
362 LLVM::LLVMDialect::getStructRetAttrName()) {
364 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
365 }
else if (attr.getName().getValue() ==
366 LLVM::LLVMDialect::getInAllocaAttrName()) {
368 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
370 convertedAttrs.push_back(attr);
373 auto mapping = result.getInputMapping(i);
374 assert(mapping &&
"unexpected deletion of function argument");
378 if (mapping->size == 1) {
379 newArgAttrs[mapping->inputNo] =
385 for (
size_t j = 0;
j < mapping->size; ++
j)
386 newArgAttrs[mapping->inputNo +
j] =
389 if (!newArgAttrs.empty())
390 newFuncOp.setAllArgAttrs(rewriter.
getArrayAttr(newArgAttrs));
398 "region types conversion failed");
402 if (funcOp->getAttrOfType<UnitAttr>(
403 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
404 if (newFuncOp.isVarArg())
405 return funcOp.emitError(
"C interface for variadic functions is not "
408 if (newFuncOp.isExternal())
430 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
433 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
434 *getTypeConverter());
435 if (failed(newFuncOp))
447 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
454 rewriter.
create<LLVM::AddressOfOp>(op.
getLoc(), type, op.getValue());
456 if (attr.getName().strref() ==
"value")
458 newOp->
setAttr(attr.getName(), attr.getValue());
460 rewriter.
replaceOp(op, newOp->getResults());
467 template <
typename CallOpType>
470 using Super = CallOpInterfaceLowering<CallOpType>;
473 LogicalResult matchAndRewriteImpl(CallOpType callOp,
474 typename CallOpType::Adaptor adaptor,
476 bool useBarePtrCallConv =
false)
const {
478 Type packedResult =
nullptr;
479 unsigned numResults = callOp.getNumResults();
480 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
482 if (numResults != 0) {
483 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
484 resultTypes, useBarePtrCallConv)))
488 if (useBarePtrCallConv) {
489 for (
auto it : callOp->getOperands()) {
490 Type operandType = it.getType();
491 if (isa<UnrankedMemRefType>(operandType)) {
498 auto promoted = this->getTypeConverter()->promoteOperands(
499 callOp.getLoc(), callOp->getOperands(),
500 adaptor.getOperands(), rewriter, useBarePtrCallConv);
501 auto newOp = rewriter.
create<LLVM::CallOp>(
503 promoted, callOp->getAttrs());
506 if (numResults < 2) {
508 results.append(newOp.result_begin(), newOp.result_end());
512 results.reserve(numResults);
513 for (
unsigned i = 0; i < numResults; ++i) {
514 results.push_back(rewriter.
create<LLVM::ExtractValueOp>(
515 callOp.getLoc(), newOp->getResult(0), i));
519 if (useBarePtrCallConv) {
522 assert(results.size() == resultTypes.size() &&
523 "The number of arguments and types doesn't match");
524 this->getTypeConverter()->promoteBarePtrsToDescriptors(
525 rewriter, callOp.getLoc(), resultTypes, results);
526 }
else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
527 resultTypes, results,
537 class CallOpLowering :
public CallOpInterfaceLowering<func::CallOp> {
542 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
543 symbolTable(symbolTable) {}
546 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
548 bool useBarePtrCallConv =
false;
549 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
550 useBarePtrCallConv =
true;
551 }
else if (symbolTable !=
nullptr) {
554 symbolTable->lookup(callOp.getCalleeAttr().getValue());
564 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
571 struct CallIndirectOpLowering
572 :
public CallOpInterfaceLowering<func::CallIndirectOp> {
576 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
578 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
582 struct UnrealizedConversionCastOpLowering
588 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
591 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
593 convertedTypes == adaptor.getInputs().getTypes()) {
594 rewriter.
replaceOp(op, adaptor.getInputs());
598 convertedTypes.clear();
599 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
601 convertedTypes == op.getOutputs().getType()) {
602 rewriter.
replaceOp(op, adaptor.getInputs());
619 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
626 bool useBarePtrCallConv =
628 if (useBarePtrCallConv) {
631 for (
auto it : llvm::zip(op->
getOperands(), adaptor.getOperands())) {
632 Type oldTy = std::get<0>(it).getType();
633 Value newOperand = std::get<1>(it);
634 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
635 cast<BaseMemRefType>(oldTy))) {
637 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
638 }
else if (isa<UnrankedMemRefType>(oldTy)) {
643 updatedOperands.push_back(newOperand);
646 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
653 if (numArguments <= 1) {
661 auto packedType = getTypeConverter()->packFunctionResults(
667 Value packed = rewriter.
create<LLVM::UndefOp>(loc, packedType);
669 packed = rewriter.
create<LLVM::InsertValueOp>(loc, packed, operand, idx);
680 patterns.
add<FuncOpConversion>(converter);
687 patterns.
add<CallIndirectOpLowering>(converter);
688 patterns.
add<CallOpLowering>(converter, symbolTable);
689 patterns.
add<ConstantOpLowering>(converter);
690 patterns.
add<ReturnOpLowering>(converter);
695 struct ConvertFuncToLLVMPass
696 :
public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
700 void runOnOperation()
override {
701 ModuleOp m = getOperation();
702 StringRef dataLayout;
703 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
704 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
706 dataLayout = dataLayoutAttr.getValue();
708 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
709 dataLayout, [
this](
const Twine &message) {
710 getOperation().emitError() << message.str();
716 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
719 dataLayoutAnalysis.getAtOrAbove(m));
720 options.useBarePtrCallConv = useBarePtrCallConv;
722 options.overrideIndexBitwidth(indexBitwidth);
723 options.dataLayout = llvm::DataLayout(dataLayout);
726 &dataLayoutAnalysis);
728 std::optional<SymbolTable> optSymbolTable = std::nullopt;
730 if (!
options.useBarePtrCallConv) {
731 optSymbolTable.emplace(m);
732 symbolTable = &optSymbolTable.value();
749 struct SetLLVMModuleDataLayoutPass
750 :
public impl::SetLLVMModuleDataLayoutPassBase<
751 SetLLVMModuleDataLayoutPass> {
755 void runOnOperation()
override {
756 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
757 this->dataLayout, [
this](
const Twine &message) {
758 getOperation().emitError() << message.str();
763 ModuleOp m = getOperation();
764 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
780 void populateConvertToLLVMConversionPatterns(
790 dialect->addInterfaces<FuncToLLVMDialectInterface>();
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/results attributes.
static constexpr StringRef barePtrAttrName
static constexpr StringRef varargsAttrName
static constexpr StringRef linkageAttrName
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
const LowerToLLVMOptions & getOptions() const
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void registerConvertFuncToLLVMInterface(DialectRegistry ®istry)
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.