MLIR  20.0.0git
FuncToLLVM.cpp
Go to the documentation of this file.
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include <algorithm>
49 #include <functional>
50 #include <optional>
51 
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
57 
58 using namespace mlir;
59 
60 #define PASS_NAME "convert-func-to-llvm"
61 
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
65 
66 /// Return `true` if the `op` should use bare pointer calling convention.
68  const LLVMTypeConverter *typeConverter) {
69  return (op && op->hasAttr(barePtrAttrName)) ||
70  typeConverter->getOptions().useBarePtrCallConv;
71 }
72 
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func,
77  for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78  if (attr.getName() == linkageAttrName ||
79  attr.getName() == varargsAttrName ||
80  attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81  continue;
82  result.push_back(attr);
83  }
84 }
85 
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88  FunctionOpInterface funcOp,
89  LLVM::LLVMFuncOp wrapperFuncOp) {
90  auto argAttrs = funcOp.getAllArgAttrs();
91  if (!resultStructType) {
92  if (auto resAttrs = funcOp.getAllResultAttrs())
93  wrapperFuncOp.setAllResultAttrs(resAttrs);
94  if (argAttrs)
95  wrapperFuncOp.setAllArgAttrs(argAttrs);
96  } else {
97  SmallVector<Attribute> argAttributes;
98  // Only modify the argument and result attributes when the result is now
99  // an argument.
100  if (argAttrs) {
101  argAttributes.push_back(builder.getDictionaryAttr({}));
102  argAttributes.append(argAttrs.begin(), argAttrs.end());
103  wrapperFuncOp.setAllArgAttrs(argAttributes);
104  }
105  }
106  cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107  .setVisibility(funcOp.getVisibility());
108 }
109 
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119  const LLVMTypeConverter &typeConverter,
120  FunctionOpInterface funcOp,
121  LLVM::LLVMFuncOp newFuncOp) {
122  auto type = cast<FunctionType>(funcOp.getFunctionType());
123  auto [wrapperFuncType, resultStructType] =
124  typeConverter.convertFunctionTypeCWrapper(type);
125 
126  SmallVector<NamedAttribute> attributes;
127  filterFuncAttributes(funcOp, attributes);
128 
129  auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131  wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133  propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
134 
135  OpBuilder::InsertionGuard guard(rewriter);
136  rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
137 
139  size_t argOffset = resultStructType ? 1 : 0;
140  for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141  Value arg = wrapperFuncOp.getArgument(index + argOffset);
142  if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143  Value loaded = rewriter.create<LLVM::LoadOp>(
144  loc, typeConverter.convertType(memrefType), arg);
145  MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146  continue;
147  }
148  if (isa<UnrankedMemRefType>(argType)) {
149  Value loaded = rewriter.create<LLVM::LoadOp>(
150  loc, typeConverter.convertType(argType), arg);
151  UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152  continue;
153  }
154 
155  args.push_back(arg);
156  }
157 
158  auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
159 
160  if (resultStructType) {
161  rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162  wrapperFuncOp.getArgument(0));
163  rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164  } else {
165  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
166  }
167 }
168 
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
178 static void wrapExternalFunction(OpBuilder &builder, Location loc,
179  const LLVMTypeConverter &typeConverter,
180  FunctionOpInterface funcOp,
181  LLVM::LLVMFuncOp newFuncOp) {
182  OpBuilder::InsertionGuard guard(builder);
183 
184  auto [wrapperType, resultStructType] =
185  typeConverter.convertFunctionTypeCWrapper(
186  cast<FunctionType>(funcOp.getFunctionType()));
187  // This conversion can only fail if it could not convert one of the argument
188  // types. But since it has been applied to a non-wrapper function before, it
189  // should have failed earlier and not reach this point at all.
190  assert(wrapperType && "unexpected type conversion failure");
191 
193  filterFuncAttributes(funcOp, attributes);
194 
195  // Create the auxiliary function.
196  auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197  loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198  wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199  /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200  propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
201 
202  // The wrapper that we synthetize here should only be visible in this module.
203  newFuncOp.setLinkage(LLVM::Linkage::Private);
204  builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
205 
206  // Get a ValueRange containing arguments.
207  FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
209  args.reserve(type.getNumInputs());
210  ValueRange wrapperArgsRange(newFuncOp.getArguments());
211 
212  if (resultStructType) {
213  // Allocate the struct on the stack and pass the pointer.
214  Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215  Value one = builder.create<LLVM::ConstantOp>(
216  loc, typeConverter.convertType(builder.getIndexType()),
217  builder.getIntegerAttr(builder.getIndexType(), 1));
218  Value result =
219  builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220  args.push_back(result);
221  }
222 
223  // Iterate over the inputs of the original function and pack values into
224  // memref descriptors if the original type is a memref.
225  for (Type input : type.getInputs()) {
226  Value arg;
227  int numToDrop = 1;
228  auto memRefType = dyn_cast<MemRefType>(input);
229  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230  if (memRefType || unrankedMemRefType) {
231  numToDrop = memRefType
234  Value packed =
235  memRefType
236  ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237  wrapperArgsRange.take_front(numToDrop))
239  builder, loc, typeConverter, unrankedMemRefType,
240  wrapperArgsRange.take_front(numToDrop));
241 
242  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
243  Value one = builder.create<LLVM::ConstantOp>(
244  loc, typeConverter.convertType(builder.getIndexType()),
245  builder.getIntegerAttr(builder.getIndexType(), 1));
246  Value allocated = builder.create<LLVM::AllocaOp>(
247  loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248  builder.create<LLVM::StoreOp>(loc, packed, allocated);
249  arg = allocated;
250  } else {
251  arg = wrapperArgsRange[0];
252  }
253 
254  args.push_back(arg);
255  wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
256  }
257  assert(wrapperArgsRange.empty() && "did not map some of the arguments");
258 
259  auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
260 
261  if (resultStructType) {
262  Value result =
263  builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264  builder.create<LLVM::ReturnOp>(loc, result);
265  } else {
266  builder.create<LLVM::ReturnOp>(loc, call.getResults());
267  }
268 }
269 
270 FailureOr<LLVM::LLVMFuncOp>
271 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
272  ConversionPatternRewriter &rewriter,
273  const LLVMTypeConverter &converter) {
274  // Check the funcOp has `FunctionType`.
275  auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
276  if (!funcTy)
277  return rewriter.notifyMatchFailure(
278  funcOp, "Only support FunctionOpInterface with FunctionType");
279 
280  // Convert the original function arguments. They are converted using the
281  // LLVMTypeConverter provided to this legalization pattern.
282  auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
283  TypeConverter::SignatureConversion result(funcOp.getNumArguments());
284  auto llvmType = converter.convertFunctionSignature(
285  funcTy, varargsAttr && varargsAttr.getValue(),
286  shouldUseBarePtrCallConv(funcOp, &converter), result);
287  if (!llvmType)
288  return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
289 
290  // Create an LLVM function, use external linkage by default until MLIR
291  // functions have linkage.
292  LLVM::Linkage linkage = LLVM::Linkage::External;
293  if (funcOp->hasAttr(linkageAttrName)) {
294  auto attr =
295  dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
296  if (!attr) {
297  funcOp->emitError() << "Contains " << linkageAttrName
298  << " attribute not of type LLVM::LinkageAttr";
299  return rewriter.notifyMatchFailure(
300  funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
301  }
302  linkage = attr.getLinkage();
303  }
304 
306  filterFuncAttributes(funcOp, attributes);
307  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
308  funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
309  /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
310  attributes);
311  cast<FunctionOpInterface>(newFuncOp.getOperation())
312  .setVisibility(funcOp.getVisibility());
313 
314  // Create a memory effect attribute corresponding to readnone.
315  StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
316  if (funcOp->hasAttr(readnoneAttrName)) {
317  auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
318  if (!attr) {
319  funcOp->emitError() << "Contains " << readnoneAttrName
320  << " attribute not of type UnitAttr";
321  return rewriter.notifyMatchFailure(
322  funcOp, "Contains readnone attribute not of type UnitAttr");
323  }
324  auto memoryAttr = LLVM::MemoryEffectsAttr::get(
325  rewriter.getContext(),
326  {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
327  LLVM::ModRefInfo::NoModRef});
328  newFuncOp.setMemoryEffectsAttr(memoryAttr);
329  }
330 
331  // Propagate argument/result attributes to all converted arguments/result
332  // obtained after converting a given original argument/result.
333  if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
334  assert(!resAttrDicts.empty() && "expected array to be non-empty");
335  if (funcOp.getNumResults() == 1)
336  newFuncOp.setAllResultAttrs(resAttrDicts);
337  }
338  if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
339  SmallVector<Attribute> newArgAttrs(
340  cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
341  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
342  // Some LLVM IR attribute have a type attached to them. During FuncOp ->
343  // LLVMFuncOp conversion these types may have changed. Account for that
344  // change by converting attributes' types as well.
345  SmallVector<NamedAttribute, 4> convertedAttrs;
346  auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
347  convertedAttrs.reserve(attrsDict.size());
348  for (const NamedAttribute &attr : attrsDict) {
349  const auto convert = [&](const NamedAttribute &attr) {
350  return TypeAttr::get(converter.convertType(
351  cast<TypeAttr>(attr.getValue()).getValue()));
352  };
353  if (attr.getName().getValue() ==
354  LLVM::LLVMDialect::getByValAttrName()) {
355  convertedAttrs.push_back(rewriter.getNamedAttr(
356  LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
357  } else if (attr.getName().getValue() ==
358  LLVM::LLVMDialect::getByRefAttrName()) {
359  convertedAttrs.push_back(rewriter.getNamedAttr(
360  LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
361  } else if (attr.getName().getValue() ==
362  LLVM::LLVMDialect::getStructRetAttrName()) {
363  convertedAttrs.push_back(rewriter.getNamedAttr(
364  LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
365  } else if (attr.getName().getValue() ==
366  LLVM::LLVMDialect::getInAllocaAttrName()) {
367  convertedAttrs.push_back(rewriter.getNamedAttr(
368  LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
369  } else {
370  convertedAttrs.push_back(attr);
371  }
372  }
373  auto mapping = result.getInputMapping(i);
374  assert(mapping && "unexpected deletion of function argument");
375  // Only attach the new argument attributes if there is a one-to-one
376  // mapping from old to new types. Otherwise, attributes might be
377  // attached to types that they do not support.
378  if (mapping->size == 1) {
379  newArgAttrs[mapping->inputNo] =
380  DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
381  continue;
382  }
383  // TODO: Implement custom handling for types that expand to multiple
384  // function arguments.
385  for (size_t j = 0; j < mapping->size; ++j)
386  newArgAttrs[mapping->inputNo + j] =
387  DictionaryAttr::get(rewriter.getContext(), {});
388  }
389  if (!newArgAttrs.empty())
390  newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
391  }
392 
393  rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
394  newFuncOp.end());
395  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
396  &result))) {
397  return rewriter.notifyMatchFailure(funcOp,
398  "region types conversion failed");
399  }
400 
401  if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
402  if (funcOp->getAttrOfType<UnitAttr>(
403  LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
404  if (newFuncOp.isVarArg())
405  return funcOp.emitError("C interface for variadic functions is not "
406  "supported yet.");
407 
408  if (newFuncOp.isExternal())
409  wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
410  newFuncOp);
411  else
412  wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
413  newFuncOp);
414  }
415  }
416 
417  return newFuncOp;
418 }
419 
420 namespace {
421 
422 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
423 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
424 /// information.
425 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
426  FuncOpConversion(const LLVMTypeConverter &converter)
427  : ConvertOpToLLVMPattern(converter) {}
428 
429  LogicalResult
430  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
431  ConversionPatternRewriter &rewriter) const override {
432  FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
433  cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
434  *getTypeConverter());
435  if (failed(newFuncOp))
436  return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
437 
438  rewriter.eraseOp(funcOp);
439  return success();
440  }
441 };
442 
443 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
445 
446  LogicalResult
447  matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
448  ConversionPatternRewriter &rewriter) const override {
449  auto type = typeConverter->convertType(op.getResult().getType());
450  if (!type || !LLVM::isCompatibleType(type))
451  return rewriter.notifyMatchFailure(op, "failed to convert result type");
452 
453  auto newOp =
454  rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
455  for (const NamedAttribute &attr : op->getAttrs()) {
456  if (attr.getName().strref() == "value")
457  continue;
458  newOp->setAttr(attr.getName(), attr.getValue());
459  }
460  rewriter.replaceOp(op, newOp->getResults());
461  return success();
462  }
463 };
464 
465 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
466 // passes the pointer to the MemRef across function boundaries.
467 template <typename CallOpType>
468 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
470  using Super = CallOpInterfaceLowering<CallOpType>;
472 
473  LogicalResult matchAndRewriteImpl(CallOpType callOp,
474  typename CallOpType::Adaptor adaptor,
475  ConversionPatternRewriter &rewriter,
476  bool useBarePtrCallConv = false) const {
477  // Pack the result types into a struct.
478  Type packedResult = nullptr;
479  unsigned numResults = callOp.getNumResults();
480  auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
481 
482  if (numResults != 0) {
483  if (!(packedResult = this->getTypeConverter()->packFunctionResults(
484  resultTypes, useBarePtrCallConv)))
485  return failure();
486  }
487 
488  if (useBarePtrCallConv) {
489  for (auto it : callOp->getOperands()) {
490  Type operandType = it.getType();
491  if (isa<UnrankedMemRefType>(operandType)) {
492  // Unranked memref is not supported in the bare pointer calling
493  // convention.
494  return failure();
495  }
496  }
497  }
498  auto promoted = this->getTypeConverter()->promoteOperands(
499  callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
500  adaptor.getOperands(), rewriter, useBarePtrCallConv);
501  auto newOp = rewriter.create<LLVM::CallOp>(
502  callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
503  promoted, callOp->getAttrs());
504 
505  SmallVector<Value, 4> results;
506  if (numResults < 2) {
507  // If < 2 results, packing did not do anything and we can just return.
508  results.append(newOp.result_begin(), newOp.result_end());
509  } else {
510  // Otherwise, it had been converted to an operation producing a structure.
511  // Extract individual results from the structure and return them as list.
512  results.reserve(numResults);
513  for (unsigned i = 0; i < numResults; ++i) {
514  results.push_back(rewriter.create<LLVM::ExtractValueOp>(
515  callOp.getLoc(), newOp->getResult(0), i));
516  }
517  }
518 
519  if (useBarePtrCallConv) {
520  // For the bare-ptr calling convention, promote memref results to
521  // descriptors.
522  assert(results.size() == resultTypes.size() &&
523  "The number of arguments and types doesn't match");
524  this->getTypeConverter()->promoteBarePtrsToDescriptors(
525  rewriter, callOp.getLoc(), resultTypes, results);
526  } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
527  resultTypes, results,
528  /*toDynamic=*/false))) {
529  return failure();
530  }
531 
532  rewriter.replaceOp(callOp, results);
533  return success();
534  }
535 };
536 
537 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
538 public:
539  CallOpLowering(const LLVMTypeConverter &typeConverter,
540  // Can be nullptr.
541  const SymbolTable *symbolTable, PatternBenefit benefit = 1)
542  : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
543  symbolTable(symbolTable) {}
544 
545  LogicalResult
546  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
547  ConversionPatternRewriter &rewriter) const override {
548  bool useBarePtrCallConv = false;
549  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
550  useBarePtrCallConv = true;
551  } else if (symbolTable != nullptr) {
552  // Fast lookup.
553  Operation *callee =
554  symbolTable->lookup(callOp.getCalleeAttr().getValue());
555  useBarePtrCallConv =
556  callee != nullptr && callee->hasAttr(barePtrAttrName);
557  } else {
558  // Warning: This is a linear lookup.
559  Operation *callee =
560  SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
561  useBarePtrCallConv =
562  callee != nullptr && callee->hasAttr(barePtrAttrName);
563  }
564  return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
565  }
566 
567 private:
568  const SymbolTable *symbolTable = nullptr;
569 };
570 
571 struct CallIndirectOpLowering
572  : public CallOpInterfaceLowering<func::CallIndirectOp> {
573  using Super::Super;
574 
575  LogicalResult
576  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
577  ConversionPatternRewriter &rewriter) const override {
578  return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
579  }
580 };
581 
582 struct UnrealizedConversionCastOpLowering
583  : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
585  UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
586 
587  LogicalResult
588  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const override {
590  SmallVector<Type> convertedTypes;
591  if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
592  convertedTypes)) &&
593  convertedTypes == adaptor.getInputs().getTypes()) {
594  rewriter.replaceOp(op, adaptor.getInputs());
595  return success();
596  }
597 
598  convertedTypes.clear();
599  if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
600  convertedTypes)) &&
601  convertedTypes == op.getOutputs().getType()) {
602  rewriter.replaceOp(op, adaptor.getInputs());
603  return success();
604  }
605  return failure();
606  }
607 };
608 
609 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
610 // `ReturnOp` interacts with the function signature and must have as many
611 // operands as the function has return values. Because in LLVM IR, functions
612 // can only return 0 or 1 value, we pack multiple values into a structure type.
613 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
614 // necessary before returning it
615 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
617 
618  LogicalResult
619  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
620  ConversionPatternRewriter &rewriter) const override {
621  Location loc = op.getLoc();
622  unsigned numArguments = op.getNumOperands();
623  SmallVector<Value, 4> updatedOperands;
624 
625  auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
626  bool useBarePtrCallConv =
627  shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
628  if (useBarePtrCallConv) {
629  // For the bare-ptr calling convention, extract the aligned pointer to
630  // be returned from the memref descriptor.
631  for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
632  Type oldTy = std::get<0>(it).getType();
633  Value newOperand = std::get<1>(it);
634  if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
635  cast<BaseMemRefType>(oldTy))) {
636  MemRefDescriptor memrefDesc(newOperand);
637  newOperand = memrefDesc.allocatedPtr(rewriter, loc);
638  } else if (isa<UnrankedMemRefType>(oldTy)) {
639  // Unranked memref is not supported in the bare pointer calling
640  // convention.
641  return failure();
642  }
643  updatedOperands.push_back(newOperand);
644  }
645  } else {
646  updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
647  (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
648  updatedOperands,
649  /*toDynamic=*/true);
650  }
651 
652  // If ReturnOp has 0 or 1 operand, create it and return immediately.
653  if (numArguments <= 1) {
654  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
655  op, TypeRange(), updatedOperands, op->getAttrs());
656  return success();
657  }
658 
659  // Otherwise, we need to pack the arguments into an LLVM struct type before
660  // returning.
661  auto packedType = getTypeConverter()->packFunctionResults(
662  op.getOperandTypes(), useBarePtrCallConv);
663  if (!packedType) {
664  return rewriter.notifyMatchFailure(op, "could not convert result types");
665  }
666 
667  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
668  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
669  packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
670  }
671  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
672  op->getAttrs());
673  return success();
674  }
675 };
676 } // namespace
677 
679  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
680  patterns.add<FuncOpConversion>(converter);
681 }
682 
684  LLVMTypeConverter &converter, RewritePatternSet &patterns,
685  const SymbolTable *symbolTable) {
686  populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
687  patterns.add<CallIndirectOpLowering>(converter);
688  patterns.add<CallOpLowering>(converter, symbolTable);
689  patterns.add<ConstantOpLowering>(converter);
690  patterns.add<ReturnOpLowering>(converter);
691 }
692 
693 namespace {
694 /// A pass converting Func operations into the LLVM IR dialect.
695 struct ConvertFuncToLLVMPass
696  : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
697  using Base::Base;
698 
699  /// Run the dialect converter on the module.
700  void runOnOperation() override {
701  ModuleOp m = getOperation();
702  StringRef dataLayout;
703  auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
704  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
705  if (dataLayoutAttr)
706  dataLayout = dataLayoutAttr.getValue();
707 
708  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
709  dataLayout, [this](const Twine &message) {
710  getOperation().emitError() << message.str();
711  }))) {
712  signalPassFailure();
713  return;
714  }
715 
716  const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
717 
719  dataLayoutAnalysis.getAtOrAbove(m));
720  options.useBarePtrCallConv = useBarePtrCallConv;
721  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
722  options.overrideIndexBitwidth(indexBitwidth);
723  options.dataLayout = llvm::DataLayout(dataLayout);
724 
725  LLVMTypeConverter typeConverter(&getContext(), options,
726  &dataLayoutAnalysis);
727 
728  std::optional<SymbolTable> optSymbolTable = std::nullopt;
729  const SymbolTable *symbolTable = nullptr;
730  if (!options.useBarePtrCallConv) {
731  optSymbolTable.emplace(m);
732  symbolTable = &optSymbolTable.value();
733  }
734 
735  RewritePatternSet patterns(&getContext());
736  populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
737 
738  // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
739  // favor of their dedicated conversion passes.
740  arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
741  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
742 
744  if (failed(applyPartialConversion(m, target, std::move(patterns))))
745  signalPassFailure();
746  }
747 };
748 
749 struct SetLLVMModuleDataLayoutPass
750  : public impl::SetLLVMModuleDataLayoutPassBase<
751  SetLLVMModuleDataLayoutPass> {
752  using Base::Base;
753 
754  /// Run the dialect converter on the module.
755  void runOnOperation() override {
756  if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
757  this->dataLayout, [this](const Twine &message) {
758  getOperation().emitError() << message.str();
759  }))) {
760  signalPassFailure();
761  return;
762  }
763  ModuleOp m = getOperation();
764  m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
765  StringAttr::get(m.getContext(), this->dataLayout));
766  }
767 };
768 } // namespace
769 
770 //===----------------------------------------------------------------------===//
771 // ConvertToLLVMPatternInterface implementation
772 //===----------------------------------------------------------------------===//
773 
774 namespace {
775 /// Implement the interface to convert Func to LLVM.
776 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
778  /// Hook for derived dialect interface to provide conversion patterns
779  /// and mark dialect legal for the conversion target.
780  void populateConvertToLLVMConversionPatterns(
781  ConversionTarget &target, LLVMTypeConverter &typeConverter,
782  RewritePatternSet &patterns) const final {
783  populateFuncToLLVMConversionPatterns(typeConverter, patterns);
784  }
785 };
786 } // namespace
787 
789  registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
790  dialect->addInterfaces<FuncToLLVMDialectInterface>();
791  });
792 }
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, LLVM::LLVMFuncOp wrapperFuncOp)
Propagate argument/results attributes.
Definition: FuncToLLVM.cpp:87
static constexpr StringRef barePtrAttrName
Definition: FuncToLLVM.cpp:64
static constexpr StringRef varargsAttrName
Definition: FuncToLLVM.cpp:62
static constexpr StringRef linkageAttrName
Definition: FuncToLLVM.cpp:63
static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl< NamedAttribute > &result)
Only retain those attributes that are not constructed by LLVMFuncOp::build.
Definition: FuncToLLVM.cpp:75
static bool shouldUseBarePtrCallConv(Operation *op, const LLVMTypeConverter *typeConverter)
Return true if the op should use bare pointer calling convention.
Definition: FuncToLLVM.cpp:67
static void wrapExternalFunction(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:178
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, const LLVMTypeConverter &typeConverter, FunctionOpInterface funcOp, LLVM::LLVMFuncOp newFuncOp)
Creates an auxiliary function with pointer-to-memref-descriptor-struct arguments instead of unpacked ...
Definition: FuncToLLVM.cpp:118
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:75
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:124
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:114
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
const LowerToLLVMOptions & getOptions() const
Definition: TypeConverter.h:94
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
static unsigned getNumUnpackedValues(MemRefType type)
Returns the number of non-aggregate values that would be produced by unpack.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_type_range getOperandTypes()
Definition: Operation.h:392
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides all of the information necessary to convert a type signature.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
static unsigned getNumUnpackedValues()
Returns the number of non-aggregate values that would be produced by unpack.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the patterns to convert from the ControlFlow dialect to LLVM.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, const SymbolTable *symbolTable=nullptr)
Collect the patterns to convert from the Func dialect to LLVM.
Definition: FuncToLLVM.cpp:683
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
void registerConvertFuncToLLVMInterface(DialectRegistry &registry)
Definition: FuncToLLVM.cpp:788
FailureOr< LLVM::LLVMFuncOp > convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter)
Convert input FunctionOpInterface operation to LLVMFuncOp by using the provided LLVMTypeConverter.
Definition: FuncToLLVM.cpp:271
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect the default pattern to convert a FuncOp to the LLVM dialect.
Definition: FuncToLLVM.cpp:678
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.