MLIR  19.0.0git
ModuleTranslation.cpp
Go to the documentation of this file.
1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/Support/LLVM.h"
37 
38 #include "llvm/ADT/PostOrderIterator.h"
39 #include "llvm/ADT/SetVector.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
43 #include "llvm/IR/BasicBlock.h"
44 #include "llvm/IR/CFG.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DerivedTypes.h"
47 #include "llvm/IR/IRBuilder.h"
48 #include "llvm/IR/InlineAsm.h"
49 #include "llvm/IR/IntrinsicsNVPTX.h"
50 #include "llvm/IR/LLVMContext.h"
51 #include "llvm/IR/MDBuilder.h"
52 #include "llvm/IR/Module.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/raw_ostream.h"
56 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
57 #include "llvm/Transforms/Utils/Cloning.h"
58 #include "llvm/Transforms/Utils/ModuleUtils.h"
59 #include <optional>
60 
61 #define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
62 
63 using namespace mlir;
64 using namespace mlir::LLVM;
65 using namespace mlir::LLVM::detail;
66 
67 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
68 
69 namespace {
70 /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
71 /// instructions that are created for future reference.
72 ///
73 /// This is intended to be used with the `CollectionScope` RAII object:
74 ///
75 /// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
76 /// {
77 /// InstructionCapturingInserter::CollectionScope scope(builder);
78 /// // Call IRBuilder methods as usual.
79 ///
80 /// // This will return a list of all instructions created by the builder,
81 /// // in order of creation.
82 /// builder.getInserter().getCapturedInstructions();
83 /// }
84 /// // This will return an empty list.
85 /// builder.getInserter().getCapturedInstructions();
86 ///
87 /// The capturing functionality is _disabled_ by default for performance
88 /// consideration. It needs to be explicitly enabled, which is achieved by
89 /// creating a `CollectionScope`.
90 class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
91 public:
92  /// Constructs the inserter.
93  InstructionCapturingInserter()
94  : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
95  if (LLVM_LIKELY(enabled))
96  capturedInstructions.push_back(instruction);
97  }) {}
98 
99  /// Returns the list of LLVM IR instructions captured since the last cleanup.
100  ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
101  return capturedInstructions;
102  }
103 
104  /// Clears the list of captured LLVM IR instructions.
105  void clearCapturedInstructions() { capturedInstructions.clear(); }
106 
107  /// RAII object enabling the capture of created LLVM IR instructions.
108  class CollectionScope {
109  public:
110  /// Creates the scope for the given inserter.
111  CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
112 
113  /// Ends the scope.
114  ~CollectionScope();
115 
116  ArrayRef<llvm::Instruction *> getCapturedInstructions() {
117  if (!inserter)
118  return {};
119  return inserter->getCapturedInstructions();
120  }
121 
122  private:
123  /// Back reference to the inserter.
124  InstructionCapturingInserter *inserter = nullptr;
125 
126  /// List of instructions in the inserter prior to this scope.
127  SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
128 
129  /// Whether the inserter was enabled prior to this scope.
130  bool wasEnabled;
131  };
132 
133  /// Enable or disable the capturing mechanism.
134  void setEnabled(bool enabled = true) { this->enabled = enabled; }
135 
136 private:
137  /// List of captured instructions.
138  SmallVector<llvm::Instruction *> capturedInstructions;
139 
140  /// Whether the collection is enabled.
141  bool enabled = false;
142 };
143 
144 using CapturingIRBuilder =
145  llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
146 } // namespace
147 
148 InstructionCapturingInserter::CollectionScope::CollectionScope(
149  llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
150 
151  if (!isBuilderCapturing)
152  return;
153 
154  auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
155  inserter = &capturingIRBuilder.getInserter();
156  wasEnabled = inserter->enabled;
157  if (wasEnabled)
158  previouslyCollectedInstructions.swap(inserter->capturedInstructions);
159  inserter->setEnabled(true);
160 }
161 
162 InstructionCapturingInserter::CollectionScope::~CollectionScope() {
163  if (!inserter)
164  return;
165 
166  previouslyCollectedInstructions.swap(inserter->capturedInstructions);
167  // If collection was enabled (likely in another, surrounding scope), keep
168  // the instructions collected in this scope.
169  if (wasEnabled) {
170  llvm::append_range(inserter->capturedInstructions,
171  previouslyCollectedInstructions);
172  }
173  inserter->setEnabled(wasEnabled);
174 }
175 
176 /// Translates the given data layout spec attribute to the LLVM IR data layout.
177 /// Supported string-keyed entries: endianness, program/global/alloca memory
/// space and stack alignment (zero values are left out of the string).
/// Supported type-keyed entries: signless integers, f16/f32/f64/f80/f128 and
/// LLVM pointers; index-type entries are ignored. Any other entry emits an
/// error at `loc` (an unknown location when none is given) and returns failure.
/// Type sizes and alignments are queried from `dataLayout` rather than read
/// from the entries directly; only the pointer index size comes from the entry.
179 translateDataLayout(DataLayoutSpecInterface attribute,
180  const DataLayout &dataLayout,
181  std::optional<Location> loc = std::nullopt) {
182  if (!loc)
183  loc = UnknownLoc::get(attribute.getContext());
184
185  // First pass: string-keyed entries (endianness, memory spaces, stack
// alignment), each appended to the stream as a `-<spec>` fragment.
186  std::string llvmDataLayout;
187  llvm::raw_string_ostream layoutStream(llvmDataLayout);
188  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
189  auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
190  if (!key)
191  continue;
192  if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
193  auto value = cast<StringAttr>(entry.getValue());
194  bool isLittleEndian =
195  value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
196  layoutStream << "-" << (isLittleEndian ? "e" : "E");
197  layoutStream.flush();
198  continue;
199  }
200  if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) {
201  auto value = cast<IntegerAttr>(entry.getValue());
202  uint64_t space = value.getValue().getZExtValue();
203  // Skip the default address space.
204  if (space == 0)
205  continue;
206  layoutStream << "-P" << space;
207  layoutStream.flush();
208  continue;
209  }
210  if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) {
211  auto value = cast<IntegerAttr>(entry.getValue());
212  uint64_t space = value.getValue().getZExtValue();
213  // Skip the default address space.
214  if (space == 0)
215  continue;
216  layoutStream << "-G" << space;
217  layoutStream.flush();
218  continue;
219  }
220  if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
221  auto value = cast<IntegerAttr>(entry.getValue());
222  uint64_t space = value.getValue().getZExtValue();
223  // Skip the default address space.
224  if (space == 0)
225  continue;
226  layoutStream << "-A" << space;
227  layoutStream.flush();
228  continue;
229  }
230  if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
231  auto value = cast<IntegerAttr>(entry.getValue());
232  uint64_t alignment = value.getValue().getZExtValue();
233  // Skip the default stack alignment.
234  if (alignment == 0)
235  continue;
236  layoutStream << "-S" << alignment;
237  layoutStream.flush();
238  continue;
239  }
240  emitError(*loc) << "unsupported data layout key " << key;
241  return failure();
242  }
243
244  // Go through the list of entries to check which types are explicitly
245  // specified in entries. Where possible, data layout queries are used instead
246  // of directly inspecting the entries.
// Second pass: integer/float entries become `i|f<size>:<abi>[:<preferred>]`
// (the preferred alignment only when it differs from the ABI one), and
// pointer entries become `p<as>:<size>:<abi>:<preferred>[:<index>]`. The
// `* 8u` factors convert the queried alignments to bits.
247  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
248  auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
249  if (!type)
250  continue;
251  // Data layout for the index type is irrelevant at this point.
252  if (isa<IndexType>(type))
253  continue;
254  layoutStream << "-";
255  LogicalResult result =
257  .Case<IntegerType, Float16Type, Float32Type, Float64Type,
258  Float80Type, Float128Type>([&](Type type) -> LogicalResult {
259  if (auto intType = dyn_cast<IntegerType>(type)) {
260  if (intType.getSignedness() != IntegerType::Signless)
261  return emitError(*loc)
262  << "unsupported data layout for non-signless integer "
263  << intType;
264  layoutStream << "i";
265  } else {
266  layoutStream << "f";
267  }
268  uint64_t size = dataLayout.getTypeSizeInBits(type);
269  uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
270  uint64_t preferred =
271  dataLayout.getTypePreferredAlignment(type) * 8u;
272  layoutStream << size << ":" << abi;
273  if (abi != preferred)
274  layoutStream << ":" << preferred;
275  return success();
276  })
277  .Case([&](LLVMPointerType ptrType) {
278  layoutStream << "p" << ptrType.getAddressSpace() << ":";
279  uint64_t size = dataLayout.getTypeSizeInBits(type);
280  uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
281  uint64_t preferred =
282  dataLayout.getTypePreferredAlignment(type) * 8u;
283  layoutStream << size << ":" << abi << ":" << preferred;
284  if (std::optional<uint64_t> index = extractPointerSpecValue(
285  entry.getValue(), PtrDLEntryPos::Index))
286  layoutStream << ":" << *index;
287  return success();
288  })
289  .Default([loc](Type type) {
290  return emitError(*loc)
291  << "unsupported type in data layout: " << type;
292  });
293  if (failed(result))
294  return failure();
295  }
// Every fragment was emitted with a leading '-'; drop the first one so the
// string is a well-formed LLVM data layout description.
296  layoutStream.flush();
297  StringRef layoutSpec(llvmDataLayout);
298  if (layoutSpec.starts_with("-"))
299  layoutSpec = layoutSpec.drop_front();
300
301  return llvm::DataLayout(layoutSpec);
302 }
303 
304 /// Builds a constant of a sequential LLVM type `type`, potentially containing
305 /// other sequential types recursively, from the individual constant values
306 /// provided in `constants`. `shape` contains the number of elements in nested
307 /// sequential types. Reports errors at `loc` and returns nullptr on error.
/// Leaf constants are taken from the front of `constants`, which is advanced
/// past every constant consumed; recursive calls share the same cursor.
/// Only the innermost level of a vector type becomes a vector constant. Every
/// other level is built as an array of `shape.front()` elements.
/// NOTE(review): `constants.front()` is not checked for emptiness; callers
/// must supply at least as many constants as the product of `shape`.
308 static llvm::Constant *
310  ArrayRef<int64_t> shape, llvm::Type *type,
311  Location loc) {
312  if (shape.empty()) {
313  llvm::Constant *result = constants.front();
314  constants = constants.drop_front();
315  return result;
316  }
317
318  llvm::Type *elementType;
319  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
320  elementType = arrayTy->getElementType();
321  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
322  elementType = vectorTy->getElementType();
323  } else {
324  emitError(loc) << "expected sequential LLVM types wrapping a scalar";
325  return nullptr;
326  }
327
329  nested.reserve(shape.front());
330  for (int64_t i = 0; i < shape.front(); ++i) {
331  nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
332  elementType, loc));
333  if (!nested.back())
334  return nullptr;
335  }
336
337  if (shape.size() == 1 && type->isVectorTy())
338  return llvm::ConstantVector::get(nested);
340  llvm::ArrayType::get(elementType, shape.front()), nested);
341 }
342 
343 /// Returns the first non-sequential type nested in sequential types.
344 static llvm::Type *getInnermostElementType(llvm::Type *type) {
345  do {
346  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
347  type = arrayTy->getElementType();
348  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
349  type = vectorTy->getElementType();
350  } else {
351  return type;
352  }
353  } while (true);
354 }
355 
356 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
357 /// storage if possible. This supports elements attributes of tensor or vector
358 /// type and avoids constructing separate objects for individual values of the
359 /// innermost dimension. Constants for other dimensions are still constructed
360 /// recursively. Returns null if constructing from raw data is not supported for
361 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
362 /// other errors at `loc`.
/// nullptr is returned without an error for any of the following, so that the
/// caller can fall back to another strategy:
/// - a null attribute;
/// - an empty attribute;
/// - an innermost LLVM type that ConstantDataSequential cannot hold;
/// - a raw element size that differs from the innermost scalar size;
/// - a non-vector splat;
/// - an unsupported shaped type.
363 static llvm::Constant *
365  llvm::Type *llvmType,
366  const ModuleTranslation &moduleTranslation) {
367  if (!denseElementsAttr)
368  return nullptr;
369
370  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
371  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
372  return nullptr;
373
374  ShapedType type = denseElementsAttr.getType();
375  if (type.getNumElements() == 0)
376  return nullptr;
377
378  // Check that the raw data size matches what is expected for the scalar size.
379  // TODO: in theory, we could repack the data here to keep constructing from
380  // raw data.
381  // TODO: we may also need to consider endianness when cross-compiling to an
382  // architecture where it is different.
383  int64_t elementByteSize = denseElementsAttr.getRawData().size() /
384  denseElementsAttr.getNumElements();
385  if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
386  return nullptr;
387
388  // Compute the shape of all dimensions but the innermost. Note that the
389  // innermost dimension may be that of the vector element type.
// NOTE(review): `getShape().back()` runs before the splat checks below, so a
// rank-0 shaped type would call back() on an empty shape. Confirm such
// attributes cannot reach this point.
390  bool hasVectorElementType = isa<VectorType>(type.getElementType());
391  int64_t numAggregates =
392  denseElementsAttr.getNumElements() /
393  (hasVectorElementType ? 1
394  : denseElementsAttr.getType().getShape().back());
395  ArrayRef<int64_t> outerShape = type.getShape();
396  if (!hasVectorElementType)
397  outerShape = outerShape.drop_back();
398
399  // Handle the case of vector splat, LLVM has special support for it.
// NOTE(review): `getSplat(0, ...)` below requests a zero-length vector, which
// looks unintended. In this file the only visible caller, getLLVMConstant,
// handles SplatElementsAttr (splat dense attributes) before calling this
// function, so this branch appears unreachable from there. Confirm before
// relying on it.
400  if (denseElementsAttr.isSplat() &&
401  (isa<VectorType>(type) || hasVectorElementType)) {
402  llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
403  innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
404  moduleTranslation);
405  llvm::Constant *splatVector =
406  llvm::ConstantDataVector::getSplat(0, splatValue);
407  SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
408  ArrayRef<llvm::Constant *> constantsRef = constants;
409  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
410  }
411  if (denseElementsAttr.isSplat())
412  return nullptr;
413
414  // In case of non-splat, create a constructor for the innermost constant from
415  // a piece of raw data.
416  std::function<llvm::Constant *(StringRef)> buildCstData;
417  if (isa<TensorType>(type)) {
418  auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
// NOTE(review): the first lambda below captures `vectorElementType` by
// reference. That local goes out of scope at the end of this `if` block,
// while `buildCstData` is only invoked later in the loop. Capturing the
// vector length by value would avoid the dangling reference.
419  if (vectorElementType && vectorElementType.getRank() == 1) {
420  buildCstData = [&](StringRef data) {
421  return llvm::ConstantDataVector::getRaw(
422  data, vectorElementType.getShape().back(), innermostLLVMType);
423  };
424  } else if (!vectorElementType) {
425  buildCstData = [&](StringRef data) {
426  return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
427  innermostLLVMType);
428  };
429  }
430  } else if (isa<VectorType>(type)) {
431  buildCstData = [&](StringRef data) {
432  return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
433  innermostLLVMType);
434  };
435  }
436  if (!buildCstData)
437  return nullptr;
438
439  // Create innermost constants and defer to the default constant creation
440  // mechanism for other dimensions.
// NOTE(review): the loop counter below is `unsigned` while `numAggregates` is
// int64_t; an int64_t counter would avoid the mixed-sign comparison.
442  int64_t aggregateSize = denseElementsAttr.getType().getShape().back() *
443  (innermostLLVMType->getScalarSizeInBits() / 8);
444  constants.reserve(numAggregates);
445  for (unsigned i = 0; i < numAggregates; ++i) {
446  StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
447  aggregateSize);
448  constants.push_back(buildCstData(data));
449  }
450
451  ArrayRef<llvm::Constant *> constantsRef = constants;
452  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
453 }
454 
455 /// Convert a dense resource elements attribute to an LLVM IR constant using its
456 /// raw data storage if possible. This supports elements attributes of tensor or
457 /// vector type and avoids constructing separate objects for individual values
458 /// of the innermost dimension. Constants for other dimensions are still
459 /// constructed recursively. Returns nullptr on failure and emits errors at
460 /// `loc`.
461 static llvm::Constant *convertDenseResourceElementsAttr(
462  Location loc, DenseResourceElementsAttr denseResourceAttr,
463  llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
464  assert(denseResourceAttr && "expected non-null attribute");
465 
466  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
467  if (!llvm::ConstantDataSequential::isElementTypeCompatible(
468  innermostLLVMType)) {
469  emitError(loc, "no known conversion for innermost element type");
470  return nullptr;
471  }
472 
473  ShapedType type = denseResourceAttr.getType();
474  assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
475 
476  AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
477  if (!blob) {
478  emitError(loc, "resource does not exist");
479  return nullptr;
480  }
481 
482  ArrayRef<char> rawData = blob->getData();
483 
484  // Check that the raw data size matches what is expected for the scalar size.
485  // TODO: in theory, we could repack the data here to keep constructing from
486  // raw data.
487  // TODO: we may also need to consider endianness when cross-compiling to an
488  // architecture where it is different.
489  int64_t numElements = denseResourceAttr.getType().getNumElements();
490  int64_t elementByteSize = rawData.size() / numElements;
491  if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
492  emitError(loc, "raw data size does not match element type size");
493  return nullptr;
494  }
495 
496  // Compute the shape of all dimensions but the innermost. Note that the
497  // innermost dimension may be that of the vector element type.
498  bool hasVectorElementType = isa<VectorType>(type.getElementType());
499  int64_t numAggregates =
500  numElements / (hasVectorElementType
501  ? 1
502  : denseResourceAttr.getType().getShape().back());
503  ArrayRef<int64_t> outerShape = type.getShape();
504  if (!hasVectorElementType)
505  outerShape = outerShape.drop_back();
506 
507  // Create a constructor for the innermost constant from a piece of raw data.
508  std::function<llvm::Constant *(StringRef)> buildCstData;
509  if (isa<TensorType>(type)) {
510  auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
511  if (vectorElementType && vectorElementType.getRank() == 1) {
512  buildCstData = [&](StringRef data) {
513  return llvm::ConstantDataVector::getRaw(
514  data, vectorElementType.getShape().back(), innermostLLVMType);
515  };
516  } else if (!vectorElementType) {
517  buildCstData = [&](StringRef data) {
518  return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
519  innermostLLVMType);
520  };
521  }
522  } else if (isa<VectorType>(type)) {
523  buildCstData = [&](StringRef data) {
524  return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
525  innermostLLVMType);
526  };
527  }
528  if (!buildCstData) {
529  emitError(loc, "unsupported dense_resource type");
530  return nullptr;
531  }
532 
533  // Create innermost constants and defer to the default constant creation
534  // mechanism for other dimensions.
536  int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
537  (innermostLLVMType->getScalarSizeInBits() / 8);
538  constants.reserve(numAggregates);
539  for (unsigned i = 0; i < numAggregates; ++i) {
540  StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
541  constants.push_back(buildCstData(data));
542  }
543 
544  ArrayRef<llvm::Constant *> constantsRef = constants;
545  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
546 }
547 
548 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
549 /// This currently supports integer, floating point, splat and dense element
550 /// attributes and combinations thereof. Also, an array attribute with two
551 /// elements is supported to represent a complex constant. In case of error,
552 /// report it to `loc` and return nullptr.
/// Cases are tried in this order:
/// - a null `attr` yields undef;
/// - a struct `llvmType` requires a two-element ArrayAttr (complex number);
/// - IntegerAttr is sign-extended or truncated to the integer width;
/// - FloatAttr is bit-reinterpreted when `llvmType` is an integer type of the
///   same width;
/// - FlatSymbolRefAttr is looked up as a function;
/// - splat attributes;
/// - dense and dense_resource elements via their raw data;
/// - any other ElementsAttr, element by element;
/// - StringAttr, as a character array.
554  llvm::Type *llvmType, Attribute attr, Location loc,
555  const ModuleTranslation &moduleTranslation) {
556  if (!attr)
557  return llvm::UndefValue::get(llvmType);
558  if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
559  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
560  if (!arrayAttr || arrayAttr.size() != 2) {
561  emitError(loc, "expected struct type to be a complex number");
562  return nullptr;
563  }
564  llvm::Type *elementType = structType->getElementType(0);
565  llvm::Constant *real =
566  getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
567  if (!real)
568  return nullptr;
569  llvm::Constant *imag =
570  getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
571  if (!imag)
572  return nullptr;
573  return llvm::ConstantStruct::get(structType, {real, imag});
574  }
575  // For integer types, we allow a mismatch in sizes as the index type in
576  // MLIR might have a different size than the index type in the LLVM module.
577  if (auto intAttr = dyn_cast<IntegerAttr>(attr))
578  return llvm::ConstantInt::get(
579  llvmType,
580  intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
581  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
582  const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
583  // Special case for 8-bit floats, which are represented by integers due to
584  // the lack of native fp8 types in LLVM at the moment. Additionally, handle
585  // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
586  // to i16.
587  unsigned floatWidth = APFloat::getSizeInBits(sem);
588  if (llvmType->isIntegerTy(floatWidth))
589  return llvm::ConstantInt::get(llvmType,
590  floatAttr.getValue().bitcastToAPInt());
591  if (llvmType !=
592  llvm::Type::getFloatingPointTy(llvmType->getContext(),
593  floatAttr.getValue().getSemantics())) {
594  emitError(loc, "FloatAttr does not match expected type of the constant");
595  return nullptr;
596  }
597  return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
598  }
// NOTE(review): the result of lookupFunction is handed to getBitCast without a
// null check. If the symbol does not name an already-translated function
// (presumably yielding nullptr; confirm), no error is reported at `loc`.
599  if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
600  return llvm::ConstantExpr::getBitCast(
601  moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
602  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
603  llvm::Type *elementType;
604  uint64_t numElements;
605  bool isScalable = false;
606  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
607  elementType = arrayTy->getElementType();
608  numElements = arrayTy->getNumElements();
609  } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
610  elementType = fVectorTy->getElementType();
611  numElements = fVectorTy->getNumElements();
612  } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
613  elementType = sVectorTy->getElementType();
614  numElements = sVectorTy->getMinNumElements();
615  isScalable = true;
616  } else {
617  llvm_unreachable("unrecognized constant vector type");
618  }
619  // Splat value is a scalar. Extract it only if the element type is not
620  // another sequence type. The recursion terminates because each step removes
621  // one outer sequential type.
622  bool elementTypeSequential =
623  isa<llvm::ArrayType, llvm::VectorType>(elementType);
624  llvm::Constant *child = getLLVMConstant(
625  elementType,
626  elementTypeSequential ? splatAttr
627  : splatAttr.getSplatValue<Attribute>(),
628  loc, moduleTranslation);
629  if (!child)
630  return nullptr;
631  if (llvmType->isVectorTy())
632  return llvm::ConstantVector::getSplat(
633  llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
// NOTE(review): this condition always holds when reached, since types that
// are neither arrays nor vectors hit llvm_unreachable above.
634  if (llvmType->isArrayTy()) {
635  auto *arrayType = llvm::ArrayType::get(elementType, numElements);
636  SmallVector<llvm::Constant *, 8> constants(numElements, child);
637  return llvm::ConstantArray::get(arrayType, constants);
638  }
639  }
640
641  // Try using raw elements data if possible.
// convertDenseElementsAttr returns nullptr for non-dense attributes (the
// dyn_cast yields null) and for layouts it cannot build from raw data, so
// control then falls through to the cases below.
642  if (llvm::Constant *result =
643  convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
644  llvmType, moduleTranslation)) {
645  return result;
646  }
647
648  if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
649  return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
650  moduleTranslation);
651  }
652
653  // Fall back to element-by-element construction otherwise.
654  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
655  assert(elementsAttr.getShapedType().hasStaticShape());
656  assert(!elementsAttr.getShapedType().getShape().empty() &&
657  "unexpected empty elements attribute shape");
658
660  constants.reserve(elementsAttr.getNumElements());
661  llvm::Type *innermostType = getInnermostElementType(llvmType);
662  for (auto n : elementsAttr.getValues<Attribute>()) {
663  constants.push_back(
664  getLLVMConstant(innermostType, n, loc, moduleTranslation));
665  if (!constants.back())
666  return nullptr;
667  }
668  ArrayRef<llvm::Constant *> constantsRef = constants;
669  llvm::Constant *result = buildSequentialConstant(
670  constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
671  assert(constantsRef.empty() && "did not consume all elemental constants");
672  return result;
673  }
674
675  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
677  moduleTranslation.getLLVMContext(),
678  ArrayRef<char>{stringAttr.getValue().data(),
679  stringAttr.getValue().size()});
680  }
681  emitError(loc, "unsupported constant value");
682  return nullptr;
683 }
684 
/// Prepares the translation of `module`, taking ownership of `llvmModule`.
/// The `debugTranslation`, `loopAnnotationTranslation` and `typeTranslator`
/// members are all constructed against the owned LLVM module (or its context).
/// Asserts that `module` satisfies LLVM module semantics.
685 ModuleTranslation::ModuleTranslation(Operation *module,
686  std::unique_ptr<llvm::Module> llvmModule)
687  : mlirModule(module), llvmModule(std::move(llvmModule)),
// From here on the `llvmModule` parameter has been moved from, so the member
// is reached through `this->`. This relies on the `llvmModule` member being
// declared before the members below, since members are initialized in
// declaration order. The class definition is not visible here; confirm in the
// header.
688  debugTranslation(
689  std::make_unique<DebugTranslation>(module, *this->llvmModule)),
690  loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
691  *this, *this->llvmModule)),
692  typeTranslator(this->llvmModule->getContext()),
693  iface(module->getContext()) {
694  assert(satisfiesLLVMModule(mlirModule) &&
695  "mlirModule should honor LLVM's module semantics.");
696 }
697 
698 ModuleTranslation::~ModuleTranslation() {
699  if (ompBuilder)
700  ompBuilder->finalize();
701 }
702 
// Walks `region` and every region nested in it, using an explicit worklist
// rather than recursion, and drops all translation state recorded for their
// contents:
// - block mappings;
// - value mappings of block arguments and operation results;
// - branch mappings of operations with successors;
// - global mappings of LLVM::GlobalOp operations.
704  SmallVector<Region *> toProcess;
705  toProcess.push_back(&region);
706  while (!toProcess.empty()) {
707  Region *current = toProcess.pop_back_val();
708  for (Block &block : *current) {
709  blockMapping.erase(&block);
710  for (Value arg : block.getArguments())
711  valueMapping.erase(arg);
712  for (Operation &op : block) {
713  for (Value value : op.getResults())
714  valueMapping.erase(value);
715  if (op.hasSuccessors())
716  branchMapping.erase(&op);
717  if (isa<LLVM::GlobalOp>(op))
718  globalsMapping.erase(&op);
// Queue nested regions so their contents are forgotten as well.
719  llvm::append_range(
720  toProcess,
721  llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
722  }
723  }
724  }
725 }
726 
727 /// Get the SSA value passed to the current block from the terminator operation
728 /// of its predecessor.
729 static Value getPHISourceValue(Block *current, Block *pred,
730  unsigned numArguments, unsigned index) {
731  Operation &terminator = *pred->getTerminator();
732  if (isa<LLVM::BrOp>(terminator))
733  return terminator.getOperand(index);
734 
735 #ifndef NDEBUG
736  llvm::SmallPtrSet<Block *, 4> seenSuccessors;
737  for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
738  Block *successor = terminator.getSuccessor(i);
739  auto branch = cast<BranchOpInterface>(terminator);
740  SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
741  assert(
742  (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
743  "successors with arguments in LLVM branches must be different blocks");
744  seenSuccessors.insert(successor);
745  }
746 #endif
747 
748  // For instructions that branch based on a condition value, we need to take
749  // the operands for the branch that was taken.
750  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
751  // For conditional branches, we take the operands from either the "true" or
752  // the "false" branch.
753  return condBranchOp.getSuccessor(0) == current
754  ? condBranchOp.getTrueDestOperands()[index]
755  : condBranchOp.getFalseDestOperands()[index];
756  }
757 
758  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
759  // For switches, we take the operands from either the default case, or from
760  // the case branch that was taken.
761  if (switchOp.getDefaultDestination() == current)
762  return switchOp.getDefaultOperands()[index];
763  for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
764  if (i.value() == current)
765  return switchOp.getCaseOperands(i.index())[index];
766  }
767 
768  if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
769  return invokeOp.getNormalDest() == current
770  ? invokeOp.getNormalDestOperands()[index]
771  : invokeOp.getUnwindDestOperands()[index];
772  }
773 
774  llvm_unreachable(
775  "only branch, switch or invoke operations can be terminators "
776  "of a block that has successors");
777 }
778 
779 /// Connect the PHI nodes to the results of preceding blocks.
/// Applies to every block of the region except the entry block. The LLVM block
/// is expected to hold one PHI node per block argument (checked by an
/// assertion). Each PHI node receives one incoming entry per MLIR predecessor,
/// pairing two things:
/// - the translated value that the predecessor's terminator forwards;
/// - the LLVM block that actually holds the translated terminator.
781  const ModuleTranslation &state) {
782  // Skip the first block, it cannot be branched to and its arguments correspond
783  // to the arguments of the LLVM function.
784  for (Block &bb : llvm::drop_begin(region)) {
785  llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
786  auto phis = llvmBB->phis();
787  auto numArguments = bb.getNumArguments();
// NOTE(review): `numArguments` is unsigned while std::distance yields a signed
// difference type, so this comparison may trigger -Wsign-compare.
788  assert(numArguments == std::distance(phis.begin(), phis.end()));
789  for (auto [index, phiNode] : llvm::enumerate(phis)) {
790  for (auto *pred : bb.getPredecessors()) {
791  // Find the LLVM IR block that contains the converted terminator
792  // instruction and use it in the PHI node. Note that this block is not
793  // necessarily the same as state.lookupBlock(pred), some operations
794  // (in particular, OpenMP operations using OpenMPIRBuilder) may have
795  // split the blocks.
796  llvm::Instruction *terminator =
797  state.lookupBranch(pred->getTerminator());
798  assert(terminator && "missing the mapping for a terminator");
799  phiNode.addIncoming(state.lookupValue(getPHISourceValue(
800  &bb, pred, numArguments, index)),
801  terminator->getParent());
802  }
803  }
804  }
805 }
806 
808  llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
// Obtains the declaration of `intrinsic` for the overload types `tys` in the
// module that owns the builder's current insertion block. It then emits a call
// to that declaration with `args` at the builder's insertion point.
// NOTE(review): assumes the builder has an insertion block; GetInsertBlock()
// is dereferenced without a check.
810  llvm::Module *module = builder.GetInsertBlock()->getModule();
811  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
812  return builder.CreateCall(fn, args);
813 }
814 
816  llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
817  Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
818  ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
819  ArrayRef<unsigned> immArgPositions,
820  ArrayRef<StringLiteral> immArgAttrNames) {
// Emits a call to `intrinsic` for `intrOp`.
// The LLVM argument list is built in two steps:
// - constants made from the attributes named in `immArgAttrNames` are placed
//   at `immArgPositions`;
// - the translated operands of `intrOp` fill the remaining slots, in order.
// The overload types are, in order:
// - the converted result types selected by `overloadedResults`;
// - the types of the arguments selected by `overloadedOperands`.
821  assert(immArgPositions.size() == immArgAttrNames.size() &&
822  "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
823  "length");
824
825  // Map operands and attributes to LLVM values.
826  auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
827  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
828  for (auto [immArgPos, immArgName] :
829  llvm::zip(immArgPositions, immArgAttrNames)) {
830  auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
831  assert(attr.getType().isIntOrFloat() && "expected int or float immarg");
832  auto *type = moduleTranslation.convertType(attr.getType());
833  args[immArgPos] = LLVM::detail::getLLVMConstant(
834  type, attr, intrOp->getLoc(), moduleTranslation);
835  }
// NOTE(review): getLLVMConstant returns nullptr on error. Such an immarg slot
// would then be treated as an operand slot by the loop below, which consumes
// one operand too many and indexes past the end of `operands`. Consider
// bailing out when a constant fails to convert.
836  unsigned opArg = 0;
837  for (auto &arg : args) {
838  if (!arg)
839  arg = operands[opArg++];
840  }
841
842  // Resolve overloaded intrinsic declaration.
843  SmallVector<llvm::Type *> overloadedTypes;
844  for (unsigned overloadedResultIdx : overloadedResults) {
845  if (numResults > 1) {
846  // More than one result is mapped to an LLVM struct.
847  overloadedTypes.push_back(moduleTranslation.convertType(
848  llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType())
849  .getBody()[overloadedResultIdx]));
850  } else {
851  overloadedTypes.push_back(
852  moduleTranslation.convertType(intrOp->getResult(0).getType()));
853  }
854  }
// These indices refer to positions in `args`, i.e. after the immediate
// arguments have been interleaved with the operands.
855  for (unsigned overloadedOperandIdx : overloadedOperands)
856  overloadedTypes.push_back(args[overloadedOperandIdx]->getType());
857  llvm::Module *module = builder.GetInsertBlock()->getModule();
858  llvm::Function *llvmIntr =
859  llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
860
861  return builder.CreateCall(llvmIntr, args);
862 }
863 
864 /// Given a single MLIR operation, create the corresponding LLVM IR operation
865 /// using the `builder`.
866 LogicalResult ModuleTranslation::convertOperation(Operation &op,
867  llvm::IRBuilderBase &builder,
868  bool recordInsertions) {
869  const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
870  if (!opIface)
871  return op.emitError("cannot be converted to LLVM IR: missing "
872  "`LLVMTranslationDialectInterface` registration for "
873  "dialect for op: ")
874  << op.getName();
875 
876  InstructionCapturingInserter::CollectionScope scope(builder,
877  recordInsertions);
878  if (failed(opIface->convertOperation(&op, builder, *this)))
879  return op.emitError("LLVM Translation failed for operation: ")
880  << op.getName();
881 
882  return convertDialectAttributes(&op, scope.getCapturedInstructions());
883 }
884 
885 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
886 /// to define values corresponding to the MLIR block arguments. These nodes
887 /// are not connected to the source basic blocks, which may not exist yet. Uses
888 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
889 /// been created for `bb` and included in the block mapping. Inserts new
890 /// instructions at the end of the block and leaves `builder` in a state
891 /// suitable for further insertion into the end of the block.
///
/// Before converting each operation, the builder's debug location is set
/// from the operation's location, scoped to the enclosing function's
/// subprogram. Each operation is converted through convertOperation with
/// `recordInsertions` forwarded unchanged. Returns failure if a block argument
/// type is not LLVM-compatible or if any operation fails to convert.
/// Conversion stops at the first failure.
892 LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
893  bool ignoreArguments,
894  llvm::IRBuilderBase &builder,
895  bool recordInsertions) {
896  builder.SetInsertPoint(lookupBlock(&bb));
897  auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
898 
899  // Before traversing operations, make block arguments available through
900  // value remapping and PHI nodes, but do not add incoming edges for the PHI
901  // nodes just yet: those values may be defined by this or following blocks.
902  // This step is omitted if "ignoreArguments" is set. The arguments of the
903  // first block have been already made available through the remapping of
904  // LLVM function arguments.
905  if (!ignoreArguments) {
906  auto predecessors = bb.getPredecessors();
907  unsigned numPredecessors =
908  std::distance(predecessors.begin(), predecessors.end());
// The predecessor count is only a capacity hint for each PHI; the incoming
// edges themselves are added later by connectPHINodes.
909  for (auto arg : bb.getArguments()) {
910  auto wrappedType = arg.getType();
911  if (!isCompatibleType(wrappedType))
912  return emitError(bb.front().getLoc(),
913  "block argument does not have an LLVM type");
914  llvm::Type *type = convertType(wrappedType);
915  llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
916  mapValue(arg, phi);
917  }
918  }
919 
920  // Traverse operations.
921  for (auto &op : bb) {
922  // Set the current debug location within the builder.
923  builder.SetCurrentDebugLocation(
924  debugTranslation->translateLoc(op.getLoc(), subprogram));
925 
926  if (failed(convertOperation(op, builder, recordInsertions)))
927  return failure();
928 
// NOTE(review): the statement controlled by the `if` below is missing from
// this listing; given the comment, it presumably calls
// setBranchWeightsMetadata(iface).
929  // Set the branch weight metadata on the translated instruction.
930  if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
932  }
933 
934  return success();
935 }
936 
937 /// A helper method to get the single Block in an operation honoring LLVM's
938 /// module requirements.
939 static Block &getModuleBody(Operation *module) {
940  return module->getRegion(0).front();
941 }
942 
943 /// A helper method to decide if a constant must not be set as a global variable
944 /// initializer. For an external linkage variable, the variable with an
945 /// initializer is considered externally visible and defined in this module, the
946 /// variable without an initializer is externally available and is defined
947 /// elsewhere.
948 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
949  llvm::Constant *cst) {
950  return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
951  linkage == llvm::GlobalVariable::ExternalWeakLinkage;
952 }
953 
954 /// Sets the runtime preemption specifier of `gv` to dso_local if
955 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
956 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
957  llvm::GlobalValue *gv) {
958  if (dsoLocalRequested)
959  gv->setDSOLocal(true);
960 }
961 
962 /// Create named global variables that correspond to llvm.mlir.global
963 /// definitions. Convert llvm.global_ctors and global_dtors ops.
///
/// Proceeds in phases:
///  1. Create an llvm::GlobalVariable for every GlobalOp. Its attribute-valued
///     initializer, linkage, TLS mode, comdat, unnamed_addr, section,
///     dso_local, alignment, visibility and debug info are translated here.
///  2. Translate region-based initializers. This runs only after every global
///     exists, because an initializer may reference any global.
///  3. Append the global_ctors / global_dtors entries.
///  4. Convert the dialect attributes of every GlobalOp.
///  5. Attach the collected global variable expressions to their compile
///     units.
/// Returns failure as soon as a constant or initializer cannot be translated.
964 LogicalResult ModuleTranslation::convertGlobals() {
965  // Mapping from compile unit to its respective set of global variables.
// NOTE(review): the declaration of `allGVars` (used in phases 1 and 5) is
// missing from this listing.
967 
968  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
969  llvm::Type *type = convertType(op.getType());
970  llvm::Constant *cst = nullptr;
971  if (op.getValueOrNull()) {
972  // String attributes are treated separately because they cannot appear as
973  // in-function constants and are thus not supported by getLLVMConstant.
974  if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
975  cst = llvm::ConstantDataArray::getString(
976  llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
977  type = cst->getType();
978  } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
979  *this))) {
980  return failure();
981  }
982  }
983 
984  auto linkage = convertLinkageToLLVM(op.getLinkage());
985  auto addrSpace = op.getAddrSpace();
986 
987  // LLVM IR requires constant with linkage other than external or weak
988  // external to have initializers. If MLIR does not provide an initializer,
989  // default to undef.
990  bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
991  if (!dropInitializer && !cst)
992  cst = llvm::UndefValue::get(type);
993  else if (dropInitializer && cst)
994  cst = nullptr;
995 
996  auto *var = new llvm::GlobalVariable(
997  *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
998  /*InsertBefore=*/nullptr,
999  op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1000  : llvm::GlobalValue::NotThreadLocal,
1001  addrSpace);
1002 
// The comdat selector is resolved through comdatMapping, so convertComdats is
// expected to have run before this function.
// NOTE(review): the symbol-lookup argument line of the cast below is missing
// from this listing.
1003  if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
1004  auto selectorOp = cast<ComdatSelectorOp>(
1006  var->setComdat(comdatMapping.lookup(selectorOp));
1007  }
1008 
1009  if (op.getUnnamedAddr().has_value())
1010  var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1011 
1012  if (op.getSection().has_value())
1013  var->setSection(*op.getSection());
1014 
1015  addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
1016 
1017  std::optional<uint64_t> alignment = op.getAlignment();
1018  if (alignment.has_value())
1019  var->setAlignment(llvm::MaybeAlign(alignment.value()));
1020 
1021  var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1022 
1023  globalsMapping.try_emplace(op, var);
1024 
1025  // Add debug information if present.
1026  if (op.getDbgExpr()) {
1027  llvm::DIGlobalVariableExpression *diGlobalExpr =
1028  debugTranslation->translateGlobalVariableExpression(op.getDbgExpr());
1029  llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable();
1030  var->addDebugInfo(diGlobalExpr);
1031 
1032  // Get the compile unit (scope) of the the global variable.
1033  if (llvm::DICompileUnit *compileUnit =
1034  dyn_cast_if_present<llvm::DICompileUnit>(
1035  diGlobalVar->getScope())) {
1036  // Update the compile unit with this incoming global variable expression
1037  // during the finalizing step later.
1038  allGVars[compileUnit].push_back(diGlobalExpr);
1039  }
1040  }
1041  }
1042 
1043  // Convert global variable bodies. This is done after all global variables
1044  // have been created in LLVM IR because a global body may refer to another
1045  // global or itself. So all global variables need to be mapped first.
1046  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1047  if (Block *initializer = op.getInitializerBlock()) {
1048  llvm::IRBuilder<> builder(llvmModule->getContext());
1049 
1050  [[maybe_unused]] int numConstantsHit = 0;
1051  [[maybe_unused]] int numConstantsErased = 0;
// Remaining (not yet translated) uses of each aggregate constant produced by
// the initializer; an entry reaching zero makes the constant a candidate for
// erasure.
1052  DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
1053 
// The loop variable `op` below shadows the GlobalOp `op` of the outer loop.
// Every non-terminator operation is expected to produce a result #0 that
// translates to an llvm::Constant.
1054  for (auto &op : initializer->without_terminator()) {
1055  if (failed(convertOperation(op, builder)))
1056  return emitError(op.getLoc(), "fail to convert global initializer");
1057  auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
1058  if (!cst)
1059  return emitError(op.getLoc(), "unemittable constant value");
1060 
1061  // When emitting an LLVM constant, a new constant is created and the old
1062  // constant may become dangling and take space. We should remove the
1063  // dangling constants to avoid memory explosion especially for constant
1064  // arrays whose number of elements is large.
1065  // Because multiple operations may refer to the same constant, we need
1066  // to count the number of uses of each constant array and remove it only
1067  // when the count becomes zero.
1068  if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
1069  numConstantsHit++;
1070  Value result = op.getResult(0);
1071  int numUsers = std::distance(result.use_begin(), result.use_end());
1072  auto [iterator, inserted] =
1073  constantAggregateUseMap.try_emplace(agg, numUsers);
1074  if (!inserted) {
1075  // Key already exists, update the value
1076  iterator->second += numUsers;
1077  }
1078  }
1079  // Scan the operands of the operation to decrement the use count of
1080  // constants. Erase the constant if the use count becomes zero.
1081  for (Value v : op.getOperands()) {
1082  auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
1083  if (!cst)
1084  continue;
1085  auto iter = constantAggregateUseMap.find(cst);
1086  assert(iter != constantAggregateUseMap.end() && "constant not found");
1087  iter->second--;
1088  if (iter->second == 0) {
1089  // NOTE: cannot call removeDeadConstantUsers() here because it
1090  // may remove the constant which has uses not be converted yet.
1091  if (cst->user_empty()) {
1092  cst->destroyConstant();
1093  numConstantsErased++;
1094  }
1095  constantAggregateUseMap.erase(iter);
1096  }
1097  }
1098  }
1099 
// The terminator's first operand is the value of the whole initializer.
1100  ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
1101  llvm::Constant *cst =
1102  cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
1103  auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
1104  if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
1105  global->setInitializer(cst);
1106 
1107  // Try to remove the dangling constants again after all operations are
1108  // converted.
1109  for (auto it : constantAggregateUseMap) {
1110  auto cst = it.first;
1111  cst->removeDeadConstantUsers();
1112  if (cst->user_empty()) {
1113  cst->destroyConstant();
1114  numConstantsErased++;
1115  }
1116  }
1117 
// NOTE(review): `op.getName()` on an op class likely prints the operation
// name rather than the global's symbol; `op.getSymName()` may have been
// intended for this debug message.
1118  LLVM_DEBUG(llvm::dbgs()
1119  << "Convert initializer for " << op.getName() << "\n";
1120  llvm::dbgs() << numConstantsHit << " new constants hit\n";
1121  llvm::dbgs()
1122  << numConstantsErased << " dangling constants erased\n";);
1123  }
1124  }
1125 
1126  // Convert llvm.mlir.global_ctors and dtors.
1127  for (Operation &op : getModuleBody(mlirModule)) {
1128  auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
1129  auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
1130  if (!ctorOp && !dtorOp)
1131  continue;
1132  auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
1133  : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
1134  auto appendGlobalFn =
1135  ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
1136  for (auto symbolAndPriority : range) {
1137  llvm::Function *f = lookupFunction(
1138  cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
1139  appendGlobalFn(*llvmModule, f,
1140  cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
1141  /*Data=*/nullptr);
1142  }
1143  }
1144 
1145  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
1146  if (failed(convertDialectAttributes(op, {})))
1147  return failure();
1148 
1149  // Finally, update the compile units their respective sets of global variables
1150  // created earlier.
1151  for (const auto &[compileUnit, globals] : allGVars) {
1152  compileUnit->replaceGlobalVariables(
1153  llvm::MDTuple::get(getLLVMContext(), globals));
1154  }
1155 
1156  return success();
1157 }
1158 
1159 /// Attempts to add an attribute identified by `key`, optionally with the given
1160 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
1161 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
1162 /// otherwise keep it as a string attribute. Performs additional checks for
1163 /// attributes known to have or not have a value in order to avoid assertions
1164 /// inside LLVM upon construction.
///
/// Summary of the cases below:
///  - unknown kind: added as a string attribute `key`=`value`;
///  - integer kind: a value is required; if `value` parses as an integer, an
///    integer attribute is added, otherwise the string form is used instead;
///  - any other known kind: must not carry a value.
/// NOTE(review): the first line of the signature (return type, name and `loc`
/// parameter) is missing from this listing.
1166  llvm::Function *llvmFunc,
1167  StringRef key,
1168  StringRef value = StringRef()) {
1169  auto kind = llvm::Attribute::getAttrKindFromName(key);
1170  if (kind == llvm::Attribute::None) {
1171  llvmFunc->addFnAttr(key, value);
1172  return success();
1173  }
1174 
1175  if (llvm::Attribute::isIntAttrKind(kind)) {
1176  if (value.empty())
1177  return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
1178 
// getAsInteger returns true on failure; Radix 0 lets it infer the base from
// the string's prefix.
1179  int64_t result;
1180  if (!value.getAsInteger(/*Radix=*/0, result))
1181  llvmFunc->addFnAttr(
1182  llvm::Attribute::get(llvmFunc->getContext(), kind, result));
1183  else
1184  llvmFunc->addFnAttr(key, value);
1185  return success();
1186  }
1187 
1188  if (!value.empty())
1189  return emitError(loc) << "LLVM attribute '" << key
1190  << "' does not expect a value, found '" << value
1191  << "'";
1192 
1193  llvmFunc->addFnAttr(kind);
1194  return success();
1195 }
1196 
1197 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1198 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
1199 /// to be an array attribute containing either string attributes, treated as
1200 /// value-less LLVM attributes, or array attributes containing two string
1201 /// attributes, with the first string being the name of the corresponding LLVM
1202 /// attribute and the second string beings its value. Note that even integer
1203 /// attributes are expected to have their values expressed as strings.
1204 static LogicalResult
1205 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
1206  llvm::Function *llvmFunc) {
1207  if (!attributes)
1208  return success();
1209 
1210  for (Attribute attr : *attributes) {
1211  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1212  if (failed(
1213  checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
1214  return failure();
1215  continue;
1216  }
1217 
1218  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
1219  if (!arrayAttr || arrayAttr.size() != 2)
1220  return emitError(loc)
1221  << "expected 'passthrough' to contain string or array attributes";
1222 
1223  auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
1224  auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
1225  if (!keyAttr || !valueAttr)
1226  return emitError(loc)
1227  << "expected arrays within 'passthrough' to contain two strings";
1228 
1229  if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
1230  valueAttr.getValue())))
1231  return failure();
1232  }
1233  return success();
1234 }
1235 
1236 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1237  // Clear the block, branch value mappings, they are only relevant within one
1238  // function.
1239  blockMapping.clear();
1240  valueMapping.clear();
1241  branchMapping.clear();
1242  llvm::Function *llvmFunc = lookupFunction(func.getName());
1243 
1244  // Add function arguments to the value remapping table.
1245  for (auto [mlirArg, llvmArg] :
1246  llvm::zip(func.getArguments(), llvmFunc->args()))
1247  mapValue(mlirArg, &llvmArg);
1248 
1249  // Check the personality and set it.
1250  if (func.getPersonality()) {
1251  llvm::Type *ty = llvm::PointerType::getUnqual(llvmFunc->getContext());
1252  if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
1253  func.getLoc(), *this))
1254  llvmFunc->setPersonalityFn(pfunc);
1255  }
1256 
1257  if (std::optional<StringRef> section = func.getSection())
1258  llvmFunc->setSection(*section);
1259 
1260  if (func.getArmStreaming())
1261  llvmFunc->addFnAttr("aarch64_pstate_sm_enabled");
1262  else if (func.getArmLocallyStreaming())
1263  llvmFunc->addFnAttr("aarch64_pstate_sm_body");
1264  else if (func.getArmStreamingCompatible())
1265  llvmFunc->addFnAttr("aarch64_pstate_sm_compatible");
1266 
1267  if (func.getArmNewZa())
1268  llvmFunc->addFnAttr("aarch64_new_za");
1269  else if (func.getArmInZa())
1270  llvmFunc->addFnAttr("aarch64_in_za");
1271  else if (func.getArmOutZa())
1272  llvmFunc->addFnAttr("aarch64_out_za");
1273  else if (func.getArmInoutZa())
1274  llvmFunc->addFnAttr("aarch64_inout_za");
1275  else if (func.getArmPreservesZa())
1276  llvmFunc->addFnAttr("aarch64_preserves_za");
1277 
1278  if (auto targetCpu = func.getTargetCpu())
1279  llvmFunc->addFnAttr("target-cpu", *targetCpu);
1280 
1281  if (auto targetFeatures = func.getTargetFeatures())
1282  llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString());
1283 
1284  if (auto attr = func.getVscaleRange())
1285  llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
1286  getLLVMContext(), attr->getMinRange().getInt(),
1287  attr->getMaxRange().getInt()));
1288 
1289  if (auto unsafeFpMath = func.getUnsafeFpMath())
1290  llvmFunc->addFnAttr("unsafe-fp-math", llvm::toStringRef(*unsafeFpMath));
1291 
1292  if (auto noInfsFpMath = func.getNoInfsFpMath())
1293  llvmFunc->addFnAttr("no-infs-fp-math", llvm::toStringRef(*noInfsFpMath));
1294 
1295  if (auto noNansFpMath = func.getNoNansFpMath())
1296  llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
1297 
1298  if (auto approxFuncFpMath = func.getApproxFuncFpMath())
1299  llvmFunc->addFnAttr("approx-func-fp-math",
1300  llvm::toStringRef(*approxFuncFpMath));
1301 
1302  if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
1303  llvmFunc->addFnAttr("no-signed-zeros-fp-math",
1304  llvm::toStringRef(*noSignedZerosFpMath));
1305 
1306  // Add function attribute frame-pointer, if found.
1307  if (FramePointerKindAttr attr = func.getFramePointerAttr())
1308  llvmFunc->addFnAttr("frame-pointer",
1309  LLVM::framePointerKind::stringifyFramePointerKind(
1310  (attr.getFramePointerKind())));
1311 
1312  // First, create all blocks so we can jump to them.
1313  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1314  for (auto &bb : func) {
1315  auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
1316  llvmBB->insertInto(llvmFunc);
1317  mapBlock(&bb, llvmBB);
1318  }
1319 
1320  // Then, convert blocks one by one in topological order to ensure defs are
1321  // converted before uses.
1322  auto blocks = getTopologicallySortedBlocks(func.getBody());
1323  for (Block *bb : blocks) {
1324  CapturingIRBuilder builder(llvmContext);
1325  if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
1326  /*recordInsertions=*/true)))
1327  return failure();
1328  }
1329 
1330  // After all blocks have been traversed and values mapped, connect the PHI
1331  // nodes to the results of preceding blocks.
1332  detail::connectPHINodes(func.getBody(), *this);
1333 
1334  // Finally, convert dialect attributes attached to the function.
1335  return convertDialectAttributes(func, {});
1336 }
1337 
1338 LogicalResult ModuleTranslation::convertDialectAttributes(
1339  Operation *op, ArrayRef<llvm::Instruction *> instructions) {
1340  for (NamedAttribute attribute : op->getDialectAttrs())
1341  if (failed(iface.amendOperation(op, instructions, attribute, *this)))
1342  return failure();
1343  return success();
1344 }
1345 
1346 /// Converts the function attributes from LLVMFuncOp and attaches them to the
1347 /// llvm::Function.
1348 static void convertFunctionAttributes(LLVMFuncOp func,
1349  llvm::Function *llvmFunc) {
1350  if (!func.getMemory())
1351  return;
1352 
1353  MemoryEffectsAttr memEffects = func.getMemoryAttr();
1354 
1355  // Add memory effects incrementally.
1356  llvm::MemoryEffects newMemEffects =
1357  llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
1358  convertModRefInfoToLLVM(memEffects.getArgMem()));
1359  newMemEffects |= llvm::MemoryEffects(
1360  llvm::MemoryEffects::Location::InaccessibleMem,
1361  convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
1362  newMemEffects |=
1363  llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
1364  convertModRefInfoToLLVM(memEffects.getOther()));
1365  llvmFunc->setMemoryEffects(newMemEffects);
1366 }
1367 
1369 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1370  DictionaryAttr paramAttrs) {
// Builds the LLVM attributes of one function parameter (`argIdx` >= 0) or of
// the function result (the caller passes -1) from `paramAttrs`:
//  - names with a known LLVM attribute kind become type, integer or enum
//    attributes, depending on whether the value is a TypeAttr, IntegerAttr or
//    UnitAttr (any other value kind is silently ignored: the switch has no
//    default case);
//  - other dialect-prefixed names are handed to the translation interface's
//    convertParameterAttr, whose failure is propagated;
//  - all remaining names are ignored.
// NOTE(review): the line holding the return type is missing from this listing.
1371  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1372  auto attrNameToKindMapping = getAttrNameToKindMapping();
1373 
1374  for (auto namedAttr : paramAttrs) {
1375  auto it = attrNameToKindMapping.find(namedAttr.getName());
1376  if (it != attrNameToKindMapping.end()) {
1377  llvm::Attribute::AttrKind llvmKind = it->second;
1378 
1379  llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1380  .Case<TypeAttr>([&](auto typeAttr) {
1381  attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
1382  })
1383  .Case<IntegerAttr>([&](auto intAttr) {
1384  attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1385  })
1386  .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
1387  } else if (namedAttr.getNameDialect()) {
1388  if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
1389  return failure();
1390  }
1391  }
1392 
1393  return attrBuilder;
1394 }
1395 
/// Declares an llvm::Function for every LLVMFuncOp in the module and records
/// it in the function mapping. Function-level properties are translated onto
/// the declaration: linkage, calling convention, dso_local, memory effects,
/// entry count, result and argument attributes, passthrough attributes,
/// visibility, comdat, GC, unnamed_addr, alignment and debug info. Bodies are
/// translated separately. Returns failure if any attribute conversion fails.
1396 LogicalResult ModuleTranslation::convertFunctionSignatures() {
1397  // Declare all functions first because there may be function calls that form a
1398  // call graph with cycles, or global initializers that reference functions.
1399  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1400  llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1401  function.getName(),
1402  cast<llvm::FunctionType>(convertType(function.getFunctionType())));
1403  llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1404  llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
1405  llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv()));
1406  mapFunction(function.getName(), llvmFunc);
1407  addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
1408 
1409  // Convert function attributes.
1410  convertFunctionAttributes(function, llvmFunc);
1411 
1412  // Convert function_entry_count attribute to metadata.
1413  if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1414  llvmFunc->setEntryCount(entryCount.value());
1415 
// Only the first result's attribute dictionary is read; argIdx -1 marks the
// result rather than a parameter.
1416  // Convert result attributes.
1417  if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1418  DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1419  FailureOr<llvm::AttrBuilder> attrBuilder =
1420  convertParameterAttrs(function, -1, resultAttrs);
1421  if (failed(attrBuilder))
1422  return failure();
1423  llvmFunc->addRetAttrs(*attrBuilder);
1424  }
1425 
1426  // Convert argument attributes.
1427  for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
1428  if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1429  FailureOr<llvm::AttrBuilder> attrBuilder =
1430  convertParameterAttrs(function, argIdx, argAttrs);
1431  if (failed(attrBuilder))
1432  return failure();
1433  llvmArg.addAttrs(*attrBuilder);
1434  }
1435  }
1436 
// NOTE(review): the first line of the call below (which presumably opens
// `if (failed(forwardPassthroughAttributes(`) is missing from this listing.
1437  // Forward the pass-through attributes to LLVM.
1439  function.getLoc(), function.getPassthrough(), llvmFunc)))
1440  return failure();
1441 
1442  // Convert visibility attribute.
1443  llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
1444 
// Comdats are resolved through comdatMapping, which convertComdats fills.
1445  // Convert the comdat attribute.
1446  if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1447  auto selectorOp = cast<ComdatSelectorOp>(
1448  SymbolTable::lookupNearestSymbolFrom(function, *comdat));
1449  llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1450  }
1451 
1452  if (auto gc = function.getGarbageCollector())
1453  llvmFunc->setGC(gc->str());
1454 
1455  if (auto unnamedAddr = function.getUnnamedAddr())
1456  llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1457 
1458  if (auto alignment = function.getAlignment())
1459  llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1460 
1461  // Translate the debug information for this function.
1462  debugTranslation->translate(function, *llvmFunc);
1463  }
1464 
1465  return success();
1466 }
1467 
1468 LogicalResult ModuleTranslation::convertFunctions() {
1469  // Convert functions.
1470  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1471  // Do not convert external functions, but do process dialect attributes
1472  // attached to them.
1473  if (function.isExternal()) {
1474  if (failed(convertDialectAttributes(function, {})))
1475  return failure();
1476  continue;
1477  }
1478 
1479  if (failed(convertOneFunction(function)))
1480  return failure();
1481  }
1482 
1483  return success();
1484 }
1485 
1486 LogicalResult ModuleTranslation::convertComdats() {
1487  for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
1488  for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1489  llvm::Module *module = getLLVMModule();
1490  if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1491  return emitError(selectorOp.getLoc())
1492  << "comdat selection symbols must be unique even in different "
1493  "comdat regions";
1494  llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1495  comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1496  comdatMapping.try_emplace(selectorOp, comdat);
1497  }
1498  }
1499  return success();
1500 }
1501 
1502 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1503  llvm::Instruction *inst) {
1504  if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1505  inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1506 }
1507 
/// Returns the LLVM metadata node for `aliasScopeAttr`, creating it on first
/// use. The node is memoized in aliasScopeMetadataMapping. Its domain node is
/// memoized separately in aliasDomainMetadataMapping, so scopes that share a
/// domain also share the domain node. Both nodes start with a placeholder
/// operand that is then replaced by the node itself (self-reference), followed
/// by the domain (scope nodes only) and an optional description string.
/// NOTE(review): the declarations of the two local `operands` vectors are
/// missing from this listing.
1508 llvm::MDNode *
1509 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
1510  auto [scopeIt, scopeInserted] =
1511  aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr);
1512  if (!scopeInserted)
1513  return scopeIt->second;
// `scopeIt` stays valid until it is filled in below: only the separate domain
// map is modified in between.
1514  llvm::LLVMContext &ctx = llvmModule->getContext();
1515  // Convert the domain metadata node if necessary.
1516  auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1517  aliasScopeAttr.getDomain(), nullptr);
1518  if (insertedDomain) {
1520  // Placeholder for self-reference.
1521  operands.push_back({});
1522  if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1523  operands.push_back(llvm::MDString::get(ctx, description));
1524  domainIt->second = llvm::MDNode::get(ctx, operands);
1525  // Self-reference for uniqueness.
1526  domainIt->second->replaceOperandWith(0, domainIt->second);
1527  }
1528  // Convert the scope metadata node.
1529  assert(domainIt->second && "Scope's domain should already be valid");
1531  // Placeholder for self-reference.
1532  operands.push_back({});
1533  operands.push_back(domainIt->second);
1534  if (StringAttr description = aliasScopeAttr.getDescription())
1535  operands.push_back(llvm::MDString::get(ctx, description));
1536  scopeIt->second = llvm::MDNode::get(ctx, operands);
1537  // Self-reference for uniqueness.
1538  scopeIt->second->replaceOperandWith(0, scopeIt->second);
1539  return scopeIt->second;
1540 }
1541 
1543  ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
// Returns a metadata tuple listing the (memoized) scope node of each entry of
// `aliasScopeAttrs`, in order.
// NOTE(review): the signature's first line and the declaration of `nodes` are
// missing from this listing.
1545  nodes.reserve(aliasScopeAttrs.size());
1546  for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1547  nodes.push_back(getOrCreateAliasScope(aliasScopeAttr));
1548  return llvm::MDNode::get(getLLVMContext(), nodes);
1549 }
1550 
1551 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1552  llvm::Instruction *inst) {
1553  auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1554  if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1555  return;
1556  llvm::MDNode *node = getOrCreateAliasScopes(
1557  llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1558  inst->setMetadata(kind, node);
1559  };
1560 
1561  populateScopeMetadata(op.getAliasScopesOrNull(),
1562  llvm::LLVMContext::MD_alias_scope);
1563  populateScopeMetadata(op.getNoAliasScopesOrNull(),
1564  llvm::LLVMContext::MD_noalias);
1565 }
1566 
1567 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1568  return tbaaMetadataMapping.lookup(tbaaAttr);
1569 }
1570 
1571 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1572  llvm::Instruction *inst) {
1573  ArrayAttr tagRefs = op.getTBAATagsOrNull();
1574  if (!tagRefs || tagRefs.empty())
1575  return;
1576 
1577  // LLVM IR currently does not support attaching more than one TBAA access tag
1578  // to a memory accessing instruction. It may be useful to support this in
1579  // future, but for the time being just ignore the metadata if MLIR operation
1580  // has multiple access tags.
1581  if (tagRefs.size() > 1) {
1582  op.emitWarning() << "TBAA access tags were not translated, because LLVM "
1583  "IR only supports a single tag per instruction";
1584  return;
1585  }
1586 
1587  llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1588  inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1589 }
1590 
1591 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
1592  DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
1593  if (!weightsAttr)
1594  return;
1595 
1596  llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
1597  assert(inst && "expected the operation to have a mapping to an instruction");
1598  SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
1599  inst->setMetadata(
1600  llvm::LLVMContext::MD_prof,
1601  llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
1602 }
1603 
1604 LogicalResult ModuleTranslation::createTBAAMetadata() {
1605  llvm::LLVMContext &ctx = llvmModule->getContext();
1606  llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
1607 
1608  // Walk the entire module and create all metadata nodes for the TBAA
1609  // attributes. The code below relies on two invariants of the
1610  // `AttrTypeWalker`:
1611  // 1. Attributes are visited in post-order: Since the attributes create a DAG,
1612  // this ensures that any lookups into `tbaaMetadataMapping` for child
1613  // attributes succeed.
1614  // 2. Attributes are only ever visited once: This way we don't leak any
1615  // LLVM metadata instances.
1616  AttrTypeWalker walker;
1617  walker.addWalk([&](TBAARootAttr root) {
1618  tbaaMetadataMapping.insert(
1619  {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))});
1620  });
1621 
1622  walker.addWalk([&](TBAATypeDescriptorAttr descriptor) {
1624  operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
1625  for (TBAAMemberAttr member : descriptor.getMembers()) {
1626  operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
1627  operands.push_back(llvm::ConstantAsMetadata::get(
1628  llvm::ConstantInt::get(offsetTy, member.getOffset())));
1629  }
1630 
1631  tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
1632  });
1633 
1634  walker.addWalk([&](TBAATagAttr tag) {
1636 
1637  operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
1638  operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
1639 
1640  operands.push_back(llvm::ConstantAsMetadata::get(
1641  llvm::ConstantInt::get(offsetTy, tag.getOffset())));
1642  if (tag.getConstant())
1643  operands.push_back(
1645 
1646  tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
1647  });
1648 
1649  mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
1650  if (auto attr = analysisOpInterface.getTBAATagsOrNull())
1651  walker.walk(attr);
1652  });
1653 
1654  return success();
1655 }
1656 
1658  llvm::Instruction *inst) {
1659  LoopAnnotationAttr attr =
1661  .Case<LLVM::BrOp, LLVM::CondBrOp>(
1662  [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
1663  if (!attr)
1664  return;
1665  llvm::MDNode *loopMD =
1666  loopAnnotationTranslation->translateLoopAnnotation(attr, op);
1667  inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
1668 }
1669 
1671  return typeTranslator.translateType(type);
1672 }
1673 
1674 /// A helper to look up remapped operands in the value remapping table.
1676  SmallVector<llvm::Value *> remapped;
1677  remapped.reserve(values.size());
1678  for (Value v : values)
1679  remapped.push_back(lookupValue(v));
1680  return remapped;
1681 }
1682 
1683 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
1684  if (!ompBuilder) {
1685  ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
1686  ompBuilder->initialize();
1687 
1688  // Flags represented as top-level OpenMP dialect attributes are set in
1689  // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
1690  // the default configuration.
1691  ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
1692  /* IsTargetDevice = */ false, /* IsGPU = */ false,
1693  /* OpenMPOffloadMandatory = */ false,
1694  /* HasRequiresReverseOffload = */ false,
1695  /* HasRequiresUnifiedAddress = */ false,
1696  /* HasRequiresUnifiedSharedMemory = */ false,
1697  /* HasRequiresDynamicAllocators = */ false));
1698  }
1699  return ompBuilder.get();
1700 }
1701 
1703  llvm::DILocalScope *scope) {
1704  return debugTranslation->translateLoc(loc, scope);
1705 }
1706 
1707 llvm::DIExpression *
1708 ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) {
1709  return debugTranslation->translateExpression(attr);
1710 }
1711 
1712 llvm::DIGlobalVariableExpression *
1714  LLVM::DIGlobalVariableExpressionAttr attr) {
1715  return debugTranslation->translateGlobalVariableExpression(attr);
1716 }
1717 
1719  return debugTranslation->translate(attr);
1720 }
1721 
1722 llvm::NamedMDNode *
1724  return llvmModule->getOrInsertNamedMetadata(name);
1725 }
1726 
1727 void ModuleTranslation::StackFrame::anchor() {}
1728 
1729 static std::unique_ptr<llvm::Module>
1730 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1731  StringRef name) {
1732  m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1733  auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1734  if (auto dataLayoutAttr =
1735  m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1736  llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
1737  } else {
1738  FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1739  if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1740  if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1741  llvmDataLayout =
1742  translateDataLayout(spec, DataLayout(iface), m->getLoc());
1743  }
1744  } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1745  if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1746  llvmDataLayout =
1747  translateDataLayout(spec, DataLayout(mod), m->getLoc());
1748  }
1749  }
1750  if (failed(llvmDataLayout))
1751  return nullptr;
1752  llvmModule->setDataLayout(*llvmDataLayout);
1753  }
1754  if (auto targetTripleAttr =
1755  m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1756  llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
1757 
1758  return llvmModule;
1759 }
1760 
/// Translates `module`, which must satisfy the LLVM dialect module
/// requirements checked by `satisfiesLLVMModule`, into a new LLVM IR module
/// named `name` that lives in `llvmContext`.
///
/// Returns nullptr in any of these cases:
///   - the operation does not qualify (an op error is emitted);
///   - preparing the module (data layout translation) fails;
///   - any translation step fails;
///   - `llvm::verifyModule` rejects the result (its report goes to
///     `llvm::errs()`).
1761 std::unique_ptr<llvm::Module>
1762 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1763  StringRef name) {
1764  if (!satisfiesLLVMModule(module)) {
1765  module->emitOpError("can not be translated to an LLVMIR module");
1766  return nullptr;
1767  }
1768 
// Create the empty llvm::Module with its data layout and target triple set.
1769  std::unique_ptr<llvm::Module> llvmModule =
1770  prepareLLVMModule(module, llvmContext, name);
1771  if (!llvmModule)
1772  return nullptr;
1773 
// NOTE(review): source lines 1774-1775 are missing from this extract. This
// page's symbol index lists LLVM::ensureDistinctSuccessors and
// LLVM::legalizeDIExpressionsRecursively, which presumably were applied to
// `module` here -- confirm against upstream.
1776 
// The translator takes ownership of the llvm::Module from here on.
1777  ModuleTranslation translator(module, std::move(llvmModule));
1778  llvm::IRBuilder<> llvmBuilder(llvmContext);
1779 
1780  // Convert module before functions and operations inside, so dialect
1781  // attributes can be used to change dialect-specific global configurations via
1782  // `amendOperation()`. These configurations can then influence the translation
1783  // of operations afterwards.
1784  if (failed(translator.convertOperation(*module, llvmBuilder)))
1785  return nullptr;
1786 
// Module-level entities are translated in this fixed order, each step
// aborting the translation on failure: comdats, function signatures,
// globals, then TBAA metadata.
1787  if (failed(translator.convertComdats()))
1788  return nullptr;
1789  if (failed(translator.convertFunctionSignatures()))
1790  return nullptr;
1791  if (failed(translator.convertGlobals()))
1792  return nullptr;
1793  if (failed(translator.createTBAAMetadata()))
1794  return nullptr;
1795 
1796  // Convert other top-level operations if possible.
// Functions, globals, global ctor/dtor lists, comdats and the body's
// terminator are excluded here; presumably they are covered by the dedicated
// steps above and below -- confirm for GlobalCtorsOp/GlobalDtorsOp.
1797  for (Operation &o : getModuleBody(module).getOperations()) {
1798  if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1799  LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
1800  !o.hasTrait<OpTrait::IsTerminator>() &&
1801  failed(translator.convertOperation(o, llvmBuilder))) {
1802  return nullptr;
1803  }
1804  }
1805 
1806  // Operations in function bodies with symbolic references must be converted
1807  // after the top-level operations they refer to are declared, so we do it
1808  // last.
1809  if (failed(translator.convertFunctions()))
1810  return nullptr;
1811 
// llvm::verifyModule returns true when the module is broken.
1812  if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1813  return nullptr;
1814 
// Hand ownership of the finished module back to the caller.
1815  return std::move(translator.llvmModule);
1816 }
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index)
Get the SSA value passed to the current block from the terminator operation of its predecessor.
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
static Block & getModuleBody(Operation *module)
A helper method to get the single Block in an operation honoring LLVM's module requirements.
static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, llvm::GlobalValue *gv)
Sets the runtime preemption specifier of gv to dso_local if dsoLocalRequested is true,...
static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvm::Function *llvmFunc, StringRef key, StringRef value=StringRef())
Attempts to add an attribute identified by key, optionally with the given value to LLVM function llvm...
static void convertFunctionAttributes(LLVMFuncOp func, llvm::Function *llvmFunc)
Converts the function attributes from LLVMFuncOp and attaches them to the llvm::Function.
static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, llvm::Constant *cst)
A helper method to decide if a constant must not be set as a global variable initializer.
static llvm::Type * getInnermostElementType(llvm::Type *type)
Returns the first non-sequential type nested in sequential types.
static std::unique_ptr< llvm::Module > prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name)
static llvm::Constant * convertDenseResourceElementsAttr(Location loc, DenseResourceElementsAttr denseResourceAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense resource elements attribute to an LLVM IR constant using its raw data storage if poss...
static LogicalResult forwardPassthroughAttributes(Location loc, std::optional< ArrayAttr > attributes, llvm::Function *llvmFunc)
Attaches the attributes listed in the given array attribute to llvmFunc.
static llvm::Constant * buildSequentialConstant(ArrayRef< llvm::Constant * > &constants, ArrayRef< int64_t > shape, llvm::Type *type, Location loc)
Builds a constant of a sequential LLVM type type, potentially containing other sequential types recur...
The following classes enable support for parsing and printing resources within MLIR assembly formats.
Definition: AsmState.h:88
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:142
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
The main mechanism for performing data layout queries.
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
virtual LogicalResult convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given function operation using the interface implemented by the dialect of one of the fun...
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given operation using the interface implemented by the dialect of one of the operation's ...
This class represents the base attribute for all debug info attributes.
Definition: LLVMAttrs.h:27
Implementation class for module translation.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::DIGlobalVariableExpression * translateGlobalVariableExpression(LLVM::DIGlobalVariableExpressionAttr attr)
Translates the given LLVM global variable expression metadata.
llvm::NamedMDNode * getOrInsertNamedModuleMetadata(StringRef name)
Gets the named metadata in the LLVM IR module being constructed, creating it if it does not exist.
llvm::Instruction * lookupBranch(Operation *op) const
Finds an LLVM IR instruction that corresponds to the given MLIR operation with successors.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
void setBranchWeightsMetadata(BranchWeightOpInterface op)
Sets LLVM profiling metadata for operations that have branch weights.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::DIExpression * translateExpression(LLVM::DIExpressionAttr attr)
Translates the given LLVM DWARF expression metadata.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::CallInst * lookupCall(Operation *op) const
Finds an LLVM call instruction that corresponds to the given MLIR call operation.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::MDNode * getOrCreateAliasScopes(ArrayRef< AliasScopeAttr > aliasScopeAttrs)
Returns the LLVM metadata corresponding to an array of mlir LLVM dialect alias scope attributes.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::MDNode * getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr)
Returns the LLVM metadata corresponding to a mlir LLVM dialect alias scope attribute.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
llvm::Type * translateType(Type type)
Translates the given MLIR LLVM dialect type to LLVM IR.
Definition: TypeToLLVM.cpp:192
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:762
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Attribute getDiscardableAttr(StringRef name)
Access a discardable attribute by name; returns a null Attribute if the discardable attribute does not exist.
Definition: Operation.h:448
Value getOperand(unsigned idx)
Definition: Operation.h:345
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumSuccessors()
Definition: Operation.h:702
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
dialect_attr_range getDialectAttrs()
Return a range corresponding to the dialect attributes for this operation.
Definition: Operation.h:632
bool hasSuccessors()
Definition: Operation.h:701
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
This class models how operands are forwarded to block arguments in control flow.
bool empty() const
Returns true if there are no successor operands.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_iterator use_end() const
Definition: Value.h:205
Type getType() const
Return the type of this value.
Definition: Value.h:125
use_iterator use_begin() const
Definition: Value.h:204
Include the generated interface declarations.
Definition: CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::DenseMap< llvm::StringRef, llvm::Attribute::AttrKind > getAttrNameToKindMapping()
Returns a dense map from LLVM attribute name to their kind in LLVM IR dialect.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
void legalizeDIExpressionsRecursively(Operation *op)
Register all known legalization patterns declared here and apply them to all ops in op.
std::optional< uint64_t > extractPointerSpecValue(Attribute attr, PtrDLEntryPos pos)
Returns the value that corresponds to named position pos from the data layout entry attr assuming it'...
Definition: LLVMTypes.cpp:262
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:845
void ensureDistinctSuccessors(Operation *op)
Make argument-taking successors of each block distinct.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SetVector< Block * > getTopologicallySortedBlocks(Region &region)
Get a topologically sorted list of blocks of the given region.
DataLayoutSpecInterface translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context)
Translate the given LLVM data layout into an MLIR equivalent using the DLTI dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule")
Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in ...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26