MLIR  20.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include <algorithm>
49 #include <functional>
50 #include <optional>
51 
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
57 
58 using namespace mlir;
59 
60 #define PASS_NAME "convert-func-to-llvm"
61 
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65 
66 /// Return `true` if the `op` should use bare pointer calling convention.
68  const LLVMTypeConverter *typeConverter) {
69  return (op && op->hasAttr(barePtrAttrName)) ||
70  typeConverter->getOptions().useBarePtrCallConv;
71 }
72 
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func,
77  for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78  if (attr.getName() == linkageAttrName ||
79  attr.getName() == varargsAttrName ||
80  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81  continue;
82  result.push_back(attr);
83  }
84 }
85 
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88  FunctionOpInterface funcOp,
89  LLVM::LLVMFuncOp wrapperFuncOp) {
90  auto argAttrs = funcOp.getAllArgAttrs();
91  if (!resultStructType) {
92  if (auto resAttrs = funcOp.getAllResultAttrs())
93  wrapperFuncOp.setAllResultAttrs(resAttrs);
94  if (argAttrs)
95  wrapperFuncOp.setAllArgAttrs(argAttrs);
96  } else {
97  SmallVector<Attribute> argAttributes;
98  // Only modify the argument and result attributes when the result is now
99  // an argument.
100  if (argAttrs) {
101  argAttributes.push_back(builder.getDictionaryAttr({}));
102  argAttributes.append(argAttrs.begin(), argAttrs.end());
103  wrapperFuncOp.setAllArgAttrs(argAttributes);
104  }
105  }
106  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107  .setVisibility(funcOp.getVisibility());
108 }
109 
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119  const LLVMTypeConverter &typeConverter,
120  FunctionOpInterface funcOp,
121  LLVM::LLVMFuncOp newFuncOp) {
122  auto type = cast<FunctionType>(funcOp.getFunctionType());
123  auto [wrapperFuncType, resultStructType] =
124  typeConverter.convertFunctionTypeCWrapper(type);
125 
126  SmallVector<NamedAttribute> attributes;
127  filterFuncAttributes(funcOp, attributes);
128 
129  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133  propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
134 
135  OpBuilder::InsertionGuard guard(rewriter);
136  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
137 
139  size_t argOffset = resultStructType ? 1 : 0;
140  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141  Value arg = wrapperFuncOp.getArgument(index + argOffset);
142  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143  Value loaded = rewriter.create<LLVM::LoadOp>(
144  loc, typeConverter.convertType(memrefType), arg);
145  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146  continue;
147  }
148  if (isa<UnrankedMemRefType>(argType)) {
149  Value loaded = rewriter.create<LLVM::LoadOp>(
150  loc, typeConverter.convertType(argType), arg);
151  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152  continue;
153  }
154 
155  args.push_back(arg);
156  }
157 
158  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
159 
160  if (resultStructType) {
161  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162  wrapperFuncOp.getArgument(0));
163  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164  } else {
165  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
166  }
167 }
168 
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
///
/// The auxiliary `_mlir_ciface_<name>` function is only declared here: no body
/// is added to it. The synthesized body goes into `newFuncOp`, whose linkage is
/// switched to private.
178 static void wrapExternalFunction(OpBuilder &builder, Location loc,
179  const LLVMTypeConverter &typeConverter,
180  FunctionOpInterface funcOp,
181  LLVM::LLVMFuncOp newFuncOp) {
182  OpBuilder::InsertionGuard guard(builder);
183 
184  auto [wrapperType, resultStructType] =
185  typeConverter.convertFunctionTypeCWrapper(
186  cast<FunctionType>(funcOp.getFunctionType()));
187  // This conversion can only fail if it could not convert one of the argument
188  // types. But since it has been applied to a non-wrapper function before, it
189  // should have failed earlier and not reach this point at all.
190  assert(wrapperType && "unexpected type conversion failure");
191 
193  filterFuncAttributes(funcOp, attributes);
194 
195  // Create the auxiliary function.
196  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200  propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
201 
202  // The wrapper that we synthetize here should only be visible in this module.
203  newFuncOp.setLinkage(LLVM::Linkage::Private);
204  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
205 
206  // Get a ValueRange containing arguments.
207  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
  // `args` collects the operands of the call to the C interface function.
209  args.reserve(type.getNumInputs());
  // Arguments of `newFuncOp` not yet consumed; each original input below
  // consumes `numToDrop` of them from the front.
210  ValueRange wrapperArgsRange(newFuncOp.getArguments());
211 
212  if (resultStructType) {
213  // Allocate the struct on the stack and pass the pointer.
214  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215  Value one = builder.create<LLVM::ConstantOp>(
216  loc, typeConverter.convertType(builder.getIndexType()),
217  builder.getIntegerAttr(builder.getIndexType(), 1));
218  Value result =
219  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220  args.push_back(result);
221  }
222 
223  // Iterate over the inputs of the original function and pack values into
224  // memref descriptors if the original type is a memref.
225  for (Type input : type.getInputs()) {
226  Value arg;
227  int numToDrop = 1;
228  auto memRefType = dyn_cast<MemRefType>(input);
229  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230  if (memRefType || unrankedMemRefType) {
231  numToDrop = memRefType
234  Value packed =
235  memRefType
236  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237  wrapperArgsRange.take_front(numToDrop))
239  builder, loc, typeConverter, unrankedMemRefType,
240  wrapperArgsRange.take_front(numToDrop));
241 
      // Spill the packed descriptor to a stack slot; the C interface receives
      // a pointer to it.
242  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
      // NOTE(review): a fresh index constant 1 is materialized for every memref
      // input (and once more for the result struct); it could be shared.
243  Value one = builder.create<LLVM::ConstantOp>(
244  loc, typeConverter.convertType(builder.getIndexType()),
245  builder.getIntegerAttr(builder.getIndexType(), 1));
246  Value allocated = builder.create<LLVM::AllocaOp>(
247  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248  builder.create<LLVM::StoreOp>(loc, packed, allocated);
249  arg = allocated;
250  } else {
251  arg = wrapperArgsRange[0];
252  }
253 
254  args.push_back(arg);
255  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
256  }
257  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
258 
259  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
260 
  // With an indirect result, the C interface wrote it through the stack slot
  // passed as the first call operand; load it back and return it.
