MLIR 23.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1//===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert MLIR Func and builtin dialects
10// into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
27#include "mlir/IR/Attributes.h"
28#include "mlir/IR/Builders.h"
30#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/IR/Type.h"
38#include "llvm/Support/DebugLog.h"
39#include "llvm/Support/FormatVariadic.h"
40#include <optional>
41
namespace mlir {
// Pull in the generated pass base classes for the two passes defined in this
// file. The GEN_PASS_DEF_* macros select which definitions the generated
// Passes.h.inc emits, so they must precede the include.
#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
#define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// DEBUG_TYPE tags the debug output of this file (e.g. LDBG()); it reuses the
// pass's command-line name.
#define PASS_NAME "convert-func-to-llvm"
#define DEBUG_TYPE PASS_NAME

/// Discardable function attribute read as a BoolAttr; when true, the
/// function signature is converted as variadic.
static constexpr StringRef varargsAttrName = "func.varargs";
/// Discardable function attribute naming the linkage of the converted
/// function.
static constexpr StringRef linkageAttrName = "llvm.linkage";
/// Discardable attribute requesting the bare pointer calling convention for an
/// individual function (checked by shouldUseBarePtrCallConv and by call
/// lowering on the callee).
static constexpr StringRef barePtrAttrName = "llvm.bareptr";
56
57/// Return `true` if the `op` should use bare pointer calling convention.
59 const LLVMTypeConverter *typeConverter) {
60 return (op && op->hasAttr(barePtrAttrName)) ||
61 typeConverter->getOptions().useBarePtrCallConv;
62}
63
64static bool isDiscardableAttr(StringRef name) {
65 return name == linkageAttrName || name == varargsAttrName ||
66 name == LLVM::LLVMDialect::getReadnoneAttrName();
67}
68
69/// Only retain those attributes that are not constructed by
70/// `LLVMFuncOp::build`.
71static void filterFuncAttributes(FunctionOpInterface func,
73 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
74 if (isDiscardableAttr(attr.getName().strref()))
75 continue;
76 result.push_back(attr);
77 }
78}
79
80/// Add custom lowered funcOp to llvm.func attributes here.
82 LLVM::LLVMFuncOp::Properties properties;
84};
85
86/// Lower discardable function attributes on `func.func` to attributes expected
87/// by `llvm.func`.
88static FailureOr<LoweredFuncAttrs>
89lowerFuncAttributes(FunctionOpInterface func) {
90 MLIRContext *ctx = func->getContext();
91 LoweredFuncAttrs lowered;
92
93 llvm::SmallDenseSet<StringRef> odsAttrNames(
94 LLVM::LLVMFuncOp::getAttributeNames().begin(),
95 LLVM::LLVMFuncOp::getAttributeNames().end());
96
97 NamedAttrList inherentAttrs;
98
99 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
100 StringRef attrName = attr.getName().strref();
101
102 if (odsAttrNames.contains(attrName)) {
103 LDBG() << "LLVM specific attributes: " << attrName
104 << "should use llvm.* prefix, discarding it";
105 continue;
106 }
107
108 StringRef inherent = attrName;
109 if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent))
110 inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs
111 else
112 lowered.discardableAttrs.push_back(attr);
113 }
114
115 // Convert collected inherent attrs into typed properties.
116 if (!inherentAttrs.empty()) {
117 DictionaryAttr dict = inherentAttrs.getDictionary(ctx);
118 auto emitError = [&] {
119 return func.emitOpError("invalid llvm.func property");
120 };
121 if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(lowered.properties, dict,
122 emitError))) {
123 return failure();
124 }
125 }
126 return lowered;
127}
128
130 FunctionOpInterface srcFunc,
131 Type llvmFuncType,
132 LLVM::LLVMFuncOp::Properties &props) {
133 MLIRContext *ctx = rewriter.getContext();
134 props.sym_name = rewriter.getStringAttr(srcFunc.getName());
135 props.function_type = TypeAttr::get(llvmFuncType);
136 props.setCConv(LLVM::CConvAttr::get(ctx, LLVM::CConv::C));
137}
138
139/// Propagate argument/results attributes.
140static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
141 FunctionOpInterface funcOp,
142 LLVM::LLVMFuncOp wrapperFuncOp) {
143 auto argAttrs = funcOp.getAllArgAttrs();
144 if (!resultStructType) {
145 if (auto resAttrs = funcOp.getAllResultAttrs())
146 wrapperFuncOp.setAllResultAttrs(resAttrs);
147 if (argAttrs)
148 wrapperFuncOp.setAllArgAttrs(argAttrs);
149 } else {
150 SmallVector<Attribute> argAttributes;
151 // Only modify the argument and result attributes when the result is now
152 // an argument.
153 if (argAttrs) {
154 argAttributes.push_back(builder.getDictionaryAttr({}));
155 argAttributes.append(argAttrs.begin(), argAttrs.end());
156 wrapperFuncOp.setAllArgAttrs(argAttributes);
157 }
158 }
159 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
160 .setVisibility(funcOp.getVisibility());
161}
162
163/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
164/// arguments instead of unpacked arguments. This function can be called from C
165/// by passing a pointer to a C struct corresponding to a memref descriptor.
166/// Similarly, returned memrefs are passed via pointers to a C struct that is
167/// passed as additional argument.
168/// Internally, the auxiliary function unpacks the descriptor into individual
169/// components and forwards them to `newFuncOp` and forwards the results to
170/// the extra arguments.
171static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
172 const LLVMTypeConverter &typeConverter,
173 FunctionOpInterface funcOp,
174 LLVM::LLVMFuncOp newFuncOp) {
175 auto type = cast<FunctionType>(funcOp.getFunctionType());
176 auto [wrapperFuncType, resultStructType] =
177 typeConverter.convertFunctionTypeCWrapper(type);
178
180 filterFuncAttributes(funcOp, attributes);
181
182 auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
183 rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
184 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
185 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
186 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
187
188 OpBuilder::InsertionGuard guard(rewriter);
189 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
190
192 size_t argOffset = resultStructType ? 1 : 0;
193 for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
194 Value arg = wrapperFuncOp.getArgument(index + argOffset);
195 if (auto memrefType = dyn_cast<MemRefType>(argType)) {
196 Value loaded = LLVM::LoadOp::create(
197 rewriter, loc, typeConverter.convertType(memrefType), arg);
198 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
199 continue;
200 }
201 if (isa<UnrankedMemRefType>(argType)) {
202 Value loaded = LLVM::LoadOp::create(
203 rewriter, loc, typeConverter.convertType(argType), arg);
204 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
205 continue;
206 }
207
208 args.push_back(arg);
209 }
210
211 auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
212
213 if (resultStructType) {
214 LLVM::StoreOp::create(rewriter, loc, call.getResult(),
215 wrapperFuncOp.getArgument(0));
216 LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
217 } else {
218 LLVM::ReturnOp::create(rewriter, loc, call.getResults());
219 }
220}
221
222/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
223/// arguments instead of unpacked arguments. Creates a body for the (external)
224/// `newFuncOp` that allocates a memref descriptor on stack, packs the
225/// individual arguments into this descriptor and passes a pointer to it into
226/// the auxiliary function. If the result of the function cannot be directly
227/// returned, we write it to a special first argument that provides a pointer
228/// to a corresponding struct. This auxiliary external function is now
229/// compatible with functions defined in C using pointers to C structs
230/// corresponding to a memref descriptor.
231static void wrapExternalFunction(OpBuilder &builder, Location loc,
232 const LLVMTypeConverter &typeConverter,
233 FunctionOpInterface funcOp,
234 LLVM::LLVMFuncOp newFuncOp) {
235 OpBuilder::InsertionGuard guard(builder);
236
237 auto [wrapperType, resultStructType] =
238 typeConverter.convertFunctionTypeCWrapper(
239 cast<FunctionType>(funcOp.getFunctionType()));
240 // This conversion can only fail if it could not convert one of the argument
241 // types. But since it has been applied to a non-wrapper function before, it
242 // should have failed earlier and not reach this point at all.
243 assert(wrapperType && "unexpected type conversion failure");
244
246 filterFuncAttributes(funcOp, attributes);
247
248 // Create the auxiliary function.
249 auto wrapperFunc = LLVM::LLVMFuncOp::create(
250 builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
251 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
252 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
253 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
254
255 // The wrapper that we synthetize here should only be visible in this module.
256 newFuncOp.setLinkage(LLVM::Linkage::Private);
257 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
258
259 // Get a ValueRange containing arguments.
260 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
262 args.reserve(type.getNumInputs());
263 ValueRange wrapperArgsRange(newFuncOp.getArguments());
264
265 if (resultStructType) {
266 // Allocate the struct on the stack and pass the pointer.
267 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
268 Value one = LLVM::ConstantOp::create(
269 builder, loc, typeConverter.convertType(builder.getIndexType()),
270 builder.getIntegerAttr(builder.getIndexType(), 1));
271 Value result =
272 LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
273 args.push_back(result);
274 }
275
276 // Iterate over the inputs of the original function and pack values into
277 // memref descriptors if the original type is a memref.
278 for (Type input : type.getInputs()) {
279 Value arg;
280 int numToDrop = 1;
281 auto memRefType = dyn_cast<MemRefType>(input);
282 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
283 if (memRefType || unrankedMemRefType) {
284 numToDrop = memRefType
287 Value packed =
288 memRefType
289 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
290 wrapperArgsRange.take_front(numToDrop))
292 builder, loc, typeConverter, unrankedMemRefType,
293 wrapperArgsRange.take_front(numToDrop));
294
295 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
296 Value one = LLVM::ConstantOp::create(
297 builder, loc, typeConverter.convertType(builder.getIndexType()),
298 builder.getIntegerAttr(builder.getIndexType(), 1));
299 Value allocated = LLVM::AllocaOp::create(
300 builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
301 LLVM::StoreOp::create(builder, loc, packed, allocated);
302 arg = allocated;
303 } else {
304 arg = wrapperArgsRange[0];
305 }
306
307 args.push_back(arg);
308 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
309 }
310 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
311
312 auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
313
314 if (resultStructType) {
315 Value result =
316 LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
317 LLVM::ReturnOp::create(builder, loc, result);
318 } else {
319 LLVM::ReturnOp::create(builder, loc, call.getResults());
320 }
321}
322
323/// Inserts `llvm.load` ops in the function body to restore the expected pointee
324/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
325/// to LLVM pointer types.
327 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
328 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
329 LLVM::LLVMFuncOp funcOp) {
330 // Nothing to do for function declarations.
331 if (funcOp.isExternal())
332 return;
333
334 ConversionPatternRewriter::InsertionGuard guard(rewriter);
335 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
336
337 for (const auto &[arg, byValRefAttr] :
338 llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
339 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
340 if (!byValRefAttr)
341 continue;
342
343 // Insert load to retrieve the actual argument passed by value/reference.
344 assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
345 "Expected LLVM pointer type for argument with "
346 "`llvm.byval`/`llvm.byref` attribute");
347 Type resTy = typeConverter.convertType(
348 cast<TypeAttr>(byValRefAttr->getValue()).getValue());
349
350 Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
351 rewriter.replaceAllUsesWith(arg, valueArg);
352 }
353}
354
355/// TODO: Refactor this function to be more modular and easier to understand.
356FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
357 FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
358 const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
359 // Check the funcOp has `FunctionType`.
360 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
361 if (!funcTy)
362 return rewriter.notifyMatchFailure(
363 funcOp, "Only support FunctionOpInterface with FunctionType");
364
365 // Convert the original function arguments. They are converted using the
366 // LLVMTypeConverter provided to this legalization pattern.
367 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
368 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
369 // overriden with an LLVM pointer type for later processing.
371 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
372 auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
373 converter.convertFunctionSignature(
374 funcOp, varargsAttr && varargsAttr.getValue(),
375 shouldUseBarePtrCallConv(funcOp, &converter), result,
376 byValRefNonPtrAttrs));
377 if (!llvmType)
378 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
379
380 // Check for unsupported variadic functions.
381 if (!shouldUseBarePtrCallConv(funcOp, &converter))
382 if (funcOp->getAttrOfType<UnitAttr>(
383 LLVM::LLVMDialect::getEmitCWrapperAttrName()))
384 if (llvmType.isVarArg())
385 return funcOp.emitError("C interface for variadic functions is not "
386 "supported yet.");
387
388 FailureOr<LoweredFuncAttrs> loweredAttrs = lowerFuncAttributes(funcOp);
389 if (failed(loweredAttrs))
390 return rewriter.notifyMatchFailure(funcOp,
391 "failed to lower func attributes");
392
393 Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
394
395 if (symbolTables && symbolTableOp) {
396 SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
397 symbolTable.remove(funcOp);
398 }
399 buildLLVMFuncProperties(rewriter, funcOp, llvmType, loweredAttrs->properties);
400 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, funcOp.getLoc(),
401 loweredAttrs->properties,
402 loweredAttrs->discardableAttrs);
403
404 if (symbolTables && symbolTableOp) {
405 auto ip = rewriter.getInsertionPoint();
406 SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
407 symbolTable.insert(newFuncOp, ip);
408 }
409
410 cast<FunctionOpInterface>(newFuncOp.getOperation())
411 .setVisibility(funcOp.getVisibility());
412
413 // Create a memory effect attribute corresponding to readnone.
414 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
415 if (funcOp->hasAttr(readnoneAttrName)) {
416 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
417 rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef,
418 /*argMem=*/LLVM::ModRefInfo::NoModRef,
419 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
420 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
421 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
422 /*targetMem1=*/LLVM::ModRefInfo::NoModRef});
423 newFuncOp.setMemoryEffectsAttr(memoryAttr);
424 }
425
426 // Propagate argument/result attributes to all converted arguments/result
427 // obtained after converting a given original argument/result.
428 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
429 assert(!resAttrDicts.empty() && "expected array to be non-empty");
430 if (funcOp.getNumResults() == 1)
431 newFuncOp.setAllResultAttrs(resAttrDicts);
432 }
433 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
434 SmallVector<Attribute> newArgAttrs(
435 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
436 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
437 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
438 // LLVMFuncOp conversion these types may have changed. Account for that
439 // change by converting attributes' types as well.
440 SmallVector<NamedAttribute, 4> convertedAttrs;
441 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
442 convertedAttrs.reserve(attrsDict.size());
443 for (const NamedAttribute &attr : attrsDict) {
444 const auto convert = [&](const NamedAttribute &attr) {
445 return TypeAttr::get(converter.convertType(
446 cast<TypeAttr>(attr.getValue()).getValue()));
447 };
448 if (attr.getName().getValue() ==
449 LLVM::LLVMDialect::getByValAttrName()) {
450 convertedAttrs.push_back(rewriter.getNamedAttr(
451 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
452 } else if (attr.getName().getValue() ==
453 LLVM::LLVMDialect::getByRefAttrName()) {
454 convertedAttrs.push_back(rewriter.getNamedAttr(
455 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
456 } else if (attr.getName().getValue() ==
457 LLVM::LLVMDialect::getStructRetAttrName()) {
458 convertedAttrs.push_back(rewriter.getNamedAttr(
459 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
460 } else if (attr.getName().getValue() ==
461 LLVM::LLVMDialect::getInAllocaAttrName()) {
462 convertedAttrs.push_back(rewriter.getNamedAttr(
463 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
464 } else {
465 convertedAttrs.push_back(attr);
466 }
467 }
468 auto mapping = result.getInputMapping(i);
469 assert(mapping && "unexpected deletion of function argument");
470 // Only attach the new argument attributes if there is a one-to-one
471 // mapping from old to new types. Otherwise, attributes might be
472 // attached to types that they do not support.
473 if (mapping->size == 1) {
474 newArgAttrs[mapping->inputNo] =
475 DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
476 continue;
477 }
478 // TODO: Implement custom handling for types that expand to multiple
479 // function arguments.
480 for (size_t j = 0; j < mapping->size; ++j)
481 newArgAttrs[mapping->inputNo + j] =
482 DictionaryAttr::get(rewriter.getContext(), {});
483 }
484 if (!newArgAttrs.empty())
485 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
486 }
487
488 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
489 newFuncOp.end());
490 // Convert just the entry block. The remaining unstructured control flow is
491 // converted by ControlFlowToLLVM.
492 if (!newFuncOp.getBody().empty())
493 rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
494 &converter);
495
496 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
497 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
498 // function arguments.
499 restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
500 newFuncOp);
501
502 if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
503 if (funcOp->getAttrOfType<UnitAttr>(
504 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
505 if (newFuncOp.isExternal())
506 wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
507 newFuncOp);
508 else
509 wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
510 newFuncOp);
511 }
512 }
513
514 return newFuncOp;
515}
516
517namespace {
518
519/// FuncOp legalization pattern that converts MemRef arguments to pointers to
520/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
521/// information.
522class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
523 SymbolTableCollection *symbolTables = nullptr;
524
525public:
526 explicit FuncOpConversion(const LLVMTypeConverter &converter,
527 SymbolTableCollection *symbolTables = nullptr)
528 : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
529
530 LogicalResult
531 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter) const override {
533 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
534 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
535 *getTypeConverter(), symbolTables);
536 if (failed(newFuncOp))
537 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
538
539 rewriter.eraseOp(funcOp);
540 return success();
541 }
542};
543
544struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
545 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
546
547 LogicalResult
548 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
549 ConversionPatternRewriter &rewriter) const override {
550 auto type = typeConverter->convertType(op.getResult().getType());
551 if (!type || !LLVM::isCompatibleType(type))
552 return rewriter.notifyMatchFailure(op, "failed to convert result type");
553
554 auto newOp =
555 LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
556 for (const NamedAttribute &attr : op->getAttrs()) {
557 if (attr.getName().strref() == "value")
558 continue;
559 newOp->setAttr(attr.getName(), attr.getValue());
560 }
561 rewriter.replaceOp(op, newOp->getResults());
562 return success();
563 }
564};
565
566// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
567// passes the pointer to the MemRef across function boundaries.
568template <typename CallOpType>
569struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
570 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
571 using Super = CallOpInterfaceLowering<CallOpType>;
572 using Base = ConvertOpToLLVMPattern<CallOpType>;
574
575 LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
576 ConversionPatternRewriter &rewriter,
577 bool useBarePtrCallConv = false) const {
578 // Pack the result types into a struct.
579 Type packedResult = nullptr;
580 SmallVector<SmallVector<Type>> groupedResultTypes;
581 unsigned numResults = callOp.getNumResults();
582 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
583 int64_t numConvertedTypes = 0;
584 if (numResults != 0) {
585 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
586 resultTypes, useBarePtrCallConv, &groupedResultTypes,
587 &numConvertedTypes)))
588 return failure();
589 }
590
591 if (useBarePtrCallConv) {
592 for (auto it : callOp->getOperands()) {
593 Type operandType = it.getType();
594 if (isa<UnrankedMemRefType>(operandType)) {
595 // Unranked memref is not supported in the bare pointer calling
596 // convention.
597 return failure();
598 }
599 }
600 }
601 auto promoted = this->getTypeConverter()->promoteOperands(
602 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
603 adaptor.getOperands(), rewriter, useBarePtrCallConv);
604 auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
605 packedResult ? TypeRange(packedResult)
606 : TypeRange(),
607 promoted, callOp->getAttrs());
608
609 newOp.getProperties().operandSegmentSizes = {
610 static_cast<int32_t>(promoted.size()), 0};
611 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
612
613 // Helper function that extracts an individual result from the return value
614 // of the new call op. llvm.call ops support only 0 or 1 result. In case of
615 // 2 or more results, the results are packed into a structure.
616 //
617 // The new call op may have more than 2 results because:
618 // a. The original call op has more than 2 results.
619 // b. An original op result type-converted to more than 1 result.
620 auto getUnpackedResult = [&](unsigned i) -> Value {
621 assert(numConvertedTypes > 0 && "convert op has no results");
622 if (numConvertedTypes == 1) {
623 assert(i == 0 && "out of bounds: converted op has only one result");
624 return newOp->getResult(0);
625 }
626 // Results have been converted to a structure. Extract individual results
627 // from the structure.
628 return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
629 newOp->getResult(0), i);
630 };
631
632 // Group the results into a vector of vectors, such that it is clear which
633 // original op result is replaced with which range of values. (In case of a
634 // 1:N conversion, there can be multiple replacements for a single result.)
635 SmallVector<SmallVector<Value>> results;
636 results.reserve(numResults);
637 unsigned counter = 0;
638 for (unsigned i = 0; i < numResults; ++i) {
639 SmallVector<Value> &group = results.emplace_back();
640 for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
641 group.push_back(getUnpackedResult(counter++));
642 }
643
644 // Special handling for MemRef types.
645 for (unsigned i = 0; i < numResults; ++i) {
646 Type origType = resultTypes[i];
647 auto memrefType = dyn_cast<MemRefType>(origType);
648 auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
649 if (useBarePtrCallConv && memrefType) {
650 // For the bare-ptr calling convention, promote memref results to
651 // descriptors.
652 assert(results[i].size() == 1 && "expected one converted result");
653 results[i].front() = MemRefDescriptor::fromStaticShape(
654 rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
655 results[i].front());
656 }
657 if (unrankedMemrefType) {
658 assert(!useBarePtrCallConv && "unranked memref is not supported in the "
659 "bare-ptr calling convention");
660 assert(results[i].size() == 1 && "expected one converted result");
661 Value desc = this->copyUnrankedDescriptor(
662 rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
663 /*toDynamic=*/false);
664 if (!desc)
665 return failure();
666 results[i].front() = desc;
667 }
668 }
669
670 rewriter.replaceOpWithMultiple(callOp, results);
671 return success();
672 }
673};
674
675class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
676public:
677 explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
678 SymbolTableCollection *symbolTables = nullptr,
679 PatternBenefit benefit = 1)
680 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
681 symbolTables(symbolTables) {}
682
683 LogicalResult
684 matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
685 ConversionPatternRewriter &rewriter) const override {
686 bool useBarePtrCallConv = false;
687 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
688 useBarePtrCallConv = true;
689 } else if (symbolTables != nullptr) {
690 // Fast lookup.
691 Operation *callee =
692 symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
693 useBarePtrCallConv =
694 callee != nullptr && callee->hasAttr(barePtrAttrName);
695 } else {
696 // Warning: This is a linear lookup.
697 Operation *callee =
698 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
699 useBarePtrCallConv =
700 callee != nullptr && callee->hasAttr(barePtrAttrName);
701 }
702 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
703 }
704
705private:
706 SymbolTableCollection *symbolTables = nullptr;
707};
708
709struct CallIndirectOpLowering
710 : public CallOpInterfaceLowering<func::CallIndirectOp> {
711 using Super::Super;
712
713 LogicalResult
714 matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
715 ConversionPatternRewriter &rewriter) const override {
716 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
717 }
718};
719
720struct UnrealizedConversionCastOpLowering
721 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
722 using ConvertOpToLLVMPattern<
723 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
724
725 LogicalResult
726 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter) const override {
728 SmallVector<Type> convertedTypes;
729 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
730 convertedTypes)) &&
731 convertedTypes == adaptor.getInputs().getTypes()) {
732 rewriter.replaceOp(op, adaptor.getInputs());
733 return success();
734 }
735
736 convertedTypes.clear();
737 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
738 convertedTypes)) &&
739 convertedTypes == op.getOutputs().getType()) {
740 rewriter.replaceOp(op, adaptor.getInputs());
741 return success();
742 }
743 return failure();
744 }
745};
746
747// Special lowering pattern for `ReturnOps`. Unlike all other operations,
748// `ReturnOp` interacts with the function signature and must have as many
749// operands as the function has return values. Because in LLVM IR, functions
750// can only return 0 or 1 value, we pack multiple values into a structure type.
751// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
752// necessary before returning it
753struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
754 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
755
756 LogicalResult
757 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
758 ConversionPatternRewriter &rewriter) const override {
759 Location loc = op.getLoc();
760 SmallVector<Value, 4> updatedOperands;
761
762 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
763 bool useBarePtrCallConv =
764 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
765
766 for (auto [oldOperand, newOperands] :
767 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
768 Type oldTy = oldOperand.getType();
769 if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
770 assert(newOperands.size() == 1 && "expected one converted result");
771 if (useBarePtrCallConv &&
772 getTypeConverter()->canConvertToBarePtr(memRefType)) {
773 // For the bare-ptr calling convention, extract the aligned pointer to
774 // be returned from the memref descriptor.
775 MemRefDescriptor memrefDesc(newOperands.front());
776 updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
777 continue;
778 }
779 } else if (auto unrankedMemRefType =
780 dyn_cast<UnrankedMemRefType>(oldTy)) {
781 assert(newOperands.size() == 1 && "expected one converted result");
782 if (useBarePtrCallConv) {
783 // Unranked memref is not supported in the bare pointer calling
784 // convention.
785 return failure();
786 }
787 Value updatedDesc =
788 copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
789 newOperands.front(), /*toDynamic=*/true);
790 if (!updatedDesc)
791 return failure();
792 updatedOperands.push_back(updatedDesc);
793 continue;
794 }
795
796 llvm::append_range(updatedOperands, newOperands);
797 }
798
799 // If ReturnOp has 0 or 1 operand, create it and return immediately.
800 if (updatedOperands.size() <= 1) {
801 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
802 op, TypeRange(), updatedOperands, op->getAttrs());
803 return success();
804 }
805
806 // Otherwise, we need to pack the arguments into an LLVM struct type before
807 // returning.
808 auto packedType = getTypeConverter()->packFunctionResults(
809 op.getOperandTypes(), useBarePtrCallConv);
810 if (!packedType) {
811 return rewriter.notifyMatchFailure(op, "could not convert result types");
812 }
813
814 Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
815 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
816 packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
817 }
818 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
819 op->getAttrs());
820 return success();
821 }
822};
823} // namespace
824
826 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
827 SymbolTableCollection *symbolTables) {
// Registers only the function-op lowering (FuncOpConversion). The converter
// and the (possibly null) symbol-table cache are forwarded to it unchanged.
828 patterns.add<FuncOpConversion>(converter, symbolTables);
829}
830
832 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
833 SymbolTableCollection *symbolTables) {
// Reuse the function-op entry point so both population helpers register the
// same FuncOpConversion pattern with the same symbol-table cache.
834 populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
// Remaining Func lowerings. Only CallOpLowering receives the symbol-table
// cache; the others are constructed from the type converter alone.
835 patterns.add<CallIndirectOpLowering>(converter);
836 patterns.add<CallOpLowering>(converter, symbolTables);
837 patterns.add<ConstantOpLowering>(converter);
838 patterns.add<ReturnOpLowering>(converter);
839}
840
841namespace {
842/// A pass converting Func operations into the LLVM IR dialect.
843struct ConvertFuncToLLVMPass
844 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
845 using Base::Base;
846
847 /// Run the dialect converter on the module.
848 void runOnOperation() override {
849 ModuleOp m = getOperation();
850 StringRef dataLayout;
851 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
852 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
853 if (dataLayoutAttr)
854 dataLayout = dataLayoutAttr.getValue();
855
856 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
857 dataLayout, [this](const Twine &message) {
858 getOperation().emitError() << message.str();
859 }))) {
860 signalPassFailure();
861 return;
862 }
863
864 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
865
866 LowerToLLVMOptions options(&getContext(),
867 dataLayoutAnalysis.getAtOrAbove(m));
868 options.useBarePtrCallConv = useBarePtrCallConv;
869 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
870 options.overrideIndexBitwidth(indexBitwidth);
871 options.dataLayout = llvm::DataLayout(dataLayout);
872
873 LLVMTypeConverter typeConverter(&getContext(), options,
874 &dataLayoutAnalysis);
875
876 RewritePatternSet patterns(&getContext());
877 SymbolTableCollection symbolTables;
878
879 populateFuncToLLVMConversionPatterns(typeConverter, patterns,
880 &symbolTables);
881
882 LLVMConversionTarget target(getContext());
883 if (failed(applyPartialConversion(m, target, std::move(patterns))))
884 signalPassFailure();
885 }
886};
887
888struct SetLLVMModuleDataLayoutPass
889 : public impl::SetLLVMModuleDataLayoutPassBase<
890 SetLLVMModuleDataLayoutPass> {
891 using Base::Base;
892
893 /// Run the dialect converter on the module.
894 void runOnOperation() override {
895 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
896 this->dataLayout, [this](const Twine &message) {
897 getOperation().emitError() << message.str();
898 }))) {
899 signalPassFailure();
900 return;
901 }
902 ModuleOp m = getOperation();
903 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
904 StringAttr::get(m.getContext(), this->dataLayout));
905 }
906};
907} // namespace
908
909//===----------------------------------------------------------------------===//
910// ConvertToLLVMPatternInterface implementation
911//===----------------------------------------------------------------------===//
912
913namespace {
914/// Implement the interface to convert Func to LLVM.
915struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
916 FuncToLLVMDialectInterface(Dialect *dialect)
917 : ConvertToLLVMPatternInterface(dialect) {}
918 /// Hook for derived dialect interface to provide conversion patterns
919 /// and mark dialect legal for the conversion target.
920 void populateConvertToLLVMConversionPatterns(
921 ConversionTarget &target, LLVMTypeConverter &typeConverter,
922 RewritePatternSet &patterns) const final {
923 populateFuncToLLVMConversionPatterns(typeConverter, patterns);
924 }
925};
926} // namespace
927
// Register an extension whose callback receives the func::FuncDialect
// instance and attaches FuncToLLVMDialectInterface to it (presumably invoked
// when that dialect is loaded into a context built from this registry;
// confirm against DialectRegistry::addExtension). The unary `+` converts the
// capture-less lambda to a plain function pointer; `ctx` is unused.
929 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
930 dialect->addInterfaces<FuncToLLVMDialectInterface>();
931 });
932}
return success()
static void restoreByValRefArgumentType(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef< std::optional< NamedAttribute > > byValRefNonPtrAttrs, LLVM::LLVMFuncOp funcOp)
Inserts llvm.load ops in the function body to restore the expected pointee value from llvm....
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/result attributes.
static bool isDiscardableAttr(StringRef name)
static constexpr StringRef barePtrAttrName
static constexpr StringRef varargsAttrName
static constexpr StringRef linkageAttrName
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
static void buildLLVMFuncProperties(PatternRewriter &rewriter, FunctionOpInterface srcFunc, Type llvmFuncType, LLVM::LLVMFuncOp::Properties &props)
static FailureOr< LoweredFuncAttrs > lowerFuncAttributes(FunctionOpInterface func)
Lower discardable function attributes on func.func to attributes expected by llvm....
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:230
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
const LowerToLLVMOptions & getOptions() const
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFuncToLLVMFuncOpConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables=nullptr)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Add custom lowered funcOp to llvm.func attributes here.
NamedAttrList discardableAttrs
LLVM::LLVMFuncOp::Properties properties
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.