MLIR  16.0.0git
ModuleTranslation.cpp
Go to the documentation of this file.
1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "DebugTranslation.h"
17 #include "mlir/Dialect/DLTI/DLTI.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/Support/LLVM.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 #include "llvm/ADT/PostOrderIterator.h"
31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/CFG.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/IRBuilder.h"
38 #include "llvm/IR/InlineAsm.h"
39 #include "llvm/IR/IntrinsicsNVPTX.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/MDBuilder.h"
42 #include "llvm/IR/Module.h"
43 #include "llvm/IR/Verifier.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/Cloning.h"
46 #include "llvm/Transforms/Utils/ModuleUtils.h"
47 
48 using namespace mlir;
49 using namespace mlir::LLVM;
50 using namespace mlir::LLVM::detail;
51 
52 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
53 
54 /// Translates the given data layout spec attribute to the LLVM IR data layout.
55 /// Only integer, float, pointer and endianness entries are currently supported.
/// Each type entry is rendered from the size and the ABI/preferred alignment
/// that `dataLayout` reports for it; entries for the index type are skipped.
/// When `loc` is not provided, diagnostics are reported at an unknown
/// location. Returns failure after emitting an error for a string key other
/// than the endianness key, a non-signless integer type, or any other
/// unsupported type.
57 translateDataLayout(DataLayoutSpecInterface attribute,
58  const DataLayout &dataLayout,
59  Optional<Location> loc = std::nullopt) {
60  if (!loc)
61  loc = UnknownLoc::get(attribute.getContext());
62 
63  // Translate the endianness attribute.
  // Only string-keyed entries are examined in this loop, and the endianness
  // key is the only one accepted; any other string key is an error.
64  std::string llvmDataLayout;
65  llvm::raw_string_ostream layoutStream(llvmDataLayout);
66  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
67  auto key = entry.getKey().dyn_cast<StringAttr>();
68  if (!key)
69  continue;
70  if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
71  auto value = entry.getValue().cast<StringAttr>();
72  bool isLittleEndian =
73  value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
74  layoutStream << (isLittleEndian ? "e" : "E");
75  layoutStream.flush();
76  continue;
77  }
78  emitError(*loc) << "unsupported data layout key " << key;
79  return failure();
80  }
81 
82  // Go through the list of entries to check which types are explicitly
83  // specified in entries. Where possible, data layout queries are used instead
84  // of directly inspecting the entries.
85  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
86  auto type = entry.getKey().dyn_cast<Type>();
87  if (!type)
88  continue;
89  // Data layout for the index type is irrelevant at this point.
90  if (type.isa<IndexType>())
91  continue;
  // Every type entry is introduced by a '-' separator.
92  layoutStream << "-";
93  LogicalResult result =
95  .Case<IntegerType, Float16Type, Float32Type, Float64Type,
96  Float80Type, Float128Type>([&](Type type) -> LogicalResult {
97  if (auto intType = type.dyn_cast<IntegerType>()) {
98  if (intType.getSignedness() != IntegerType::Signless)
99  return emitError(*loc)
100  << "unsupported data layout for non-signless integer "
101  << intType;
102  layoutStream << "i";
103  } else {
104  layoutStream << "f";
105  }
  // The `* 8u` presumably scales byte alignments to the bit alignments used
  // in LLVM data layout strings — confirm against DataLayout's API.
106  unsigned size = dataLayout.getTypeSizeInBits(type);
107  unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
108  unsigned preferred =
109  dataLayout.getTypePreferredAlignment(type) * 8u;
  // The preferred alignment is only printed when it differs from the ABI one.
110  layoutStream << size << ":" << abi;
111  if (abi != preferred)
112  layoutStream << ":" << preferred;
113  return success();
114  })
  // Pointer entries always print size, ABI and preferred alignment, followed
  // by the index component when the entry provides one.
115  .Case([&](LLVMPointerType ptrType) {
116  layoutStream << "p" << ptrType.getAddressSpace() << ":";
117  unsigned size = dataLayout.getTypeSizeInBits(type);
118  unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
119  unsigned preferred =
120  dataLayout.getTypePreferredAlignment(type) * 8u;
121  layoutStream << size << ":" << abi << ":" << preferred;
123  entry.getValue(), PtrDLEntryPos::Index))
124  layoutStream << ":" << *index;
125  return success();
126  })
127  .Default([loc](Type type) {
128  return emitError(*loc)
129  << "unsupported type in data layout: " << type;
130  });
131  if (failed(result))
132  return failure();
133  }
134  layoutStream.flush();
  // When no endianness entry was emitted, the string starts with the separator
  // of the first type entry; strip it.
135  StringRef layoutSpec(llvmDataLayout);
136  if (layoutSpec.startswith("-"))
137  layoutSpec = layoutSpec.drop_front();
138 
139  return llvm::DataLayout(layoutSpec);
140 }
141 
142 /// Builds a constant of a sequential LLVM type `type`, potentially containing
143 /// other sequential types recursively, from the individual constant values
144 /// provided in `constants`. `shape` contains the number of elements in nested
145 /// sequential types. Reports errors at `loc` and returns nullptr on error.
/// Leaf constants are taken from the front of `constants`, which is advanced
/// past every element consumed, so the caller must supply at least as many
/// constants as the product of the dimensions in `shape`. Only the innermost
/// level of a vector `type` is built as an `llvm::ConstantVector`; every other
/// level is built as an `llvm::ConstantArray`.
146 static llvm::Constant *
148  ArrayRef<int64_t> shape, llvm::Type *type,
149  Location loc) {
  // With no dimensions left, consume and return the next leaf constant.
150  if (shape.empty()) {
151  llvm::Constant *result = constants.front();
152  constants = constants.drop_front();
153  return result;
154  }
155 
156  llvm::Type *elementType;
157  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
158  elementType = arrayTy->getElementType();
159  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
160  elementType = vectorTy->getElementType();
161  } else {
162  emitError(loc) << "expected sequential LLVM types wrapping a scalar";
163  return nullptr;
164  }
165 
  // Recursively build one nested constant per element of the outermost
  // dimension, stopping at the first failure.