261  if (resultStructType) {
262  Value result =
263  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264  builder.create<LLVM::ReturnOp>(loc, result);
265  } else {
266  builder.create<LLVM::ReturnOp>(loc, call.getResults());
267  }
268 }
269 
270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272 /// to LLVM pointer types.
274  ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275  ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276  ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) {
277  // Nothing to do for function declarations.
278  if (funcOp.isExternal())
279  return;
280 
281  ConversionPatternRewriter::InsertionGuard guard(rewriter);
282  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
283 
284  for (const auto &[arg, oldArg, byValRefAttr] :
285  llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286  // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287  if (!byValRefAttr)
288  continue;
289 
290  // Insert load to retrieve the actual argument passed by value/reference.
291  assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292  "Expected LLVM pointer type for argument with "
293  "`llvm.byval`/`llvm.byref` attribute");
294  Type resTy = typeConverter.convertType(
295  cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296 
297  auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298  rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
299  }
300 }
301 
302 FailureOr<LLVM::LLVMFuncOp>
303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304  ConversionPatternRewriter &rewriter,
305  const LLVMTypeConverter &converter) {
306  // Check the funcOp has `FunctionType`.
307  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308  if (!funcTy)
309  return rewriter.notifyMatchFailure(
310  funcOp, "Only support FunctionOpInterface with FunctionType");
311 
312  // Keep track of the entry block arguments. They will be needed later.
313  SmallVector<BlockArgument> oldBlockArgs =
314  llvm::to_vector(funcOp.getArguments());
315 
316  // Convert the original function arguments. They are converted using the
317  // LLVMTypeConverter provided to this legalization pattern.
318  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
319  // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320  // overriden with an LLVM pointer type for later processing.
321  SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
322  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
323  auto llvmType = converter.convertFunctionSignature(
324  funcOp, varargsAttr && varargsAttr.getValue(),
325  shouldUseBarePtrCallConv(funcOp, &converter), result,
326  byValRefNonPtrAttrs);
327  if (!llvmType)
328  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
329 
330  // Create an LLVM function, use external linkage by default until MLIR
331  // functions have linkage.
332  LLVM::Linkage linkage = LLVM::Linkage::External;
333  if (funcOp->hasAttr(linkageAttrName)) {
334  auto attr =
335  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
336  if (!attr) {
337  funcOp->emitError() << "Contains " << linkageAttrName
338  << " attribute not of type LLVM::LinkageAttr";
339  return rewriter.notifyMatchFailure(
340  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
341  }
342  linkage = attr.getLinkage();
343  }
344 
346  filterFuncAttributes(funcOp, attributes);
347  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
350  attributes);
351  cast<FunctionOpInterface>(newFuncOp.getOperation())
352  .setVisibility(funcOp.getVisibility());
353 
354  // Create a memory effect attribute corresponding to readnone.
355  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
356  if (funcOp->hasAttr(readnoneAttrName)) {
357  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
358  if (!attr) {
359  funcOp->emitError() << "Contains " << readnoneAttrName
360  << " attribute not of type UnitAttr";
361  return rewriter.notifyMatchFailure(
362  funcOp, "Contains readnone attribute not of type UnitAttr");
363  }
364  auto memoryAttr = LLVM::MemoryEffectsAttr::get(
365  rewriter.getContext(),
366  {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
367  LLVM::ModRefInfo::NoModRef});
368  newFuncOp.setMemoryEffectsAttr(memoryAttr);
369  }
370 
371  // Propagate argument/result attributes to all converted arguments/result
372  // obtained after converting a given original argument/result.
373  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
374  assert(!resAttrDicts.empty() && "expected array to be non-empty");
375  if (funcOp.getNumResults() == 1)
376  newFuncOp.setAllResultAttrs(resAttrDicts);
377  }
378  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
379  SmallVector<Attribute> newArgAttrs(
380  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
381  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
382  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
383  // LLVMFuncOp conversion these types may have changed. Account for that
384  // change by converting attributes' types as well.
385  SmallVector<NamedAttribute, 4> convertedAttrs;
386  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
387  convertedAttrs.reserve(attrsDict.size());
388  for (const NamedAttribute &attr : attrsDict) {
389  const auto convert = [&](const NamedAttribute &attr) {
390  return TypeAttr::get(converter.convertType(
391  cast<TypeAttr>(attr.getValue()).getValue()));
392  };
393  if (attr.getName().getValue() ==
394  LLVM::LLVMDialect::getByValAttrName()) {
395  convertedAttrs.push_back(rewriter.getNamedAttr(
396  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
397  } else if (attr.getName().getValue() ==
398  LLVM::LLVMDialect::getByRefAttrName()) {
399  convertedAttrs.push_back(rewriter.getNamedAttr(
400  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
401  } else if (attr.getName().getValue() ==
402  LLVM::LLVMDialect::getStructRetAttrName()) {
403  convertedAttrs.push_back(rewriter.getNamedAttr(
404  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
405  } else if (attr.getName().getValue() ==
406  LLVM::LLVMDialect::getInAllocaAttrName()) {
407  convertedAttrs.push_back(rewriter.getNamedAttr(
408  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
409  } else {
410  convertedAttrs.push_back(attr);
411  }
412  }
413  auto mapping = result.getInputMapping(i);
414  assert(mapping && "unexpected deletion of function argument");
415  // Only attach the new argument attributes if there is a one-to-one
416  // mapping from old to new types. Otherwise, attributes might be
417  // attached to types that they do not support.
418  if (mapping->size == 1) {
419  newArgAttrs[mapping->inputNo] =
420  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
421  continue;
422  }
423  // TODO: Implement custom handling for types that expand to multiple
424  // function arguments.
425  for (size_t j = 0; j < mapping->size; ++j)
426  newArgAttrs[mapping->inputNo + j] =
427  DictionaryAttr::get(rewriter.getContext(), {});
428  }
429  if (!newArgAttrs.empty())
430  newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
431  }
432 
433  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
434  newFuncOp.end());
435  // Convert just the entry block. The remaining unstructured control flow is
436  // converted by ControlFlowToLLVM.
437  if (!newFuncOp.getBody().empty())
438  rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
439  &converter);
440 
441  // Fix the type mismatch between the materialized `llvm.ptr` and the expected
442  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
443  // function arguments.
444  restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
445  oldBlockArgs, newFuncOp);
446 
447  if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
448  if (funcOp->getAttrOfType<UnitAttr>(
449  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
450  if (newFuncOp.isVarArg())
451  return funcOp.emitError("C interface for variadic functions is not "
452  "supported yet.");
453 
454  if (newFuncOp.isExternal())
455  wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
456  newFuncOp);
457  else
458  wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
459  newFuncOp);
460  }
461  }
462 
463  return newFuncOp;
464 }
465 
466 namespace {
467 
468 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
469 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
470 /// information.
471 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
472  FuncOpConversion(const LLVMTypeConverter &converter)
473  : ConvertOpToLLVMPattern(converter) {}
474 
475  LogicalResult
476  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
477  ConversionPatternRewriter &rewriter) const override {
478  FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
479  cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
480  *getTypeConverter());
481  if (failed(newFuncOp))
482  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
483 
484  rewriter.eraseOp(funcOp);
485  return success();
486  }
487 };
488 
489 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
491 
492  LogicalResult
493  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
494  ConversionPatternRewriter &rewriter) const override {
495  auto type = typeConverter->convertType(op.getResult().getType());
496  if (!type || !LLVM::isCompatibleType(type))
497  return rewriter.notifyMatchFailure(op, "failed to convert result type");
498 
499  auto newOp =
500  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
501  for (const NamedAttribute &attr : op->getAttrs()) {
502  if (attr.getName().strref() == "value")
503  continue;
504  newOp->setAttr(attr.getName(), attr.getValue());
505  }
506  rewriter.replaceOp(op, newOp->getResults());
507  return success();
508  }
509 };
510 
511 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
512 // passes the pointer to the MemRef across function boundaries.
513 template <typename CallOpType>
514 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
516  using Super = CallOpInterfaceLowering<CallOpType>;
518 
519  LogicalResult matchAndRewriteImpl(CallOpType callOp,
520  typename CallOpType::Adaptor adaptor,
521  ConversionPatternRewriter &rewriter,
522  bool useBarePtrCallConv = false) const {
523  // Pack the result types into a struct.
524  Type packedResult = nullptr;
525  unsigned numResults = callOp.getNumResults();
526  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
527 
528  if (numResults != 0) {
529  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
530  resultTypes, useBarePtrCallConv)))
531  return failure();
532  }
533 
534  if (useBarePtrCallConv) {
535  for (auto it : callOp->getOperands()) {
536  Type operandType = it.getType();
537  if (isa<UnrankedMemRefType>(operandType)) {
538  // Unranked memref is not supported in the bare pointer calling
539  // convention.
540  return failure();
541  }
542  }
543  }
544  auto promoted = this->getTypeConverter()->promoteOperands(
545  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
546  adaptor.getOperands(), rewriter, useBarePtrCallConv);
547  auto newOp = rewriter.create<LLVM::CallOp>(
548  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
549  promoted, callOp->getAttrs());
550 
551  newOp.getProperties().operandSegmentSizes = {
552  static_cast<int32_t>(promoted.size()), 0};
553  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
554 
555  SmallVector<Value, 4> results;
556  if (numResults < 2) {
557  // If < 2 results, packing did not do anything and we can just return.
558  results.append(newOp.result_begin(), newOp.result_end());
559  } else {
560  // Otherwise, it had been converted to an operation producing a structure.
561  // Extract individual results from the structure and return them as list.
562  results.reserve(numResults);
563  for (unsigned i = 0; i < numResults; ++i) {
564  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
565  callOp.getLoc(), newOp->getResult(0), i));
566  }
567  }
568 
569  if (useBarePtrCallConv) {
570  // For the bare-ptr calling convention, promote memref results to
571  // descriptors.
572  assert(results.size() == resultTypes.size() &&
573  "The number of arguments and types doesn't match");
574  this->getTypeConverter()->promoteBarePtrsToDescriptors(
575  rewriter, callOp.getLoc(), resultTypes, results);
576  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
577  resultTypes, results,
578  /*toDynamic=*/false))) {
579  return failure();
580  }
581 
582  rewriter.replaceOp(callOp, results);
583  return success();
584  }
585 };
586 
587 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
588 public:
589  CallOpLowering(const LLVMTypeConverter &typeConverter,
590  // Can be nullptr.
591  const SymbolTable *symbolTable, PatternBenefit benefit = 1)
592  : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
593  symbolTable(symbolTable) {}
594 
595  LogicalResult
596  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
597  ConversionPatternRewriter &rewriter) const override {
598  bool useBarePtrCallConv = false;
599  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
600  useBarePtrCallConv = true;
601  } else if (symbolTable != nullptr) {
602  // Fast lookup.
603  Operation *callee =
604  symbolTable->lookup(callOp.getCalleeAttr().getValue());
605  useBarePtrCallConv =
606  callee != nullptr && callee->hasAttr(barePtrAttrName);
607  } else {
608  // Warning: This is a linear lookup.
609  Operation *callee =
610  SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
611  useBarePtrCallConv =
612  callee != nullptr && callee->hasAttr(barePtrAttrName);
613  }
614  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
615  }
616 
617 private:
618  const SymbolTable *symbolTable = nullptr;
619 };
620 
621 struct CallIndirectOpLowering
622  : public CallOpInterfaceLowering<func::CallIndirectOp> {
623  using Super::Super;
624 
625  LogicalResult
626  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
627  ConversionPatternRewriter &rewriter) const override {
628  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
629  }
630 };
631 
632 struct UnrealizedConversionCastOpLowering
633  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
635  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
636 
637  LogicalResult
638  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
639  ConversionPatternRewriter &rewriter) const override {
640  SmallVector<Type> convertedTypes;
641  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
642  convertedTypes)) &&
643  convertedTypes == adaptor.getInputs().getTypes()) {
644  rewriter.replaceOp(op, adaptor.getInputs());
645  return success();
646  }
647 
648  convertedTypes.clear();
649  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
650  convertedTypes)) &&
651  convertedTypes == op.getOutputs().getType()) {
652  rewriter.replaceOp(op, adaptor.getInputs());
653  return success();
654  }
655  return failure();
656  }
657 };
658 
659 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
660 // `ReturnOp` interacts with the function signature and must have as many
661 // operands as the function has return values. Because in LLVM IR, functions
662 // can only return 0 or 1 value, we pack multiple values into a structure type.
663 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
664 // necessary before returning it
665 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
667 
668  LogicalResult
669  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
670  ConversionPatternRewriter &rewriter) const override {
671  Location loc = op.getLoc();
672  unsigned numArguments = op.getNumOperands();
673  SmallVector<Value, 4> updatedOperands;
674 
675  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
676  bool useBarePtrCallConv =
677  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
678  if (useBarePtrCallConv) {
679  // For the bare-ptr calling convention, extract the aligned pointer to
680  // be returned from the memref descriptor.
681  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
682  Type oldTy = std::get<0>(it).getType();
683  Value newOperand = std::get<1>(it);
684  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
685  cast<BaseMemRefType>(oldTy))) {
686  MemRefDescriptor memrefDesc(newOperand);
687  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
688  } else if (isa<UnrankedMemRefType>(oldTy)) {
689  // Unranked memref is not supported in the bare pointer calling
690  // convention.
691  return failure();
692  }
693  updatedOperands.push_back(newOperand);
694  }
695  } else {
696  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
697  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
698  updatedOperands,
699  /*toDynamic=*/true);
700  }
701 
702  // If ReturnOp has 0 or 1 operand, create it and return immediately.
703  if (numArguments <= 1) {
704  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
705  op, TypeRange(), updatedOperands, op->getAttrs());
706  return success();
707  }
708 
709  // Otherwise, we need to pack the arguments into an LLVM struct type before
710  // returning.
711  auto packedType = getTypeConverter()->packFunctionResults(
712  op.getOperandTypes(), useBarePtrCallConv);
713  if (!packedType) {
714  return rewriter.notifyMatchFailure(op, "could not convert result types");
715  }
716 
717  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
718  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
719  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
720  }
721  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
722  op->getAttrs());
723  return success();
724  }
725 };
726 } // namespace
727 
729  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Adds only the `func.func` -> `llvm.func` lowering (FuncOpConversion),
  // which uses `converter` for the signature conversion.
