MLIR  18.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Transforms/Passes.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/IR/DerivedTypes.h"
44 #include "llvm/IR/IRBuilder.h"
45 #include "llvm/IR/Type.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/CommandLine.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include <algorithm>
50 #include <functional>
51 
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
57 
58 using namespace mlir;
59 
60 #define PASS_NAME "convert-func-to-llvm"
61 
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65 
66 /// Return `true` if the `op` should use bare pointer calling convention.
68  const LLVMTypeConverter *typeConverter) {
69  return (op && op->hasAttr(barePtrAttrName)) ||
70  typeConverter->getOptions().useBarePtrCallConv;
71 }
72 
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
75 /// attributes.
76 static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs,
78  for (const NamedAttribute &attr : func->getAttrs()) {
79  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
80  attr.getName() == func.getFunctionTypeAttrName() ||
81  attr.getName() == linkageAttrName ||
82  attr.getName() == varargsAttrName ||
83  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName() ||
84  (filterArgAndResAttrs &&
85  (attr.getName() == func.getArgAttrsAttrName() ||
86  attr.getName() == func.getResAttrsAttrName())))
87  continue;
88  result.push_back(attr);
89  }
90 }
91 
92 /// Adds a an empty set of argument attributes for the newly added argument in
93 /// front of the existing ones.
94 static void prependEmptyArgAttr(OpBuilder &builder,
95  SmallVectorImpl<NamedAttribute> &newFuncAttrs,
96  func::FuncOp func) {
97  auto argAttrs = func.getArgAttrs();
98  // Nothing to do when there were no arg attrs beforehand.
99  if (!argAttrs)
100  return;
101 
102  size_t numArguments = func.getNumArguments();
103  SmallVector<Attribute> newArgAttrs;
104  newArgAttrs.reserve(numArguments + 1);
105  // Insert empty dictionary for the new argument.
106  newArgAttrs.push_back(builder.getDictionaryAttr({}));
107 
108  llvm::append_range(newArgAttrs, *argAttrs);
109  auto newNamedAttr = builder.getNamedAttr(func.getArgAttrsAttrName(),
110  builder.getArrayAttr(newArgAttrs));
111  newFuncAttrs.push_back(newNamedAttr);
112 }
113 
114 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
115 /// arguments instead of unpacked arguments. This function can be called from C
116 /// by passing a pointer to a C struct corresponding to a memref descriptor.
117 /// Similarly, returned memrefs are passed via pointers to a C struct that is
118 /// passed as additional argument.
119 /// Internally, the auxiliary function unpacks the descriptor into individual
120 /// components and forwards them to `newFuncOp` and forwards the results to
121 /// the extra arguments.
/// The wrapper is named `_mlir_ciface_<name>` and has external linkage and
/// the C calling convention. When the C wrapper type reports a result struct
/// type, the pointer to that struct is the wrapper's FIRST argument, so the
/// original arguments are shifted by one. In that case the callee's result is
/// stored through the pointer and the wrapper returns nothing; otherwise the
/// callee's results are returned directly.
122 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
123  const LLVMTypeConverter &typeConverter,
124  func::FuncOp funcOp,
125  LLVM::LLVMFuncOp newFuncOp) {
126  auto type = funcOp.getFunctionType();
127  auto [wrapperFuncType, resultStructType] =
128  typeConverter.convertFunctionTypeCWrapper(type);
129 
131  // Only modify the argument and result attributes when the result is now an
132  // argument.
133  if (resultStructType)
134  prependEmptyArgAttr(rewriter, attributes, funcOp);
136  funcOp, /*filterArgAndResAttrs=*/static_cast<bool>(resultStructType),
137  attributes);
138  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
139  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
140  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
141  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
142 
143  OpBuilder::InsertionGuard guard(rewriter);
144  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
145 
// Skip the leading result-struct pointer argument when there is one.
147  size_t argOffset = resultStructType ? 1 : 0;
148  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
149  Value arg = wrapperFuncOp.getArgument(index + argOffset);
// Memref arguments reach the wrapper as pointers to descriptors: load the
// descriptor and unpack it into the individual components `newFuncOp` expects.
150  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
151  Value loaded = rewriter.create<LLVM::LoadOp>(
152  loc, typeConverter.convertType(memrefType), arg);
153  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
154  continue;
155  }
156  if (isa<UnrankedMemRefType>(argType)) {
157  Value loaded = rewriter.create<LLVM::LoadOp>(
158  loc, typeConverter.convertType(argType), arg);
159  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
160  continue;
161  }
162 
// Any other argument is forwarded unchanged.
163  args.push_back(arg);
164  }
165 
166  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
167 
// Results either go through the caller-provided struct pointer (argument 0)
// or are returned as-is.
168  if (resultStructType) {
169  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
170  wrapperFuncOp.getArgument(0));
171  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
172  } else {
173  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
174  }
175 }
176 
177 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
178 /// arguments instead of unpacked arguments. Creates a body for the (external)
179 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
180 /// individual arguments into this descriptor and passes a pointer to it into
181 /// the auxiliary function. If the result of the function cannot be directly
182 /// returned, we write it to a special first argument that provides a pointer
183 /// to a corresponding struct. This auxiliary external function is now
184 /// compatible with functions defined in C using pointers to C structs
185 /// corresponding to a memref descriptor.
/// Concretely:
///   - the auxiliary `_mlir_ciface_<name>` function is only declared here
///     (no body is added to it);
///   - `newFuncOp` receives the forwarding body and is made private.
186 static void wrapExternalFunction(OpBuilder &builder, Location loc,
187  const LLVMTypeConverter &typeConverter,
188  func::FuncOp funcOp,
189  LLVM::LLVMFuncOp newFuncOp) {
190  OpBuilder::InsertionGuard guard(builder);
191 
192  auto [wrapperType, resultStructType] =
193  typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
194  // This conversion can only fail if it could not convert one of the argument
195  // types. But since it has been applied to a non-wrapper function before, it
196  // should have failed earlier and not reach this point at all.
197  assert(wrapperType && "unexpected type conversion failure");
198 
200  // Only modify the argument and result attributes when the result is now an
201  // argument.
202  if (resultStructType)
203  prependEmptyArgAttr(builder, attributes, funcOp);
205  funcOp, /*filterArgAndResAttrs=*/static_cast<bool>(resultStructType),
206  attributes);
207 
208  // Create the auxiliary function.
209  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
210  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
211  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
212  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
213 
214  // The wrapper that we synthesize here should only be visible in this module.
215  newFuncOp.setLinkage(LLVM::Linkage::Private);
216  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
217 
218  // Get a ValueRange containing arguments.
219  FunctionType type = funcOp.getFunctionType();
221  args.reserve(type.getNumInputs());
222  ValueRange wrapperArgsRange(newFuncOp.getArguments());
223 
224  if (resultStructType) {
225  // Allocate the struct on the stack and pass the pointer.
226  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
227  Value one = builder.create<LLVM::ConstantOp>(
228  loc, typeConverter.convertType(builder.getIndexType()),
229  builder.getIntegerAttr(builder.getIndexType(), 1));
230  Value result =
231  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
232  args.push_back(result);
233  }
234 
235  // Iterate over the inputs of the original function and pack values into
236  // memref descriptors if the original type is a memref.
// A memref input consumes `numToDrop` consecutive arguments of `newFuncOp`.
// They are packed into a descriptor, which is spilled to a stack slot and
// passed by pointer. Every other input consumes exactly one argument.
237  for (Type input : type.getInputs()) {
238  Value arg;
239  int numToDrop = 1;
240  auto memRefType = dyn_cast<MemRefType>(input);
241  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
242  if (memRefType || unrankedMemRefType) {
243  numToDrop = memRefType
246  Value packed =
247  memRefType
248  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
249  wrapperArgsRange.take_front(numToDrop))
251  builder, loc, typeConverter, unrankedMemRefType,
252  wrapperArgsRange.take_front(numToDrop));
253 
254  auto ptrTy = typeConverter.getPointerType(packed.getType());
255  Value one = builder.create<LLVM::ConstantOp>(
256  loc, typeConverter.convertType(builder.getIndexType()),
257  builder.getIntegerAttr(builder.getIndexType(), 1));
258  Value allocated = builder.create<LLVM::AllocaOp>(
259  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
260  builder.create<LLVM::StoreOp>(loc, packed, allocated);
261  arg = allocated;
262  } else {
263  arg = wrapperArgsRange[0];
264  }
265 
266  args.push_back(arg);
267  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
268  }
269  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
270 
271  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
272 
// With a result struct, the callee wrote its results through the stack slot
// passed as args.front(); load the struct back and return it.
273  if (resultStructType) {
274  Value result =
275  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
276  builder.create<LLVM::ReturnOp>(loc, result);
277  } else {
278  builder.create<LLVM::ReturnOp>(loc, call.getResults());
279  }
280 }
281 
282 /// Modifies the body of the function to construct the `MemRefDescriptor` from
283 /// the bare pointer calling convention lowering of `memref` types.
285  ConversionPatternRewriter &rewriter, Location loc,
286  const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
287  TypeRange oldArgTypes) {
288  if (funcOp.getBody().empty())
289  return;
290 
291  // Promote bare pointers from memref arguments to memref descriptors at the
292  // beginning of the function so that all the memrefs in the function have a
293  // uniform representation.
294  Block *entryBlock = &funcOp.getBody().front();
295  auto blockArgs = entryBlock->getArguments();
296  assert(blockArgs.size() == oldArgTypes.size() &&
297  "The number of arguments and types doesn't match");
298 
299  OpBuilder::InsertionGuard guard(rewriter);
300  rewriter.setInsertionPointToStart(entryBlock);
301  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
302  BlockArgument arg = std::get<0>(it);
303  Type argTy = std::get<1>(it);
304 
305  // Unranked memrefs are not supported in the bare pointer calling
306  // convention. We should have bailed out before in the presence of
307  // unranked memrefs.
308  assert(!isa<UnrankedMemRefType>(argTy) &&
309  "Unranked memref is not supported");
310  auto memrefTy = dyn_cast<MemRefType>(argTy);
311  if (!memrefTy)
312  continue;
313 
314  // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
315  // or unranked memref descriptor and replace placeholder with the last
316  // instruction of the memref descriptor.
317  // TODO: The placeholder is needed to avoid replacing barePtr uses in the
318  // MemRef descriptor instructions. We may want to have a utility in the
319  // rewriter to properly handle this use case.
320  Location loc = funcOp.getLoc();
321  auto placeholder = rewriter.create<LLVM::UndefOp>(
322  loc, typeConverter.convertType(memrefTy));
323  rewriter.replaceUsesOfBlockArgument(arg, placeholder);
324 
325  Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
326  memrefTy, arg);
327  rewriter.replaceOp(placeholder, {desc});
328  }
329 }
330 
331 namespace {
332 
/// Shared base for `func.func` lowering. It converts the signature, the
/// function attributes and the argument/result attributes, creates the
/// `llvm.func`, and moves the original body into it.
333 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
334 protected:
336 
337  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
338  // to this legalization pattern.
// Returns a match failure in any of these cases:
//   - the signature conversion fails;
//   - the linkage attribute is not an LLVM::LinkageAttr;
//   - the readnone attribute is not a UnitAttr;
//   - the region types cannot be converted.
340  convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
341  ConversionPatternRewriter &rewriter) const {
342  // Convert the original function arguments. They are converted using the
343  // LLVMTypeConverter provided to this legalization pattern.
344  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
345  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
346  auto llvmType = getTypeConverter()->convertFunctionSignature(
347  funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
348  shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
349  if (!llvmType)
350  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
351 
352  // Propagate argument/result attributes to all converted arguments/result
353  // obtained after converting a given original argument/result.
355  filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes);
// Several results are packed into a single struct return value, so only one
// (empty) result-attribute dictionary is emitted in that case.
356  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
357  assert(!resAttrDicts.empty() && "expected array to be non-empty");
358  auto newResAttrDicts =
359  (funcOp.getNumResults() == 1)
360  ? resAttrDicts
361  : rewriter.getArrayAttr(rewriter.getDictionaryAttr({}));
362  attributes.push_back(
363  rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts));
364  }
365  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
// One dictionary slot per converted LLVM parameter; it is filled from the
// original argument that each parameter maps back to.
366  SmallVector<Attribute, 4> newArgAttrs(
367  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
368  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
369  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
370  // LLVMFuncOp conversion these types may have changed. Account for that
371  // change by converting attributes' types as well.
372  SmallVector<NamedAttribute, 4> convertedAttrs;
373  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
374  convertedAttrs.reserve(attrsDict.size());
375  for (const NamedAttribute &attr : attrsDict) {
// NOTE(review): convertType can return a null Type; ConstantOpLowering in
// this file checks for that case. Here a null type would be passed to
// TypeAttr::get. Consider bailing out with a match failure instead.
376  const auto convert = [&](const NamedAttribute &attr) {
377  return TypeAttr::get(getTypeConverter()->convertType(
378  cast<TypeAttr>(attr.getValue()).getValue()));
379  };
380  if (attr.getName().getValue() ==
381  LLVM::LLVMDialect::getByValAttrName()) {
382  convertedAttrs.push_back(rewriter.getNamedAttr(
383  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
384  } else if (attr.getName().getValue() ==
385  LLVM::LLVMDialect::getByRefAttrName()) {
386  convertedAttrs.push_back(rewriter.getNamedAttr(
387  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
388  } else if (attr.getName().getValue() ==
389  LLVM::LLVMDialect::getStructRetAttrName()) {
390  convertedAttrs.push_back(rewriter.getNamedAttr(
391  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
392  } else if (attr.getName().getValue() ==
393  LLVM::LLVMDialect::getInAllocaAttrName()) {
394  convertedAttrs.push_back(rewriter.getNamedAttr(
395  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
396  } else {
397  convertedAttrs.push_back(attr);
398  }
399  }
400  auto mapping = result.getInputMapping(i);
401  assert(mapping && "unexpected deletion of function argument");
402  // Only attach the new argument attributes if there is a one-to-one
403  // mapping from old to new types. Otherwise, attributes might be
404  // attached to types that they do not support.
405  if (mapping->size == 1) {
406  newArgAttrs[mapping->inputNo] =
407  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
408  continue;
409  }
410  // TODO: Implement custom handling for types that expand to multiple
411  // function arguments.
412  for (size_t j = 0; j < mapping->size; ++j)
413  newArgAttrs[mapping->inputNo + j] =
414  DictionaryAttr::get(rewriter.getContext(), {});
415  }
416  attributes.push_back(rewriter.getNamedAttr(
417  funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs)));
418  }
419 
420  // Create an LLVM function, use external linkage by default until MLIR
421  // functions have linkage.
422  LLVM::Linkage linkage = LLVM::Linkage::External;
423  if (funcOp->hasAttr(linkageAttrName)) {
424  auto attr =
425  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
426  if (!attr) {
427  funcOp->emitError() << "Contains " << linkageAttrName
428  << " attribute not of type LLVM::LinkageAttr";
429  return rewriter.notifyMatchFailure(
430  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
431  }
432  linkage = attr.getLinkage();
433  }
434 
435  // Create a memory effect attribute corresponding to readnone.
// All three ModRefInfo fields are NoModRef, i.e. the function neither reads
// nor writes memory.
436  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
437  LLVM::MemoryEffectsAttr memoryAttr = {};
438  if (funcOp->hasAttr(readnoneAttrName)) {
439  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
440  if (!attr) {
441  funcOp->emitError() << "Contains " << readnoneAttrName
442  << " attribute not of type UnitAttr";
443  return rewriter.notifyMatchFailure(
444  funcOp, "Contains readnone attribute not of type UnitAttr");
445  }
446  memoryAttr = LLVM::MemoryEffectsAttr::get(rewriter.getContext(),
447  {LLVM::ModRefInfo::NoModRef,
448  LLVM::ModRefInfo::NoModRef,
449  LLVM::ModRefInfo::NoModRef});
450  }
451  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
452  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
453  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
454  attributes);
455  // If the memory attribute was created, add it to the function.
456  if (memoryAttr)
457  newFuncOp.setMemoryAttr(memoryAttr);
// Move the original body into the new function, then rewrite its block
// argument types according to the signature conversion computed above.
458  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
459  newFuncOp.end());
460  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
461  &result))) {
462  return rewriter.notifyMatchFailure(funcOp,
463  "region types conversion failed");
464  }
465 
466  return newFuncOp;
467  }
468 };
469 
470 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
471 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
472 /// information.
/// After the generic conversion there are two paths:
///   - Default convention, function carries the emit-C-wrapper unit
///     attribute: a C-interface wrapper is generated. Declarations go through
///     wrapExternalFunction and definitions through wrapForExternalCallers.
///     Variadic functions are rejected.
///   - Bare pointer convention: memref arguments are rebuilt into descriptors
///     at the start of the body.
/// The original `func.func` is erased in both cases.
473 struct FuncOpConversion : public FuncOpConversionBase {
474  FuncOpConversion(const LLVMTypeConverter &converter)
475  : FuncOpConversionBase(converter) {}
476 
478  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
479  ConversionPatternRewriter &rewriter) const override {
480  FailureOr<LLVM::LLVMFuncOp> newFuncOp =
481  convertFuncOpToLLVMFuncOp(funcOp, rewriter);
482  if (failed(newFuncOp))
483  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
484 
485  if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
486  if (funcOp->getAttrOfType<UnitAttr>(
487  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
488  if (newFuncOp->isVarArg())
489  return funcOp->emitError("C interface for variadic functions is not "
490  "supported yet.");
491 
492  if (newFuncOp->isExternal())
493  wrapExternalFunction(rewriter, funcOp->getLoc(), *getTypeConverter(),
494  funcOp, *newFuncOp);
495  else
496  wrapForExternalCallers(rewriter, funcOp->getLoc(),
497  *getTypeConverter(), funcOp, *newFuncOp);
498  }
499  } else {
500  modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp->getLoc(),
501  *getTypeConverter(), *newFuncOp,
502  funcOp.getFunctionType().getInputs());
503  }
504 
505  rewriter.eraseOp(funcOp);
506  return success();
507  }
508 };
509 
/// Lowers `func.constant` to `llvm.mlir.addressof` on the same symbol.
/// Every attribute except `value` is copied onto the new op. Fails when the
/// result type does not convert to an LLVM-compatible type.
510 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
512 
514  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
515  ConversionPatternRewriter &rewriter) const override {
516  auto type = typeConverter->convertType(op.getResult().getType());
517  if (!type || !LLVM::isCompatibleType(type))
518  return rewriter.notifyMatchFailure(op, "failed to convert result type");
519 
520  auto newOp =
521  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
// `value` is consumed as the referenced symbol above; carry over the rest.
522  for (const NamedAttribute &attr : op->getAttrs()) {
523  if (attr.getName().strref() == "value")
524  continue;
525  newOp->setAttr(attr.getName(), attr.getValue());
526  }
527  rewriter.replaceOp(op, newOp->getResults());
528  return success();
529  }
530 };
531 
532 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
533 // passes the pointer to the MemRef across function boundaries.
// Shared implementation for direct and indirect calls:
//   - results are packed into one struct when there are two or more and
//     unpacked again after the call;
//   - operands are promoted by the type converter.
// Under the bare pointer convention, unranked memref operands are rejected
// and memref results are promoted back to descriptors. Otherwise unranked
// descriptors in the results are copied via copyUnrankedDescriptors.
534 template <typename CallOpType>
535 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
537  using Super = CallOpInterfaceLowering<CallOpType>;
539 
540  LogicalResult matchAndRewriteImpl(CallOpType callOp,
541  typename CallOpType::Adaptor adaptor,
542  ConversionPatternRewriter &rewriter,
543  bool useBarePtrCallConv = false) const {
544  // Pack the result types into a struct.
545  Type packedResult = nullptr;
546  unsigned numResults = callOp.getNumResults();
547  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
548 
549  if (numResults != 0) {
550  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
551  resultTypes, useBarePtrCallConv)))
552  return failure();
553  }
554 
555  if (useBarePtrCallConv) {
556  for (auto it : callOp->getOperands()) {
557  Type operandType = it.getType();
558  if (isa<UnrankedMemRefType>(operandType)) {
559  // Unranked memref is not supported in the bare pointer calling
560  // convention.
561  return failure();
562  }
563  }
564  }
565  auto promoted = this->getTypeConverter()->promoteOperands(
566  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
567  adaptor.getOperands(), rewriter, useBarePtrCallConv);
// All original call attributes are forwarded to the llvm.call.
568  auto newOp = rewriter.create<LLVM::CallOp>(
569  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
570  promoted, callOp->getAttrs());
571 
572  SmallVector<Value, 4> results;
573  if (numResults < 2) {
574  // If < 2 results, packing did not do anything and we can just return.
575  results.append(newOp.result_begin(), newOp.result_end());
576  } else {
577  // Otherwise, it had been converted to an operation producing a structure.
578  // Extract individual results from the structure and return them as list.
579  results.reserve(numResults);
580  for (unsigned i = 0; i < numResults; ++i) {
581  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
582  callOp.getLoc(), newOp->getResult(0), i));
583  }
584  }
585 
586  if (useBarePtrCallConv) {
587  // For the bare-ptr calling convention, promote memref results to
588  // descriptors.
589  assert(results.size() == resultTypes.size() &&
590  "The number of arguments and types doesn't match");
591  this->getTypeConverter()->promoteBarePtrsToDescriptors(
592  rewriter, callOp.getLoc(), resultTypes, results);
593  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
594  resultTypes, results,
595  /*toDynamic=*/false))) {
596  return failure();
597  }
598 
599  rewriter.replaceOp(callOp, results);
600  return success();
601  }
602 };
603 
/// Lowers `func.call`. The bare pointer convention is chosen from the callee
/// operation when it can be resolved from the callee symbol; otherwise the
/// default convention is used.
604 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
605  using Super::Super;
606 
608  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
609  ConversionPatternRewriter &rewriter) const override {
610  bool useBarePtrCallConv = false;
612  callOp, callOp.getCalleeAttr())) {
613  useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
614  }
615  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
616  }
617 };
618 
/// Lowers `func.call_indirect`. The callee is not known statically, so
/// matchAndRewriteImpl is invoked with its default `useBarePtrCallConv`
/// (false).
/// NOTE(review): the global `useBarePtrCallConv` option of the type converter
/// is not consulted here, unlike shouldUseBarePtrCallConv. Presumably the
/// type converter helpers fall back to it; confirm.
619 struct CallIndirectOpLowering
620  : public CallOpInterfaceLowering<func::CallIndirectOp> {
621  using Super::Super;
622 
624  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
625  ConversionPatternRewriter &rewriter) const override {
626  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
627  }
628 };
629 
/// Folds away an `unrealized_conversion_cast` whose two sides already agree
/// after type conversion. This holds when either:
///   - the converted output types equal the (adapted) input types, or
///   - the converted input types equal the original output types.
/// In both cases the op is replaced by its adapted inputs.
/// NOTE(review): neither populate function in this file registers this
/// pattern.
630 struct UnrealizedConversionCastOpLowering
631  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
633  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
634 
636  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
637  ConversionPatternRewriter &rewriter) const override {
638  SmallVector<Type> convertedTypes;
639  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
640  convertedTypes)) &&
641  convertedTypes == adaptor.getInputs().getTypes()) {
642  rewriter.replaceOp(op, adaptor.getInputs());
643  return success();
644  }
645 
646  convertedTypes.clear();
647  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
648  convertedTypes)) &&
649  convertedTypes == op.getOutputs().getType()) {
650  rewriter.replaceOp(op, adaptor.getInputs());
651  return success();
652  }
653  return failure();
654  }
655 };
656 
657 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
658 // `ReturnOp` interacts with the function signature and must have as many
659 // operands as the function has return values. Because in LLVM IR, functions
660 // can only return 0 or 1 value, we pack multiple values into a structure type.
661 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
662 // necessary before returning it
663 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
665 
667  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
668  ConversionPatternRewriter &rewriter) const override {
669  Location loc = op.getLoc();
670  unsigned numArguments = op.getNumOperands();
671  SmallVector<Value, 4> updatedOperands;
672 
673  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
674  bool useBarePtrCallConv =
675  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
676  if (useBarePtrCallConv) {
677  // For the bare-ptr calling convention, extract the aligned pointer to
678  // be returned from the memref descriptor.
// NOTE(review): the comment above says "aligned pointer", but the code below
// returns memrefDesc.allocatedPtr(...). Confirm which pointer the bare-ptr
// convention intends to return, then align the comment or the code.
679  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
680  Type oldTy = std::get<0>(it).getType();
681  Value newOperand = std::get<1>(it);
682  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
683  cast<BaseMemRefType>(oldTy))) {
684  MemRefDescriptor memrefDesc(newOperand);
685  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
686  } else if (isa<UnrankedMemRefType>(oldTy)) {
687  // Unranked memref is not supported in the bare pointer calling
688  // convention.
689  return failure();
690  }
691  updatedOperands.push_back(newOperand);
692  }
693  } else {
694  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
// NOTE(review): a failure from copyUnrankedDescriptors is discarded here,
// whereas CallOpInterfaceLowering returns failure() in the same situation.
695  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
696  updatedOperands,
697  /*toDynamic=*/true);
698  }
699 
700  // If ReturnOp has 0 or 1 operand, create it and return immediately.
701  if (numArguments <= 1) {
702  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
703  op, TypeRange(), updatedOperands, op->getAttrs());
704  return success();
705  }
706 
707  // Otherwise, we need to pack the arguments into an LLVM struct type before
708  // returning.
709  auto packedType = getTypeConverter()->packFunctionResults(
710  op.getOperandTypes(), useBarePtrCallConv);
711  if (!packedType) {
712  return rewriter.notifyMatchFailure(op, "could not convert result types");
713  }
714 
715  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
716  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
717  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
718  }
719  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
720  op->getAttrs());
721  return success();
722  }
723 };
724 } // namespace
725 
727  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Registers only the func.func -> llvm.func conversion pattern.
728  patterns.add<FuncOpConversion>(converter);
729 }
730 
732  RewritePatternSet &patterns) {
// The function conversion pattern plus the call, indirect-call, constant and
// return lowerings. UnrealizedConversionCastOpLowering is not added here.
733  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
734  // clang-format off
735  patterns.add<
736  CallIndirectOpLowering,
737  CallOpLowering,
738  ConstantOpLowering,
739  ReturnOpLowering>(converter);
740  // clang-format on
741 }
742 
743 namespace {
744 /// A pass converting Func operations into the LLVM IR dialect.
// Besides the Func patterns, it also adds the Arith and ControlFlow to LLVM
// patterns (see the TODO below).
745 struct ConvertFuncToLLVMPass
746  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
747  using Base::Base;
748 
749  /// Run the dialect converter on the module.
750  void runOnOperation() override {
751  ModuleOp m = getOperation();
// Use the module's data layout string attribute when present; otherwise the
// string stays empty. Fail the pass if the string does not verify.
752  StringRef dataLayout;
753  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
754  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
755  if (dataLayoutAttr)
756  dataLayout = dataLayoutAttr.getValue();
757 
758  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
759  dataLayout, [this](const Twine &message) {
760  getOperation().emitError() << message.str();
761  }))) {
762  signalPassFailure();
763  return;
764  }
765 
766  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
767 
// Lowering options are derived from the data layout analysis and from the
// pass members `useBarePtrCallConv`, `indexBitwidth` and `useOpaquePointers`
// (presumably pass options declared in Passes.td).
769  dataLayoutAnalysis.getAtOrAbove(m));
770  options.useBarePtrCallConv = useBarePtrCallConv;
771  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
772  options.overrideIndexBitwidth(indexBitwidth);
773  options.dataLayout = llvm::DataLayout(dataLayout);
774  options.useOpaquePointers = useOpaquePointers;
775 
776  LLVMTypeConverter typeConverter(&getContext(), options,
777  &dataLayoutAnalysis);
778 
779  RewritePatternSet patterns(&getContext());
780  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
781 
782  // TODO: Remove these in favor of their dedicated conversion passes.
783  arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
784  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
785 
787  if (failed(applyPartialConversion(m, target, std::move(patterns))))
788  signalPassFailure();
789  }
790 };
791 
792 struct SetLLVMModuleDataLayoutPass
793  : public impl::SetLLVMModuleDataLayoutPassBase<
794  SetLLVMModuleDataLayoutPass> {
795  using Base::Base;
796 
797  /// Run the dialect converter on the module.
798  void runOnOperation() override {
799  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
800  this->dataLayout, [this](const Twine &message) {
801  getOperation().emitError() << message.str();
802  }))) {
803  signalPassFailure();
804  return;
805  }
806  ModuleOp m = getOperation();
807  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
808  StringAttr::get(m.getContext(), this->dataLayout));
809  }
810 };
811 } // namespace
812 
813 //===----------------------------------------------------------------------===//
814 // ConvertToLLVMPatternInterface implementation
815 //===----------------------------------------------------------------------===//
816 
817 namespace {
818 /// Implement the interface to convert Func to LLVM.
819 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
821  /// Hook for derived dialect interface to provide conversion patterns
822  /// and mark dialect legal for the conversion target.
// Only patterns are contributed here; `target` is not modified.
823  void populateConvertToLLVMConversionPatterns(
824  ConversionTarget &target, LLVMTypeConverter &typeConverter,
825  RewritePatternSet &patterns) const final {
826  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
827  }
828 };
829 } // namespace
830 
// Registers a registry extension that attaches FuncToLLVMDialectInterface to
// the func dialect.
832  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
833  dialect->addInterfaces<FuncToLLVMDialectInterface>();
834  });
835 }
static void modifyFuncOpToUseBarePtrCallingConv(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp, TypeRange oldArgTypes)
Modifies the body of the function to construct the MemRefDescriptor from the bare pointer calling con...
Definition: FuncToLLVM.cpp:284
static void prependEmptyArgAttr(OpBuilder &builder, SmallVectorImpl< NamedAttribute > &newFuncAttrs, func::FuncOp func)
Adds an empty set of argument attributes for the newly added argument in front of the existing ones...
Definition: FuncToLLVM.cpp:94
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:122
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:64
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:62
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:63
static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:76
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:67
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:186
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
Definition: Value.h:310
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
IndexType getIndexType()
Definition: Builders.cpp:71
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:120
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:139
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:93
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:59
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:913
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:831
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:731
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:726
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.