167  nested.reserve(shape.front());
168  for (int64_t i = 0; i < shape.front(); ++i) {
169  nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
170  elementType, loc));
171  if (!nested.back())
172  return nullptr;
173  }
174 
175  if (shape.size() == 1 && type->isVectorTy())
176  return llvm::ConstantVector::get(nested);
177  return llvm::ConstantArray::get(
178  llvm::ArrayType::get(elementType, shape.front()), nested);
179 }
180 
181 /// Returns the first non-sequential type nested in sequential types.
182 static llvm::Type *getInnermostElementType(llvm::Type *type) {
183  do {
184  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
185  type = arrayTy->getElementType();
186  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
187  type = vectorTy->getElementType();
188  } else {
189  return type;
190  }
191  } while (true);
192 }
193 
194 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
195 /// storage if possible. This supports elements attributes of tensor or vector
196 /// type and avoids constructing separate objects for individual values of the
197 /// innermost dimension. Constants for other dimensions are still constructed
198 /// recursively. Returns null if constructing from raw data is not supported for
199 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
200 /// other errors at `loc`.
/// Also returns null for a null attribute, for an attribute with no elements,
/// and for splat attributes whose type is neither a vector nor has a vector
/// element type; callers are expected to fall back to another strategy.
201 static llvm::Constant *
203  llvm::Type *llvmType,
204  const ModuleTranslation &moduleTranslation) {
205  if (!denseElementsAttr)
206  return nullptr;
207 
208  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
209  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
210  return nullptr;
211 
212  ShapedType type = denseElementsAttr.getType();
213  if (type.getNumElements() == 0)
214  return nullptr;
215 
216  // Compute the shape of all dimensions but the innermost. Note that the
217  // innermost dimension may be that of the vector element type.
218  bool hasVectorElementType = type.getElementType().isa<VectorType>();
219  unsigned numAggregates =
220  denseElementsAttr.getNumElements() /
221  (hasVectorElementType ? 1
222  : denseElementsAttr.getType().getShape().back());
223  ArrayRef<int64_t> outerShape = type.getShape();
224  if (!hasVectorElementType)
225  outerShape = outerShape.drop_back();
226 
227  // Handle the case of vector splat, LLVM has special support for it.
228  if (denseElementsAttr.isSplat() &&
229  (type.isa<VectorType>() || hasVectorElementType)) {
230  llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
231  innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
232  moduleTranslation);
  // NOTE(review): getSplat is asked for 0 elements here, which is not a
  // usable vector length; the innermost dimension size was presumably
  // intended. Verify before relying on this path.
233  llvm::Constant *splatVector =
234  llvm::ConstantDataVector::getSplat(0, splatValue);
235  SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
236  ArrayRef<llvm::Constant *> constantsRef = constants;
237  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
238  }
239  if (denseElementsAttr.isSplat())
240  return nullptr;
241 
242  // In case of non-splat, create a constructor for the innermost constant from
243  // a piece of raw data.
244  std::function<llvm::Constant *(StringRef)> buildCstData;
245  if (type.isa<TensorType>()) {
246  auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
247  if (vectorElementType && vectorElementType.getRank() == 1) {
248  buildCstData = [&](StringRef data) {
249  return llvm::ConstantDataVector::getRaw(
250  data, vectorElementType.getShape().back(), innermostLLVMType);
251  };
252  } else if (!vectorElementType) {
253  buildCstData = [&](StringRef data) {
254  return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
255  innermostLLVMType);
256  };
257  }
258  } else if (type.isa<VectorType>()) {
259  buildCstData = [&](StringRef data) {
260  return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
261  innermostLLVMType);
262  };
263  }
  // Tensors of multi-dimensional vectors (and any other shaped type) have no
  // raw-data constructor.
264  if (!buildCstData)
265  return nullptr;
266 
267  // Create innermost constants and defer to the default constant creation
268  // mechanism for other dimensions.
  // NOTE(review): the bytes per aggregate come from the attribute type's last
  // dimension, while for a tensor of vectors `buildCstData` consumes
  // `vectorElementType.getShape().back()` elements; these agree only when the
  // two sizes coincide. Verify.
270  unsigned aggregateSize = denseElementsAttr.getType().getShape().back() *
271  (innermostLLVMType->getScalarSizeInBits() / 8);
272  constants.reserve(numAggregates);
273  for (unsigned i = 0; i < numAggregates; ++i) {
274  StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
275  aggregateSize);
276  constants.push_back(buildCstData(data));
277  }
278 
279  ArrayRef<llvm::Constant *> constantsRef = constants;
280  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
281 }
282 
283 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
284 /// This currently supports integer, floating point, splat and dense element
285 /// attributes and combinations thereof. Also, an array attribute with two
286 /// elements is supported to represent a complex constant. In case of error,
287 /// report it to `loc` and return nullptr.
/// In addition, a null `attr` yields `undef` of `llvmType`, a
/// `FlatSymbolRefAttr` yields the referenced function bitcast to `llvmType`,
/// and a `StringAttr` yields a character array holding exactly the string's
/// bytes (no terminating null is appended). Integer values are sign-extended
/// or truncated to the bit width of `llvmType`. Other elements attributes are
/// first converted from their raw data when possible and otherwise built
/// element by element.
289  llvm::Type *llvmType, Attribute attr, Location loc,
290  const ModuleTranslation &moduleTranslation) {
291  if (!attr)
292  return llvm::UndefValue::get(llvmType);
293  if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
294  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
295  if (!arrayAttr || arrayAttr.size() != 2) {
296  emitError(loc, "expected struct type to be a complex number");
297  return nullptr;
298  }
  // Both the real and the imaginary part use the type of the first field.
299  llvm::Type *elementType = structType->getElementType(0);
300  llvm::Constant *real =
301  getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
302  if (!real)
303  return nullptr;
304  llvm::Constant *imag =
305  getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
306  if (!imag)
307  return nullptr;
308  return llvm::ConstantStruct::get(structType, {real, imag});
309  }
310  // For integer types, we allow a mismatch in sizes as the index type in
311  // MLIR might have a different size than the index type in the LLVM module.
  // NOTE(review): `getIntegerBitWidth()` presumes `llvmType` is an integer
  // type; an IntegerAttr paired with any other LLVM type is not diagnosed here.
312  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
313  return llvm::ConstantInt::get(
314  llvmType,
315  intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
316  if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
317  if (llvmType !=
318  llvm::Type::getFloatingPointTy(llvmType->getContext(),
319  floatAttr.getValue().getSemantics())) {
320  emitError(loc, "FloatAttr does not match expected type of the constant");
321  return nullptr;
322  }
323  return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
324  }
  // NOTE(review): assumes the symbol names a function already known to
  // `moduleTranslation`; a null lookup result is not checked before the cast.