730  patterns.add<FuncOpConversion>(converter);
731 }
732 
734  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
735  const SymbolTable *symbolTable) {
  // Lowerings for func.call_indirect, func.call, func.constant and
  // func.return. `symbolTable` may be null. In that case CallOpLowering
  // resolves each callee with a linear nearest-symbol lookup, unless the
  // bare-pointer convention is forced globally.
737  patterns.add<CallIndirectOpLowering>(converter);
738  patterns.add<CallOpLowering>(converter, symbolTable);
739  patterns.add<ConstantOpLowering>(converter);
740  patterns.add<ReturnOpLowering>(converter);
741 }
742 
743 namespace {
744 /// A pass converting Func operations into the LLVM IR dialect.
745 struct ConvertFuncToLLVMPass
746  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
747  using Base::Base;
748 
749  /// Verifies the module's data layout string, builds the LLVM type converter
  /// from the pass options, and applies a partial conversion with the
  /// Func-to-LLVM patterns.
750  void runOnOperation() override {
751  ModuleOp m = getOperation();
  // Use the module's data layout string if present; otherwise it stays empty.
752  StringRef dataLayout;
753  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
754  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
755  if (dataLayoutAttr)
756  dataLayout = dataLayoutAttr.getValue();
757 
758  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
759  dataLayout, [this](const Twine &message) {
760  getOperation().emitError() << message.str();
761  }))) {
762  signalPassFailure();
763  return;
764  }
765 
766  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
767 
769  dataLayoutAnalysis.getAtOrAbove(m));
770  options.useBarePtrCallConv = useBarePtrCallConv;
  // Only override the index width when the option requests an explicit value.
