MLIR  19.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Transforms/Passes.h"
41 #include "llvm/ADT/SmallVector.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/IR/DerivedTypes.h"
44 #include "llvm/IR/IRBuilder.h"
45 #include "llvm/IR/Type.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/CommandLine.h"
48 #include "llvm/Support/FormatVariadic.h"
49 #include <algorithm>
50 #include <functional>
51 #include <optional>
52 
53 namespace mlir {
54 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
55 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
56 #include "mlir/Conversion/Passes.h.inc"
57 } // namespace mlir
58 
59 using namespace mlir;
60 
61 #define PASS_NAME "convert-func-to-llvm"
62 
63 static constexpr StringRef varargsAttrName = "func.varargs";
64 static constexpr StringRef linkageAttrName = "llvm.linkage";
65 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
66 
67 /// Return `true` if the `op` should use bare pointer calling convention.
69  const LLVMTypeConverter *typeConverter) {
70  return (op && op->hasAttr(barePtrAttrName)) ||
71  typeConverter->getOptions().useBarePtrCallConv;
72 }
73 
74 /// Only retain those attributes that are not constructed by
75 /// `LLVMFuncOp::build`.
76 static void filterFuncAttributes(FunctionOpInterface func,
78  for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
79  if (attr.getName() == linkageAttrName ||
80  attr.getName() == varargsAttrName ||
81  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
82  continue;
83  result.push_back(attr);
84  }
85 }
86 
87 /// Propagate argument/results attributes.
88 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
89  FunctionOpInterface funcOp,
90  LLVM::LLVMFuncOp wrapperFuncOp) {
91  auto argAttrs = funcOp.getAllArgAttrs();
92  if (!resultStructType) {
93  if (auto resAttrs = funcOp.getAllResultAttrs())
94  wrapperFuncOp.setAllResultAttrs(resAttrs);
95  if (argAttrs)
96  wrapperFuncOp.setAllArgAttrs(argAttrs);
97  } else {
98  SmallVector<Attribute> argAttributes;
99  // Only modify the argument and result attributes when the result is now
100  // an argument.
101  if (argAttrs) {
102  argAttributes.push_back(builder.getDictionaryAttr({}));
103  argAttributes.append(argAttrs.begin(), argAttrs.end());
104  wrapperFuncOp.setAllArgAttrs(argAttributes);
105  }
106  }
107  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
108  .setVisibility(funcOp.getVisibility());
109 }
110 
111 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
112 /// arguments instead of unpacked arguments. This function can be called from C
113 /// by passing a pointer to a C struct corresponding to a memref descriptor.
114 /// Similarly, returned memrefs are passed via pointers to a C struct that is
115 /// passed as additional argument.
116 /// Internally, the auxiliary function unpacks the descriptor into individual
117 /// components and forwards them to `newFuncOp` and forwards the results to
118 /// the extra arguments.
119 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
120  const LLVMTypeConverter &typeConverter,
121  FunctionOpInterface funcOp,
122  LLVM::LLVMFuncOp newFuncOp) {
123  auto type = cast<FunctionType>(funcOp.getFunctionType());
124  auto [wrapperFuncType, resultStructType] =
125  typeConverter.convertFunctionTypeCWrapper(type);
126 
127  SmallVector<NamedAttribute> attributes;
128  filterFuncAttributes(funcOp, attributes);
129 
130  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
131  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
132  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
133  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
134  propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
135 
136  OpBuilder::InsertionGuard guard(rewriter);
137  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
138 
140  size_t argOffset = resultStructType ? 1 : 0;
141  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
142  Value arg = wrapperFuncOp.getArgument(index + argOffset);
143  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
144  Value loaded = rewriter.create<LLVM::LoadOp>(
145  loc, typeConverter.convertType(memrefType), arg);
146  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
147  continue;
148  }
149  if (isa<UnrankedMemRefType>(argType)) {
150  Value loaded = rewriter.create<LLVM::LoadOp>(
151  loc, typeConverter.convertType(argType), arg);
152  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
153  continue;
154  }
155 
156  args.push_back(arg);
157  }
158 
159  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
160 
161  if (resultStructType) {
162  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
163  wrapperFuncOp.getArgument(0));
164  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
165  } else {
166  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
167  }
168 }
169 
170 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
171 /// arguments instead of unpacked arguments. Creates a body for the (external)
172 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
173 /// individual arguments into this descriptor and passes a pointer to it into
174 /// the auxiliary function. If the result of the function cannot be directly
175 /// returned, we write it to a special first argument that provides a pointer
176 /// to a corresponding struct. This auxiliary external function is now
177 /// compatible with functions defined in C using pointers to C structs
178 /// corresponding to a memref descriptor.
179 static void wrapExternalFunction(OpBuilder &builder, Location loc,
180  const LLVMTypeConverter &typeConverter,
181  FunctionOpInterface funcOp,
182  LLVM::LLVMFuncOp newFuncOp) {
183  OpBuilder::InsertionGuard guard(builder);
184 
185  auto [wrapperType, resultStructType] =
186  typeConverter.convertFunctionTypeCWrapper(
187  cast<FunctionType>(funcOp.getFunctionType()));
188  // This conversion can only fail if it could not convert one of the argument
189  // types. But since it has been applied to a non-wrapper function before, it
190  // should have failed earlier and not reach this point at all.
191  assert(wrapperType && "unexpected type conversion failure");
192 
194  filterFuncAttributes(funcOp, attributes);
195 
196  // Create the auxiliary function.
197  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
198  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
199  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
200  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
201  propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
202 
203  // The wrapper that we synthetize here should only be visible in this module.
204  newFuncOp.setLinkage(LLVM::Linkage::Private);
205  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
206 
207  // Get a ValueRange containing arguments.
208  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
210  args.reserve(type.getNumInputs());
211  ValueRange wrapperArgsRange(newFuncOp.getArguments());
212 
213  if (resultStructType) {
214  // Allocate the struct on the stack and pass the pointer.
215  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
216  Value one = builder.create<LLVM::ConstantOp>(
217  loc, typeConverter.convertType(builder.getIndexType()),
218  builder.getIntegerAttr(builder.getIndexType(), 1));
219  Value result =
220  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
221  args.push_back(result);
222  }
223 
224  // Iterate over the inputs of the original function and pack values into
225  // memref descriptors if the original type is a memref.
226  for (Type input : type.getInputs()) {
227  Value arg;
228  int numToDrop = 1;
229  auto memRefType = dyn_cast<MemRefType>(input);
230  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
231  if (memRefType || unrankedMemRefType) {
232  numToDrop = memRefType
235  Value packed =
236  memRefType
237  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
238  wrapperArgsRange.take_front(numToDrop))
240  builder, loc, typeConverter, unrankedMemRefType,
241  wrapperArgsRange.take_front(numToDrop));
242 
243  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
244  Value one = builder.create<LLVM::ConstantOp>(
245  loc, typeConverter.convertType(builder.getIndexType()),
246  builder.getIntegerAttr(builder.getIndexType(), 1));
247  Value allocated = builder.create<LLVM::AllocaOp>(
248  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
249  builder.create<LLVM::StoreOp>(loc, packed, allocated);
250  arg = allocated;
251  } else {
252  arg = wrapperArgsRange[0];
253  }
254 
255  args.push_back(arg);
256  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
257  }
258  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
259 
260  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
261 
262  if (resultStructType) {
263  Value result =
264  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
265  builder.create<LLVM::ReturnOp>(loc, result);
266  } else {
267  builder.create<LLVM::ReturnOp>(loc, call.getResults());
268  }
269 }
270 
271 /// Modifies the body of the function to construct the `MemRefDescriptor` from
272 /// the bare pointer calling convention lowering of `memref` types.
274  ConversionPatternRewriter &rewriter, Location loc,
275  const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
276  TypeRange oldArgTypes) {
277  if (funcOp.getBody().empty())
278  return;
279 
280  // Promote bare pointers from memref arguments to memref descriptors at the
281  // beginning of the function so that all the memrefs in the function have a
282  // uniform representation.
283  Block *entryBlock = &funcOp.getBody().front();
284  auto blockArgs = entryBlock->getArguments();
285  assert(blockArgs.size() == oldArgTypes.size() &&
286  "The number of arguments and types doesn't match");
287 
288  OpBuilder::InsertionGuard guard(rewriter);
289  rewriter.setInsertionPointToStart(entryBlock);
290  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
291  BlockArgument arg = std::get<0>(it);
292  Type argTy = std::get<1>(it);
293 
294  // Unranked memrefs are not supported in the bare pointer calling
295  // convention. We should have bailed out before in the presence of
296  // unranked memrefs.
297  assert(!isa<UnrankedMemRefType>(argTy) &&
298  "Unranked memref is not supported");
299  auto memrefTy = dyn_cast<MemRefType>(argTy);
300  if (!memrefTy)
301  continue;
302 
303  // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
304  // or unranked memref descriptor and replace placeholder with the last
305  // instruction of the memref descriptor.
306  // TODO: The placeholder is needed to avoid replacing barePtr uses in the
307  // MemRef descriptor instructions. We may want to have a utility in the
308  // rewriter to properly handle this use case.
309  Location loc = funcOp.getLoc();
310  auto placeholder = rewriter.create<LLVM::UndefOp>(
311  loc, typeConverter.convertType(memrefTy));
312  rewriter.replaceUsesOfBlockArgument(arg, placeholder);
313 
314  Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
315  memrefTy, arg);
316  rewriter.replaceOp(placeholder, {desc});
317  }
318 }
319 
321 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
322  ConversionPatternRewriter &rewriter,
323  const LLVMTypeConverter &converter) {
324  // Check the funcOp has `FunctionType`.
325  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
326  if (!funcTy)
327  return rewriter.notifyMatchFailure(
328  funcOp, "Only support FunctionOpInterface with FunctionType");
329 
330  // Convert the original function arguments. They are converted using the
331  // LLVMTypeConverter provided to this legalization pattern.
332  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
333  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
334  auto llvmType = converter.convertFunctionSignature(
335  funcTy, varargsAttr && varargsAttr.getValue(),
336  shouldUseBarePtrCallConv(funcOp, &converter), result);
337  if (!llvmType)
338  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
339 
340  // Create an LLVM function, use external linkage by default until MLIR
341  // functions have linkage.
342  LLVM::Linkage linkage = LLVM::Linkage::External;
343  if (funcOp->hasAttr(linkageAttrName)) {
344  auto attr =
345  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
346  if (!attr) {
347  funcOp->emitError() << "Contains " << linkageAttrName
348  << " attribute not of type LLVM::LinkageAttr";
349  return rewriter.notifyMatchFailure(
350  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
351  }
352  linkage = attr.getLinkage();
353  }
354 
356  filterFuncAttributes(funcOp, attributes);
357  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
358  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
359  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
360  attributes);
361  cast<FunctionOpInterface>(newFuncOp.getOperation())
362  .setVisibility(funcOp.getVisibility());
363 
364  // Create a memory effect attribute corresponding to readnone.
365  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
366  if (funcOp->hasAttr(readnoneAttrName)) {
367  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
368  if (!attr) {
369  funcOp->emitError() << "Contains " << readnoneAttrName
370  << " attribute not of type UnitAttr";
371  return rewriter.notifyMatchFailure(
372  funcOp, "Contains readnone attribute not of type UnitAttr");
373  }
374  auto memoryAttr = LLVM::MemoryEffectsAttr::get(
375  rewriter.getContext(),
376  {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
377  LLVM::ModRefInfo::NoModRef});
378  newFuncOp.setMemoryAttr(memoryAttr);
379  }
380 
381  // Propagate argument/result attributes to all converted arguments/result
382  // obtained after converting a given original argument/result.
383  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
384  assert(!resAttrDicts.empty() && "expected array to be non-empty");
385  if (funcOp.getNumResults() == 1)
386  newFuncOp.setAllResultAttrs(resAttrDicts);
387  }
388  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
389  SmallVector<Attribute> newArgAttrs(
390  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
391  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
392  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
393  // LLVMFuncOp conversion these types may have changed. Account for that
394  // change by converting attributes' types as well.
395  SmallVector<NamedAttribute, 4> convertedAttrs;
396  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
397  convertedAttrs.reserve(attrsDict.size());
398  for (const NamedAttribute &attr : attrsDict) {
399  const auto convert = [&](const NamedAttribute &attr) {
400  return TypeAttr::get(converter.convertType(
401  cast<TypeAttr>(attr.getValue()).getValue()));
402  };
403  if (attr.getName().getValue() ==
404  LLVM::LLVMDialect::getByValAttrName()) {
405  convertedAttrs.push_back(rewriter.getNamedAttr(
406  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
407  } else if (attr.getName().getValue() ==
408  LLVM::LLVMDialect::getByRefAttrName()) {
409  convertedAttrs.push_back(rewriter.getNamedAttr(
410  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
411  } else if (attr.getName().getValue() ==
412  LLVM::LLVMDialect::getStructRetAttrName()) {
413  convertedAttrs.push_back(rewriter.getNamedAttr(
414  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
415  } else if (attr.getName().getValue() ==
416  LLVM::LLVMDialect::getInAllocaAttrName()) {
417  convertedAttrs.push_back(rewriter.getNamedAttr(
418  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
419  } else {
420  convertedAttrs.push_back(attr);
421  }
422  }
423  auto mapping = result.getInputMapping(i);
424  assert(mapping && "unexpected deletion of function argument");
425  // Only attach the new argument attributes if there is a one-to-one
426  // mapping from old to new types. Otherwise, attributes might be
427  // attached to types that they do not support.
428  if (mapping->size == 1) {
429  newArgAttrs[mapping->inputNo] =
430  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
431  continue;
432  }
433  // TODO: Implement custom handling for types that expand to multiple
434  // function arguments.
435  for (size_t j = 0; j < mapping->size; ++j)
436  newArgAttrs[mapping->inputNo + j] =
437  DictionaryAttr::get(rewriter.getContext(), {});
438  }
439  if (!newArgAttrs.empty())
440  newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
441  }
442 
443  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
444  newFuncOp.end());
445  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
446  &result))) {
447  return rewriter.notifyMatchFailure(funcOp,
448  "region types conversion failed");
449  }
450 
451  if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
452  if (funcOp->getAttrOfType<UnitAttr>(
453  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
454  if (newFuncOp.isVarArg())
455  return funcOp.emitError("C interface for variadic functions is not "
456  "supported yet.");
457 
458  if (newFuncOp.isExternal())
459  wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
460  newFuncOp);
461  else
462  wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
463  newFuncOp);
464  }
465  } else {
467  rewriter, funcOp->getLoc(), converter, newFuncOp,
468  llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
469  }
470 
471  return newFuncOp;
472 }
473 
474 namespace {
475 
476 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
477 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
478 /// information.
479 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
480  FuncOpConversion(const LLVMTypeConverter &converter)
481  : ConvertOpToLLVMPattern(converter) {}
482 
484  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
487  cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
488  *getTypeConverter());
489  if (failed(newFuncOp))
490  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
491 
492  rewriter.eraseOp(funcOp);
493  return success();
494  }
495 };
496 
497 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
499 
501  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
502  ConversionPatternRewriter &rewriter) const override {
503  auto type = typeConverter->convertType(op.getResult().getType());
504  if (!type || !LLVM::isCompatibleType(type))
505  return rewriter.notifyMatchFailure(op, "failed to convert result type");
506 
507  auto newOp =
508  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
509  for (const NamedAttribute &attr : op->getAttrs()) {
510  if (attr.getName().strref() == "value")
511  continue;
512  newOp->setAttr(attr.getName(), attr.getValue());
513  }
514  rewriter.replaceOp(op, newOp->getResults());
515  return success();
516  }
517 };
518 
519 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
520 // passes the pointer to the MemRef across function boundaries.
521 template <typename CallOpType>
522 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
524  using Super = CallOpInterfaceLowering<CallOpType>;
526 
527  LogicalResult matchAndRewriteImpl(CallOpType callOp,
528  typename CallOpType::Adaptor adaptor,
529  ConversionPatternRewriter &rewriter,
530  bool useBarePtrCallConv = false) const {
531  // Pack the result types into a struct.
532  Type packedResult = nullptr;
533  unsigned numResults = callOp.getNumResults();
534  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
535 
536  if (numResults != 0) {
537  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
538  resultTypes, useBarePtrCallConv)))
539  return failure();
540  }
541 
542  if (useBarePtrCallConv) {
543  for (auto it : callOp->getOperands()) {
544  Type operandType = it.getType();
545  if (isa<UnrankedMemRefType>(operandType)) {
546  // Unranked memref is not supported in the bare pointer calling
547  // convention.
548  return failure();
549  }
550  }
551  }
552  auto promoted = this->getTypeConverter()->promoteOperands(
553  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
554  adaptor.getOperands(), rewriter, useBarePtrCallConv);
555  auto newOp = rewriter.create<LLVM::CallOp>(
556  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
557  promoted, callOp->getAttrs());
558 
559  SmallVector<Value, 4> results;
560  if (numResults < 2) {
561  // If < 2 results, packing did not do anything and we can just return.
562  results.append(newOp.result_begin(), newOp.result_end());
563  } else {
564  // Otherwise, it had been converted to an operation producing a structure.
565  // Extract individual results from the structure and return them as list.
566  results.reserve(numResults);
567  for (unsigned i = 0; i < numResults; ++i) {
568  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
569  callOp.getLoc(), newOp->getResult(0), i));
570  }
571  }
572 
573  if (useBarePtrCallConv) {
574  // For the bare-ptr calling convention, promote memref results to
575  // descriptors.
576  assert(results.size() == resultTypes.size() &&
577  "The number of arguments and types doesn't match");
578  this->getTypeConverter()->promoteBarePtrsToDescriptors(
579  rewriter, callOp.getLoc(), resultTypes, results);
580  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
581  resultTypes, results,
582  /*toDynamic=*/false))) {
583  return failure();
584  }
585 
586  rewriter.replaceOp(callOp, results);
587  return success();
588  }
589 };
590 
591 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
592 public:
593  CallOpLowering(const LLVMTypeConverter &typeConverter,
594  // Can be nullptr.
595  const SymbolTable *symbolTable, PatternBenefit benefit = 1)
596  : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
597  symbolTable(symbolTable) {}
598 
600  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
601  ConversionPatternRewriter &rewriter) const override {
602  bool useBarePtrCallConv = false;
603  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
604  useBarePtrCallConv = true;
605  } else if (symbolTable != nullptr) {
606  // Fast lookup.
607  Operation *callee =
608  symbolTable->lookup(callOp.getCalleeAttr().getValue());
609  useBarePtrCallConv =
610  callee != nullptr && callee->hasAttr(barePtrAttrName);
611  } else {
612  // Warning: This is a linear lookup.
613  Operation *callee =
614  SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
615  useBarePtrCallConv =
616  callee != nullptr && callee->hasAttr(barePtrAttrName);
617  }
618  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
619  }
620 
621 private:
622  const SymbolTable *symbolTable = nullptr;
623 };
624 
625 struct CallIndirectOpLowering
626  : public CallOpInterfaceLowering<func::CallIndirectOp> {
627  using Super::Super;
628 
630  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
631  ConversionPatternRewriter &rewriter) const override {
632  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
633  }
634 };
635 
636 struct UnrealizedConversionCastOpLowering
637  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
639  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
640 
642  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
643  ConversionPatternRewriter &rewriter) const override {
644  SmallVector<Type> convertedTypes;
645  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
646  convertedTypes)) &&
647  convertedTypes == adaptor.getInputs().getTypes()) {
648  rewriter.replaceOp(op, adaptor.getInputs());
649  return success();
650  }
651 
652  convertedTypes.clear();
653  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
654  convertedTypes)) &&
655  convertedTypes == op.getOutputs().getType()) {
656  rewriter.replaceOp(op, adaptor.getInputs());
657  return success();
658  }
659  return failure();
660  }
661 };
662 
663 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
664 // `ReturnOp` interacts with the function signature and must have as many
665 // operands as the function has return values. Because in LLVM IR, functions
666 // can only return 0 or 1 value, we pack multiple values into a structure type.
667 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
668 // necessary before returning it
669 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
671 
673  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
674  ConversionPatternRewriter &rewriter) const override {
675  Location loc = op.getLoc();
676  unsigned numArguments = op.getNumOperands();
677  SmallVector<Value, 4> updatedOperands;
678 
679  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
680  bool useBarePtrCallConv =
681  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
682  if (useBarePtrCallConv) {
683  // For the bare-ptr calling convention, extract the aligned pointer to
684  // be returned from the memref descriptor.
685  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
686  Type oldTy = std::get<0>(it).getType();
687  Value newOperand = std::get<1>(it);
688  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
689  cast<BaseMemRefType>(oldTy))) {
690  MemRefDescriptor memrefDesc(newOperand);
691  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
692  } else if (isa<UnrankedMemRefType>(oldTy)) {
693  // Unranked memref is not supported in the bare pointer calling
694  // convention.
695  return failure();
696  }
697  updatedOperands.push_back(newOperand);
698  }
699  } else {
700  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
701  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
702  updatedOperands,
703  /*toDynamic=*/true);
704  }
705 
706  // If ReturnOp has 0 or 1 operand, create it and return immediately.
707  if (numArguments <= 1) {
708  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
709  op, TypeRange(), updatedOperands, op->getAttrs());
710  return success();
711  }
712 
713  // Otherwise, we need to pack the arguments into an LLVM struct type before
714  // returning.
715  auto packedType = getTypeConverter()->packFunctionResults(
716  op.getOperandTypes(), useBarePtrCallConv);
717  if (!packedType) {
718  return rewriter.notifyMatchFailure(op, "could not convert result types");
719  }
720 
721  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
722  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
723  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
724  }
725  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
726  op->getAttrs());
727  return success();
728  }
729 };
730 } // namespace
731 
733  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
734  patterns.add<FuncOpConversion>(converter);
735 }
736 
738  LLVMTypeConverter &converter, RewritePatternSet &patterns,
739  const SymbolTable *symbolTable) {
740  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
741  patterns.add<CallIndirectOpLowering>(converter);
742  patterns.add<CallOpLowering>(converter, symbolTable);
743  patterns.add<ConstantOpLowering>(converter);
744  patterns.add<ReturnOpLowering>(converter);
745 }
746 
747 namespace {
748 /// A pass converting Func operations into the LLVM IR dialect.
749 struct ConvertFuncToLLVMPass
750  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
751  using Base::Base;
752 
753  /// Run the dialect converter on the module.
754  void runOnOperation() override {
755  ModuleOp m = getOperation();
756  StringRef dataLayout;
757  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
758  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
759  if (dataLayoutAttr)
760  dataLayout = dataLayoutAttr.getValue();
761 
762  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
763  dataLayout, [this](const Twine &message) {
764  getOperation().emitError() << message.str();
765  }))) {
766  signalPassFailure();
767  return;
768  }
769 
770  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
771 
773  dataLayoutAnalysis.getAtOrAbove(m));
774  options.useBarePtrCallConv = useBarePtrCallConv;
775  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
776  options.overrideIndexBitwidth(indexBitwidth);
777  options.dataLayout = llvm::DataLayout(dataLayout);
778 
779  LLVMTypeConverter typeConverter(&getContext(), options,
780  &dataLayoutAnalysis);
781 
782  std::optional<SymbolTable> optSymbolTable = std::nullopt;
783  const SymbolTable *symbolTable = nullptr;
784  if (!options.useBarePtrCallConv) {
785  optSymbolTable.emplace(m);
786  symbolTable = &optSymbolTable.value();
787  }
788 
789  RewritePatternSet patterns(&getContext());
790  populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
791 
792  // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
793  // favor of their dedicated conversion passes.
794  arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
795  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
796 
798  if (failed(applyPartialConversion(m, target, std::move(patterns))))
799  signalPassFailure();
800  }
801 };
802 
803 struct SetLLVMModuleDataLayoutPass
804  : public impl::SetLLVMModuleDataLayoutPassBase<
805  SetLLVMModuleDataLayoutPass> {
806  using Base::Base;
807 
808  /// Run the dialect converter on the module.
809  void runOnOperation() override {
810  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
811  this->dataLayout, [this](const Twine &message) {
812  getOperation().emitError() << message.str();
813  }))) {
814  signalPassFailure();
815  return;
816  }
817  ModuleOp m = getOperation();
818  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
819  StringAttr::get(m.getContext(), this->dataLayout));
820  }
821 };
822 } // namespace
823 
824 //===----------------------------------------------------------------------===//
825 // ConvertToLLVMPatternInterface implementation
826 //===----------------------------------------------------------------------===//
827 
828 namespace {
829 /// Implement the interface to convert Func to LLVM.
830 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
832  /// Hook for derived dialect interface to provide conversion patterns
833  /// and mark dialect legal for the conversion target.
834  void populateConvertToLLVMConversionPatterns(
835  ConversionTarget &target, LLVMTypeConverter &typeConverter,
836  RewritePatternSet &patterns) const final {
837  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
838  }
839 };
840 } // namespace
841 
843  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
844  dialect->addInterfaces<FuncToLLVMDialectInterface>();
845  });
846 }
static void modifyFuncOpToUseBarePtrCallingConv(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp, TypeRange oldArgTypes)
Modifies the body of the function to construct the MemRefDescriptor from the bare pointer calling convention lowering of memref types.
Definition: FuncToLLVM.cpp:273
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/results attributes.
Definition: FuncToLLVM.cpp:88
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:65
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:63
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:64
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:76
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:68
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked arguments.
Definition: FuncToLLVM.cpp:179
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked arguments.
Definition: FuncToLLVM.cpp:119
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
IndexType getIndexType()
Definition: Builders.cpp:71
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:120
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:110
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:94
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:737
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:842
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Definition: FuncToLLVM.cpp:321
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:732
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.