MLIR  21.0.0git
LLVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR LLVM dialect and LLVM IR.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
18 
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/InlineAsm.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/MDBuilder.h"
24 #include "llvm/IR/MatrixBuilder.h"
25 #include "llvm/IR/Operator.h"
26 
27 using namespace mlir;
28 using namespace mlir::LLVM;
30 
31 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
32 
33 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
34  using llvmFMF = llvm::FastMathFlags;
35  using FuncT = void (llvmFMF::*)(bool);
36  const std::pair<FastmathFlags, FuncT> handlers[] = {
37  // clang-format off
38  {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
39  {FastmathFlags::ninf, &llvmFMF::setNoInfs},
40  {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
41  {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
42  {FastmathFlags::contract, &llvmFMF::setAllowContract},
43  {FastmathFlags::afn, &llvmFMF::setApproxFunc},
44  {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
45  // clang-format on
46  };
47  llvm::FastMathFlags ret;
48  ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
49  for (auto it : handlers)
50  if (bitEnumContainsAll(fmfMlir, it.first))
51  (ret.*(it.second))(true);
52  return ret;
53 }
54 
55 /// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
57  SmallVector<unsigned> position;
58  llvm::append_range(position, indices);
59  return position;
60 }
61 
62 /// Convert an LLVM type to a string for printing in diagnostics.
63 static std::string diagStr(const llvm::Type *type) {
64  std::string str;
65  llvm::raw_string_ostream os(str);
66  type->print(os);
67  return str;
68 }
69 
70 /// Get the declaration of an overloaded llvm intrinsic. First we get the
71 /// overloaded argument types and/or result type from the CallIntrinsicOp, and
72 /// then use those to get the correct declaration of the overloaded intrinsic.
73 static FailureOr<llvm::Function *>
75  llvm::Module *module,
76  LLVM::ModuleTranslation &moduleTranslation) {
78  for (Type type : op->getOperandTypes())
79  allArgTys.push_back(moduleTranslation.convertType(type));
80 
81  llvm::Type *resTy;
82  if (op.getNumResults() == 0)
83  resTy = llvm::Type::getVoidTy(module->getContext());
84  else
85  resTy = moduleTranslation.convertType(op.getResult(0).getType());
86 
87  // ATM we do not support variadic intrinsics.
88  llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);
89 
91  getIntrinsicInfoTableEntries(id, table);
93 
94  SmallVector<llvm::Type *, 8> overloadedArgTys;
95  if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
96  overloadedArgTys) !=
97  llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
98  return mlir::emitError(op.getLoc(), "call intrinsic signature ")
99  << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr()
100  << " does not match any of the overloads";
101  }
102 
103  ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
104  return llvm::Intrinsic::getOrInsertDeclaration(module, id,
105  overloadedArgTysRef);
106 }
107 
108 static llvm::OperandBundleDef
109 convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
110  LLVM::ModuleTranslation &moduleTranslation) {
111  std::vector<llvm::Value *> operands;
112  operands.reserve(bundleOperands.size());
113  for (Value bundleArg : bundleOperands)
114  operands.push_back(moduleTranslation.lookupValue(bundleArg));
115  return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
116 }
117 
119 convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags,
120  LLVM::ModuleTranslation &moduleTranslation) {
122  bundles.reserve(bundleOperands.size());
123 
124  for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
125  StringRef tag = cast<StringAttr>(tagAttr).getValue();
126  bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
127  }
128  return bundles;
129 }
130 
133  std::optional<ArrayAttr> bundleTags,
134  LLVM::ModuleTranslation &moduleTranslation) {
135  if (!bundleTags)
136  return {};
137  return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
138 }
139 
140 static LogicalResult
141 convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
142  ArrayAttr resAttrsArray, llvm::CallBase *call,
143  LLVM::ModuleTranslation &moduleTranslation) {
144  if (argAttrsArray) {
145  for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
146  if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
147  !argAttrs.empty()) {
148  FailureOr<llvm::AttrBuilder> attrBuilder =
149  moduleTranslation.convertParameterAttrs(loc, argAttrs);
150  if (failed(attrBuilder))
151  return failure();
152  call->addParamAttrs(argIdx, *attrBuilder);
153  }
154  }
155  }
156 
157  if (resAttrsArray && resAttrsArray.size() > 0) {
158  if (resAttrsArray.size() != 1)
159  return mlir::emitError(loc, "llvm.func cannot have multiple results");
160  if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
161  !resAttrs.empty()) {
162  FailureOr<llvm::AttrBuilder> attrBuilder =
163  moduleTranslation.convertParameterAttrs(loc, resAttrs);
164  if (failed(attrBuilder))
165  return failure();
166  call->addRetAttrs(*attrBuilder);
167  }
168  }
169  return success();
170 }
171 
172 static LogicalResult
173 convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
174  LLVM::ModuleTranslation &moduleTranslation) {
176  callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
177  moduleTranslation);
178 }
179 
180 /// Builder for LLVM_CallIntrinsicOp
181 static LogicalResult
182 convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
183  LLVM::ModuleTranslation &moduleTranslation) {
184  llvm::Module *module = builder.GetInsertBlock()->getModule();
186  llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr());
187  if (!id)
188  return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
189  << op.getIntrinAttr();
190 
191  llvm::Function *fn = nullptr;
192  if (llvm::Intrinsic::isOverloaded(id)) {
193  auto fnOrFailure =
194  getOverloadedDeclaration(op, id, module, moduleTranslation);
195  if (failed(fnOrFailure))
196  return failure();
197  fn = *fnOrFailure;
198  } else {
199  fn = llvm::Intrinsic::getOrInsertDeclaration(module, id, {});
200  }
201 
202  // Check the result type of the call.
203  const llvm::Type *intrinType =
204  op.getNumResults() == 0
205  ? llvm::Type::getVoidTy(module->getContext())
206  : moduleTranslation.convertType(op.getResultTypes().front());
207  if (intrinType != fn->getReturnType()) {
208  return mlir::emitError(op.getLoc(), "intrinsic call returns ")
209  << diagStr(intrinType) << " but " << op.getIntrinAttr()
210  << " actually returns " << diagStr(fn->getReturnType());
211  }
212 
213  // Check the argument types of the call. If the function is variadic, check
214  // the subrange of required arguments.
215  if (!fn->getFunctionType()->isVarArg() &&
216  op.getArgs().size() != fn->arg_size()) {
217  return mlir::emitError(op.getLoc(), "intrinsic call has ")
218  << op.getArgs().size() << " operands but " << op.getIntrinAttr()
219  << " expects " << fn->arg_size();
220  }
221  if (fn->getFunctionType()->isVarArg() &&
222  op.getArgs().size() < fn->arg_size()) {
223  return mlir::emitError(op.getLoc(), "intrinsic call has ")
224  << op.getArgs().size() << " operands but variadic "
225  << op.getIntrinAttr() << " expects at least " << fn->arg_size();
226  }
227  // Check the arguments up to the number the function requires.
228  for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
229  const llvm::Type *expected = fn->getArg(i)->getType();
230  const llvm::Type *actual =
231  moduleTranslation.convertType(op.getOperandTypes()[i]);
232  if (actual != expected) {
233  return mlir::emitError(op.getLoc(), "intrinsic call operand #")
234  << i << " has type " << diagStr(actual) << " but "
235  << op.getIntrinAttr() << " expects " << diagStr(expected);
236  }
237  }
238 
239  FastmathFlagsInterface itf = op;
240  builder.setFastMathFlags(getFastmathFlags(itf));
241 
242  auto *inst = builder.CreateCall(
243  fn, moduleTranslation.lookupValues(op.getArgs()),
244  convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
245  moduleTranslation));
246 
247  if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
248  op.getResAttrsAttr(), inst,
249  moduleTranslation)))
250  return failure();
251 
252  if (op.getNumResults() == 1)
253  moduleTranslation.mapValue(op->getResults().front()) = inst;
254  return success();
255 }
256 
257 static void convertLinkerOptionsOp(ArrayAttr options,
258  llvm::IRBuilderBase &builder,
259  LLVM::ModuleTranslation &moduleTranslation) {
260  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
261  llvm::LLVMContext &context = llvmModule->getContext();
262  llvm::NamedMDNode *linkerMDNode =
263  llvmModule->getOrInsertNamedMetadata("llvm.linker.options");
265  MDNodes.reserve(options.size());
266  for (auto s : options.getAsRange<StringAttr>()) {
267  auto *MDNode = llvm::MDString::get(context, s.getValue());
268  MDNodes.push_back(MDNode);
269  }
270 
271  auto *listMDNode = llvm::MDTuple::get(context, MDNodes);
272  linkerMDNode->addOperand(listMDNode);
273 }
274 
275 static llvm::Metadata *
276 convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
277  llvm::IRBuilderBase &builder,
278  LLVM::ModuleTranslation &moduleTranslation) {
279  llvm::LLVMContext &context = builder.getContext();
280  llvm::MDBuilder mdb(context);
282 
283  if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
284  for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
285  llvm::Metadata *fromMetadata =
286  entry.getFrom()
287  ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction(
288  entry.getFrom().getValue()))
289  : nullptr;
290  llvm::Metadata *toMetadata =
291  entry.getTo()
293  moduleTranslation.lookupFunction(entry.getTo().getValue()))
294  : nullptr;
295 
296  llvm::Metadata *vals[] = {
297  fromMetadata, toMetadata,
298  mdb.createConstant(llvm::ConstantInt::get(
299  llvm::Type::getInt64Ty(context), entry.getCount()))};
300  nodes.push_back(llvm::MDNode::get(context, vals));
301  }
302  return llvm::MDTuple::getDistinct(context, nodes);
303  }
304  return nullptr;
305 }
306 
307 static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
308  StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
309  llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
310  llvm::LLVMContext &context = builder.getContext();
311  llvm::MDBuilder mdb(context);
312 
313  auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
315  mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
316  llvm::Type::getInt64Ty(context), val))};
317  return llvm::MDTuple::get(context, tupleNodes);
318  };
319 
321  mdb.createString("ProfileFormat"),
322  mdb.createString(
323  stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};
324 
326  llvm::MDTuple::get(context, fmtNode),
327  getIntTuple("TotalCount", summaryAttr.getTotalCount()),
328  getIntTuple("MaxCount", summaryAttr.getMaxCount()),
329  getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
330  getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
331  getIntTuple("NumCounts", summaryAttr.getNumCounts()),
332  getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
333  };
334 
335  if (summaryAttr.getIsPartialProfile())
336  vals.push_back(
337  getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile()));
338 
339  if (summaryAttr.getPartialProfileRatio()) {
341  mdb.createString("PartialProfileRatio"),
342  mdb.createConstant(llvm::ConstantFP::get(
343  llvm::Type::getDoubleTy(context),
344  summaryAttr.getPartialProfileRatio().getValue()))};
345  vals.push_back(llvm::MDTuple::get(context, tupleNodes));
346  }
347 
348  SmallVector<llvm::Metadata *> detailedEntries;
349  llvm::Type *llvmInt64Type = llvm::Type::getInt64Ty(context);
350  for (ModuleFlagProfileSummaryDetailedAttr detailedEntry :
351  summaryAttr.getDetailedSummary()) {
353  mdb.createConstant(
354  llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getCutOff())),
355  mdb.createConstant(
356  llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getMinCount())),
357  mdb.createConstant(llvm::ConstantInt::get(
358  llvmInt64Type, detailedEntry.getNumCounts()))};
359  detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
360  }
361  SmallVector<llvm::Metadata *> detailedSummary{
362  mdb.createString("DetailedSummary"),
363  llvm::MDTuple::get(context, detailedEntries)};
364  vals.push_back(llvm::MDTuple::get(context, detailedSummary));
365 
366  return llvm::MDNode::get(context, vals);
367 }
368 
369 static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
370  LLVM::ModuleTranslation &moduleTranslation) {
371  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
372  for (auto flagAttr : flags.getAsRange<ModuleFlagAttr>()) {
373  llvm::Metadata *valueMetadata =
375  .Case<StringAttr>([&](auto strAttr) {
376  return llvm::MDString::get(builder.getContext(),
377  strAttr.getValue());
378  })
379  .Case<IntegerAttr>([&](auto intAttr) {
381  llvm::Type::getInt32Ty(builder.getContext()),
382  intAttr.getInt()));
383  })
384  .Case<ArrayAttr>([&](auto arrayAttr) {
385  return convertModuleFlagValue(flagAttr.getKey().getValue(),
386  arrayAttr, builder,
387  moduleTranslation);
388  })
389  .Case([&](ModuleFlagProfileSummaryAttr summaryAttr) {
391  flagAttr.getKey().getValue(), summaryAttr, builder,
392  moduleTranslation);
393  })
394  .Default([](auto) { return nullptr; });
395 
396  assert(valueMetadata && "expected valid metadata");
397  llvmModule->addModuleFlag(
398  convertModFlagBehaviorToLLVM(flagAttr.getBehavior()),
399  flagAttr.getKey().getValue(), valueMetadata);
400  }
401 }
402 
403 static LogicalResult
404 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
405  LLVM::ModuleTranslation &moduleTranslation) {
406 
407  llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
408  if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
409  builder.setFastMathFlags(getFastmathFlags(fmf));
410 
411 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
412 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
413 
414  // Emit function calls. If the "callee" attribute is present, this is a
415  // direct function call and we also need to look up the remapped function
416  // itself. Otherwise, this is an indirect call and the callee is the first
417  // operand, look it up as a normal value.
418  if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
419  auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
421  convertOperandBundles(callOp.getOpBundleOperands(),
422  callOp.getOpBundleTags(), moduleTranslation);
423  ArrayRef<llvm::Value *> operandsRef(operands);
424  llvm::CallInst *call;
425  if (auto attr = callOp.getCalleeAttr()) {
426  call =
427  builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()),
428  operandsRef, opBundles);
429  } else {
430  llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
431  moduleTranslation.convertType(callOp.getCalleeFunctionType()));
432  call = builder.CreateCall(calleeType, operandsRef.front(),
433  operandsRef.drop_front(), opBundles);
434  }
435  call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
436  call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
437  if (callOp.getConvergentAttr())
438  call->addFnAttr(llvm::Attribute::Convergent);
439  if (callOp.getNoUnwindAttr())
440  call->addFnAttr(llvm::Attribute::NoUnwind);
441  if (callOp.getWillReturnAttr())
442  call->addFnAttr(llvm::Attribute::WillReturn);
443  if (callOp.getNoInlineAttr())
444  call->addFnAttr(llvm::Attribute::NoInline);
445  if (callOp.getAlwaysInlineAttr())
446  call->addFnAttr(llvm::Attribute::AlwaysInline);
447  if (callOp.getInlineHintAttr())
448  call->addFnAttr(llvm::Attribute::InlineHint);
449 
450  if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
451  return failure();
452 
453  if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
454  llvm::MemoryEffects memEffects =
455  llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
456  convertModRefInfoToLLVM(memAttr.getArgMem())) |
457  llvm::MemoryEffects(
458  llvm::MemoryEffects::Location::InaccessibleMem,
459  convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
460  llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
461  convertModRefInfoToLLVM(memAttr.getOther()));
462  call->setMemoryEffects(memEffects);
463  }
464 
465  moduleTranslation.setAccessGroupsMetadata(callOp, call);
466  moduleTranslation.setAliasScopeMetadata(callOp, call);
467  moduleTranslation.setTBAAMetadata(callOp, call);
468  // If the called function has a result, remap the corresponding value. Note
469  // that LLVM IR dialect CallOp has either 0 or 1 result.
470  if (opInst.getNumResults() != 0)
471  moduleTranslation.mapValue(opInst.getResult(0), call);
472  // Check that LLVM call returns void for 0-result functions.
473  else if (!call->getType()->isVoidTy())
474  return failure();
475  moduleTranslation.mapCall(callOp, call);
476  return success();
477  }
478 
479  if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
480  // TODO: refactor function type creation which usually occurs in std-LLVM
481  // conversion.
482  SmallVector<Type, 8> operandTypes;
483  llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
484 
485  Type resultType;
486  if (inlineAsmOp.getNumResults() == 0) {
487  resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
488  } else {
489  assert(inlineAsmOp.getNumResults() == 1);
490  resultType = inlineAsmOp.getResultTypes()[0];
491  }
492  auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
493  llvm::InlineAsm *inlineAsmInst =
494  inlineAsmOp.getAsmDialect()
496  static_cast<llvm::FunctionType *>(
497  moduleTranslation.convertType(ft)),
498  inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
499  inlineAsmOp.getHasSideEffects(),
500  inlineAsmOp.getIsAlignStack(),
501  convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
502  : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
503  moduleTranslation.convertType(ft)),
504  inlineAsmOp.getAsmString(),
505  inlineAsmOp.getConstraints(),
506  inlineAsmOp.getHasSideEffects(),
507  inlineAsmOp.getIsAlignStack());
508  llvm::CallInst *inst = builder.CreateCall(
509  inlineAsmInst,
510  moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
511  inst->setTailCallKind(convertTailCallKindToLLVM(
512  inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
513  if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
514  llvm::AttributeList attrList;
515  for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
516  Attribute attr = it.value();
517  if (!attr)
518  continue;
519  DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
520  if (dAttr.empty())
521  continue;
522  TypeAttr tAttr =
523  cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
524  llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
525  llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
526  b.addTypeAttr(llvm::Attribute::ElementType, ty);
527  // shift to account for the returned value (this is always 1 aggregate
528  // value in LLVM).
529  int shift = (opInst.getNumResults() > 0) ? 1 : 0;
530  attrList = attrList.addAttributesAtIndex(
531  moduleTranslation.getLLVMContext(), it.index() + shift, b);
532  }
533  inst->setAttributes(attrList);
534  }
535 
536  if (opInst.getNumResults() != 0)
537  moduleTranslation.mapValue(opInst.getResult(0), inst);
538  return success();
539  }
540 
541  if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
542  auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
544  convertOperandBundles(invOp.getOpBundleOperands(),
545  invOp.getOpBundleTags(), moduleTranslation);
546  ArrayRef<llvm::Value *> operandsRef(operands);
547  llvm::InvokeInst *result;
548  if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
549  result = builder.CreateInvoke(
550  moduleTranslation.lookupFunction(attr.getValue()),
551  moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
552  moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
553  opBundles);
554  } else {
555  llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
556  moduleTranslation.convertType(invOp.getCalleeFunctionType()));
557  result = builder.CreateInvoke(
558  calleeType, operandsRef.front(),
559  moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
560  moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
561  operandsRef.drop_front(), opBundles);
562  }
563  result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
564  if (failed(
565  convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
566  return failure();
567  moduleTranslation.mapBranch(invOp, result);
568  // InvokeOp can only have 0 or 1 result
569  if (invOp->getNumResults() != 0) {
570  moduleTranslation.mapValue(opInst.getResult(0), result);
571  return success();
572  }
573  return success(result->getType()->isVoidTy());
574  }
575 
576  if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
577  llvm::Type *ty = moduleTranslation.convertType(lpOp.getType());
578  llvm::LandingPadInst *lpi =
579  builder.CreateLandingPad(ty, lpOp.getNumOperands());
580  lpi->setCleanup(lpOp.getCleanup());
581 
582  // Add clauses
583  for (llvm::Value *operand :
584  moduleTranslation.lookupValues(lpOp.getOperands())) {
585  // All operands should be constant - checked by verifier
586  if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
587  lpi->addClause(constOperand);
588  }
589  moduleTranslation.mapValue(lpOp.getResult(), lpi);
590  return success();
591  }
592 
593  // Emit branches. We need to look up the remapped blocks and ignore the
594  // block arguments that were transformed into PHI nodes.
595  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
596  llvm::BranchInst *branch =
597  builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
598  moduleTranslation.mapBranch(&opInst, branch);
599  moduleTranslation.setLoopMetadata(&opInst, branch);
600  return success();
601  }
602  if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
603  llvm::BranchInst *branch = builder.CreateCondBr(
604  moduleTranslation.lookupValue(condbrOp.getOperand(0)),
605  moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
606  moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
607  moduleTranslation.mapBranch(&opInst, branch);
608  moduleTranslation.setLoopMetadata(&opInst, branch);
609  return success();
610  }
611  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
612  llvm::SwitchInst *switchInst = builder.CreateSwitch(
613  moduleTranslation.lookupValue(switchOp.getValue()),
614  moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
615  switchOp.getCaseDestinations().size());
616 
617  // Handle switch with zero cases.
618  if (!switchOp.getCaseValues())
619  return success();
620 
621  auto *ty = llvm::cast<llvm::IntegerType>(
622  moduleTranslation.convertType(switchOp.getValue().getType()));
623  for (auto i :
624  llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
625  switchOp.getCaseDestinations()))
626  switchInst->addCase(
627  llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
628  moduleTranslation.lookupBlock(std::get<1>(i)));
629 
630  moduleTranslation.mapBranch(&opInst, switchInst);
631  return success();
632  }
633  if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(opInst)) {
634  llvm::IndirectBrInst *indBr = builder.CreateIndirectBr(
635  moduleTranslation.lookupValue(indBrOp.getAddr()),
636  indBrOp->getNumSuccessors());
637  for (auto *succ : indBrOp.getSuccessors())
638  indBr->addDestination(moduleTranslation.lookupBlock(succ));
639  moduleTranslation.mapBranch(&opInst, indBr);
640  return success();
641  }
642 
643  // Emit addressof. We need to look up the global value referenced by the
644  // operation and store it in the MLIR-to-LLVM value mapping. This does not
645  // emit any LLVM instruction.
646  if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
647  LLVM::GlobalOp global =
648  addressOfOp.getGlobal(moduleTranslation.symbolTable());
649  LLVM::LLVMFuncOp function =
650  addressOfOp.getFunction(moduleTranslation.symbolTable());
651  LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable());
652 
653  // The verifier should not have allowed this.
654  assert((global || function || alias) &&
655  "referencing an undefined global, function, or alias");
656 
657  llvm::Value *llvmValue = nullptr;
658  if (global)
659  llvmValue = moduleTranslation.lookupGlobal(global);
660  else if (alias)
661  llvmValue = moduleTranslation.lookupAlias(alias);
662  else
663  llvmValue = moduleTranslation.lookupFunction(function.getName());
664 
665  moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue);
666  return success();
667  }
668 
669  // Emit dso_local_equivalent. We need to look up the global value referenced
670  // by the operation and store it in the MLIR-to-LLVM value mapping.
671  if (auto dsoLocalEquivalentOp =
672  dyn_cast<LLVM::DSOLocalEquivalentOp>(opInst)) {
673  LLVM::LLVMFuncOp function =
674  dsoLocalEquivalentOp.getFunction(moduleTranslation.symbolTable());
675  LLVM::AliasOp alias =
676  dsoLocalEquivalentOp.getAlias(moduleTranslation.symbolTable());
677 
678  // The verifier should not have allowed this.
679  assert((function || alias) &&
680  "referencing an undefined function, or alias");
681 
682  llvm::Value *llvmValue = nullptr;
683  if (alias)
684  llvmValue = moduleTranslation.lookupAlias(alias);
685  else
686  llvmValue = moduleTranslation.lookupFunction(function.getName());
687 
688  moduleTranslation.mapValue(
689  dsoLocalEquivalentOp.getResult(),
690  llvm::DSOLocalEquivalent::get(cast<llvm::GlobalValue>(llvmValue)));
691  return success();
692  }
693 
694  // Emit blockaddress. We first need to find the LLVM block referenced by this
695  // operation and then create a LLVM block address for it.
696  if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
697  BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
698  llvm::BasicBlock *llvmBlock =
699  moduleTranslation.lookupBlockAddress(blockAddressAttr);
700 
701  llvm::Value *llvmValue = nullptr;
702  StringRef fnName = blockAddressAttr.getFunction().getValue();
703  if (llvmBlock) {
704  llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
705  llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
706  } else {
707  // The matching LLVM block is not yet emitted, a placeholder is created
708  // in its place. When the LLVM block is emitted later in translation,
709  // the llvmValue is replaced with the actual llvm::BlockAddress.
710  // A GlobalVariable is chosen as placeholder because in general LLVM
711  // constants are uniqued and are not proper for RAUW, since that could
712  // harm unrelated uses of the constant.
713  llvmValue = new llvm::GlobalVariable(
714  *moduleTranslation.getLLVMModule(),
715  llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()),
716  /*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
717  /*Initializer=*/nullptr,
718  Twine("__mlir_block_address_")
719  .concat(Twine(fnName))
720  .concat(Twine((uint64_t)blockAddressOp.getOperation())));
721  moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue);
722  }
723 
724  moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue);
725  return success();
726  }
727 
728  // Emit block label. If this label is seen before BlockAddressOp is
729  // translated, go ahead and already map it.
730  if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
731  auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
732  BlockAddressAttr blockAddressAttr = BlockAddressAttr::get(
733  &moduleTranslation.getContext(),
734  FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
735  funcOp.getName()),
736  blockTagOp.getTag());
737  moduleTranslation.mapBlockAddress(blockAddressAttr,
738  builder.GetInsertBlock());
739  return success();
740  }
741 
742  return failure();
743 }
744 
745 namespace {
746 /// Implementation of the dialect interface that converts operations belonging
747 /// to the LLVM dialect to LLVM IR.
748 class LLVMDialectLLVMIRTranslationInterface
750 public:
752 
753  /// Translates the given operation to LLVM IR using the provided IR builder
754  /// and saving the state in `moduleTranslation`.
755  LogicalResult
756  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
757  LLVM::ModuleTranslation &moduleTranslation) const final {
758  return convertOperationImpl(*op, builder, moduleTranslation);
759  }
760 };
761 } // namespace
762 
764  registry.insert<LLVM::LLVMDialect>();
765  registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
766  dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
767  });
768 }
769 
771  DialectRegistry registry;
773  context.appendDialectRegistry(registry);
774 }
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::string diagStr(const llvm::Type *type)
Convert an LLVM type to a string for printing in diagnostics.
static LogicalResult convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, ArrayAttr resAttrsArray, llvm::CallBase *call, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Metadata * convertModuleFlagProfileSummaryAttr(StringRef key, ModuleFlagProfileSummaryAttr summaryAttr, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static FailureOr< llvm::Function * > getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation)
Get the declaration of an overloaded llvm intrinsic.
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static SmallVector< llvm::OperandBundleDef > convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Builder for LLVM_CallIntrinsicOp.
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
static llvm::OperandBundleDef convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Metadata * convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void convertLinkerOptionsOp(ArrayAttr options, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::ManagedStatic< PassManagerOptions > options
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
const float * table
Attributes are known-constant values of operations.
Definition: Attributes.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst)
Maps a blockaddress operation to its corresponding placeholder LLVM value.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapCall(Operation *mlir, llvm::CallInst *llvm)
Stores a mapping between an MLIR call operation and a corresponding LLVM call instruction.
FailureOr< llvm::AttrBuilder > convertParameterAttrs(mlir::Location loc, DictionaryAttr paramAttrs)
Translates parameter attributes of a call and adds them to the returned AttrBuilder.
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::GlobalValue * lookupAlias(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global alias va...
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::BasicBlock * lookupBlockAddress(BlockAddressAttr attr) const
Finds the LLVM basic block that corresponds to the given BlockAddressAttr.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void mapBlockAddress(BlockAddressAttr attr, llvm::BasicBlock *block)
Maps a BlockAddressAttr to its corresponding LLVM basic block.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
This class represents a contiguous range of operand ranges, e.g.
Definition: ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2362
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and the translation from it to the LLVM IR in the given registry;.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...