MLIR  21.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include <algorithm>
49 #include <functional>
50 #include <optional>
51 
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
57 
58 using namespace mlir;
59 
60 #define PASS_NAME "convert-func-to-llvm"
61 
// Discardable attribute names inspected on the source function:
//  - `func.varargs`: read as a BoolAttr; marks the converted signature as
//    variadic.
//  - `llvm.linkage`: must be an `LLVM::LinkageAttr`; selects the linkage of
//    the generated `llvm.func` (external otherwise).
//  - `llvm.bareptr`: requests the bare-pointer calling convention for this
//    function; call lowering also checks it on the resolved callee.
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65 
66 /// Return `true` if the `op` should use bare pointer calling convention.
68  const LLVMTypeConverter *typeConverter) {
69  return (op && op->hasAttr(barePtrAttrName)) ||
70  typeConverter->getOptions().useBarePtrCallConv;
71 }
72 
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func,
77  for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78  if (attr.getName() == linkageAttrName ||
79  attr.getName() == varargsAttrName ||
80  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81  continue;
82  result.push_back(attr);
83  }
84 }
85 
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88  FunctionOpInterface funcOp,
89  LLVM::LLVMFuncOp wrapperFuncOp) {
90  auto argAttrs = funcOp.getAllArgAttrs();
91  if (!resultStructType) {
92  if (auto resAttrs = funcOp.getAllResultAttrs())
93  wrapperFuncOp.setAllResultAttrs(resAttrs);
94  if (argAttrs)
95  wrapperFuncOp.setAllArgAttrs(argAttrs);
96  } else {
97  SmallVector<Attribute> argAttributes;
98  // Only modify the argument and result attributes when the result is now
99  // an argument.
100  if (argAttrs) {
101  argAttributes.push_back(builder.getDictionaryAttr({}));
102  argAttributes.append(argAttrs.begin(), argAttrs.end());
103  wrapperFuncOp.setAllArgAttrs(argAttributes);
104  }
105  }
106  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107  .setVisibility(funcOp.getVisibility());
108 }
109 
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119  const LLVMTypeConverter &typeConverter,
120  FunctionOpInterface funcOp,
121  LLVM::LLVMFuncOp newFuncOp) {
122  auto type = cast<FunctionType>(funcOp.getFunctionType());
123  auto [wrapperFuncType, resultStructType] =
124  typeConverter.convertFunctionTypeCWrapper(type);
125 
126  SmallVector<NamedAttribute> attributes;
127  filterFuncAttributes(funcOp, attributes);
128 
129  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133  propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
134 
135  OpBuilder::InsertionGuard guard(rewriter);
136  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
137 
139  size_t argOffset = resultStructType ? 1 : 0;
140  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141  Value arg = wrapperFuncOp.getArgument(index + argOffset);
142  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143  Value loaded = rewriter.create<LLVM::LoadOp>(
144  loc, typeConverter.convertType(memrefType), arg);
145  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146  continue;
147  }
148  if (isa<UnrankedMemRefType>(argType)) {
149  Value loaded = rewriter.create<LLVM::LoadOp>(
150  loc, typeConverter.convertType(argType), arg);
151  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152  continue;
153  }
154 
155  args.push_back(arg);
156  }
157 
158  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
159 
160  if (resultStructType) {
161  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162  wrapperFuncOp.getArgument(0));
163  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164  } else {
165  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
166  }
167 }
168 
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
///
/// The auxiliary function is named `_mlir_ciface_<name>` and has external
/// linkage; `newFuncOp` becomes a private function whose synthesized body
/// forwards to it.
178 static void wrapExternalFunction(OpBuilder &builder, Location loc,
179  const LLVMTypeConverter &typeConverter,
180  FunctionOpInterface funcOp,
181  LLVM::LLVMFuncOp newFuncOp) {
182  OpBuilder::InsertionGuard guard(builder);
183 
184  auto [wrapperType, resultStructType] =
185  typeConverter.convertFunctionTypeCWrapper(
186  cast<FunctionType>(funcOp.getFunctionType()));
187  // This conversion can only fail if it could not convert one of the argument
188  // types. But since it has been applied to a non-wrapper function before, it
189  // should have failed earlier and not reach this point at all.
190  assert(wrapperType && "unexpected type conversion failure");
191 
193  filterFuncAttributes(funcOp, attributes);
194 
195  // Create the auxiliary function.
196  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200  propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
201 
202  // The body we synthesize for `newFuncOp` only forwards to the C interface, so `newFuncOp` itself should only be visible in this module.
203  newFuncOp.setLinkage(LLVM::Linkage::Private);
204  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
205 
206  // Collect the call arguments; `wrapperArgsRange` walks the (unpacked) arguments of `newFuncOp`.
207  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
209  args.reserve(type.getNumInputs());
210  ValueRange wrapperArgsRange(newFuncOp.getArguments());
211 
212  if (resultStructType) {
213  // Allocate the struct on the stack and pass the pointer.
// `one` is the array size of the alloca (a single struct).
214  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215  Value one = builder.create<LLVM::ConstantOp>(
216  loc, typeConverter.convertType(builder.getIndexType()),
217  builder.getIntegerAttr(builder.getIndexType(), 1));
218  Value result =
219  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220  args.push_back(result);
221  }
222 
223  // Iterate over the inputs of the original function and pack values into
224  // memref descriptors if the original type is a memref.
225  for (Type input : type.getInputs()) {
226  Value arg;
// numToDrop is the number of `newFuncOp` arguments that make up one original
// input: 1 for non-memref inputs, the number of unpacked descriptor values for
// memrefs. Exactly that many are consumed from wrapperArgsRange below.
227  int numToDrop = 1;
228  auto memRefType = dyn_cast<MemRefType>(input);
229  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230  if (memRefType || unrankedMemRefType) {
231  numToDrop = memRefType
// Rebuild the descriptor from its unpacked values, spill it to a stack slot
// and pass the slot's address to the C interface function.
234  Value packed =
235  memRefType
236  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237  wrapperArgsRange.take_front(numToDrop))
239  builder, loc, typeConverter, unrankedMemRefType,
240  wrapperArgsRange.take_front(numToDrop));
241 
242  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
243  Value one = builder.create<LLVM::ConstantOp>(
244  loc, typeConverter.convertType(builder.getIndexType()),
245  builder.getIntegerAttr(builder.getIndexType(), 1));
246  Value allocated = builder.create<LLVM::AllocaOp>(
247  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248  builder.create<LLVM::StoreOp>(loc, packed, allocated);
249  arg = allocated;
250  } else {
251  arg = wrapperArgsRange[0];
252  }
253 
254  args.push_back(arg);
255  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
256  }
257  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
258 
259  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
260 
// When results travel through the leading out-pointer (args.front(), the
// stack slot allocated above), load them back and return them.
261  if (resultStructType) {
262  Value result =
263  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264  builder.create<LLVM::ReturnOp>(loc, result);
265  } else {
266  builder.create<LLVM::ReturnOp>(loc, call.getResults());
267  }
268 }
269 
270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272 /// to LLVM pointer types.
274  ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275  ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276  ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) {
277  // Nothing to do for function declarations.
278  if (funcOp.isExternal())
279  return;
280 
281  ConversionPatternRewriter::InsertionGuard guard(rewriter);
282  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
283 
284  for (const auto &[arg, oldArg, byValRefAttr] :
285  llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286  // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287  if (!byValRefAttr)
288  continue;
289 
290  // Insert load to retrieve the actual argument passed by value/reference.
291  assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292  "Expected LLVM pointer type for argument with "
293  "`llvm.byval`/`llvm.byref` attribute");
294  Type resTy = typeConverter.convertType(
295  cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296 
297  auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298  rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
299  }
300 }
301 
302 FailureOr<LLVM::LLVMFuncOp>
303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304  ConversionPatternRewriter &rewriter,
305  const LLVMTypeConverter &converter) {
306  // Check the funcOp has `FunctionType`.
307  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308  if (!funcTy)
309  return rewriter.notifyMatchFailure(
310  funcOp, "Only support FunctionOpInterface with FunctionType");
311 
312  // Keep track of the entry block arguments. They will be needed later.
313  SmallVector<BlockArgument> oldBlockArgs =
314  llvm::to_vector(funcOp.getArguments());
315 
316  // Convert the original function arguments. They are converted using the
317  // LLVMTypeConverter provided to this legalization pattern.
318  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
319  // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320  // overriden with an LLVM pointer type for later processing.
321  SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
322  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
323  auto llvmType = dyn_cast_or_null<LLVM::LLVMFunctionType>(
324  converter.convertFunctionSignature(
325  funcOp, varargsAttr && varargsAttr.getValue(),
326  shouldUseBarePtrCallConv(funcOp, &converter), result,
327  byValRefNonPtrAttrs));
328  if (!llvmType)
329  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
330 
331  // Check for unsupported variadic functions.
332  if (!shouldUseBarePtrCallConv(funcOp, &converter))
333  if (funcOp->getAttrOfType<UnitAttr>(
334  LLVM::LLVMDialect::getEmitCWrapperAttrName()))
335  if (llvmType.isVarArg())
336  return funcOp.emitError("C interface for variadic functions is not "
337  "supported yet.");
338 
339  // Create an LLVM function, use external linkage by default until MLIR
340  // functions have linkage.
341  LLVM::Linkage linkage = LLVM::Linkage::External;
342  if (funcOp->hasAttr(linkageAttrName)) {
343  auto attr =
344  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
345  if (!attr) {
346  funcOp->emitError() << "Contains " << linkageAttrName
347  << " attribute not of type LLVM::LinkageAttr";
348  return rewriter.notifyMatchFailure(
349  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
350  }
351  linkage = attr.getLinkage();
352  }
353 
354  // Check for invalid attributes.
355  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
356  if (funcOp->hasAttr(readnoneAttrName)) {
357  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
358  if (!attr) {
359  funcOp->emitError() << "Contains " << readnoneAttrName
360  << " attribute not of type UnitAttr";
361  return rewriter.notifyMatchFailure(
362  funcOp, "Contains readnone attribute not of type UnitAttr");
363  }
364  }
365 
367  filterFuncAttributes(funcOp, attributes);
368  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
369  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
370  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
371  attributes);
372  cast<FunctionOpInterface>(newFuncOp.getOperation())
373  .setVisibility(funcOp.getVisibility());
374 
375  // Create a memory effect attribute corresponding to readnone.
376  if (funcOp->hasAttr(readnoneAttrName)) {
377  auto memoryAttr = LLVM::MemoryEffectsAttr::get(
378  rewriter.getContext(),
379  {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
380  LLVM::ModRefInfo::NoModRef});
381  newFuncOp.setMemoryEffectsAttr(memoryAttr);
382  }
383 
384  // Propagate argument/result attributes to all converted arguments/result
385  // obtained after converting a given original argument/result.
386  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
387  assert(!resAttrDicts.empty() && "expected array to be non-empty");
388  if (funcOp.getNumResults() == 1)
389  newFuncOp.setAllResultAttrs(resAttrDicts);
390  }
391  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
392  SmallVector<Attribute> newArgAttrs(
393  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
394  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
395  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
396  // LLVMFuncOp conversion these types may have changed. Account for that
397  // change by converting attributes' types as well.
398  SmallVector<NamedAttribute, 4> convertedAttrs;
399  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
400  convertedAttrs.reserve(attrsDict.size());
401  for (const NamedAttribute &attr : attrsDict) {
402  const auto convert = [&](const NamedAttribute &attr) {
403  return TypeAttr::get(converter.convertType(
404  cast<TypeAttr>(attr.getValue()).getValue()));
405  };
406  if (attr.getName().getValue() ==
407  LLVM::LLVMDialect::getByValAttrName()) {
408  convertedAttrs.push_back(rewriter.getNamedAttr(
409  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
410  } else if (attr.getName().getValue() ==
411  LLVM::LLVMDialect::getByRefAttrName()) {
412  convertedAttrs.push_back(rewriter.getNamedAttr(
413  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
414  } else if (attr.getName().getValue() ==
415  LLVM::LLVMDialect::getStructRetAttrName()) {
416  convertedAttrs.push_back(rewriter.getNamedAttr(
417  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
418  } else if (attr.getName().getValue() ==
419  LLVM::LLVMDialect::getInAllocaAttrName()) {
420  convertedAttrs.push_back(rewriter.getNamedAttr(
421  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
422  } else {
423  convertedAttrs.push_back(attr);
424  }
425  }
426  auto mapping = result.getInputMapping(i);
427  assert(mapping && "unexpected deletion of function argument");
428  // Only attach the new argument attributes if there is a one-to-one
429  // mapping from old to new types. Otherwise, attributes might be
430  // attached to types that they do not support.
431  if (mapping->size == 1) {
432  newArgAttrs[mapping->inputNo] =
433  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
434  continue;
435  }
436  // TODO: Implement custom handling for types that expand to multiple
437  // function arguments.
438  for (size_t j = 0; j < mapping->size; ++j)
439  newArgAttrs[mapping->inputNo + j] =
440  DictionaryAttr::get(rewriter.getContext(), {});
441  }
442  if (!newArgAttrs.empty())
443  newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
444  }
445 
446  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
447  newFuncOp.end());
448  // Convert just the entry block. The remaining unstructured control flow is
449  // converted by ControlFlowToLLVM.
450  if (!newFuncOp.getBody().empty())
451  rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
452  &converter);
453 
454  // Fix the type mismatch between the materialized `llvm.ptr` and the expected
455  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
456  // function arguments.
457  restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
458  oldBlockArgs, newFuncOp);
459 
460  if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
461  if (funcOp->getAttrOfType<UnitAttr>(
462  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
463  if (newFuncOp.isExternal())
464  wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
465  newFuncOp);
466  else
467  wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
468  newFuncOp);
469  }
470  }
471 
472  return newFuncOp;
473 }
474 
475 namespace {
476 
477 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
478 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
479 /// information.
480 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
481  FuncOpConversion(const LLVMTypeConverter &converter)
482  : ConvertOpToLLVMPattern(converter) {}
483 
484  LogicalResult
485  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
486  ConversionPatternRewriter &rewriter) const override {
487  FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
488  cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
489  *getTypeConverter());
490  if (failed(newFuncOp))
491  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
492 
493  rewriter.eraseOp(funcOp);
494  return success();
495  }
496 };
497 
498 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
500 
501  LogicalResult
502  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
503  ConversionPatternRewriter &rewriter) const override {
504  auto type = typeConverter->convertType(op.getResult().getType());
505  if (!type || !LLVM::isCompatibleType(type))
506  return rewriter.notifyMatchFailure(op, "failed to convert result type");
507 
508  auto newOp =
509  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
510  for (const NamedAttribute &attr : op->getAttrs()) {
511  if (attr.getName().strref() == "value")
512  continue;
513  newOp->setAttr(attr.getName(), attr.getValue());
514  }
515  rewriter.replaceOp(op, newOp->getResults());
516  return success();
517  }
518 };
519 
520 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
521 // passes the pointer to the MemRef across function boundaries.
// NOTE(review): the operand promotion itself is done by
// `LLVMTypeConverter::promoteOperands`, which is not shown here; confirm the
// alloca/store description above against it.
522 template <typename CallOpType>
523 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
525  using Super = CallOpInterfaceLowering<CallOpType>;
527 
/// Shared lowering for direct and indirect calls. Packs multiple results into
/// one struct result, promotes the operands (forwarding `useBarePtrCallConv`
/// to `promoteOperands`), creates an `llvm.call` that keeps `callOp`'s
/// attributes, and unpacks the results again. Fails if the results cannot be
/// packed, if an unranked memref is passed under the bare-pointer convention,
/// or if copying unranked result descriptors fails.
528  LogicalResult matchAndRewriteImpl(CallOpType callOp,
529  typename CallOpType::Adaptor adaptor,
530  ConversionPatternRewriter &rewriter,
531  bool useBarePtrCallConv = false) const {
532  // Pack the result types into a struct.
533  Type packedResult = nullptr;
534  unsigned numResults = callOp.getNumResults();
535  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
536 
537  if (numResults != 0) {
538  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
539  resultTypes, useBarePtrCallConv)))
540  return failure();
541  }
542 
543  if (useBarePtrCallConv) {
544  for (auto it : callOp->getOperands()) {
545  Type operandType = it.getType();
546  if (isa<UnrankedMemRefType>(operandType)) {
547  // Unranked memref is not supported in the bare pointer calling
548  // convention.
549  return failure();
550  }
551  }
552  }
553  auto promoted = this->getTypeConverter()->promoteOperands(
554  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
555  adaptor.getOperands(), rewriter, useBarePtrCallConv);
556  auto newOp = rewriter.create<LLVM::CallOp>(
557  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
558  promoted, callOp->getAttrs());
559 
// The new call has no operand bundles: every promoted value goes into the
// first operand segment, the second segment is empty and no bundle sizes are
// recorded.
560  newOp.getProperties().operandSegmentSizes = {
561  static_cast<int32_t>(promoted.size()), 0};
562  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
563 
564  SmallVector<Value, 4> results;
565  if (numResults < 2) {
566  // If < 2 results, packing did not do anything and we can just return.
567  results.append(newOp.result_begin(), newOp.result_end());
568  } else {
569  // Otherwise, it had been converted to an operation producing a structure.
570  // Extract individual results from the structure and return them as list.
571  results.reserve(numResults);
572  for (unsigned i = 0; i < numResults; ++i) {
573  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
574  callOp.getLoc(), newOp->getResult(0), i));
575  }
576  }
577 
578  if (useBarePtrCallConv) {
579  // For the bare-ptr calling convention, promote memref results to
580  // descriptors.
581  assert(results.size() == resultTypes.size() &&
582  "The number of arguments and types doesn't match");
583  this->getTypeConverter()->promoteBarePtrsToDescriptors(
584  rewriter, callOp.getLoc(), resultTypes, results);
// NOTE(review): this failure is returned after the call and extract ops were
// created; it relies on the conversion driver rolling that IR back.
585  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
586  resultTypes, results,
587  /*toDynamic=*/false))) {
588  return failure();
589  }
590 
591  rewriter.replaceOp(callOp, results);
592  return success();
593  }
594 };
595 
596 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
597 public:
598  CallOpLowering(const LLVMTypeConverter &typeConverter,
599  // Can be nullptr.
600  const SymbolTable *symbolTable, PatternBenefit benefit = 1)
601  : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
602  symbolTable(symbolTable) {}
603 
604  LogicalResult
605  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
606  ConversionPatternRewriter &rewriter) const override {
607  bool useBarePtrCallConv = false;
608  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
609  useBarePtrCallConv = true;
610  } else if (symbolTable != nullptr) {
611  // Fast lookup.
612  Operation *callee =
613  symbolTable->lookup(callOp.getCalleeAttr().getValue());
614  useBarePtrCallConv =
615  callee != nullptr && callee->hasAttr(barePtrAttrName);
616  } else {
617  // Warning: This is a linear lookup.
618  Operation *callee =
619  SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
620  useBarePtrCallConv =
621  callee != nullptr && callee->hasAttr(barePtrAttrName);
622  }
623  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
624  }
625 
626 private:
627  const SymbolTable *symbolTable = nullptr;
628 };
629 
630 struct CallIndirectOpLowering
631  : public CallOpInterfaceLowering<func::CallIndirectOp> {
632  using Super::Super;
633 
634  LogicalResult
635  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
636  ConversionPatternRewriter &rewriter) const override {
637  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
638  }
639 };
640 
641 struct UnrealizedConversionCastOpLowering
642  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
644  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
645 
646  LogicalResult
647  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
648  ConversionPatternRewriter &rewriter) const override {
649  SmallVector<Type> convertedTypes;
650  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
651  convertedTypes)) &&
652  convertedTypes == adaptor.getInputs().getTypes()) {
653  rewriter.replaceOp(op, adaptor.getInputs());
654  return success();
655  }
656 
657  convertedTypes.clear();
658  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
659  convertedTypes)) &&
660  convertedTypes == op.getOutputs().getType()) {
661  rewriter.replaceOp(op, adaptor.getInputs());
662  return success();
663  }
664  return failure();
665  }
666 };
667 
668 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
669 // `ReturnOp` interacts with the function signature and must have as many
670 // operands as the function has return values. Because in LLVM IR, functions
671 // can only return 0 or 1 value, we pack multiple values into a structure type.
672 // Emit `PoisonOp` followed by `InsertValueOp`s to create such structure if
673 // necessary before returning it
674 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
676 
677  LogicalResult
678  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
679  ConversionPatternRewriter &rewriter) const override {
680  Location loc = op.getLoc();
681  unsigned numArguments = op.getNumOperands();
682  SmallVector<Value, 4> updatedOperands;
683 
684  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
685  bool useBarePtrCallConv =
686  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
687  if (useBarePtrCallConv) {
688  // For the bare-ptr calling convention, extract the aligned pointer to
689  // be returned from the memref descriptor.
690  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
691  Type oldTy = std::get<0>(it).getType();
692  Value newOperand = std::get<1>(it);
693  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
694  cast<BaseMemRefType>(oldTy))) {
695  MemRefDescriptor memrefDesc(newOperand);
696  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
697  } else if (isa<UnrankedMemRefType>(oldTy)) {
698  // Unranked memref is not supported in the bare pointer calling
699  // convention.
700  return failure();
701  }
702  updatedOperands.push_back(newOperand);
703  }
704  } else {
705  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
706  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
707  updatedOperands,
708  /*toDynamic=*/true);
709  }
710 
711  // If ReturnOp has 0 or 1 operand, create it and return immediately.
712  if (numArguments <= 1) {
713  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
714  op, TypeRange(), updatedOperands, op->getAttrs());
715  return success();
716  }
717 
718  // Otherwise, we need to pack the arguments into an LLVM struct type before
719  // returning.
720  auto packedType = getTypeConverter()->packFunctionResults(
721  op.getOperandTypes(), useBarePtrCallConv);
722  if (!packedType) {
723  return rewriter.notifyMatchFailure(op, "could not convert result types");
724  }
725 
726  Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType);
727  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
728  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
729  }
730  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
731  op->getAttrs());
732  return success();
733  }
734 };
735 } // namespace
736 
738  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Adds only the func.func -> llvm.func conversion pattern (FuncOpConversion).
739  patterns.add<FuncOpConversion>(converter);
740 }
741 
743  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
744  const SymbolTable *symbolTable) {
// Call, indirect-call, constant and return lowerings. `symbolTable` may be
// null; CallOpLowering then resolves each callee with a linear
// nearest-symbol lookup instead.
// NOTE(review): UnrealizedConversionCastOpLowering is not added by any of the
// registration lines shown here; confirm whether it is still needed.
746  patterns.add<CallIndirectOpLowering>(converter);
747  patterns.add<CallOpLowering>(converter, symbolTable);
748  patterns.add<ConstantOpLowering>(converter);
749  patterns.add<ReturnOpLowering>(converter);
750 }
751 
752 namespace {
753 /// A pass converting Func operations into the LLVM IR dialect.
754 struct ConvertFuncToLLVMPass
755  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
756  using Base::Base;
757 
758  /// Run the dialect converter on the module.
/// The module's LLVM data layout attribute (empty if absent) is validated
/// first, failing the pass if it is malformed, and is then used together with
/// the pass options to configure the LLVM type converter.
759  void runOnOperation() override {
760  ModuleOp m = getOperation();
761  StringRef dataLayout;
762  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
763  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
764  if (dataLayoutAttr)
765  dataLayout = dataLayoutAttr.getValue();
766 
767  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
768  dataLayout, [this](const Twine &message) {
769  getOperation().emitError() << message.str();
770  }))) {
771  signalPassFailure();
772  return;
773  }
774 
775  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
776 
778  dataLayoutAnalysis.getAtOrAbove(m));
779  options.useBarePtrCallConv = useBarePtrCallConv;
780  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
781  options.overrideIndexBitwidth(indexBitwidth);
782  options.dataLayout = llvm::DataLayout(dataLayout);
783 
784  LLVMTypeConverter typeConverter(&getContext(), options,
785  &dataLayoutAnalysis);
786 
// The symbol table only serves per-callee `llvm.bareptr` lookups, which the
// call lowering performs only when the global bare-pointer option is off.
787  std::optional<SymbolTable> optSymbolTable = std::nullopt;
788  const SymbolTable *symbolTable = nullptr;
789  if (!options.useBarePtrCallConv) {
790  optSymbolTable.emplace(m);
791  symbolTable = &optSymbolTable.value();
792  }
793 
795  populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
796 
798  if (failed(applyPartialConversion(m, target, std::move(patterns))))
799  signalPassFailure();
800  }
801 };
802 
803 struct SetLLVMModuleDataLayoutPass
804  : public impl::SetLLVMModuleDataLayoutPassBase<
805  SetLLVMModuleDataLayoutPass> {
806  using Base::Base;
807 
808  /// Run the dialect converter on the module.
809  void runOnOperation() override {
810  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
811  this->dataLayout, [this](const Twine &message) {
812  getOperation().emitError() << message.str();
813  }))) {
814  signalPassFailure();
815  return;
816  }
817  ModuleOp m = getOperation();
818  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
819  StringAttr::get(m.getContext(), this->dataLayout));
820  }
821 };
822 } // namespace
823 
824 //===----------------------------------------------------------------------===//
825 // ConvertToLLVMPatternInterface implementation
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 /// Implement the interface to convert Func to LLVM.
/// Attached to the func dialect via `addInterfaces<FuncToLLVMDialectInterface>`
/// in the registration function of this file.
830 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
832  /// Hook for derived dialect interface to provide conversion patterns
833  /// and mark dialect legal for the conversion target.
834  void populateConvertToLLVMConversionPatterns(
835  ConversionTarget &target, LLVMTypeConverter &typeConverter,
836  RewritePatternSet &patterns) const final {
838  }
839 };
840 } // namespace
841 
// Registers a dialect extension that attaches FuncToLLVMDialectInterface to
// the func dialect; the extension does not use `ctx`.
843  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
844  dialect->addInterfaces<FuncToLLVMDialectInterface>();
845  });
846 }
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/result attributes.
Definition: FuncToLLVM.cpp:87
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:64
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:62
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:63
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:75
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:67
static void restoreByValRefArgumentType(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, ArrayRef< std::optional< NamedAttribute >> byValRefNonPtrAttrs, ArrayRef< BlockArgument > oldBlockArgs, LLVM::LLVMFuncOp funcOp)
Inserts llvm.load ops in the function body to restore the expected pointee value from llvm....
Definition: FuncToLLVM.cpp:273
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:178
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:118
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:100
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:90
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
const LowerToLLVMOptions & getOptions() const
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void populateFuncToLLVMFuncOpConversionPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:737
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:842
const FrozenRewritePatternSet & patterns
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Definition: FuncToLLVM.cpp:303
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateFuncToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:742
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.