325  if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
326  return llvm::ConstantExpr::getBitCast(
327  moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
328  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
329  llvm::Type *elementType;
330  uint64_t numElements;
331  bool isScalable = false;
332  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
333  elementType = arrayTy->getElementType();
334  numElements = arrayTy->getNumElements();
335  } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
336  elementType = fVectorTy->getElementType();
337  numElements = fVectorTy->getNumElements();
338  } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
339  elementType = sVectorTy->getElementType();
340  numElements = sVectorTy->getMinNumElements();
341  isScalable = true;
342  } else {
  // NOTE(review): a splat attribute paired with an LLVM type that is neither
  // an array nor a vector aborts here instead of reporting an error at `loc`.
343  llvm_unreachable("unrecognized constant vector type");
344  }
345  // Splat value is a scalar. Extract it only if the element type is not
346  // another sequence type. The recursion terminates because each step removes
347  // one outer sequential type.
348  bool elementTypeSequential =
349  isa<llvm::ArrayType, llvm::VectorType>(elementType);
350  llvm::Constant *child = getLLVMConstant(
351  elementType,
352  elementTypeSequential ? splatAttr
353  : splatAttr.getSplatValue<Attribute>(),
354  loc, moduleTranslation);
355  if (!child)
356  return nullptr;
357  if (llvmType->isVectorTy())
358  return llvm::ConstantVector::getSplat(
359  llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
360  if (llvmType->isArrayTy()) {
361  auto *arrayType = llvm::ArrayType::get(elementType, numElements);
362  SmallVector<llvm::Constant *, 8> constants(numElements, child);
363  return llvm::ConstantArray::get(arrayType, constants);
364  }
365  }
366 
367  // Try using raw elements data if possible.
368  if (llvm::Constant *result =
370  llvmType, moduleTranslation)) {
371  return result;
372  }
373 
374  // Fall back to element-by-element construction otherwise.
375  if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
376  assert(elementsAttr.getType().hasStaticShape());
377  assert(!elementsAttr.getType().getShape().empty() &&
378  "unexpected empty elements attribute shape");
379 
381  constants.reserve(elementsAttr.getNumElements());
  // Convert every element against the innermost scalar type, then nest the
  // flat list according to the attribute's shape.
382  llvm::Type *innermostType = getInnermostElementType(llvmType);
383  for (auto n : elementsAttr.getValues<Attribute>()) {
384  constants.push_back(
385  getLLVMConstant(innermostType, n, loc, moduleTranslation));
386  if (!constants.back())
387  return nullptr;
388  }
389  ArrayRef<llvm::Constant *> constantsRef = constants;
390  llvm::Constant *result = buildSequentialConstant(
391  constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
392  assert(constantsRef.empty() && "did not consume all elemental constants");
393  return result;
394  }
395 
396  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
397  return llvm::ConstantDataArray::get(
398  moduleTranslation.getLLVMContext(),
399  ArrayRef<char>{stringAttr.getValue().data(),
400  stringAttr.getValue().size()});
401  }
402  emitError(loc, "unsupported constant value");
403  return nullptr;
404 }
405 
406 ModuleTranslation::ModuleTranslation(Operation *module,
407  std::unique_ptr<llvm::Module> llvmModule)
408  : mlirModule(module), llvmModule(std::move(llvmModule)),
409  debugTranslation(
410  std::make_unique<DebugTranslation>(module, *this->llvmModule)),
411  typeTranslator(this->llvmModule->getContext()),
412  iface(module->getContext()) {
413  assert(satisfiesLLVMModule(mlirModule) &&
414  "mlirModule should honor LLVM's module semantics.");
415 }
416 ModuleTranslation::~ModuleTranslation() {
417  if (ompBuilder)
418  ompBuilder->finalize();
419 }
420 
  // Walk `region` and every region nested in its operations using an explicit
  // worklist, dropping the block, block-argument, op-result, branch, global and
  // access-group mappings recorded for the IR found there.
422  SmallVector<Region *> toProcess;
423  toProcess.push_back(&region);
424  while (!toProcess.empty()) {
425  Region *current = toProcess.pop_back_val();
426  for (Block &block : *current) {
427  blockMapping.erase(&block);
428  for (Value arg : block.getArguments())
429  valueMapping.erase(arg);
430  for (Operation &op : block) {
431  for (Value value : op.getResults())
432  valueMapping.erase(value);
433  if (op.hasSuccessors())
434  branchMapping.erase(&op);
435  if (isa<LLVM::GlobalOp>(op))
436  globalsMapping.erase(&op);
437  accessGroupMetadataMapping.erase(&op);
  // Queue the op's nested regions so their mappings are dropped as well.
438  llvm::append_range(
439  toProcess,
440  llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
441  }
442  }
443  }
444 }
445 
446 /// Get the SSA value passed to the current block from the terminator operation
447 /// of its predecessor.
448 static Value getPHISourceValue(Block *current, Block *pred,
449  unsigned numArguments, unsigned index) {
450  Operation &terminator = *pred->getTerminator();
451  if (isa<LLVM::BrOp>(terminator))
452  return terminator.getOperand(index);
453 
454 #ifndef NDEBUG
455  llvm::SmallPtrSet<Block *, 4> seenSuccessors;
456  for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
457  Block *successor = terminator.getSuccessor(i);
458  auto branch = cast<BranchOpInterface>(terminator);
459  SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
460  assert(
461  (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
462  "successors with arguments in LLVM branches must be different blocks");
463  seenSuccessors.insert(successor);
464  }
465 #endif
466 
467  // For instructions that branch based on a condition value, we need to take
468  // the operands for the branch that was taken.
469  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
470  // For conditional branches, we take the operands from either the "true" or
471  // the "false" branch.
472  return condBranchOp.getSuccessor(0) == current
473  ? condBranchOp.getTrueDestOperands()[index]
474  : condBranchOp.getFalseDestOperands()[index];
475  }
476 
477  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
478  // For switches, we take the operands from either the default case, or from
479  // the case branch that was taken.
480  if (switchOp.getDefaultDestination() == current)
481  return switchOp.getDefaultOperands()[index];
482  for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
483  if (i.value() == current)
484  return switchOp.getCaseOperands(i.index())[index];
485  }
486 
487  if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
488  return invokeOp.getNormalDest() == current
489  ? invokeOp.getNormalDestOperands()[index]
490  : invokeOp.getUnwindDestOperands()[index];
491  }
492 
493  llvm_unreachable(
494  "only branch, switch or invoke operations can be terminators "
495  "of a block that has successors");
496 }
497 
498 /// Connect the PHI nodes to the results of preceding blocks.
/// For every block except the entry block, adds one incoming value per MLIR
/// predecessor to each PHI node created for the block's arguments. The
/// incoming LLVM block is the parent of the predecessor's converted
/// terminator, found through the branch mapping. Expects each LLVM block to
/// start with exactly as many PHI nodes as the MLIR block has arguments.
500  const ModuleTranslation &state) {
501  // Skip the first block, it cannot be branched to and its arguments correspond
502  // to the arguments of the LLVM function.
503  for (Block &bb : llvm::drop_begin(region)) {
504  llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
505  auto phis = llvmBB->phis();
506  auto numArguments = bb.getNumArguments();
507  assert(numArguments == std::distance(phis.begin(), phis.end()));
  // The i-th PHI node of the block stands for the i-th block argument.
508  for (auto &numberedPhiNode : llvm::enumerate(phis)) {
509  auto &phiNode = numberedPhiNode.value();
510  unsigned index = numberedPhiNode.index();
511  for (auto *pred : bb.getPredecessors()) {
512  // Find the LLVM IR block that contains the converted terminator
513  // instruction and use it in the PHI node. Note that this block is not
514  // necessarily the same as state.lookupBlock(pred), some operations
515  // (in particular, OpenMP operations using OpenMPIRBuilder) may have
516  // split the blocks.
517  llvm::Instruction *terminator =
518  state.lookupBranch(pred->getTerminator());
519  assert(terminator && "missing the mapping for a terminator");
520  phiNode.addIncoming(state.lookupValue(getPHISourceValue(
521  &bb, pred, numArguments, index)),
522  terminator->getParent());
523  }
524  }
525  }
526 }
527 
528 /// Sort function blocks topologically.
/// Blocks reachable from the first block come first, in reverse post-order.
/// Each block not reached by an earlier traversal starts a traversal of its
/// own, in region order. The SetVector keeps only the first insertion of a
/// block, so blocks reached from several starting points are not duplicated.
531  // For each block that has not been visited yet (i.e. that was not reached
532  // by an earlier traversal), add it to the list as well as its successors.
533  SetVector<Block *> blocks;
534  for (Block &b : region) {
535  if (blocks.count(&b) == 0) {
536  llvm::ReversePostOrderTraversal<Block *> traversal(&b);
537  blocks.insert(traversal.begin(), traversal.end());
538  }
539  }
540  assert(blocks.size() == region.getBlocks().size() &&
541  "some blocks are not sorted");
542 
543  return blocks;
544 }
545 
547  llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
  // Get (or create) the declaration of `intrinsic`, with overload types `tys`,
  // in the module owning the builder's current insertion block, then emit a
  // call to it with `args`. The builder must have an insertion block.
