MLIR  16.0.0git
LLVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1 //===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR LLVM dialect and LLVM IR.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
18 
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InlineAsm.h"
21 #include "llvm/IR/MDBuilder.h"
22 #include "llvm/IR/MatrixBuilder.h"
23 #include "llvm/IR/Operator.h"
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
28 
29 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
30 
31 /// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
32 static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
33  switch (p) {
34  case LLVM::ICmpPredicate::eq:
35  return llvm::CmpInst::Predicate::ICMP_EQ;
36  case LLVM::ICmpPredicate::ne:
37  return llvm::CmpInst::Predicate::ICMP_NE;
38  case LLVM::ICmpPredicate::slt:
39  return llvm::CmpInst::Predicate::ICMP_SLT;
40  case LLVM::ICmpPredicate::sle:
41  return llvm::CmpInst::Predicate::ICMP_SLE;
42  case LLVM::ICmpPredicate::sgt:
43  return llvm::CmpInst::Predicate::ICMP_SGT;
44  case LLVM::ICmpPredicate::sge:
45  return llvm::CmpInst::Predicate::ICMP_SGE;
46  case LLVM::ICmpPredicate::ult:
47  return llvm::CmpInst::Predicate::ICMP_ULT;
48  case LLVM::ICmpPredicate::ule:
49  return llvm::CmpInst::Predicate::ICMP_ULE;
50  case LLVM::ICmpPredicate::ugt:
51  return llvm::CmpInst::Predicate::ICMP_UGT;
52  case LLVM::ICmpPredicate::uge:
53  return llvm::CmpInst::Predicate::ICMP_UGE;
54  }
55  llvm_unreachable("incorrect comparison predicate");
56 }
57 
58 static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
59  switch (p) {
60  case LLVM::FCmpPredicate::_false:
61  return llvm::CmpInst::Predicate::FCMP_FALSE;
62  case LLVM::FCmpPredicate::oeq:
63  return llvm::CmpInst::Predicate::FCMP_OEQ;
64  case LLVM::FCmpPredicate::ogt:
65  return llvm::CmpInst::Predicate::FCMP_OGT;
66  case LLVM::FCmpPredicate::oge:
67  return llvm::CmpInst::Predicate::FCMP_OGE;
68  case LLVM::FCmpPredicate::olt:
69  return llvm::CmpInst::Predicate::FCMP_OLT;
70  case LLVM::FCmpPredicate::ole:
71  return llvm::CmpInst::Predicate::FCMP_OLE;
72  case LLVM::FCmpPredicate::one:
73  return llvm::CmpInst::Predicate::FCMP_ONE;
74  case LLVM::FCmpPredicate::ord:
75  return llvm::CmpInst::Predicate::FCMP_ORD;
76  case LLVM::FCmpPredicate::ueq:
77  return llvm::CmpInst::Predicate::FCMP_UEQ;
78  case LLVM::FCmpPredicate::ugt:
79  return llvm::CmpInst::Predicate::FCMP_UGT;
80  case LLVM::FCmpPredicate::uge:
81  return llvm::CmpInst::Predicate::FCMP_UGE;
82  case LLVM::FCmpPredicate::ult:
83  return llvm::CmpInst::Predicate::FCMP_ULT;
84  case LLVM::FCmpPredicate::ule:
85  return llvm::CmpInst::Predicate::FCMP_ULE;
86  case LLVM::FCmpPredicate::une:
87  return llvm::CmpInst::Predicate::FCMP_UNE;
88  case LLVM::FCmpPredicate::uno:
89  return llvm::CmpInst::Predicate::FCMP_UNO;
90  case LLVM::FCmpPredicate::_true:
91  return llvm::CmpInst::Predicate::FCMP_TRUE;
92  }
93  llvm_unreachable("incorrect comparison predicate");
94 }
95 
96 static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
97  switch (op) {
98  case LLVM::AtomicBinOp::xchg:
99  return llvm::AtomicRMWInst::BinOp::Xchg;
100  case LLVM::AtomicBinOp::add:
101  return llvm::AtomicRMWInst::BinOp::Add;
102  case LLVM::AtomicBinOp::sub:
103  return llvm::AtomicRMWInst::BinOp::Sub;
104  case LLVM::AtomicBinOp::_and:
105  return llvm::AtomicRMWInst::BinOp::And;
106  case LLVM::AtomicBinOp::nand:
107  return llvm::AtomicRMWInst::BinOp::Nand;
108  case LLVM::AtomicBinOp::_or:
109  return llvm::AtomicRMWInst::BinOp::Or;
110  case LLVM::AtomicBinOp::_xor:
111  return llvm::AtomicRMWInst::BinOp::Xor;
113  return llvm::AtomicRMWInst::BinOp::Max;
115  return llvm::AtomicRMWInst::BinOp::Min;
116  case LLVM::AtomicBinOp::umax:
117  return llvm::AtomicRMWInst::BinOp::UMax;
118  case LLVM::AtomicBinOp::umin:
119  return llvm::AtomicRMWInst::BinOp::UMin;
120  case LLVM::AtomicBinOp::fadd:
121  return llvm::AtomicRMWInst::BinOp::FAdd;
122  case LLVM::AtomicBinOp::fsub:
123  return llvm::AtomicRMWInst::BinOp::FSub;
124  }
125  llvm_unreachable("incorrect atomic binary operator");
126 }
127 
128 static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
129  switch (ordering) {
130  case LLVM::AtomicOrdering::not_atomic:
131  return llvm::AtomicOrdering::NotAtomic;
132  case LLVM::AtomicOrdering::unordered:
133  return llvm::AtomicOrdering::Unordered;
134  case LLVM::AtomicOrdering::monotonic:
135  return llvm::AtomicOrdering::Monotonic;
136  case LLVM::AtomicOrdering::acquire:
137  return llvm::AtomicOrdering::Acquire;
138  case LLVM::AtomicOrdering::release:
139  return llvm::AtomicOrdering::Release;
140  case LLVM::AtomicOrdering::acq_rel:
141  return llvm::AtomicOrdering::AcquireRelease;
142  case LLVM::AtomicOrdering::seq_cst:
143  return llvm::AtomicOrdering::SequentiallyConsistent;
144  }
145  llvm_unreachable("incorrect atomic ordering");
146 }
147 
148 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
149  using llvmFMF = llvm::FastMathFlags;
150  using FuncT = void (llvmFMF::*)(bool);
151  const std::pair<FastmathFlags, FuncT> handlers[] = {
152  // clang-format off
153  {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
154  {FastmathFlags::ninf, &llvmFMF::setNoInfs},
155  {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
156  {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
157  {FastmathFlags::contract, &llvmFMF::setAllowContract},
158  {FastmathFlags::afn, &llvmFMF::setApproxFunc},
159  {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
160  // clang-format on
161  };
162  llvm::FastMathFlags ret;
163  ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
164  for (auto it : handlers)
165  if (bitEnumContainsAll(fmfMlir, it.first))
166  (ret.*(it.second))(true);
167  return ret;
168 }
169 
170 /// Returns an LLVM metadata node corresponding to a loop option. This metadata
171 /// is attached to an llvm.loop node.
172 static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx,
173  LoopOptionCase option,
174  int64_t value) {
175  StringRef name;
176  llvm::Constant *cstValue = nullptr;
177  switch (option) {
178  case LoopOptionCase::disable_licm:
179  name = "llvm.licm.disable";
180  cstValue = llvm::ConstantInt::getBool(ctx, value);
181  break;
182  case LoopOptionCase::disable_unroll:
183  name = "llvm.loop.unroll.disable";
184  cstValue = llvm::ConstantInt::getBool(ctx, value);
185  break;
186  case LoopOptionCase::interleave_count:
187  name = "llvm.loop.interleave.count";
188  cstValue = llvm::ConstantInt::get(
189  llvm::IntegerType::get(ctx, /*NumBits=*/32), value);
190  break;
191  case LoopOptionCase::disable_pipeline:
192  name = "llvm.loop.pipeline.disable";
193  cstValue = llvm::ConstantInt::getBool(ctx, value);
194  break;
195  case LoopOptionCase::pipeline_initiation_interval:
196  name = "llvm.loop.pipeline.initiationinterval";
197  cstValue = llvm::ConstantInt::get(
198  llvm::IntegerType::get(ctx, /*NumBits=*/32), value);
199  break;
200  }
201  return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
202  llvm::ConstantAsMetadata::get(cstValue)});
203 }
204 
205 static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst,
206  llvm::IRBuilderBase &builder,
207  LLVM::ModuleTranslation &moduleTranslation) {
208  if (Attribute attr = opInst.getAttr(LLVMDialect::getLoopAttrName())) {
209  llvm::Module *module = builder.GetInsertBlock()->getModule();
210  llvm::MDNode *loopMD = moduleTranslation.lookupLoopOptionsMetadata(attr);
211  if (!loopMD) {
212  llvm::LLVMContext &ctx = module->getContext();
213 
214  SmallVector<llvm::Metadata *> loopOptions;
215  // Reserve operand 0 for loop id self reference.
216  auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
217  loopOptions.push_back(dummy.get());
218 
219  auto loopAttr = attr.cast<DictionaryAttr>();
220  auto parallelAccessGroup =
221  loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
222  if (parallelAccessGroup) {
223  SmallVector<llvm::Metadata *> parallelAccess;
224  parallelAccess.push_back(
225  llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
226  for (SymbolRefAttr accessGroupRef : parallelAccessGroup->getValue()
227  .cast<ArrayAttr>()
228  .getAsRange<SymbolRefAttr>())
229  parallelAccess.push_back(
230  moduleTranslation.getAccessGroup(opInst, accessGroupRef));
231  loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess));
232  }
233 
234  if (auto loopOptionsAttr = loopAttr.getAs<LoopOptionsAttr>(
235  LLVMDialect::getLoopOptionsAttrName())) {
236  for (auto option : loopOptionsAttr.getOptions())
237  loopOptions.push_back(
238  getLoopOptionMetadata(ctx, option.first, option.second));
239  }
240 
241  // Create loop options and set the first operand to itself.
242  loopMD = llvm::MDNode::get(ctx, loopOptions);
243  loopMD->replaceOperandWith(0, loopMD);
244 
245  // Store a map from this Attribute to the LLVM metadata in case we
246  // encounter it again.
247  moduleTranslation.mapLoopOptionsMetadata(attr, loopMD);
248  }
249 
250  llvmInst.setMetadata(module->getMDKindID("llvm.loop"), loopMD);
251  }
252 }
253 
254 /// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
256  SmallVector<unsigned> position;
257  llvm::append_range(position, indices);
258  return position;
259 }
260 
261 /// Get the declaration of an overloaded llvm intrinsic. First we get the
262 /// overloaded argument types and/or result type from the CallIntrinsicOp, and
263 /// then use those to get the correct declaration of the overloaded intrinsic.
266  llvm::Module *module,
267  LLVM::ModuleTranslation &moduleTranslation) {
269  for (Type type : op->getOperandTypes())
270  allArgTys.push_back(moduleTranslation.convertType(type));
271 
272  llvm::Type *resTy;
273  if (op.getNumResults() == 0)
274  resTy = llvm::Type::getVoidTy(module->getContext());
275  else
276  resTy = moduleTranslation.convertType(op.getResult(0).getType());
277 
278  // ATM we do not support variadic intrinsics.
279  llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);
280 
282  getIntrinsicInfoTableEntries(id, table);
284 
285  SmallVector<llvm::Type *, 8> overloadedArgTys;
286  if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
287  overloadedArgTys) !=
288  llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
289  return op.emitOpError("intrinsic type is not a match");
290  }
291 
292  ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
293  return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
294 }
295 
296 /// Builder for LLVM_CallIntrinsicOp
297 static LogicalResult
298 convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder,
299  LLVM::ModuleTranslation &moduleTranslation) {
300  llvm::Module *module = builder.GetInsertBlock()->getModule();
302  llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
303  if (!id)
304  return op.emitOpError()
305  << "couldn't find intrinsic: " << op.getIntrinAttr();
306 
307  llvm::Function *fn = nullptr;
308  if (llvm::Intrinsic::isOverloaded(id)) {
309  auto fnOrFailure =
310  getOverloadedDeclaration(op, id, module, moduleTranslation);
311  if (failed(fnOrFailure))
312  return failure();
313  fn = fnOrFailure.value();
314  } else {
315  fn = llvm::Intrinsic::getDeclaration(module, id, {});
316  }
317 
318  auto *inst =
319  builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
320  if (op.getNumResults() == 1)
321  moduleTranslation.mapValue(op->getResults().front()) = inst;
322  return success();
323 }
324 
325 static LogicalResult
326 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
327  LLVM::ModuleTranslation &moduleTranslation) {
328 
329  llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
330  if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
331  builder.setFastMathFlags(getFastmathFlags(fmf));
332 
333 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
334 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
335 
336  // Emit function calls. If the "callee" attribute is present, this is a
337  // direct function call and we also need to look up the remapped function
338  // itself. Otherwise, this is an indirect call and the callee is the first
339  // operand, look it up as a normal value. Return the llvm::Value
340  // representing the function result, which may be of llvm::VoidTy type.
341  auto convertCall = [&](Operation &op) -> llvm::Value * {
342  auto operands = moduleTranslation.lookupValues(op.getOperands());
343  ArrayRef<llvm::Value *> operandsRef(operands);
344  if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
345  return builder.CreateCall(
346  moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
347  auto calleeType =
348  op.getOperands().front().getType().cast<LLVMPointerType>();
349  auto *calleeFunctionType = cast<llvm::FunctionType>(
350  moduleTranslation.convertType(calleeType.getElementType()));
351  return builder.CreateCall(calleeFunctionType, operandsRef.front(),
352  operandsRef.drop_front());
353  };
354 
355  // Emit calls. If the called function has a result, remap the corresponding
356  // value. Note that LLVM IR dialect CallOp has either 0 or 1 result.
357  if (isa<LLVM::CallOp>(opInst)) {
358  llvm::Value *result = convertCall(opInst);
359  if (opInst.getNumResults() != 0) {
360  moduleTranslation.mapValue(opInst.getResult(0), result);
361  return success();
362  }
363  // Check that LLVM call returns void for 0-result functions.
364  return success(result->getType()->isVoidTy());
365  }
366 
367  if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
368  // TODO: refactor function type creation which usually occurs in std-LLVM
369  // conversion.
370  SmallVector<Type, 8> operandTypes;
371  llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
372 
373  Type resultType;
374  if (inlineAsmOp.getNumResults() == 0) {
375  resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
376  } else {
377  assert(inlineAsmOp.getNumResults() == 1);
378  resultType = inlineAsmOp.getResultTypes()[0];
379  }
380  auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
381  llvm::InlineAsm *inlineAsmInst =
382  inlineAsmOp.getAsmDialect()
383  ? llvm::InlineAsm::get(
384  static_cast<llvm::FunctionType *>(
385  moduleTranslation.convertType(ft)),
386  inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
387  inlineAsmOp.getHasSideEffects(),
388  inlineAsmOp.getIsAlignStack(),
389  convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
390  : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
391  moduleTranslation.convertType(ft)),
392  inlineAsmOp.getAsmString(),
393  inlineAsmOp.getConstraints(),
394  inlineAsmOp.getHasSideEffects(),
395  inlineAsmOp.getIsAlignStack());
396  llvm::CallInst *inst = builder.CreateCall(
397  inlineAsmInst,
398  moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
399  if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
400  llvm::AttributeList attrList;
401  for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
402  Attribute attr = it.value();
403  if (!attr)
404  continue;
405  DictionaryAttr dAttr = attr.cast<DictionaryAttr>();
406  TypeAttr tAttr =
407  dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast<TypeAttr>();
408  llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
409  llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
410  b.addTypeAttr(llvm::Attribute::ElementType, ty);
411  // shift to account for the returned value (this is always 1 aggregate
412  // value in LLVM).
413  int shift = (opInst.getNumResults() > 0) ? 1 : 0;
414  attrList = attrList.addAttributesAtIndex(
415  moduleTranslation.getLLVMContext(), it.index() + shift, b);
416  }
417  inst->setAttributes(attrList);
418  }
419 
420  if (opInst.getNumResults() != 0)
421  moduleTranslation.mapValue(opInst.getResult(0), inst);
422  return success();
423  }
424 
425  if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
426  auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
427  ArrayRef<llvm::Value *> operandsRef(operands);
428  llvm::Instruction *result;
429  if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
430  result = builder.CreateInvoke(
431  moduleTranslation.lookupFunction(attr.getValue()),
432  moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
433  moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef);
434  } else {
435  auto calleeType =
436  invOp.getCalleeOperands().front().getType().cast<LLVMPointerType>();
437  auto *calleeFunctionType = cast<llvm::FunctionType>(
438  moduleTranslation.convertType(calleeType.getElementType()));
439  result = builder.CreateInvoke(
440  calleeFunctionType, operandsRef.front(),
441  moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
442  moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
443  operandsRef.drop_front());
444  }
445  moduleTranslation.mapBranch(invOp, result);
446  // InvokeOp can only have 0 or 1 result
447  if (invOp->getNumResults() != 0) {
448  moduleTranslation.mapValue(opInst.getResult(0), result);
449  return success();
450  }
451  return success(result->getType()->isVoidTy());
452  }
453 
454  if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
455  llvm::Type *ty = moduleTranslation.convertType(lpOp.getType());
456  llvm::LandingPadInst *lpi =
457  builder.CreateLandingPad(ty, lpOp.getNumOperands());
458  lpi->setCleanup(lpOp.getCleanup());
459 
460  // Add clauses
461  for (llvm::Value *operand :
462  moduleTranslation.lookupValues(lpOp.getOperands())) {
463  // All operands should be constant - checked by verifier
464  if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
465  lpi->addClause(constOperand);
466  }
467  moduleTranslation.mapValue(lpOp.getResult(), lpi);
468  return success();
469  }
470 
471  // Emit branches. We need to look up the remapped blocks and ignore the
472  // block arguments that were transformed into PHI nodes.
473  if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
474  llvm::BranchInst *branch =
475  builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
476  moduleTranslation.mapBranch(&opInst, branch);
477  setLoopMetadata(opInst, *branch, builder, moduleTranslation);
478  return success();
479  }
480  if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
481  llvm::MDNode *branchWeights = nullptr;
482  if (auto weights = condbrOp.getBranchWeights()) {
483  // Map weight attributes to LLVM metadata.
484  auto weightValues = weights->getValues<APInt>();
485  auto trueWeight = weightValues[0].getSExtValue();
486  auto falseWeight = weightValues[1].getSExtValue();
487  branchWeights =
488  llvm::MDBuilder(moduleTranslation.getLLVMContext())
489  .createBranchWeights(static_cast<uint32_t>(trueWeight),
490  static_cast<uint32_t>(falseWeight));
491  }
492  llvm::BranchInst *branch = builder.CreateCondBr(
493  moduleTranslation.lookupValue(condbrOp.getOperand(0)),
494  moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
495  moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
496  moduleTranslation.mapBranch(&opInst, branch);
497  setLoopMetadata(opInst, *branch, builder, moduleTranslation);
498  return success();
499  }
500  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
501  llvm::MDNode *branchWeights = nullptr;
502  if (auto weights = switchOp.getBranchWeights()) {
503  llvm::SmallVector<uint32_t> weightValues;
504  weightValues.reserve(weights->size());
505  for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
506  weightValues.push_back(weight.getLimitedValue());
507  branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
508  .createBranchWeights(weightValues);
509  }
510 
511  llvm::SwitchInst *switchInst = builder.CreateSwitch(
512  moduleTranslation.lookupValue(switchOp.getValue()),
513  moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
514  switchOp.getCaseDestinations().size(), branchWeights);
515 
516  auto *ty = llvm::cast<llvm::IntegerType>(
517  moduleTranslation.convertType(switchOp.getValue().getType()));
518  for (auto i :
519  llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
520  switchOp.getCaseDestinations()))
521  switchInst->addCase(
522  llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
523  moduleTranslation.lookupBlock(std::get<1>(i)));
524 
525  moduleTranslation.mapBranch(&opInst, switchInst);
526  return success();
527  }
528 
529  // Emit addressof. We need to look up the global value referenced by the
530  // operation and store it in the MLIR-to-LLVM value mapping. This does not
531  // emit any LLVM instruction.
532  if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
533  LLVM::GlobalOp global =
534  addressOfOp.getGlobal(moduleTranslation.symbolTable());
535  LLVM::LLVMFuncOp function =
536  addressOfOp.getFunction(moduleTranslation.symbolTable());
537 
538  // The verifier should not have allowed this.
539  assert((global || function) &&
540  "referencing an undefined global or function");
541 
542  moduleTranslation.mapValue(
543  addressOfOp.getResult(),
544  global ? moduleTranslation.lookupGlobal(global)
545  : moduleTranslation.lookupFunction(function.getName()));
546  return success();
547  }
548 
549  return failure();
550 }
551 
552 namespace {
553 /// Implementation of the dialect interface that converts operations belonging
554 /// to the LLVM dialect to LLVM IR.
555 class LLVMDialectLLVMIRTranslationInterface
557 public:
559 
560  /// Translates the given operation to LLVM IR using the provided IR builder
561  /// and saving the state in `moduleTranslation`.
563  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
564  LLVM::ModuleTranslation &moduleTranslation) const final {
565  return convertOperationImpl(*op, builder, moduleTranslation);
566  }
567 };
568 } // namespace
569 
571  registry.insert<LLVM::LLVMDialect>();
572  registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
573  dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
574  });
575 }
576 
578  DialectRegistry registry;
580  context.appendDialectRegistry(registry);
581 }
static constexpr const bool value
static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p)
Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static FailureOr< llvm::Function * > getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation)
Get the declaration of an overloaded llvm intrinsic.
static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering)
static llvm::MDNode * getLoopOptionMetadata(llvm::LLVMContext &ctx, LoopOptionCase option, int64_t value)
Returns an LLVM metadata node corresponding to a loop option.
static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op)
static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Builder for LLVM_CallIntrinsicOp.
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:137
An attribute that represents a reference to a dense integer vector or tensor object.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
SymbolTableCollection & symbolTable()
llvm::MDNode * getAccessGroup(Operation &opInst, SymbolRefAttr accessGroupRef) const
Returns the LLVM metadata corresponding to a reference to an mlir LLVM dialect access group operation...
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void mapLoopOptionsMetadata(Attribute options, llvm::MDNode *metadata)
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::MDNode * lookupLoopOptionsMetadata(Attribute options) const
Returns the LLVM metadata corresponding to a llvm loop's codegen options attribute.
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:371
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and its translation to LLVM IR in the given registry.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26