MLIR  15.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
27 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Transforms/Passes.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/IR/DerivedTypes.h"
39 #include "llvm/IR/IRBuilder.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include <algorithm>
44 #include <functional>
45 
46 using namespace mlir;
47 
48 #define PASS_NAME "convert-func-to-llvm"
49 
50 /// Only retain those attributes that are not constructed by
51 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
52 /// attributes.
54  bool filterArgAndResAttrs,
56  for (const auto &attr : attrs) {
57  if (attr.getName() == SymbolTable::getSymbolAttrName() ||
58  attr.getName() == FunctionOpInterface::getTypeAttrName() ||
59  attr.getName() == "func.varargs" ||
60  (filterArgAndResAttrs &&
61  (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
62  attr.getName() == FunctionOpInterface::getResultDictAttrName())))
63  continue;
64  result.push_back(attr);
65  }
66 }
67 
68 /// Helper function for wrapping all attributes into a single DictionaryAttr
69 static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
70  return DictionaryAttr::get(
71  b.getContext(),
72  b.getNamedAttr(LLVM::LLVMDialect::getStructAttrsAttrName(), attrs));
73 }
74 
75 /// Combines all result attributes into a single DictionaryAttr
76 /// and prepends to argument attrs.
77 /// This is intended to be used to format the attributes for a C wrapper
78 /// function when the result(s) is converted to the first function argument
79 /// (in the multiple return case, all returns get wrapped into a single
80 /// argument). The total number of argument attributes should be equal to
81 /// (number of function arguments) + 1.
82 static void
85  size_t numArguments) {
86  auto allAttrs = SmallVector<Attribute>(
87  numArguments + 1, DictionaryAttr::get(builder.getContext()));
88  NamedAttribute *argAttrs = nullptr;
89  for (auto *it = attributes.begin(); it != attributes.end();) {
90  if (it->getName() == FunctionOpInterface::getArgDictAttrName()) {
91  auto arrayAttrs = it->getValue().cast<ArrayAttr>();
92  assert(arrayAttrs.size() == numArguments &&
93  "Number of arg attrs and args should match");
94  std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1);
95  argAttrs = it;
96  } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) {
97  auto arrayAttrs = it->getValue().cast<ArrayAttr>();
98  assert(!arrayAttrs.empty() && "expected array to be non-empty");
99  allAttrs[0] = (arrayAttrs.size() == 1)
100  ? arrayAttrs[0]
101  : wrapAsStructAttrs(builder, arrayAttrs);
102  it = attributes.erase(it);
103  continue;
104  }
105  it++;
106  }
107 
108  auto newArgAttrs =
110  builder.getArrayAttr(allAttrs));
111  if (!argAttrs) {
112  attributes.emplace_back(newArgAttrs);
113  return;
114  }
115  *argAttrs = newArgAttrs;
116 }
117 
118 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
119 /// arguments instead of unpacked arguments. This function can be called from C
120 /// by passing a pointer to a C struct corresponding to a memref descriptor.
121 /// Similarly, returned memrefs are passed via pointers to a C struct that is
122 /// passed as additional argument.
123 /// Internally, the auxiliary function unpacks the descriptor into individual
124 /// components and forwards them to `newFuncOp` and forwards the results to
125 /// the extra arguments.
126 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
127  LLVMTypeConverter &typeConverter,
128  func::FuncOp funcOp,
129  LLVM::LLVMFuncOp newFuncOp) {
130  auto type = funcOp.getFunctionType();
132  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
133  attributes);
134  Type wrapperFuncType;
135  bool resultIsNowArg;
136  std::tie(wrapperFuncType, resultIsNowArg) =
137  typeConverter.convertFunctionTypeCWrapper(type);
138  if (resultIsNowArg)
139  prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
140  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
141  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
142  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false,
143  /*cconv*/ LLVM::CConv::C, attributes);
144 
145  OpBuilder::InsertionGuard guard(rewriter);
146  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
147 
149  size_t argOffset = resultIsNowArg ? 1 : 0;
150  for (auto &en : llvm::enumerate(type.getInputs())) {
151  Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
152  if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
153  Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
154  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
155  continue;
156  }
157  if (en.value().isa<UnrankedMemRefType>()) {
158  Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
159  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
160  continue;
161  }
162 
163  args.push_back(arg);
164  }
165 
166  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
167 
168  if (resultIsNowArg) {
169  rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
170  wrapperFuncOp.getArgument(0));
171  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
172  } else {
173  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
174  }
175 }
176 
177 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
178 /// arguments instead of unpacked arguments. Creates a body for the (external)
179 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
180 /// individual arguments into this descriptor and passes a pointer to it into
181 /// the auxiliary function. If the result of the function cannot be directly
182 /// returned, we write it to a special first argument that provides a pointer
183 /// to a corresponding struct. This auxiliary external function is now
184 /// compatible with functions defined in C using pointers to C structs
185 /// corresponding to a memref descriptor.
186 static void wrapExternalFunction(OpBuilder &builder, Location loc,
187  LLVMTypeConverter &typeConverter,
188  func::FuncOp funcOp,
189  LLVM::LLVMFuncOp newFuncOp) {
190  OpBuilder::InsertionGuard guard(builder);
191 
192  Type wrapperType;
193  bool resultIsNowArg;
194  std::tie(wrapperType, resultIsNowArg) =
195  typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
196  // This conversion can only fail if it could not convert one of the argument
197  // types. But since it has been applied to a non-wrapper function before, it
198  // should have failed earlier and not reach this point at all.
199  assert(wrapperType && "unexpected type conversion failure");
200 
202  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
203  attributes);
204 
205  if (resultIsNowArg)
206  prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
207  // Create the auxiliary function.
208  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
209  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
210  wrapperType, LLVM::Linkage::External, /*dsoLocal*/ false,
211  /*cconv*/ LLVM::CConv::C, attributes);
212 
213  builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
214 
215  // Get a ValueRange containing arguments.
216  FunctionType type = funcOp.getFunctionType();
218  args.reserve(type.getNumInputs());
219  ValueRange wrapperArgsRange(newFuncOp.getArguments());
220 
221  if (resultIsNowArg) {
222  // Allocate the struct on the stack and pass the pointer.
223  Type resultType =
224  wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
225  Value one = builder.create<LLVM::ConstantOp>(
226  loc, typeConverter.convertType(builder.getIndexType()),
227  builder.getIntegerAttr(builder.getIndexType(), 1));
228  Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
229  args.push_back(result);
230  }
231 
232  // Iterate over the inputs of the original function and pack values into
233  // memref descriptors if the original type is a memref.
234  for (auto &en : llvm::enumerate(type.getInputs())) {
235  Value arg;
236  int numToDrop = 1;
237  auto memRefType = en.value().dyn_cast<MemRefType>();
238  auto unrankedMemRefType = en.value().dyn_cast<UnrankedMemRefType>();
239  if (memRefType || unrankedMemRefType) {
240  numToDrop = memRefType
243  Value packed =
244  memRefType
245  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
246  wrapperArgsRange.take_front(numToDrop))
248  builder, loc, typeConverter, unrankedMemRefType,
249  wrapperArgsRange.take_front(numToDrop));
250 
251  auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
252  Value one = builder.create<LLVM::ConstantOp>(
253  loc, typeConverter.convertType(builder.getIndexType()),
254  builder.getIntegerAttr(builder.getIndexType(), 1));
255  Value allocated =
256  builder.create<LLVM::AllocaOp>(loc, ptrTy, one, /*alignment=*/0);
257  builder.create<LLVM::StoreOp>(loc, packed, allocated);
258  arg = allocated;
259  } else {
260  arg = wrapperArgsRange[0];
261  }
262 
263  args.push_back(arg);
264  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
265  }
266  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
267 
268  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
269 
270  if (resultIsNowArg) {
271  Value result = builder.create<LLVM::LoadOp>(loc, args.front());
272  builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
273  } else {
274  builder.create<LLVM::ReturnOp>(loc, call.getResults());
275  }
276 }
277 
278 namespace {
279 
280 struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
281 protected:
283 
284  // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
285  // to this legalization pattern.
286  LLVM::LLVMFuncOp
287  convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
288  ConversionPatternRewriter &rewriter) const {
289  // Convert the original function arguments. They are converted using the
290  // LLVMTypeConverter provided to this legalization pattern.
291  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
292  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
293  auto llvmType = getTypeConverter()->convertFunctionSignature(
294  funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
295  result);
296  if (!llvmType)
297  return nullptr;
298 
299  // Propagate argument/result attributes to all converted arguments/result
300  // obtained after converting a given original argument/result.
302  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
303  attributes);
304  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
305  assert(!resAttrDicts.empty() && "expected array to be non-empty");
306  auto newResAttrDicts =
307  (funcOp.getNumResults() == 1)
308  ? resAttrDicts
309  : rewriter.getArrayAttr(
310  {wrapAsStructAttrs(rewriter, resAttrDicts)});
311  attributes.push_back(rewriter.getNamedAttr(
312  FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
313  }
314  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
315  SmallVector<Attribute, 4> newArgAttrs(
316  llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
317  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
318  auto mapping = result.getInputMapping(i);
319  assert(mapping && "unexpected deletion of function argument");
320  for (size_t j = 0; j < mapping->size; ++j)
321  newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
322  }
323  attributes.push_back(
325  rewriter.getArrayAttr(newArgAttrs)));
326  }
327  for (const auto &pair : llvm::enumerate(attributes)) {
328  if (pair.value().getName() == "llvm.linkage") {
329  attributes.erase(attributes.begin() + pair.index());
330  break;
331  }
332  }
333 
334  // Create an LLVM function, use external linkage by default until MLIR
335  // functions have linkage.
336  LLVM::Linkage linkage = LLVM::Linkage::External;
337  if (funcOp->hasAttr("llvm.linkage")) {
338  auto attr =
339  funcOp->getAttr("llvm.linkage").dyn_cast<mlir::LLVM::LinkageAttr>();
340  if (!attr) {
341  funcOp->emitError()
342  << "Contains llvm.linkage attribute not of type LLVM::LinkageAttr";
343  return nullptr;
344  }
345  linkage = attr.getLinkage();
346  }
347  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349  /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes);
350  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
351  newFuncOp.end());
352  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
353  &result)))
354  return nullptr;
355 
356  return newFuncOp;
357  }
358 };
359 
360 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
361 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
362 /// information.
363 struct FuncOpConversion : public FuncOpConversionBase {
364  FuncOpConversion(LLVMTypeConverter &converter)
365  : FuncOpConversionBase(converter) {}
366 
368  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
369  ConversionPatternRewriter &rewriter) const override {
370  auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
371  if (!newFuncOp)
372  return failure();
373 
374  if (funcOp->getAttrOfType<UnitAttr>(
375  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
376  if (newFuncOp.isVarArg())
377  return funcOp->emitError("C interface for variadic functions is not "
378  "supported yet.");
379 
380  if (newFuncOp.isExternal())
381  wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
382  funcOp, newFuncOp);
383  else
384  wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
385  funcOp, newFuncOp);
386  }
387 
388  rewriter.eraseOp(funcOp);
389  return success();
390  }
391 };
392 
393 /// FuncOp legalization pattern that converts MemRef arguments to bare pointers
394 /// to the MemRef element type. This will impact the calling convention and ABI.
395 struct BarePtrFuncOpConversion : public FuncOpConversionBase {
396  using FuncOpConversionBase::FuncOpConversionBase;
397 
399  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
400  ConversionPatternRewriter &rewriter) const override {
401 
402  // TODO: bare ptr conversion could be handled by argument materialization
403  // and most of the code below would go away. But to do this, we would need a
404  // way to distinguish between FuncOp and other regions in the
405  // addArgumentMaterialization hook.
406 
407  // Store the type of memref-typed arguments before the conversion so that we
408  // can promote them to MemRef descriptor at the beginning of the function.
409  SmallVector<Type, 8> oldArgTypes =
410  llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
411 
412  auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
413  if (!newFuncOp)
414  return failure();
415  if (newFuncOp.getBody().empty()) {
416  rewriter.eraseOp(funcOp);
417  return success();
418  }
419 
420  // Promote bare pointers from memref arguments to memref descriptors at the
421  // beginning of the function so that all the memrefs in the function have a
422  // uniform representation.
423  Block *entryBlock = &newFuncOp.getBody().front();
424  auto blockArgs = entryBlock->getArguments();
425  assert(blockArgs.size() == oldArgTypes.size() &&
426  "The number of arguments and types doesn't match");
427 
428  OpBuilder::InsertionGuard guard(rewriter);
429  rewriter.setInsertionPointToStart(entryBlock);
430  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
431  BlockArgument arg = std::get<0>(it);
432  Type argTy = std::get<1>(it);
433 
434  // Unranked memrefs are not supported in the bare pointer calling
435  // convention. We should have bailed out before in the presence of
436  // unranked memrefs.
437  assert(!argTy.isa<UnrankedMemRefType>() &&
438  "Unranked memref is not supported");
439  auto memrefTy = argTy.dyn_cast<MemRefType>();
440  if (!memrefTy)
441  continue;
442 
443  // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
444  // memref descriptor and replace the placeholder with the resulting
445  // descriptor value.
446  // TODO: The placeholder is needed to avoid replacing barePtr uses in the
447  // MemRef descriptor instructions. We may want to have a utility in the
448  // rewriter to properly handle this use case.
449  Location loc = funcOp.getLoc();
450  auto placeholder = rewriter.create<LLVM::UndefOp>(
451  loc, getTypeConverter()->convertType(memrefTy));
452  rewriter.replaceUsesOfBlockArgument(arg, placeholder);
453 
455  rewriter, loc, *getTypeConverter(), memrefTy, arg);
456  rewriter.replaceOp(placeholder, {desc});
457  }
458 
459  rewriter.eraseOp(funcOp);
460  return success();
461  }
462 };
463 
464 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
466 
468  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
469  ConversionPatternRewriter &rewriter) const override {
470  auto type = typeConverter->convertType(op.getResult().getType());
471  if (!type || !LLVM::isCompatibleType(type))
472  return rewriter.notifyMatchFailure(op, "failed to convert result type");
473 
474  auto newOp =
475  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
476  for (const NamedAttribute &attr : op->getAttrs()) {
477  if (attr.getName().strref() == "value")
478  continue;
479  newOp->setAttr(attr.getName(), attr.getValue());
480  }
481  rewriter.replaceOp(op, newOp->getResults());
482  return success();
483  }
484 };
485 
486 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
487 // passes the pointer to the MemRef across function boundaries.
488 template <typename CallOpType>
489 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
491  using Super = CallOpInterfaceLowering<CallOpType>;
493 
495  matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
496  ConversionPatternRewriter &rewriter) const override {
497  // Pack the result types into a struct.
498  Type packedResult = nullptr;
499  unsigned numResults = callOp.getNumResults();
500  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
501 
502  if (numResults != 0) {
503  if (!(packedResult =
504  this->getTypeConverter()->packFunctionResults(resultTypes)))
505  return failure();
506  }
507 
508  auto promoted = this->getTypeConverter()->promoteOperands(
509  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
510  adaptor.getOperands(), rewriter);
511  auto newOp = rewriter.create<LLVM::CallOp>(
512  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
513  promoted, callOp->getAttrs());
514 
515  SmallVector<Value, 4> results;
516  if (numResults < 2) {
517  // If < 2 results, packing did not do anything and we can just return.
518  results.append(newOp.result_begin(), newOp.result_end());
519  } else {
520  // Otherwise, it had been converted to an operation producing a structure.
521  // Extract individual results from the structure and return them as list.
522  results.reserve(numResults);
523  for (unsigned i = 0; i < numResults; ++i) {
524  auto type =
525  this->typeConverter->convertType(callOp.getResult(i).getType());
526  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
527  callOp.getLoc(), type, newOp->getResult(0),
528  rewriter.getI64ArrayAttr(i)));
529  }
530  }
531 
532  if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
533  // For the bare-ptr calling convention, promote memref results to
534  // descriptors.
535  assert(results.size() == resultTypes.size() &&
536  "The number of arguments and types doesn't match");
537  this->getTypeConverter()->promoteBarePtrsToDescriptors(
538  rewriter, callOp.getLoc(), resultTypes, results);
539  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
540  resultTypes, results,
541  /*toDynamic=*/false))) {
542  return failure();
543  }
544 
545  rewriter.replaceOp(callOp, results);
546  return success();
547  }
548 };
549 
550 struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
551  using Super::Super;
552 };
553 
554 struct CallIndirectOpLowering
555  : public CallOpInterfaceLowering<func::CallIndirectOp> {
556  using Super::Super;
557 };
558 
559 struct UnrealizedConversionCastOpLowering
560  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
562  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
563 
565  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
566  ConversionPatternRewriter &rewriter) const override {
567  SmallVector<Type> convertedTypes;
568  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
569  convertedTypes)) &&
570  convertedTypes == adaptor.getInputs().getTypes()) {
571  rewriter.replaceOp(op, adaptor.getInputs());
572  return success();
573  }
574 
575  convertedTypes.clear();
576  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
577  convertedTypes)) &&
578  convertedTypes == op.getOutputs().getType()) {
579  rewriter.replaceOp(op, adaptor.getInputs());
580  return success();
581  }
582  return failure();
583  }
584 };
585 
586 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
587 // `ReturnOp` interacts with the function signature and must have as many
588 // operands as the function has return values. Because in LLVM IR, functions
589 // can only return 0 or 1 value, we pack multiple values into a structure type.
590 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
591 // necessary before returning it
592 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
594 
596  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
597  ConversionPatternRewriter &rewriter) const override {
598  Location loc = op.getLoc();
599  unsigned numArguments = op.getNumOperands();
600  SmallVector<Value, 4> updatedOperands;
601 
602  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
603  // For the bare-ptr calling convention, extract the aligned pointer to
604  // be returned from the memref descriptor.
605  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
606  Type oldTy = std::get<0>(it).getType();
607  Value newOperand = std::get<1>(it);
608  if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
609  oldTy.cast<BaseMemRefType>())) {
610  MemRefDescriptor memrefDesc(newOperand);
611  newOperand = memrefDesc.alignedPtr(rewriter, loc);
612  } else if (oldTy.isa<UnrankedMemRefType>()) {
613  // Unranked memref is not supported in the bare pointer calling
614  // convention.
615  return failure();
616  }
617  updatedOperands.push_back(newOperand);
618  }
619  } else {
620  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
621  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
622  updatedOperands,
623  /*toDynamic=*/true);
624  }
625 
626  // If ReturnOp has 0 or 1 operand, create it and return immediately.
627  if (numArguments == 0) {
628  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
629  op->getAttrs());
630  return success();
631  }
632  if (numArguments == 1) {
633  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
634  op, TypeRange(), updatedOperands, op->getAttrs());
635  return success();
636  }
637 
638  // Otherwise, we need to pack the arguments into an LLVM struct type before
639  // returning.
640  auto packedType = getTypeConverter()->packFunctionResults(
641  llvm::to_vector<4>(op.getOperandTypes()));
642 
643  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
644  for (unsigned i = 0; i < numArguments; ++i) {
645  packed = rewriter.create<LLVM::InsertValueOp>(
646  loc, packedType, packed, updatedOperands[i],
647  rewriter.getI64ArrayAttr(i));
648  }
649  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
650  op->getAttrs());
651  return success();
652  }
653 };
654 } // namespace
655 
657  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
658  if (converter.getOptions().useBarePtrCallConv)
659  patterns.add<BarePtrFuncOpConversion>(converter);
660  else
661  patterns.add<FuncOpConversion>(converter);
662 }
663 
665  RewritePatternSet &patterns) {
666  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
667  // clang-format off
668  patterns.add<
669  CallIndirectOpLowering,
670  CallOpLowering,
671  ConstantOpLowering,
672  ReturnOpLowering>(converter);
673  // clang-format on
674 }
675 
676 namespace {
677 /// A pass converting Func operations into the LLVM IR dialect.
678 struct ConvertFuncToLLVMPass
679  : public ConvertFuncToLLVMBase<ConvertFuncToLLVMPass> {
680  ConvertFuncToLLVMPass() = default;
681  ConvertFuncToLLVMPass(bool useBarePtrCallConv, unsigned indexBitwidth,
682  bool useAlignedAlloc,
683  const llvm::DataLayout &dataLayout) {
684  this->useBarePtrCallConv = useBarePtrCallConv;
685  this->indexBitwidth = indexBitwidth;
686  this->dataLayout = dataLayout.getStringRepresentation();
687  }
688 
689  /// Run the dialect converter on the module.
690  void runOnOperation() override {
691  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
692  this->dataLayout, [this](const Twine &message) {
693  getOperation().emitError() << message.str();
694  }))) {
695  signalPassFailure();
696  return;
697  }
698 
699  ModuleOp m = getOperation();
700  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
701 
702  LowerToLLVMOptions options(&getContext(),
703  dataLayoutAnalysis.getAtOrAbove(m));
704  options.useBarePtrCallConv = useBarePtrCallConv;
705  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
706  options.overrideIndexBitwidth(indexBitwidth);
707  options.dataLayout = llvm::DataLayout(this->dataLayout);
708 
709  LLVMTypeConverter typeConverter(&getContext(), options,
710  &dataLayoutAnalysis);
711 
712  RewritePatternSet patterns(&getContext());
713  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
714 
715  // TODO: Remove these in favor of their dedicated conversion passes.
716  arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns);
717  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
718 
719  LLVMConversionTarget target(getContext());
720  if (failed(applyPartialConversion(m, target, std::move(patterns))))
721  signalPassFailure();
722 
723  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
724  StringAttr::get(m.getContext(), this->dataLayout));
725  }
726 };
727 } // namespace
728 
/// Creates a ConvertFuncToLLVMPass using its default constructor, i.e. with
/// every pass option left at its default value.
729 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToLLVMPass() {
730  return std::make_unique<ConvertFuncToLLVMPass>();
731 }
732 
733 std::unique_ptr<OperationPass<ModuleOp>>
735  auto allocLowering = options.allocLowering;
736  // There is no way to provide additional patterns for pass, so
737  // AllocLowering::None will always fail.
738  assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
739  "ConvertFuncToLLVMPass doesn't support AllocLowering::None");
740  bool useAlignedAlloc =
742  return std::make_unique<ConvertFuncToLLVMPass>(
743  options.useBarePtrCallConv, options.getIndexBitwidth(), useAlignedAlloc,
744  options.dataLayout);
745 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
Do not lower heap allocations.
StringRef getResultDictAttrName()
Return the name of the attribute used for function result attributes.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
MLIRContext * getContext() const
Definition: Builders.h:54
U cast() const
Definition: Attributes.h:130
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:81
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:158
StringRef getArgDictAttrName()
Return the name of the attribute used for function argument attributes.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
LLVM dialect function type.
Definition: LLVMTypes.h:128
Derived class that automatically populates legalization information for different LLVM ops...
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
llvm::DataLayout dataLayout
The data layout of the module to produce.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl< NamedAttribute > &attributes, size_t numArguments)
Combines all result attributes into a single DictionaryAttr and prepends to argument attrs...
Definition: FuncToLLVM.cpp:83
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
Operation & front()
Definition: Block.h:144
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:81
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
unsigned getIndexBitwidth() const
Get the index bitwidth.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static void filterFuncAttributes(ArrayRef< NamedAttribute > attrs, bool filterArgAndResAttrs, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:53
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:144
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs)
Helper function for wrapping all attributes into a single DictionaryAttr.
Definition: FuncToLLVM.cpp:69
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:822
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:126
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, func::FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:186
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
U dyn_cast() const
Definition: Value.h:100
std::pair< Type, bool > convertFunctionTypeCWrapper(FunctionType type)
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
Use aligned_alloc for heap allocations.
unsigned getNumParams()
Returns the number of arguments to the function.
Definition: LLVMTypes.cpp:135
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:300
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:656
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
std::unique_ptr< OperationPass< ModuleOp > > createConvertFuncToLLVMPass()
Creates a pass to convert the Func dialect into the LLVMIR dialect.
Definition: FuncToLLVM.cpp:729
static llvm::ManagedStatic< PassManagerOptions > options
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:369
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:286
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:402
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
IndexType getIndexType()
Definition: Builders.cpp:48
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:113
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
Options to control the LLVM lowering.
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:664
static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
bool isa() const
Definition: Types.h:246
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
This class helps build Operations.
Definition: Builders.h:184
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
This class provides an abstraction over the different types of ranges over Values.
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
U cast() const
Definition: Types.h:262