549  llvm::Module *module = builder.GetInsertBlock()->getModule();
550  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
551  return builder.CreateCall(fn, args);
552 }
553 
554 /// Given a single MLIR operation, create the corresponding LLVM IR operation
555 /// using the `builder`.
/// The conversion itself is delegated to the `LLVMTranslationDialectInterface`
/// found for `op`; on success, the operation's dialect attributes are
/// converted as well. Emits an error on `op` and fails when no interface is
/// registered or when the interface fails to convert the operation.
557 ModuleTranslation::convertOperation(Operation &op,
558  llvm::IRBuilderBase &builder) {
559  const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
560  if (!opIface)
561  return op.emitError("cannot be converted to LLVM IR: missing "
562  "`LLVMTranslationDialectInterface` registration for "
563  "dialect for op: ")
564  << op.getName();
565 
566  if (failed(opIface->convertOperation(&op, builder, *this)))
567  return op.emitError("LLVM Translation failed for operation: ")
568  << op.getName();
569 
570  return convertDialectAttributes(&op);
571 }
572 
573 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
574 /// to define values corresponding to the MLIR block arguments. These nodes
575 /// are not connected to the source basic blocks, which may not exist yet. Uses
576 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
577 /// been created for `bb` and included in the block mapping. Inserts new
578 /// instructions at the end of the block and leaves `builder` in a state
579 /// suitable for further insertion into the end of the block.
/// Fails with an error at the block's first operation if a block argument has
/// a type that is not LLVM-compatible, and fails as soon as any operation in
/// the block fails to convert. Each operation is emitted with a debug location
/// translated from its MLIR location within the enclosing function's
/// subprogram.
581  llvm::IRBuilderBase &builder) {
582  builder.SetInsertPoint(lookupBlock(&bb));
583  auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
584 
585  // Before traversing operations, make block arguments available through
586  // value remapping and PHI nodes, but do not add incoming edges for the PHI
587  // nodes just yet: those values may be defined by this or following blocks.
588  // This step is omitted if "ignoreArguments" is set. The arguments of the
589  // first block have been already made available through the remapping of
590  // LLVM function arguments.
591  if (!ignoreArguments) {
592  auto predecessors = bb.getPredecessors();
593  unsigned numPredecessors =
594  std::distance(predecessors.begin(), predecessors.end());
595  for (auto arg : bb.getArguments()) {
596  auto wrappedType = arg.getType();
597  if (!isCompatibleType(wrappedType))
598  return emitError(bb.front().getLoc(),
599  "block argument does not have an LLVM type");
600  llvm::Type *type = convertType(wrappedType);
  // Reserve one incoming slot per predecessor; the incoming values themselves
  // are added once all blocks have been converted.
601  llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
602  mapValue(arg, phi);
603  }
604  }
605 
606  // Traverse operations.
607  for (auto &op : bb) {
608  // Set the current debug location within the builder.
609  builder.SetCurrentDebugLocation(
610  debugTranslation->translateLoc(op.getLoc(), subprogram));
611 
612  if (failed(convertOperation(op, builder)))
613  return failure();
614  }
615 
616  return success();
617 }
618 
619 /// A helper method to get the single Block in an operation honoring LLVM's
620 /// module requirements.
621 static Block &getModuleBody(Operation *module) {
622  return module->getRegion(0).front();
623 }
624 
625 /// A helper method to decide if a constant must not be set as a global variable
626 /// initializer. For an external linkage variable, the variable with an
627 /// initializer is considered externally visible and defined in this module, the
628 /// variable without an initializer is externally available and is defined
629 /// elsewhere.
630 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
631  llvm::Constant *cst) {
632  return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
633  linkage == llvm::GlobalVariable::ExternalWeakLinkage;
634 }
635 
636 /// Sets the runtime preemption specifier of `gv` to dso_local if
637 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
638 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
639  llvm::GlobalValue *gv) {
640  if (dsoLocalRequested)
641  gv->setDSOLocal(true);
642 }
643 
644 /// Create named global variables that correspond to llvm.mlir.global
645 /// definitions. Convert llvm.global_ctors and global_dtors ops.
646 LogicalResult ModuleTranslation::convertGlobals() {
647  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
648  llvm::Type *type = convertType(op.getType());
649  llvm::Constant *cst = nullptr;
650  if (op.getValueOrNull()) {
651  // String attributes are treated separately because they cannot appear as
652  // in-function constants and are thus not supported by getLLVMConstant.
653  if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
654  cst = llvm::ConstantDataArray::getString(
655  llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
656  type = cst->getType();
657  } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
658  *this))) {
659  return failure();
660  }
661  }
662 
663  auto linkage = convertLinkageToLLVM(op.getLinkage());
664  auto addrSpace = op.getAddrSpace();
665 
666  // LLVM IR requires constant with linkage other than external or weak
667  // external to have initializers. If MLIR does not provide an initializer,
668  // default to undef.
669  bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
670  if (!dropInitializer && !cst)
671  cst = llvm::UndefValue::get(type);
672  else if (dropInitializer && cst)
673  cst = nullptr;
674 
675  auto *var = new llvm::GlobalVariable(
676  *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
677  /*InsertBefore=*/nullptr,
678  op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
679  : llvm::GlobalValue::NotThreadLocal,
680  addrSpace);
681 
682  if (op.getUnnamedAddr().has_value())
683  var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
684 
685  if (op.getSection().has_value())
686  var->setSection(*op.getSection());
687 
688  addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
689 
690  Optional<uint64_t> alignment = op.getAlignment();
691  if (alignment.has_value())
692  var->setAlignment(llvm::MaybeAlign(alignment.value()));
693 
694  globalsMapping.try_emplace(op, var);
695  }
696 
697  // Convert global variable bodies. This is done after all global variables
698  // have been created in LLVM IR because a global body may refer to another
699  // global or itself. So all global variables need to be mapped first.
700  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
701  if (Block *initializer = op.getInitializerBlock()) {
702  llvm::IRBuilder<> builder(llvmModule->getContext());
703  for (auto &op : initializer->without_terminator()) {
704  if (failed(convertOperation(op, builder)) ||
705  !isa<llvm::Constant>(lookupValue(op.getResult(0))))
706  return emitError(op.getLoc(), "unemittable constant value");
707  }
708  ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
709  llvm::Constant *cst =
710  cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
711  auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
712  if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
713  global->setInitializer(cst);
714  }
715  }
716 
717  // Convert llvm.mlir.global_ctors and dtors.
718  for (Operation &op : getModuleBody(mlirModule)) {
719  auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
720  auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
721  if (!ctorOp && !dtorOp)
722  continue;
723  auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
724  : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
725  auto appendGlobalFn =
726  ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
727  for (auto symbolAndPriority : range) {
728  llvm::Function *f = lookupFunction(
729  std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
730  appendGlobalFn(
731  *llvmModule, f,
732  std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
733  /*Data=*/nullptr);
734  }
735  }
736 
737  return success();
738 }
739 
740 /// Attempts to add an attribute identified by `key`, optionally with the given
741 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
742 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
743 /// otherwise keep it as a string attribute. Performs additional checks for
744 /// attributes known to have or not have a value in order to avoid assertions
745 /// inside LLVM upon construction.
/// Keys unknown to LLVM become string attributes carrying `value`. Integer
/// attribute kinds require a value; known non-integer kinds must not have one.
/// Returns failure after emitting an error when either rule is violated.
747  llvm::Function *llvmFunc,
748  StringRef key,
749  StringRef value = StringRef()) {
750  auto kind = llvm::Attribute::getAttrKindFromName(key);
751  if (kind == llvm::Attribute::None) {
752  llvmFunc->addFnAttr(key, value);
753  return success();
754  }
755 
756  if (llvm::Attribute::isIntAttrKind(kind)) {
757  if (value.empty())
758  return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
759 
  // getAsInteger returns false on success. NOTE(review): the value is parsed
  // as `int`, so a value that does not parse as an `int` is attached as a
  // string attribute instead of an integer attribute.
760  int result;
761  if (!value.getAsInteger(/*Radix=*/0, result))
762  llvmFunc->addFnAttr(
763  llvm::Attribute::get(llvmFunc->getContext(), kind, result));
764  else
765  llvmFunc->addFnAttr(key, value);
766  return success();
767  }
768 
769  if (!value.empty())
770  return emitError(loc) << "LLVM attribute '" << key
771  << "' does not expect a value, found '" << value
772  << "'";
773 
774  llvmFunc->addFnAttr(kind);
775  return success();
776 }
777 
778 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
779 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
780 /// to be an array attribute containing either string attributes, treated as
781 /// value-less LLVM attributes, or array attributes containing two string
782 /// attributes, with the first string being the name of the corresponding LLVM
783 /// attribute and the second string being its value. Note that even integer
784 /// attributes are expected to have their values expressed as strings.
/// An absent `attributes` value is a successful no-op. Entries are processed
/// in order, so entries preceding the first invalid one have already been
/// attached to `llvmFunc` when failure is returned.
785 static LogicalResult
787  llvm::Function *llvmFunc) {
788  if (!attributes)
789  return success();
790 
791  for (Attribute attr : *attributes) {
  // A bare string names a value-less attribute.
792  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
793  if (failed(
794  checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
795  return failure();
796  continue;
797  }
798 
  // Otherwise the entry must be a [key, value] pair of strings.
799  auto arrayAttr = attr.dyn_cast<ArrayAttr>();
800  if (!arrayAttr || arrayAttr.size() != 2)
801  return emitError(loc)
802  << "expected 'passthrough' to contain string or array attributes";
803 
804  auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>();
805  auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>();
806  if (!keyAttr || !valueAttr)
807  return emitError(loc)
808  << "expected arrays within 'passthrough' to contain two strings";
809 
810  if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
811  valueAttr.getValue())))
812  return failure();
813  }
814  return success();
815 }
816 
817 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
818  // Clear the block, branch value mappings, they are only relevant within one
819  // function.
820  blockMapping.clear();
821  valueMapping.clear();
822  branchMapping.clear();
823  llvm::Function *llvmFunc = lookupFunction(func.getName());
824 
825  // Translate the debug information for this function.
826  debugTranslation->translate(func, *llvmFunc);
827 
828  // Add function arguments to the value remapping table.
829  for (auto [mlirArg, llvmArg] :
830  llvm::zip(func.getArguments(), llvmFunc->args()))
831  mapValue(mlirArg, &llvmArg);
832 
833  // Check the personality and set it.
834  if (func.getPersonality()) {
835  llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
836  if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
837  func.getLoc(), *this))
838  llvmFunc->setPersonalityFn(pfunc);
839  }
840 
841  if (auto gc = func.getGarbageCollector())
842  llvmFunc->setGC(gc->str());
843 
844  // First, create all blocks so we can jump to them.
845  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
846  for (auto &bb : func) {
847  auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
848  llvmBB->insertInto(llvmFunc);
849  mapBlock(&bb, llvmBB);
850  }
851 
852  // Then, convert blocks one by one in topological order to ensure defs are
853  // converted before uses.
854  auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
855  for (Block *bb : blocks) {
856  llvm::IRBuilder<> builder(llvmContext);
857  if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
858  return failure();
859  }
860 
861  // After all blocks have been traversed and values mapped, connect the PHI
862  // nodes to the results of preceding blocks.
863  detail::connectPHINodes(func.getBody(), *this);
864 
865  // Finally, convert dialect attributes attached to the function.
866  return convertDialectAttributes(func);
867 }
868 
869 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
870  for (NamedAttribute attribute : op->getDialectAttrs())
871  if (failed(iface.amendOperation(op, attribute, *this)))
872  return failure();
873  return success();
874 }
875 
876 LogicalResult ModuleTranslation::convertFunctionSignatures() {
877  // Declare all functions first because there may be function calls that form a
878  // call graph with cycles, or global initializers that reference functions.
879  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
880  llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
881  function.getName(),
882  cast<llvm::FunctionType>(convertType(function.getFunctionType())));
883  llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
884  llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
885  mapFunction(function.getName(), llvmFunc);
886  addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
887 
888  // Convert function attributes.
889  if (function->getAttrOfType<UnitAttr>(LLVMDialect::getReadnoneAttrName()))
890  llvmFunc->setDoesNotAccessMemory();
891 
892  // Convert result attributes.
893  if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
894  llvm::AttrBuilder retAttrs(llvmFunc->getContext());
895  DictionaryAttr resultAttrs = allResultAttrs[0].cast<DictionaryAttr>();
896  for (const NamedAttribute &attr : resultAttrs) {
897  StringAttr name = attr.getName();
898  if (name == LLVMDialect::getAlignAttrName()) {
899  auto alignAmount = attr.getValue().cast<IntegerAttr>();
900  retAttrs.addAlignmentAttr(llvm::Align(alignAmount.getInt()));
901  } else if (name == LLVMDialect::getNoAliasAttrName()) {
902  retAttrs.addAttribute(llvm::Attribute::NoAlias);
903  } else if (name == LLVMDialect::getNoUndefAttrName()) {
904  retAttrs.addAttribute(llvm::Attribute::NoUndef);
905  } else if (name == LLVMDialect::getSExtAttrName()) {
906  retAttrs.addAttribute(llvm::Attribute::SExt);
907  } else if (name == LLVMDialect::getZExtAttrName()) {
908  retAttrs.addAttribute(llvm::Attribute::ZExt);
909  }
910  }
911  llvmFunc->addRetAttrs(retAttrs);
912  }
913 
914  // Convert argument attributes.
915  unsigned int argIdx = 0;
916  for (auto [mlirArgTy, llvmArg] :
917  llvm::zip(function.getArgumentTypes(), llvmFunc->args())) {
918  if (auto attr = function.getArgAttrOfType<UnitAttr>(
919  argIdx, LLVMDialect::getNoAliasAttrName())) {
920  // NB: Attribute already verified to be boolean, so check if we can
921  // indeed attach the attribute to this argument, based on its type.
922  if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
923  return function.emitError(
924  "llvm.noalias attribute attached to LLVM non-pointer argument");
925  llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
926  }
927 
928  if (auto attr = function.getArgAttrOfType<IntegerAttr>(
929  argIdx, LLVMDialect::getAlignAttrName())) {
930  // NB: Attribute already verified to be int, so check if we can indeed
931  // attach the attribute to this argument, based on its type.
932  if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
933  return function.emitError(
934  "llvm.align attribute attached to LLVM non-pointer argument");
935  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
936  .addAlignmentAttr(llvm::Align(attr.getInt())));
937  }
938 
939  if (auto attr = function.getArgAttrOfType<TypeAttr>(
940  argIdx, LLVMDialect::getStructRetAttrName())) {
941  auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
942  if (!argTy)
943  return function.emitError(
944  "llvm.sret attribute attached to LLVM non-pointer argument");
945  if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
946  return function.emitError(
947  "llvm.sret attribute attached to LLVM pointer "
948  "argument of a different type");
949  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
950  .addStructRetAttr(convertType(attr.getValue())));
951  }
952 
953  if (auto attr = function.getArgAttrOfType<TypeAttr>(
954  argIdx, LLVMDialect::getByValAttrName())) {
955  auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
956  if (!argTy)
957  return function.emitError(
958  "llvm.byval attribute attached to LLVM non-pointer argument");
959  if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
960  return function.emitError(
961  "llvm.byval attribute attached to LLVM pointer "
962  "argument of a different type");
963  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
964  .addByValAttr(convertType(attr.getValue())));
965  }
966 
967  if (auto attr = function.getArgAttrOfType<TypeAttr>(
968  argIdx, LLVMDialect::getByRefAttrName())) {
969  auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
970  if (!argTy)
971  return function.emitError(
972  "llvm.byref attribute attached to LLVM non-pointer argument");
973  if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
974  return function.emitError(
975  "llvm.byref attribute attached to LLVM pointer "
976  "argument of a different type");
977  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
978  .addByRefAttr(convertType(attr.getValue())));
979  }
980 
981  if (auto attr = function.getArgAttrOfType<TypeAttr>(
982  argIdx, LLVMDialect::getInAllocaAttrName())) {
983  auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
984  if (!argTy)
985  return function.emitError(
986  "llvm.inalloca attribute attached to LLVM non-pointer argument");
987  if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
988  return function.emitError(
989  "llvm.inalloca attribute attached to LLVM pointer "
990  "argument of a different type");
991  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
992  .addInAllocaAttr(convertType(attr.getValue())));
993  }
994 
995  if (auto attr =
996  function.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
997  if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
998  return function.emitError(
999  "llvm.nest attribute attached to LLVM non-pointer argument");
1000  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
1001  .addAttribute(llvm::Attribute::Nest));
1002  }
1003 
1004  if (auto attr = function.getArgAttrOfType<UnitAttr>(
1005  argIdx, LLVMDialect::getNoUndefAttrName())) {
1006  // llvm.noundef can be added to any argument type.
1007  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
1008  .addAttribute(llvm::Attribute::NoUndef));
1009  }
1010  if (auto attr = function.getArgAttrOfType<UnitAttr>(
1011  argIdx, LLVMDialect::getSExtAttrName())) {
1012  // llvm.signext can be added to any integer argument type.
1013  if (!mlirArgTy.isa<mlir::IntegerType>())
1014  return function.emitError(
1015  "llvm.signext attribute attached to LLVM non-integer argument");
1016  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
1017  .addAttribute(llvm::Attribute::SExt));
1018  }
1019  if (auto attr = function.getArgAttrOfType<UnitAttr>(
1020  argIdx, LLVMDialect::getZExtAttrName())) {
1021  // llvm.zeroext can be added to any integer argument type.
1022  if (!mlirArgTy.isa<mlir::IntegerType>())
1023  return function.emitError(
1024  "llvm.zeroext attribute attached to LLVM non-integer argument");
1025  llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
1026  .addAttribute(llvm::Attribute::ZExt));
1027  }
1028 
1029  ++argIdx;
1030  }
1031 
1032  // Forward the pass-through attributes to LLVM.
1034  function.getLoc(), function.getPassthrough(), llvmFunc)))
1035  return failure();
1036  }
1037 
1038  return success();
1039 }
1040 
1041 LogicalResult ModuleTranslation::convertFunctions() {
1042  // Convert functions.
1043  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1044  // Ignore external functions.
1045  if (function.isExternal())
1046  continue;
1047 
1048  if (failed(convertOneFunction(function)))
1049  return failure();
1050  }
1051 
1052  return success();
1053 }
1054 
1055 llvm::MDNode *
1057  SymbolRefAttr accessGroupRef) const {
1058  auto metadataName = accessGroupRef.getRootReference();
1059  auto accessGroupName = accessGroupRef.getLeafReference();
1060  auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
1061  opInst.getParentOp(), metadataName);
1062  auto *accessGroupOp =
1063  SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
1064  return accessGroupMetadataMapping.lookup(accessGroupOp);
1065 }
1066 
1067 LogicalResult ModuleTranslation::createAccessGroupMetadata() {
1068  mlirModule->walk([&](LLVM::MetadataOp metadatas) {
1069  metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
1070  llvm::LLVMContext &ctx = llvmModule->getContext();
1071  llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
1072  accessGroupMetadataMapping.insert({op, accessGroup});
1073  });
1074  });
1075  return success();
1076 }
1077 
1079  llvm::Instruction *inst) {
1080  auto accessGroups =
1081  op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
1082  if (accessGroups && !accessGroups.empty()) {
1083  llvm::Module *module = inst->getModule();
1085  for (SymbolRefAttr accessGroupRef :
1086  accessGroups.getAsRange<SymbolRefAttr>())
1087  metadatas.push_back(getAccessGroup(*op, accessGroupRef));
1088 
1089  llvm::MDNode *unionMD = nullptr;
1090  if (metadatas.size() == 1)
1091  unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
1092  else if (metadatas.size() >= 2)
1093  unionMD = llvm::MDNode::get(module->getContext(), metadatas);
1094 
1095  inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
1096  }
1097 }
1098 
1099 LogicalResult ModuleTranslation::createAliasScopeMetadata() {
1100  mlirModule->walk([&](LLVM::MetadataOp metadatas) {
1101  // Create the domains first, so they can be reference below in the scopes.
1102  DenseMap<Operation *, llvm::MDNode *> aliasScopeDomainMetadataMapping;
1103  metadatas.walk([&](LLVM::AliasScopeDomainMetadataOp op) {
1104  llvm::LLVMContext &ctx = llvmModule->getContext();
1106  operands.push_back({}); // Placeholder for self-reference
1107  if (Optional<StringRef> description = op.getDescription())
1108  operands.push_back(llvm::MDString::get(ctx, *description));
1109  llvm::MDNode *domain = llvm::MDNode::get(ctx, operands);
1110  domain->replaceOperandWith(0, domain); // Self-reference for uniqueness
1111  aliasScopeDomainMetadataMapping.insert({op, domain});
1112  });
1113 
1114  // Now create the scopes, referencing the domains created above.
1115  metadatas.walk([&](LLVM::AliasScopeMetadataOp op) {
1116  llvm::LLVMContext &ctx = llvmModule->getContext();
1117  assert(isa<LLVM::MetadataOp>(op->getParentOp()));
1118  auto metadataOp = dyn_cast<LLVM::MetadataOp>(op->getParentOp());
1119  Operation *domainOp =
1120  SymbolTable::lookupNearestSymbolFrom(metadataOp, op.getDomainAttr());
1121  llvm::MDNode *domain = aliasScopeDomainMetadataMapping.lookup(domainOp);
1122  assert(domain && "Scope's domain should already be valid");
1124  operands.push_back({}); // Placeholder for self-reference
1125  operands.push_back(domain);
1126  if (Optional<StringRef> description = op.getDescription())
1127  operands.push_back(llvm::MDString::get(ctx, *description));
1128  llvm::MDNode *scope = llvm::MDNode::get(ctx, operands);
1129  scope->replaceOperandWith(0, scope); // Self-reference for uniqueness
1130  aliasScopeMetadataMapping.insert({op, scope});
1131  });
1132  });
1133  return success();
1134 }
1135 
1136 llvm::MDNode *
1138  SymbolRefAttr aliasScopeRef) const {
1139  StringAttr metadataName = aliasScopeRef.getRootReference();
1140  StringAttr scopeName = aliasScopeRef.getLeafReference();
1141  auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
1142  opInst.getParentOp(), metadataName);
1143  Operation *aliasScopeOp =
1144  SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
1145  return aliasScopeMetadataMapping.lookup(aliasScopeOp);
1146 }
1147 
1149  llvm::Instruction *inst) {
1150  auto populateScopeMetadata = [this, op, inst](StringRef attrName,
1151  StringRef llvmMetadataName) {
1152  auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
1153  if (!scopes || scopes.empty())
1154  return;
1155  llvm::Module *module = inst->getModule();
1157  for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
1158  scopeMDs.push_back(getAliasScope(*op, scopeRef));
1159  llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
1160  inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
1161  };
1162 
1163  populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
1164  populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
1165 }
1166 
1168  return typeTranslator.translateType(type);
1169 }
1170 
1171 /// A helper to look up remapped operands in the value remapping table.
1173  SmallVector<llvm::Value *> remapped;
1174  remapped.reserve(values.size());
1175  for (Value v : values)
1176  remapped.push_back(lookupValue(v));
1177  return remapped;
1178 }
1179 
1180 const llvm::DILocation *
1181 ModuleTranslation::translateLoc(Location loc, llvm::DILocalScope *scope) {
1182  return debugTranslation->translateLoc(loc, scope);
1183 }
1184 
1186  return debugTranslation->translate(attr);
1187 }
1188 
1189 llvm::NamedMDNode *
1191  return llvmModule->getOrInsertNamedMetadata(name);
1192 }
1193 
1194 void ModuleTranslation::StackFrame::anchor() {}
1195 
1196 static std::unique_ptr<llvm::Module>
1197 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1198  StringRef name) {
1199  m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1200  auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1201  if (auto dataLayoutAttr =
1202  m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1203  llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
1204  } else {
1205  FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1206  if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1207  if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1208  llvmDataLayout =
1209  translateDataLayout(spec, DataLayout(iface), m->getLoc());
1210  }
1211  } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1212  if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1213  llvmDataLayout =
1214  translateDataLayout(spec, DataLayout(mod), m->getLoc());
1215  }
1216  }
1217  if (failed(llvmDataLayout))
1218  return nullptr;
1219  llvmModule->setDataLayout(*llvmDataLayout);
1220  }
1221  if (auto targetTripleAttr =
1222  m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1223  llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue());
1224 
1225  // Inject declarations for `malloc` and `free` functions that can be used in
1226  // memref allocation/deallocation coming from standard ops lowering.
1227  llvm::IRBuilder<> builder(llvmContext);
1228  llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
1229  builder.getInt64Ty());
1230  llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
1231  builder.getInt8PtrTy());
1232 
1233  return llvmModule;
1234 }
1235 
1236 std::unique_ptr<llvm::Module>
1237 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1238  StringRef name) {
1239  if (!satisfiesLLVMModule(module)) {
1240  module->emitOpError("can not be translated to an LLVMIR module");
1241  return nullptr;
1242  }
1243 
1244  std::unique_ptr<llvm::Module> llvmModule =
1245  prepareLLVMModule(module, llvmContext, name);
1246  if (!llvmModule)
1247  return nullptr;
1248 
1250 
1251  ModuleTranslation translator(module, std::move(llvmModule));
1252  if (failed(translator.convertFunctionSignatures()))
1253  return nullptr;
1254  if (failed(translator.convertGlobals()))
1255  return nullptr;
1256  if (failed(translator.createAccessGroupMetadata()))
1257  return nullptr;
1258  if (failed(translator.createAliasScopeMetadata()))
1259  return nullptr;
1260  if (failed(translator.convertFunctions()))
1261  return nullptr;
1262 
1263  // Convert other top-level operations if possible.
1264  llvm::IRBuilder<> llvmBuilder(llvmContext);
1265  for (Operation &o : getModuleBody(module).getOperations()) {
1266  if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1267  LLVM::GlobalDtorsOp, LLVM::MetadataOp>(&o) &&
1268  !o.hasTrait<OpTrait::IsTerminator>() &&
1269  failed(translator.convertOperation(o, llvmBuilder))) {
1270  return nullptr;
1271  }
1272  }
1273 
1274  if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1275  return nullptr;
1276 
1277  return std::move(translator.llvmModule);
1278 }
static constexpr const bool value
@ None
static Value getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index)
Get the SSA value passed to the current block from the terminator operation of its predecessor.
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
static Block & getModuleBody(Operation *module)
A helper method to get the single Block in an operation honoring LLVM's module requirements.
static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, llvm::GlobalValue *gv)
Sets the runtime preemption specifier of gv to dso_local if dsoLocalRequested is true,...
static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvm::Function *llvmFunc, StringRef key, StringRef value=StringRef())
Attempts to add an attribute identified by key, optionally with the given value to LLVM function llvm...
static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, llvm::Constant *cst)
A helper method to decide if a constant must not be set as a global variable initializer.
static llvm::Type * getInnermostElementType(llvm::Type *type)
Returns the first non-sequential type nested in sequential types.
static std::unique_ptr< llvm::Module > prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name)
static llvm::Constant * buildSequentialConstant(ArrayRef< llvm::Constant * > &constants, ArrayRef< int64_t > shape, llvm::Type *type, Location loc)
Builds a constant of a sequential LLVM type type, potentially containing other sequential types recur...
static LogicalResult forwardPassthroughAttributes(Location loc, Optional< ArrayAttr > attributes, llvm::Function *llvmFunc)
Attaches the attributes listed in the given array attribute to llvmFunc.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:223
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
The main mechanism for performing data layout queries.
unsigned getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
unsigned getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
unsigned getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
A symbol reference with a reference path containing a single element.
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
virtual LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given operation using the interface implemented by the dialect of one of the operation's ...
This class represents the base attribute for all debug info attributes.
Definition: LLVMAttrs.h:27
Implementation class for module translation.
LogicalResult convertBlock(Block &bb, bool ignoreArguments, llvm::IRBuilderBase &builder)
Translates the contents of the given block to LLVM IR using this translator.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::NamedMDNode * getOrInsertNamedModuleMetadata(StringRef name)
Gets the named metadata in the LLVM IR module being constructed, creating it if it does not exist.
llvm::Instruction * lookupBranch(Operation *op) const
Finds an LLVM IR instruction that corresponds to the given MLIR operation with successors.
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst)
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
const llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::MDNode * getAccessGroup(Operation &opInst, SymbolRefAttr accessGroupRef) const
Returns the LLVM metadata corresponding to a reference to an mlir LLVM dialect access group operation...
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void setAliasScopeMetadata(Operation *op, llvm::Instruction *inst)
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::MDNode * getAliasScope(Operation &opInst, SymbolRefAttr aliasScopeRef) const
Returns the LLVM metadata corresponding to a reference to an mlir LLVM dialect alias scope operation.
llvm::Type * translateType(Type type)
Translates the given MLIR LLVM dialect type to LLVM IR.
Definition: TypeToLLVM.cpp:187
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:93
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:701
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:371
Block * getSuccessor(unsigned index)
Definition: Operation.h:508
unsigned getNumSuccessors()
Definition: Operation.h:506
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:165
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
dialect_attr_range getDialectAttrs()
Return a range corresponding to the dialect attributes for this operation.
Definition: Operation.h:438
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:574
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
This class models how operands are forwarded to block arguments in control flow.
bool empty() const
Returns true if there are no successor operands.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:78
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Include the generated interface declarations.
Definition: CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::Value * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
SetVector< Block * > getTopologicallySortedBlocks(Region &region)
Get a topologically sorted list of blocks of the given region.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
Optional< unsigned > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:269
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
void ensureDistinctSuccessors(Operation *op)
Make argument-taking successors of each block distinct.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Type getFunctionType(Builder &builder, ArrayRef< OpAsmParser::Argument > argAttrs, ArrayRef< Type > resultTypes)
Get a function type corresponding to an array of arguments (which have types) and a set of result typ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
DataLayoutSpecInterface translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context)
Translate the given LLVM data layout into an MLIR equivalent using the DLTI dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule")
Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in ...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26