MLIR 22.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1//===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert MLIR Func and builtin dialects
10// into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
27#include "mlir/IR/Attributes.h"
28#include "mlir/IR/Builders.h"
30#include "mlir/IR/BuiltinOps.h"
32#include "mlir/IR/SymbolTable.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/IR/Type.h"
38#include "llvm/Support/FormatVariadic.h"
39#include <optional>
40
41namespace mlir {
42#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
43#define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
44#include "mlir/Conversion/Passes.h.inc"
45} // namespace mlir
46
47using namespace mlir;
48
49#define PASS_NAME "convert-func-to-llvm"
50
51static constexpr StringRef varargsAttrName = "func.varargs";
52static constexpr StringRef linkageAttrName = "llvm.linkage";
53static constexpr StringRef barePtrAttrName = "llvm.bareptr";
54
55/// Return `true` if the `op` should use bare pointer calling convention.
57 const LLVMTypeConverter *typeConverter) {
58 return (op && op->hasAttr(barePtrAttrName)) ||
59 typeConverter->getOptions().useBarePtrCallConv;
60}
61
62/// Only retain those attributes that are not constructed by
63/// `LLVMFuncOp::build`.
64static void filterFuncAttributes(FunctionOpInterface func,
66 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
67 if (attr.getName() == linkageAttrName ||
68 attr.getName() == varargsAttrName ||
69 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
70 continue;
71 result.push_back(attr);
72 }
73}
74
75/// Propagate argument/results attributes.
76static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
77 FunctionOpInterface funcOp,
78 LLVM::LLVMFuncOp wrapperFuncOp) {
79 auto argAttrs = funcOp.getAllArgAttrs();
80 if (!resultStructType) {
81 if (auto resAttrs = funcOp.getAllResultAttrs())
82 wrapperFuncOp.setAllResultAttrs(resAttrs);
83 if (argAttrs)
84 wrapperFuncOp.setAllArgAttrs(argAttrs);
85 } else {
86 SmallVector<Attribute> argAttributes;
87 // Only modify the argument and result attributes when the result is now
88 // an argument.
89 if (argAttrs) {
90 argAttributes.push_back(builder.getDictionaryAttr({}));
91 argAttributes.append(argAttrs.begin(), argAttrs.end());
92 wrapperFuncOp.setAllArgAttrs(argAttributes);
93 }
94 }
95 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
96 .setVisibility(funcOp.getVisibility());
97}
98
99/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
100/// arguments instead of unpacked arguments. This function can be called from C
101/// by passing a pointer to a C struct corresponding to a memref descriptor.
102/// Similarly, returned memrefs are passed via pointers to a C struct that is
103/// passed as additional argument.
104/// Internally, the auxiliary function unpacks the descriptor into individual
105/// components and forwards them to `newFuncOp` and forwards the results to
106/// the extra arguments.
107static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
108 const LLVMTypeConverter &typeConverter,
109 FunctionOpInterface funcOp,
110 LLVM::LLVMFuncOp newFuncOp) {
111 auto type = cast<FunctionType>(funcOp.getFunctionType());
112 auto [wrapperFuncType, resultStructType] =
113 typeConverter.convertFunctionTypeCWrapper(type);
114
116 filterFuncAttributes(funcOp, attributes);
117
118 auto wrapperFuncOp = LLVM::LLVMFuncOp::create(
119 rewriter, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
120 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
121 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
122 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
123
124 OpBuilder::InsertionGuard guard(rewriter);
125 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
126
128 size_t argOffset = resultStructType ? 1 : 0;
129 for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
130 Value arg = wrapperFuncOp.getArgument(index + argOffset);
131 if (auto memrefType = dyn_cast<MemRefType>(argType)) {
132 Value loaded = LLVM::LoadOp::create(
133 rewriter, loc, typeConverter.convertType(memrefType), arg);
134 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
135 continue;
136 }
137 if (isa<UnrankedMemRefType>(argType)) {
138 Value loaded = LLVM::LoadOp::create(
139 rewriter, loc, typeConverter.convertType(argType), arg);
140 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
141 continue;
142 }
143
144 args.push_back(arg);
145 }
146
147 auto call = LLVM::CallOp::create(rewriter, loc, newFuncOp, args);
148
149 if (resultStructType) {
150 LLVM::StoreOp::create(rewriter, loc, call.getResult(),
151 wrapperFuncOp.getArgument(0));
152 LLVM::ReturnOp::create(rewriter, loc, ValueRange{});
153 } else {
154 LLVM::ReturnOp::create(rewriter, loc, call.getResults());
155 }
156}
157
158/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
159/// arguments instead of unpacked arguments. Creates a body for the (external)
160/// `newFuncOp` that allocates a memref descriptor on stack, packs the
161/// individual arguments into this descriptor and passes a pointer to it into
162/// the auxiliary function. If the result of the function cannot be directly
163/// returned, we write it to a special first argument that provides a pointer
164/// to a corresponding struct. This auxiliary external function is now
165/// compatible with functions defined in C using pointers to C structs
166/// corresponding to a memref descriptor.
167static void wrapExternalFunction(OpBuilder &builder, Location loc,
168 const LLVMTypeConverter &typeConverter,
169 FunctionOpInterface funcOp,
170 LLVM::LLVMFuncOp newFuncOp) {
171 OpBuilder::InsertionGuard guard(builder);
172
173 auto [wrapperType, resultStructType] =
174 typeConverter.convertFunctionTypeCWrapper(
175 cast<FunctionType>(funcOp.getFunctionType()));
176 // This conversion can only fail if it could not convert one of the argument
177 // types. But since it has been applied to a non-wrapper function before, it
178 // should have failed earlier and not reach this point at all.
179 assert(wrapperType && "unexpected type conversion failure");
180
182 filterFuncAttributes(funcOp, attributes);
183
184 // Create the auxiliary function.
185 auto wrapperFunc = LLVM::LLVMFuncOp::create(
186 builder, loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
187 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
188 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
189 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
190
191 // The wrapper that we synthetize here should only be visible in this module.
192 newFuncOp.setLinkage(LLVM::Linkage::Private);
193 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
194
195 // Get a ValueRange containing arguments.
196 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
198 args.reserve(type.getNumInputs());
199 ValueRange wrapperArgsRange(newFuncOp.getArguments());
200
201 if (resultStructType) {
202 // Allocate the struct on the stack and pass the pointer.
203 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
204 Value one = LLVM::ConstantOp::create(
205 builder, loc, typeConverter.convertType(builder.getIndexType()),
206 builder.getIntegerAttr(builder.getIndexType(), 1));
207 Value result =
208 LLVM::AllocaOp::create(builder, loc, resultType, resultStructType, one);
209 args.push_back(result);
210 }
211
212 // Iterate over the inputs of the original function and pack values into
213 // memref descriptors if the original type is a memref.
214 for (Type input : type.getInputs()) {
215 Value arg;
216 int numToDrop = 1;
217 auto memRefType = dyn_cast<MemRefType>(input);
218 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
219 if (memRefType || unrankedMemRefType) {
220 numToDrop = memRefType
223 Value packed =
224 memRefType
225 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
226 wrapperArgsRange.take_front(numToDrop))
228 builder, loc, typeConverter, unrankedMemRefType,
229 wrapperArgsRange.take_front(numToDrop));
230
231 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
232 Value one = LLVM::ConstantOp::create(
233 builder, loc, typeConverter.convertType(builder.getIndexType()),
234 builder.getIntegerAttr(builder.getIndexType(), 1));
235 Value allocated = LLVM::AllocaOp::create(
236 builder, loc, ptrTy, packed.getType(), one, /*alignment=*/0);
237 LLVM::StoreOp::create(builder, loc, packed, allocated);
238 arg = allocated;
239 } else {
240 arg = wrapperArgsRange[0];
241 }
242
243 args.push_back(arg);
244 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
245 }
246 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
247
248 auto call = LLVM::CallOp::create(builder, loc, wrapperFunc, args);
249
250 if (resultStructType) {
251 Value result =
252 LLVM::LoadOp::create(builder, loc, resultStructType, args.front());
253 LLVM::ReturnOp::create(builder, loc, result);
254 } else {
255 LLVM::ReturnOp::create(builder, loc, call.getResults());
256 }
257}
258
259/// Inserts `llvm.load` ops in the function body to restore the expected pointee
260/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
261/// to LLVM pointer types.
263 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
264 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
265 LLVM::LLVMFuncOp funcOp) {
266 // Nothing to do for function declarations.
267 if (funcOp.isExternal())
268 return;
269
270 ConversionPatternRewriter::InsertionGuard guard(rewriter);
271 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
272
273 for (const auto &[arg, byValRefAttr] :
274 llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
275 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
276 if (!byValRefAttr)
277 continue;
278
279 // Insert load to retrieve the actual argument passed by value/reference.
280 assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
281 "Expected LLVM pointer type for argument with "
282 "`llvm.byval`/`llvm.byref` attribute");
283 Type resTy = typeConverter.convertType(
284 cast<TypeAttr>(byValRefAttr->getValue()).getValue());
285
286 Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
287 rewriter.replaceAllUsesWith(arg, valueArg);
288 }
289}
290
291FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
292 FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
293 const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
294 // Check the funcOp has `FunctionType`.
295 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
296 if (!funcTy)
297 return rewriter.notifyMatchFailure(
298 funcOp, "Only support FunctionOpInterface with FunctionType");
299
300 // Convert the original function arguments. They are converted using the
301 // LLVMTypeConverter provided to this legalization pattern.
302 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
303 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
304 // overriden with an LLVM pointer type for later processing.
306 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
307 auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
308 converter.convertFunctionSignature(
309 funcOp, varargsAttr && varargsAttr.getValue(),
310 shouldUseBarePtrCallConv(funcOp, &converter), result,
311 byValRefNonPtrAttrs));
312 if (!llvmType)
313 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
314
315 // Check for unsupported variadic functions.
316 if (!shouldUseBarePtrCallConv(funcOp, &converter))
317 if (funcOp->getAttrOfType<UnitAttr>(
318 LLVM::LLVMDialect::getEmitCWrapperAttrName()))
319 if (llvmType.isVarArg())
320 return funcOp.emitError("C interface for variadic functions is not "
321 "supported yet.");
322
323 // Create an LLVM function, use external linkage by default until MLIR
324 // functions have linkage.
325 LLVM::Linkage linkage = LLVM::Linkage::External;
326 if (funcOp->hasAttr(linkageAttrName)) {
327 auto attr =
328 dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
329 if (!attr) {
330 funcOp->emitError() << "Contains " << linkageAttrName
331 << " attribute not of type LLVM::LinkageAttr";
332 return rewriter.notifyMatchFailure(
333 funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
334 }
335 linkage = attr.getLinkage();
336 }
337
338 // Check for invalid attributes.
339 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
340 if (funcOp->hasAttr(readnoneAttrName)) {
341 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
342 if (!attr) {
343 funcOp->emitError() << "Contains " << readnoneAttrName
344 << " attribute not of type UnitAttr";
345 return rewriter.notifyMatchFailure(
346 funcOp, "Contains readnone attribute not of type UnitAttr");
347 }
348 }
349
351 filterFuncAttributes(funcOp, attributes);
352
353 Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
354
355 if (symbolTables && symbolTableOp) {
356 SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
357 symbolTable.remove(funcOp);
358 }
359
360 auto newFuncOp = LLVM::LLVMFuncOp::create(
361 rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
362 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
363 attributes);
364
365 if (symbolTables && symbolTableOp) {
366 auto ip = rewriter.getInsertionPoint();
367 SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
368 symbolTable.insert(newFuncOp, ip);
369 }
370
371 cast<FunctionOpInterface>(newFuncOp.getOperation())
372 .setVisibility(funcOp.getVisibility());
373
374 // Create a memory effect attribute corresponding to readnone.
375 if (funcOp->hasAttr(readnoneAttrName)) {
376 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
377 rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef,
378 /*argMem=*/LLVM::ModRefInfo::NoModRef,
379 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
380 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
381 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
382 /*targetMem1=*/LLVM::ModRefInfo::NoModRef});
383 newFuncOp.setMemoryEffectsAttr(memoryAttr);
384 }
385
386 // Propagate argument/result attributes to all converted arguments/result
387 // obtained after converting a given original argument/result.
388 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
389 assert(!resAttrDicts.empty() && "expected array to be non-empty");
390 if (funcOp.getNumResults() == 1)
391 newFuncOp.setAllResultAttrs(resAttrDicts);
392 }
393 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
394 SmallVector<Attribute> newArgAttrs(
395 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
396 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
397 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
398 // LLVMFuncOp conversion these types may have changed. Account for that
399 // change by converting attributes' types as well.
400 SmallVector<NamedAttribute, 4> convertedAttrs;
401 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
402 convertedAttrs.reserve(attrsDict.size());
403 for (const NamedAttribute &attr : attrsDict) {
404 const auto convert = [&](const NamedAttribute &attr) {
405 return TypeAttr::get(converter.convertType(
406 cast<TypeAttr>(attr.getValue()).getValue()));
407 };
408 if (attr.getName().getValue() ==
409 LLVM::LLVMDialect::getByValAttrName()) {
410 convertedAttrs.push_back(rewriter.getNamedAttr(
411 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
412 } else if (attr.getName().getValue() ==
413 LLVM::LLVMDialect::getByRefAttrName()) {
414 convertedAttrs.push_back(rewriter.getNamedAttr(
415 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
416 } else if (attr.getName().getValue() ==
417 LLVM::LLVMDialect::getStructRetAttrName()) {
418 convertedAttrs.push_back(rewriter.getNamedAttr(
419 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
420 } else if (attr.getName().getValue() ==
421 LLVM::LLVMDialect::getInAllocaAttrName()) {
422 convertedAttrs.push_back(rewriter.getNamedAttr(
423 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
424 } else {
425 convertedAttrs.push_back(attr);
426 }
427 }
428 auto mapping = result.getInputMapping(i);
429 assert(mapping && "unexpected deletion of function argument");
430 // Only attach the new argument attributes if there is a one-to-one
431 // mapping from old to new types. Otherwise, attributes might be
432 // attached to types that they do not support.
433 if (mapping->size == 1) {
434 newArgAttrs[mapping->inputNo] =
435 DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
436 continue;
437 }
438 // TODO: Implement custom handling for types that expand to multiple
439 // function arguments.
440 for (size_t j = 0; j < mapping->size; ++j)
441 newArgAttrs[mapping->inputNo + j] =
442 DictionaryAttr::get(rewriter.getContext(), {});
443 }
444 if (!newArgAttrs.empty())
445 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
446 }
447
448 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
449 newFuncOp.end());
450 // Convert just the entry block. The remaining unstructured control flow is
451 // converted by ControlFlowToLLVM.
452 if (!newFuncOp.getBody().empty())
453 rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
454 &converter);
455
456 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
457 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
458 // function arguments.
459 restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
460 newFuncOp);
461
462 if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
463 if (funcOp->getAttrOfType<UnitAttr>(
464 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
465 if (newFuncOp.isExternal())
466 wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
467 newFuncOp);
468 else
469 wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
470 newFuncOp);
471 }
472 }
473
474 return newFuncOp;
475}
476
477namespace {
478
479/// FuncOp legalization pattern that converts MemRef arguments to pointers to
480/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
481/// information.
482class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
483 SymbolTableCollection *symbolTables = nullptr;
484
485public:
486 explicit FuncOpConversion(const LLVMTypeConverter &converter,
487 SymbolTableCollection *symbolTables = nullptr)
488 : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
489
490 LogicalResult
491 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
492 ConversionPatternRewriter &rewriter) const override {
493 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
494 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
495 *getTypeConverter(), symbolTables);
496 if (failed(newFuncOp))
497 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
498
499 rewriter.eraseOp(funcOp);
500 return success();
501 }
502};
503
504struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
505 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
506
507 LogicalResult
508 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 auto type = typeConverter->convertType(op.getResult().getType());
511 if (!type || !LLVM::isCompatibleType(type))
512 return rewriter.notifyMatchFailure(op, "failed to convert result type");
513
514 auto newOp =
515 LLVM::AddressOfOp::create(rewriter, op.getLoc(), type, op.getValue());
516 for (const NamedAttribute &attr : op->getAttrs()) {
517 if (attr.getName().strref() == "value")
518 continue;
519 newOp->setAttr(attr.getName(), attr.getValue());
520 }
521 rewriter.replaceOp(op, newOp->getResults());
522 return success();
523 }
524};
525
526// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
527// passes the pointer to the MemRef across function boundaries.
528template <typename CallOpType>
529struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
530 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
531 using Super = CallOpInterfaceLowering<CallOpType>;
532 using Base = ConvertOpToLLVMPattern<CallOpType>;
534
535 LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
536 ConversionPatternRewriter &rewriter,
537 bool useBarePtrCallConv = false) const {
538 // Pack the result types into a struct.
539 Type packedResult = nullptr;
540 SmallVector<SmallVector<Type>> groupedResultTypes;
541 unsigned numResults = callOp.getNumResults();
542 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
543 int64_t numConvertedTypes = 0;
544 if (numResults != 0) {
545 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
546 resultTypes, useBarePtrCallConv, &groupedResultTypes,
547 &numConvertedTypes)))
548 return failure();
549 }
550
551 if (useBarePtrCallConv) {
552 for (auto it : callOp->getOperands()) {
553 Type operandType = it.getType();
554 if (isa<UnrankedMemRefType>(operandType)) {
555 // Unranked memref is not supported in the bare pointer calling
556 // convention.
557 return failure();
558 }
559 }
560 }
561 auto promoted = this->getTypeConverter()->promoteOperands(
562 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
563 adaptor.getOperands(), rewriter, useBarePtrCallConv);
564 auto newOp = LLVM::CallOp::create(rewriter, callOp.getLoc(),
565 packedResult ? TypeRange(packedResult)
566 : TypeRange(),
567 promoted, callOp->getAttrs());
568
569 newOp.getProperties().operandSegmentSizes = {
570 static_cast<int32_t>(promoted.size()), 0};
571 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
572
573 // Helper function that extracts an individual result from the return value
574 // of the new call op. llvm.call ops support only 0 or 1 result. In case of
575 // 2 or more results, the results are packed into a structure.
576 //
577 // The new call op may have more than 2 results because:
578 // a. The original call op has more than 2 results.
579 // b. An original op result type-converted to more than 1 result.
580 auto getUnpackedResult = [&](unsigned i) -> Value {
581 assert(numConvertedTypes > 0 && "convert op has no results");
582 if (numConvertedTypes == 1) {
583 assert(i == 0 && "out of bounds: converted op has only one result");
584 return newOp->getResult(0);
585 }
586 // Results have been converted to a structure. Extract individual results
587 // from the structure.
588 return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
589 newOp->getResult(0), i);
590 };
591
592 // Group the results into a vector of vectors, such that it is clear which
593 // original op result is replaced with which range of values. (In case of a
594 // 1:N conversion, there can be multiple replacements for a single result.)
595 SmallVector<SmallVector<Value>> results;
596 results.reserve(numResults);
597 unsigned counter = 0;
598 for (unsigned i = 0; i < numResults; ++i) {
599 SmallVector<Value> &group = results.emplace_back();
600 for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
601 group.push_back(getUnpackedResult(counter++));
602 }
603
604 // Special handling for MemRef types.
605 for (unsigned i = 0; i < numResults; ++i) {
606 Type origType = resultTypes[i];
607 auto memrefType = dyn_cast<MemRefType>(origType);
608 auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
609 if (useBarePtrCallConv && memrefType) {
610 // For the bare-ptr calling convention, promote memref results to
611 // descriptors.
612 assert(results[i].size() == 1 && "expected one converted result");
613 results[i].front() = MemRefDescriptor::fromStaticShape(
614 rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
615 results[i].front());
616 }
617 if (unrankedMemrefType) {
618 assert(!useBarePtrCallConv && "unranked memref is not supported in the "
619 "bare-ptr calling convention");
620 assert(results[i].size() == 1 && "expected one converted result");
621 Value desc = this->copyUnrankedDescriptor(
622 rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
623 /*toDynamic=*/false);
624 if (!desc)
625 return failure();
626 results[i].front() = desc;
627 }
628 }
629
630 rewriter.replaceOpWithMultiple(callOp, results);
631 return success();
632 }
633};
634
635class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
636public:
637 explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
638 SymbolTableCollection *symbolTables = nullptr,
639 PatternBenefit benefit = 1)
640 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
641 symbolTables(symbolTables) {}
642
643 LogicalResult
644 matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
645 ConversionPatternRewriter &rewriter) const override {
646 bool useBarePtrCallConv = false;
647 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
648 useBarePtrCallConv = true;
649 } else if (symbolTables != nullptr) {
650 // Fast lookup.
651 Operation *callee =
652 symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
653 useBarePtrCallConv =
654 callee != nullptr && callee->hasAttr(barePtrAttrName);
655 } else {
656 // Warning: This is a linear lookup.
657 Operation *callee =
658 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
659 useBarePtrCallConv =
660 callee != nullptr && callee->hasAttr(barePtrAttrName);
661 }
662 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
663 }
664
665private:
666 SymbolTableCollection *symbolTables = nullptr;
667};
668
669struct CallIndirectOpLowering
670 : public CallOpInterfaceLowering<func::CallIndirectOp> {
671 using Super::Super;
672
673 LogicalResult
674 matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter) const override {
676 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
677 }
678};
679
680struct UnrealizedConversionCastOpLowering
681 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
682 using ConvertOpToLLVMPattern<
683 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
684
685 LogicalResult
686 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter) const override {
688 SmallVector<Type> convertedTypes;
689 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
690 convertedTypes)) &&
691 convertedTypes == adaptor.getInputs().getTypes()) {
692 rewriter.replaceOp(op, adaptor.getInputs());
693 return success();
694 }
695
696 convertedTypes.clear();
697 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
698 convertedTypes)) &&
699 convertedTypes == op.getOutputs().getType()) {
700 rewriter.replaceOp(op, adaptor.getInputs());
701 return success();
702 }
703 return failure();
704 }
705};
706
707// Special lowering pattern for `ReturnOps`. Unlike all other operations,
708// `ReturnOp` interacts with the function signature and must have as many
709// operands as the function has return values. Because in LLVM IR, functions
710// can only return 0 or 1 value, we pack multiple values into a structure type.
711// Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
712// necessary before returning it
713struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
714 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
715
716 LogicalResult
717 matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter) const override {
719 Location loc = op.getLoc();
720 SmallVector<Value, 4> updatedOperands;
721
722 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
723 bool useBarePtrCallConv =
724 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
725
726 for (auto [oldOperand, newOperands] :
727 llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
728 Type oldTy = oldOperand.getType();
729 if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
730 assert(newOperands.size() == 1 && "expected one converted result");
731 if (useBarePtrCallConv &&
732 getTypeConverter()->canConvertToBarePtr(memRefType)) {
733 // For the bare-ptr calling convention, extract the aligned pointer to
734 // be returned from the memref descriptor.
735 MemRefDescriptor memrefDesc(newOperands.front());
736 updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
737 continue;
738 }
739 } else if (auto unrankedMemRefType =
740 dyn_cast<UnrankedMemRefType>(oldTy)) {
741 assert(newOperands.size() == 1 && "expected one converted result");
742 if (useBarePtrCallConv) {
743 // Unranked memref is not supported in the bare pointer calling
744 // convention.
745 return failure();
746 }
747 Value updatedDesc =
748 copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
749 newOperands.front(), /*toDynamic=*/true);
750 if (!updatedDesc)
751 return failure();
752 updatedOperands.push_back(updatedDesc);
753 continue;
754 }
755
756 llvm::append_range(updatedOperands, newOperands);
757 }
758
759 // If ReturnOp has 0 or 1 operand, create it and return immediately.
760 if (updatedOperands.size() <= 1) {
761 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
762 op, TypeRange(), updatedOperands, op->getAttrs());
763 return success();
764 }
765
766 // Otherwise, we need to pack the arguments into an LLVM struct type before
767 // returning.
768 auto packedType = getTypeConverter()->packFunctionResults(
769 op.getOperandTypes(), useBarePtrCallConv);
770 if (!packedType) {
771 return rewriter.notifyMatchFailure(op, "could not convert result types");
772 }
773
774 Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
775 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
776 packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
777 }
778 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
779 op->getAttrs());
780 return success();
781 }
782};
783} // namespace
784
787 SymbolTableCollection *symbolTables) {
788 patterns.add<FuncOpConversion>(converter, symbolTables);
789}
790
793 SymbolTableCollection *symbolTables) {
794 populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
795 patterns.add<CallIndirectOpLowering>(converter);
796 patterns.add<CallOpLowering>(converter, symbolTables);
797 patterns.add<ConstantOpLowering>(converter);
798 patterns.add<ReturnOpLowering>(converter);
799}
800
801namespace {
802/// A pass converting Func operations into the LLVM IR dialect.
803struct ConvertFuncToLLVMPass
804 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
805 using Base::Base;
806
807 /// Run the dialect converter on the module.
808 void runOnOperation() override {
809 ModuleOp m = getOperation();
810 StringRef dataLayout;
811 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
812 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
813 if (dataLayoutAttr)
814 dataLayout = dataLayoutAttr.getValue();
815
816 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
817 dataLayout, [this](const Twine &message) {
818 getOperation().emitError() << message.str();
819 }))) {
820 signalPassFailure();
821 return;
822 }
823
824 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
825
826 LowerToLLVMOptions options(&getContext(),
827 dataLayoutAnalysis.getAtOrAbove(m));
828 options.useBarePtrCallConv = useBarePtrCallConv;
829 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
830 options.overrideIndexBitwidth(indexBitwidth);
831 options.dataLayout = llvm::DataLayout(dataLayout);
832
833 LLVMTypeConverter typeConverter(&getContext(), options,
834 &dataLayoutAnalysis);
835
836 RewritePatternSet patterns(&getContext());
837 SymbolTableCollection symbolTables;
838
840 &symbolTables);
841
842 LLVMConversionTarget target(getContext());
843 if (failed(applyPartialConversion(m, target, std::move(patterns))))
844 signalPassFailure();
845 }
846};
847
848struct SetLLVMModuleDataLayoutPass
850 SetLLVMModuleDataLayoutPass> {
851 using Base::Base;
852
853 /// Run the dialect converter on the module.
854 void runOnOperation() override {
855 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
856 this->dataLayout, [this](const Twine &message) {
857 getOperation().emitError() << message.str();
858 }))) {
859 signalPassFailure();
860 return;
861 }
862 ModuleOp m = getOperation();
863 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
864 StringAttr::get(m.getContext(), this->dataLayout));
865 }
866};
867} // namespace
868
869//===----------------------------------------------------------------------===//
870// ConvertToLLVMPatternInterface implementation
871//===----------------------------------------------------------------------===//
872
873namespace {
874/// Implement the interface to convert Func to LLVM.
875struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
877 /// Hook for derived dialect interface to provide conversion patterns
878 /// and mark dialect legal for the conversion target.
879 void populateConvertToLLVMConversionPatterns(
880 ConversionTarget &target, LLVMTypeConverter &typeConverter,
881 RewritePatternSet &patterns) const final {
883 }
884};
885} // namespace
886
888 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
889 dialect->addInterfaces<FuncToLLVMDialectInterface>();
890 });
891}
return success()
static void restoreByValRefArgumentType(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef< std::optional< NamedAttribute > > byValRefNonPtrAttrs, LLVM::LLVMFuncOp funcOp)
Inserts llvm.load ops in the function body to restore the expected pointee value from llvm.byval/llvm.byref function arguments that were converted to LLVM pointer types.
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/results attributes.
static constexpr StringRef barePtrAttrName
static constexpr StringRef varargsAttrName
static constexpr StringRef linkageAttrName
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked arguments.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked arguments.
ArrayAttr()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:104
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition Pattern.h:207
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition Pattern.h:210
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
const LowerToLLVMOptions & getOptions() const
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descriptors for arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as the `results` list.
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent values.
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them as the `results` list.
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
void populateFuncToLLVMFuncOpConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
const FrozenRewritePatternSet & patterns
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables=nullptr)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.