771  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
772  options.overrideIndexBitwidth(indexBitwidth);
773  options.dataLayout = llvm::DataLayout(dataLayout);
774 
775  LLVMTypeConverter typeConverter(&getContext(), options,
776  &dataLayoutAnalysis);
777 
  // CallOpLowering consults the symbol table only when the bare-pointer
  // convention is not forced globally, so the table is built only then. A
  // prebuilt table avoids a linear symbol lookup per call.
778  std::optional<SymbolTable> optSymbolTable = std::nullopt;
779  const SymbolTable *symbolTable = nullptr;
780  if (!options.useBarePtrCallConv) {
781  optSymbolTable.emplace(m);
782  symbolTable = &optSymbolTable.value();
783  }
784 
786  populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
787 
789  if (failed(applyPartialConversion(m, target, std::move(patterns))))
790  signalPassFailure();
791  }
792 };
793 
794 struct SetLLVMModuleDataLayoutPass
795  : public impl::SetLLVMModuleDataLayoutPassBase<
796  SetLLVMModuleDataLayoutPass> {
797  using Base::Base;
798 
799  /// Run the dialect converter on the module.
800  void runOnOperation() override {
801  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
802  this->dataLayout, [this](const Twine &message) {
803  getOperation().emitError() << message.str();
804  }))) {
805  signalPassFailure();
806  return;
807  }
808  ModuleOp m = getOperation();
809  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
810  StringAttr::get(m.getContext(), this->dataLayout));
811  }
812 };
813 } // namespace
814 
815 //===----------------------------------------------------------------------===//
816 // ConvertToLLVMPatternInterface implementation
817 //===----------------------------------------------------------------------===//
818 
819 namespace {
820 /// Implement the interface to convert Func to LLVM.
/// Dialect interface through which the Func-to-LLVM patterns are exposed
/// (presumably queried by the generic convert-to-LLVM infrastructure).
821 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
823  /// Hook for derived dialect interface to provide conversion patterns
824  /// and mark dialect legal for the conversion target.
825  void populateConvertToLLVMConversionPatterns(
826  ConversionTarget &target, LLVMTypeConverter &typeConverter,
827  RewritePatternSet &patterns) const final {
  // NOTE(review): the statement(s) that populate `patterns` are not visible
  // here; confirm which pattern sets are added and whether `target` is
  // modified.
829  }
830 };
831 } // namespace
832 
834  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
  // Attach FuncToLLVMDialectInterface to the func dialect through a registry
  // extension (presumably applied when the dialect is loaded into a context).
835  dialect->addInterfaces<FuncToLLVMDialectInterface>();
836  });
837 }
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/result attributes.
Definition: FuncToLLVM.cpp:87
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:64
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:62
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:63
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:75
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:67
static void restoreByValRefArgumentType(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef< std::optional< NamedAttribute >> byValRefNonPtrAttrs, ArrayRef< BlockArgument > oldBlockArgs, LLVM::LLVMFuncOp funcOp)
Inserts llvm.load ops in the function body to restore the expected pointee value from llvm....
Definition: FuncToLLVM.cpp:273
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:178
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:118
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
IndexType getIndexType()
Definition: Builders.cpp:95
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:144
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:134
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:860
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the data layout.
void populateFuncToLLVMFuncOpConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:728
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:833
const FrozenRewritePatternSet & patterns
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Definition: FuncToLLVM.cpp:303
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:733
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.