MLIR  21.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/SymbolTable.h"
33 #include "mlir/IR/TypeUtilities.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/SmallVector.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/Support/FormatVariadic.h"
39 #include <optional>
40 
41 namespace mlir {
42 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
43 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
44 #include "mlir/Conversion/Passes.h.inc"
45 } // namespace mlir
46 
47 using namespace mlir;
48 
// NOTE(review): PASS_NAME is not referenced anywhere in the visible code.
#define PASS_NAME "convert-func-to-llvm"

/// Discardable BoolAttr that marks a function as variadic; its value is
/// forwarded to the function-signature conversion.
static constexpr StringRef varargsAttrName = "func.varargs";
/// Discardable attribute holding an `LLVM::LinkageAttr` that selects the
/// linkage of the lowered `llvm.func` (external linkage otherwise).
static constexpr StringRef linkageAttrName = "llvm.linkage";
/// Discardable attribute that opts a single function, and calls to it, into
/// the bare-pointer memref calling convention.
static constexpr StringRef barePtrAttrName = "llvm.bareptr";
54 
55 /// Return `true` if the `op` should use bare pointer calling convention.
57  const LLVMTypeConverter *typeConverter) {
58  return (op && op->hasAttr(barePtrAttrName)) ||
59  typeConverter->getOptions().useBarePtrCallConv;
60 }
61 
62 /// Only retain those attributes that are not constructed by
63 /// `LLVMFuncOp::build`.
64 static void filterFuncAttributes(FunctionOpInterface func,
66  for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
67  if (attr.getName() == linkageAttrName ||
68  attr.getName() == varargsAttrName ||
69  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
70  continue;
71  result.push_back(attr);
72  }
73 }
74 
75 /// Propagate argument/results attributes.
76 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
77  FunctionOpInterface funcOp,
78  LLVM::LLVMFuncOp wrapperFuncOp) {
79  auto argAttrs = funcOp.getAllArgAttrs();
80  if (!resultStructType) {
81  if (auto resAttrs = funcOp.getAllResultAttrs())
82  wrapperFuncOp.setAllResultAttrs(resAttrs);
83  if (argAttrs)
84  wrapperFuncOp.setAllArgAttrs(argAttrs);
85  } else {
86  SmallVector<Attribute> argAttributes;
87  // Only modify the argument and result attributes when the result is now
88  // an argument.
89  if (argAttrs) {
90  argAttributes.push_back(builder.getDictionaryAttr({}));
91  argAttributes.append(argAttrs.begin(), argAttrs.end());
92  wrapperFuncOp.setAllArgAttrs(argAttributes);
93  }
94  }
95  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
96  .setVisibility(funcOp.getVisibility());
97 }
98 
99 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
100 /// arguments instead of unpacked arguments. This function can be called from C
101 /// by passing a pointer to a C struct corresponding to a memref descriptor.
102 /// Similarly, returned memrefs are passed via pointers to a C struct that is
103 /// passed as additional argument.
104 /// Internally, the auxiliary function unpacks the descriptor into individual
105 /// components and forwards them to `newFuncOp` and forwards the results to
106 /// the extra arguments.
107 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
108  const LLVMTypeConverter &typeConverter,
109  FunctionOpInterface funcOp,
110  LLVM::LLVMFuncOp newFuncOp) {
111  auto type = cast<FunctionType>(funcOp.getFunctionType());
112  auto [wrapperFuncType, resultStructType] =
113  typeConverter.convertFunctionTypeCWrapper(type);
114 
115  SmallVector<NamedAttribute> attributes;
116  filterFuncAttributes(funcOp, attributes);
117 
118  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
119  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
120  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
121  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
122  propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
123 
124  OpBuilder::InsertionGuard guard(rewriter);
125  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
126 
128  size_t argOffset = resultStructType ? 1 : 0;
129  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
130  Value arg = wrapperFuncOp.getArgument(index + argOffset);
131  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
132  Value loaded = rewriter.create<LLVM::LoadOp>(
133  loc, typeConverter.convertType(memrefType), arg);
134  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
135  continue;
136  }
137  if (isa<UnrankedMemRefType>(argType)) {
138  Value loaded = rewriter.create<LLVM::LoadOp>(
139  loc, typeConverter.convertType(argType), arg);
140  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
141  continue;
142  }
143 
144  args.push_back(arg);
145  }
146 
147  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
148 
149  if (resultStructType) {
150  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
151  wrapperFuncOp.getArgument(0));
152  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
153  } else {
154  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
155  }
156 }
157 
158 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
159 /// arguments instead of unpacked arguments. Creates a body for the (external)
160 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
161 /// individual arguments into this descriptor and passes a pointer to it into
162 /// the auxiliary function. If the result of the function cannot be directly
163 /// returned, we write it to a special first argument that provides a pointer
164 /// to a corresponding struct. This auxiliary external function is now
165 /// compatible with functions defined in C using pointers to C structs
166 /// corresponding to a memref descriptor.
167 static void wrapExternalFunction(OpBuilder &builder, Location loc,
168  const LLVMTypeConverter &typeConverter,
169  FunctionOpInterface funcOp,
170  LLVM::LLVMFuncOp newFuncOp) {
171  OpBuilder::InsertionGuard guard(builder);
172 
173  auto [wrapperType, resultStructType] =
174  typeConverter.convertFunctionTypeCWrapper(
175  cast<FunctionType>(funcOp.getFunctionType()));
176  // This conversion can only fail if it could not convert one of the argument
177  // types. But since it has been applied to a non-wrapper function before, it
178  // should have failed earlier and not reach this point at all.
179  assert(wrapperType && "unexpected type conversion failure");
180 
182  filterFuncAttributes(funcOp, attributes);
183 
184  // Create the auxiliary function.
185  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
186  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
187  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
188  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
189  propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
190 
191  // The wrapper that we synthetize here should only be visible in this module.
192  newFuncOp.setLinkage(LLVM::Linkage::Private);
193  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
194 
195  // Get a ValueRange containing arguments.
196  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
198  args.reserve(type.getNumInputs());
199  ValueRange wrapperArgsRange(newFuncOp.getArguments());
200 
201  if (resultStructType) {
202  // Allocate the struct on the stack and pass the pointer.
203  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
204  Value one = builder.create<LLVM::ConstantOp>(
205  loc, typeConverter.convertType(builder.getIndexType()),
206  builder.getIntegerAttr(builder.getIndexType(), 1));
207  Value result =
208  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
209  args.push_back(result);
210  }
211 
212  // Iterate over the inputs of the original function and pack values into
213  // memref descriptors if the original type is a memref.
214  for (Type input : type.getInputs()) {
215  Value arg;
216  int numToDrop = 1;
217  auto memRefType = dyn_cast<MemRefType>(input);
218  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
219  if (memRefType || unrankedMemRefType) {
220  numToDrop = memRefType
223  Value packed =
224  memRefType
225  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
226  wrapperArgsRange.take_front(numToDrop))
228  builder, loc, typeConverter, unrankedMemRefType,
229  wrapperArgsRange.take_front(numToDrop));
230 
231  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
232  Value one = builder.create<LLVM::ConstantOp>(
233  loc, typeConverter.convertType(builder.getIndexType()),
234  builder.getIntegerAttr(builder.getIndexType(), 1));
235  Value allocated = builder.create<LLVM::AllocaOp>(
236  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
237  builder.create<LLVM::StoreOp>(loc, packed, allocated);
238  arg = allocated;
239  } else {
240  arg = wrapperArgsRange[0];
241  }
242 
243  args.push_back(arg);
244  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
245  }
246  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
247 
248  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
249 
250  if (resultStructType) {
251  Value result =
252  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
253  builder.create<LLVM::ReturnOp>(loc, result);
254  } else {
255  builder.create<LLVM::ReturnOp>(loc, call.getResults());
256  }
257 }
258 
259 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
260 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
261 /// to LLVM pointer types.
263  ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
264  ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
265  LLVM::LLVMFuncOp funcOp) {
266  // Nothing to do for function declarations.
267  if (funcOp.isExternal())
268  return;
269 
270  ConversionPatternRewriter::InsertionGuard guard(rewriter);
271  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
272 
273  for (const auto &[arg, byValRefAttr] :
274  llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
275  // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
276  if (!byValRefAttr)
277  continue;
278 
279  // Insert load to retrieve the actual argument passed by value/reference.
280  assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
281  "Expected LLVM pointer type for argument with "
282  "`llvm.byval`/`llvm.byref` attribute");
283  Type resTy = typeConverter.convertType(
284  cast<TypeAttr>(byValRefAttr->getValue()).getValue());
285 
286  Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
287  rewriter.replaceUsesOfBlockArgument(arg, valueArg);
288  }
289 }
290 
291 FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
292  FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
293  const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
294  // Check the funcOp has `FunctionType`.
295  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
296  if (!funcTy)
297  return rewriter.notifyMatchFailure(
298  funcOp, "Only support FunctionOpInterface with FunctionType");
299 
300  // Convert the original function arguments. They are converted using the
301  // LLVMTypeConverter provided to this legalization pattern.
302  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
303  // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
304  // overriden with an LLVM pointer type for later processing.
305  SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
306  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
307  auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
308  converter.convertFunctionSignature(
309  funcOp, varargsAttr && varargsAttr.getValue(),
310  shouldUseBarePtrCallConv(funcOp, &converter), result,
311  byValRefNonPtrAttrs));
312  if (!llvmType)
313  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
314 
315  // Check for unsupported variadic functions.
316  if (!shouldUseBarePtrCallConv(funcOp, &converter))
317  if (funcOp->getAttrOfType<UnitAttr>(
318  LLVM::LLVMDialect::getEmitCWrapperAttrName()))
319  if (llvmType.isVarArg())
320  return funcOp.emitError("C interface for variadic functions is not "
321  "supported yet.");
322 
323  // Create an LLVM function, use external linkage by default until MLIR
324  // functions have linkage.
325  LLVM::Linkage linkage = LLVM::Linkage::External;
326  if (funcOp->hasAttr(linkageAttrName)) {
327  auto attr =
328  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
329  if (!attr) {
330  funcOp->emitError() << "Contains " << linkageAttrName
331  << " attribute not of type LLVM::LinkageAttr";
332  return rewriter.notifyMatchFailure(
333  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
334  }
335  linkage = attr.getLinkage();
336  }
337 
338  // Check for invalid attributes.
339  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
340  if (funcOp->hasAttr(readnoneAttrName)) {
341  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
342  if (!attr) {
343  funcOp->emitError() << "Contains " << readnoneAttrName
344  << " attribute not of type UnitAttr";
345  return rewriter.notifyMatchFailure(
346  funcOp, "Contains readnone attribute not of type UnitAttr");
347  }
348  }
349 
351  filterFuncAttributes(funcOp, attributes);
352 
353  Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
354 
355  if (symbolTables && symbolTableOp) {
356  SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
357  symbolTable.remove(funcOp);
358  }
359 
360  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
361  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
362  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
363  attributes);
364 
365  if (symbolTables && symbolTableOp) {
366  auto ip = rewriter.getInsertionPoint();
367  SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
368  symbolTable.insert(newFuncOp, ip);
369  }
370 
371  cast<FunctionOpInterface>(newFuncOp.getOperation())
372  .setVisibility(funcOp.getVisibility());
373 
374  // Create a memory effect attribute corresponding to readnone.
375  if (funcOp->hasAttr(readnoneAttrName)) {
376  auto memoryAttr = LLVM::MemoryEffectsAttr::get(
377  rewriter.getContext(),
378  {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
379  LLVM::ModRefInfo::NoModRef});
380  newFuncOp.setMemoryEffectsAttr(memoryAttr);
381  }
382 
383  // Propagate argument/result attributes to all converted arguments/result
384  // obtained after converting a given original argument/result.
385  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
386  assert(!resAttrDicts.empty() && "expected array to be non-empty");
387  if (funcOp.getNumResults() == 1)
388  newFuncOp.setAllResultAttrs(resAttrDicts);
389  }
390  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
391  SmallVector<Attribute> newArgAttrs(
392  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
393  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
394  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
395  // LLVMFuncOp conversion these types may have changed. Account for that
396  // change by converting attributes' types as well.
397  SmallVector<NamedAttribute, 4> convertedAttrs;
398  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
399  convertedAttrs.reserve(attrsDict.size());
400  for (const NamedAttribute &attr : attrsDict) {
401  const auto convert = [&](const NamedAttribute &attr) {
402  return TypeAttr::get(converter.convertType(
403  cast<TypeAttr>(attr.getValue()).getValue()));
404  };
405  if (attr.getName().getValue() ==
406  LLVM::LLVMDialect::getByValAttrName()) {
407  convertedAttrs.push_back(rewriter.getNamedAttr(
408  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
409  } else if (attr.getName().getValue() ==
410  LLVM::LLVMDialect::getByRefAttrName()) {
411  convertedAttrs.push_back(rewriter.getNamedAttr(
412  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
413  } else if (attr.getName().getValue() ==
414  LLVM::LLVMDialect::getStructRetAttrName()) {
415  convertedAttrs.push_back(rewriter.getNamedAttr(
416  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
417  } else if (attr.getName().getValue() ==
418  LLVM::LLVMDialect::getInAllocaAttrName()) {
419  convertedAttrs.push_back(rewriter.getNamedAttr(
420  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
421  } else {
422  convertedAttrs.push_back(attr);
423  }
424  }
425  auto mapping = result.getInputMapping(i);
426  assert(mapping && "unexpected deletion of function argument");
427  // Only attach the new argument attributes if there is a one-to-one
428  // mapping from old to new types. Otherwise, attributes might be
429  // attached to types that they do not support.
430  if (mapping->size == 1) {
431  newArgAttrs[mapping->inputNo] =
432  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
433  continue;
434  }
435  // TODO: Implement custom handling for types that expand to multiple
436  // function arguments.
437  for (size_t j = 0; j < mapping->size; ++j)
438  newArgAttrs[mapping->inputNo + j] =
439  DictionaryAttr::get(rewriter.getContext(), {});
440  }
441  if (!newArgAttrs.empty())
442  newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
443  }
444 
445  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
446  newFuncOp.end());
447  // Convert just the entry block. The remaining unstructured control flow is
448  // converted by ControlFlowToLLVM.
449  if (!newFuncOp.getBody().empty())
450  rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
451  &converter);
452 
453  // Fix the type mismatch between the materialized `llvm.ptr` and the expected
454  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
455  // function arguments.
456  restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
457  newFuncOp);
458 
459  if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
460  if (funcOp->getAttrOfType<UnitAttr>(
461  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
462  if (newFuncOp.isExternal())
463  wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
464  newFuncOp);
465  else
466  wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
467  newFuncOp);
468  }
469  }
470 
471  return newFuncOp;
472 }
473 
474 namespace {
475 
476 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
477 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
478 /// information.
479 class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
480  SymbolTableCollection *symbolTables = nullptr;
481 
482 public:
483  explicit FuncOpConversion(const LLVMTypeConverter &converter,
484  SymbolTableCollection *symbolTables = nullptr)
485  : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
486 
487  LogicalResult
488  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
489  ConversionPatternRewriter &rewriter) const override {
490  FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
491  cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
492  *getTypeConverter(), symbolTables);
493  if (failed(newFuncOp))
494  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
495 
496  rewriter.eraseOp(funcOp);
497  return success();
498  }
499 };
500 
501 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
503 
504  LogicalResult
505  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  auto type = typeConverter->convertType(op.getResult().getType());
508  if (!type || !LLVM::isCompatibleType(type))
509  return rewriter.notifyMatchFailure(op, "failed to convert result type");
510 
511  auto newOp =
512  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
513  for (const NamedAttribute &attr : op->getAttrs()) {
514  if (attr.getName().strref() == "value")
515  continue;
516  newOp->setAttr(attr.getName(), attr.getValue());
517  }
518  rewriter.replaceOp(op, newOp->getResults());
519  return success();
520  }
521 };
522 
523 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
524 // passes the pointer to the MemRef across function boundaries.
525 template <typename CallOpType>
526 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
528  using Super = CallOpInterfaceLowering<CallOpType>;
530 
531  LogicalResult matchAndRewriteImpl(CallOpType callOp,
532  typename CallOpType::Adaptor adaptor,
533  ConversionPatternRewriter &rewriter,
534  bool useBarePtrCallConv = false) const {
535  // Pack the result types into a struct.
536  Type packedResult = nullptr;
537  unsigned numResults = callOp.getNumResults();
538  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
539 
540  if (numResults != 0) {
541  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
542  resultTypes, useBarePtrCallConv)))
543  return failure();
544  }
545 
546  if (useBarePtrCallConv) {
547  for (auto it : callOp->getOperands()) {
548  Type operandType = it.getType();
549  if (isa<UnrankedMemRefType>(operandType)) {
550  // Unranked memref is not supported in the bare pointer calling
551  // convention.
552  return failure();
553  }
554  }
555  }
556  auto promoted = this->getTypeConverter()->promoteOperands(
557  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
558  adaptor.getOperands(), rewriter, useBarePtrCallConv);
559  auto newOp = rewriter.create<LLVM::CallOp>(
560  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
561  promoted, callOp->getAttrs());
562 
563  newOp.getProperties().operandSegmentSizes = {
564  static_cast<int32_t>(promoted.size()), 0};
565  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
566 
567  SmallVector<Value, 4> results;
568  if (numResults < 2) {
569  // If < 2 results, packing did not do anything and we can just return.
570  results.append(newOp.result_begin(), newOp.result_end());
571  } else {
572  // Otherwise, it had been converted to an operation producing a structure.
573  // Extract individual results from the structure and return them as list.
574  results.reserve(numResults);
575  for (unsigned i = 0; i < numResults; ++i) {
576  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
577  callOp.getLoc(), newOp->getResult(0), i));
578  }
579  }
580 
581  if (useBarePtrCallConv) {
582  // For the bare-ptr calling convention, promote memref results to
583  // descriptors.
584  assert(results.size() == resultTypes.size() &&
585  "The number of arguments and types doesn't match");
586  this->getTypeConverter()->promoteBarePtrsToDescriptors(
587  rewriter, callOp.getLoc(), resultTypes, results);
588  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
589  resultTypes, results,
590  /*toDynamic=*/false))) {
591  return failure();
592  }
593 
594  rewriter.replaceOp(callOp, results);
595  return success();
596  }
597 };
598 
599 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
600 public:
601  explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
602  SymbolTableCollection *symbolTables = nullptr,
603  PatternBenefit benefit = 1)
604  : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
605  symbolTables(symbolTables) {}
606 
607  LogicalResult
608  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
609  ConversionPatternRewriter &rewriter) const override {
610  bool useBarePtrCallConv = false;
611  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
612  useBarePtrCallConv = true;
613  } else if (symbolTables != nullptr) {
614  // Fast lookup.
615  Operation *callee =
616  symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
617  useBarePtrCallConv =
618  callee != nullptr && callee->hasAttr(barePtrAttrName);
619  } else {
620  // Warning: This is a linear lookup.
621  Operation *callee =
622  SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
623  useBarePtrCallConv =
624  callee != nullptr && callee->hasAttr(barePtrAttrName);
625  }
626  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
627  }
628 
629 private:
630  SymbolTableCollection *symbolTables = nullptr;
631 };
632 
633 struct CallIndirectOpLowering
634  : public CallOpInterfaceLowering<func::CallIndirectOp> {
635  using Super::Super;
636 
637  LogicalResult
638  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
639  ConversionPatternRewriter &rewriter) const override {
640  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
641  }
642 };
643 
644 struct UnrealizedConversionCastOpLowering
645  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
647  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
648 
649  LogicalResult
650  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
651  ConversionPatternRewriter &rewriter) const override {
652  SmallVector<Type> convertedTypes;
653  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
654  convertedTypes)) &&
655  convertedTypes == adaptor.getInputs().getTypes()) {
656  rewriter.replaceOp(op, adaptor.getInputs());
657  return success();
658  }
659 
660  convertedTypes.clear();
661  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
662  convertedTypes)) &&
663  convertedTypes == op.getOutputs().getType()) {
664  rewriter.replaceOp(op, adaptor.getInputs());
665  return success();
666  }
667  return failure();
668  }
669 };
670 
671 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
672 // `ReturnOp` interacts with the function signature and must have as many
673 // operands as the function has return values. Because in LLVM IR, functions
674 // can only return 0 or 1 value, we pack multiple values into a structure type.
675 // Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
676 // necessary before returning it
677 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
679 
680  LogicalResult
681  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
682  ConversionPatternRewriter &rewriter) const override {
683  Location loc = op.getLoc();
684  unsigned numArguments = op.getNumOperands();
685  SmallVector<Value, 4> updatedOperands;
686 
687  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
688  bool useBarePtrCallConv =
689  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
690  if (useBarePtrCallConv) {
691  // For the bare-ptr calling convention, extract the aligned pointer to
692  // be returned from the memref descriptor.
693  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
694  Type oldTy = std::get<0>(it).getType();
695  Value newOperand = std::get<1>(it);
696  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
697  cast<BaseMemRefType>(oldTy))) {
698  MemRefDescriptor memrefDesc(newOperand);
699  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
700  } else if (isa<UnrankedMemRefType>(oldTy)) {
701  // Unranked memref is not supported in the bare pointer calling
702  // convention.
703  return failure();
704  }
705  updatedOperands.push_back(newOperand);
706  }
707  } else {
708  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
709  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
710  updatedOperands,
711  /*toDynamic=*/true);
712  }
713 
714  // If ReturnOp has 0 or 1 operand, create it and return immediately.
715  if (numArguments <= 1) {
716  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
717  op, TypeRange(), updatedOperands, op->getAttrs());
718  return success();
719  }
720 
721  // Otherwise, we need to pack the arguments into an LLVM struct type before
722  // returning.
723  auto packedType = getTypeConverter()->packFunctionResults(
724  op.getOperandTypes(), useBarePtrCallConv);
725  if (!packedType) {
726  return rewriter.notifyMatchFailure(op, "could not convert result types");
727  }
728 
729  Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
730  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
731  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
732  }
733  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
734  op->getAttrs());
735  return success();
736  }
737 };
738 } // namespace
739 
741  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
742  SymbolTableCollection *symbolTables) {
743  patterns.add<FuncOpConversion>(converter, symbolTables);
744 }
745 
747  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
748  SymbolTableCollection *symbolTables) {
749  populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
750  patterns.add<CallIndirectOpLowering>(converter);
751  patterns.add<CallOpLowering>(converter, symbolTables);
752  patterns.add<ConstantOpLowering>(converter);
753  patterns.add<ReturnOpLowering>(converter);
754 }
755 
756 namespace {
757 /// A pass converting Func operations into the LLVM IR dialect.
758 struct ConvertFuncToLLVMPass
759  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
760  using Base::Base;
761 
762  /// Run the dialect converter on the module.
763  void runOnOperation() override {
764  ModuleOp m = getOperation();
765  StringRef dataLayout;
766  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
767  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
768  if (dataLayoutAttr)
769  dataLayout = dataLayoutAttr.getValue();
770 
771  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
772  dataLayout, [this](const Twine &message) {
773  getOperation().emitError() << message.str();
774  }))) {
775  signalPassFailure();
776  return;
777  }
778 
779  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
780 
782  dataLayoutAnalysis.getAtOrAbove(m));
783  options.useBarePtrCallConv = useBarePtrCallConv;
784  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
785  options.overrideIndexBitwidth(indexBitwidth);
786  options.dataLayout = llvm::DataLayout(dataLayout);
787 
788  LLVMTypeConverter typeConverter(&getContext(), options,
789  &dataLayoutAnalysis);
790 
792  SymbolTableCollection symbolTables;
793 
795  &symbolTables);
796 
798  if (failed(applyPartialConversion(m, target, std::move(patterns))))
799  signalPassFailure();
800  }
801 };
802 
803 struct SetLLVMModuleDataLayoutPass
804  : public impl::SetLLVMModuleDataLayoutPassBase<
805  SetLLVMModuleDataLayoutPass> {
806  using Base::Base;
807 
808  /// Run the dialect converter on the module.
809  void runOnOperation() override {
810  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
811  this->dataLayout, [this](const Twine &message) {
812  getOperation().emitError() << message.str();
813  }))) {
814  signalPassFailure();
815  return;
816  }
817  ModuleOp m = getOperation();
818  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
819  StringAttr::get(m.getContext(), this->dataLayout));
820  }
821 };
822 } // namespace
823 
824 //===----------------------------------------------------------------------===//
825 // ConvertToLLVMPatternInterface implementation
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 /// Implement the interface to convert Func to LLVM.
830 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
832  /// Hook for derived dialect interface to provide conversion patterns
833  /// and mark dialect legal for the conversion target.
834  void populateConvertToLLVMConversionPatterns(
835  ConversionTarget &target, LLVMTypeConverter &typeConverter,
836  RewritePatternSet &patterns) const final {
838  }
839 };
840 } // namespace
841 
843  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
844  dialect->addInterfaces<FuncToLLVMDialectInterface>();
845  });
846 }
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/result attributes.
Definition: FuncToLLVM.cpp:76
static void restoreByValRefArgumentType(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef< std::optional< NamedAttribute >> byValRefNonPtrAttrs, LLVM::LLVMFuncOp funcOp)
Inserts llvm.load ops in the function body to restore the expected pointee value from llvm....
Definition: FuncToLLVM.cpp:262
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:53
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:51
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:52
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:64
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:56
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:167
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:107
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
IndexType getIndexType()
Definition: Builders.cpp:50
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:99
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:89
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to)
Replace all the uses of the block argument from with to.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual SymbolTable & getSymbolTable(Operation *op)
Look up, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
void remove(Operation *op)
Remove the given symbol from the table, without deleting it.
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:842
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:746
void populateFuncToLLVMFuncOpConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, SymbolTableCollection *symbolTables=nullptr)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:740
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables=nullptr)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Definition: FuncToLLVM.cpp:291
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.