MLIR  19.0.0git
ModuleTranslation.cpp
Go to the documentation of this file.
1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "AttrKindDetail.h"
17 #include "DebugTranslation.h"
19 #include "mlir/Dialect/DLTI/DLTI.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/BuiltinTypes.h"
32 #include "mlir/Support/LLVM.h"
37 
38 #include "llvm/ADT/PostOrderIterator.h"
39 #include "llvm/ADT/SetVector.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
43 #include "llvm/IR/BasicBlock.h"
44 #include "llvm/IR/CFG.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DerivedTypes.h"
47 #include "llvm/IR/IRBuilder.h"
48 #include "llvm/IR/InlineAsm.h"
49 #include "llvm/IR/IntrinsicsNVPTX.h"
50 #include "llvm/IR/LLVMContext.h"
51 #include "llvm/IR/MDBuilder.h"
52 #include "llvm/IR/Module.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Support/Debug.h"
55 #include "llvm/Support/raw_ostream.h"
56 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
57 #include "llvm/Transforms/Utils/Cloning.h"
58 #include "llvm/Transforms/Utils/ModuleUtils.h"
59 #include <optional>
60 
61 #define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
62 
63 using namespace mlir;
64 using namespace mlir::LLVM;
65 using namespace mlir::LLVM::detail;
66 
67 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
68 
69 namespace {
70 /// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
71 /// instructions that are created for future reference.
72 ///
73 /// This is intended to be used with the `CollectionScope` RAII object:
74 ///
75 /// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
76 /// {
77 /// InstructionCapturingInserter::CollectionScope scope(builder);
78 /// // Call IRBuilder methods as usual.
79 ///
80 /// // This will return a list of all instructions created by the builder,
81 /// // in order of creation.
82 /// builder.getInserter().getCapturedInstructions();
83 /// }
84 /// // This will return an empty list.
85 /// builder.getInserter().getCapturedInstructions();
86 ///
87 /// The capturing functionality is _disabled_ by default for performance
88 /// consideration. It needs to be explicitly enabled, which is achieved by
89 /// creating a `CollectionScope`.
90 class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
91 public:
92  /// Constructs the inserter.
93  InstructionCapturingInserter()
94  : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
95  if (LLVM_LIKELY(enabled))
96  capturedInstructions.push_back(instruction);
97  }) {}
98 
99  /// Returns the list of LLVM IR instructions captured since the last cleanup.
100  ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
101  return capturedInstructions;
102  }
103 
104  /// Clears the list of captured LLVM IR instructions.
105  void clearCapturedInstructions() { capturedInstructions.clear(); }
106 
107  /// RAII object enabling the capture of created LLVM IR instructions.
108  class CollectionScope {
109  public:
110  /// Creates the scope for the given inserter.
111  CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
112 
113  /// Ends the scope.
114  ~CollectionScope();
115 
116  ArrayRef<llvm::Instruction *> getCapturedInstructions() {
117  if (!inserter)
118  return {};
119  return inserter->getCapturedInstructions();
120  }
121 
122  private:
123  /// Back reference to the inserter.
124  InstructionCapturingInserter *inserter = nullptr;
125 
126  /// List of instructions in the inserter prior to this scope.
127  SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
128 
129  /// Whether the inserter was enabled prior to this scope.
130  bool wasEnabled;
131  };
132 
133  /// Enable or disable the capturing mechanism.
134  void setEnabled(bool enabled = true) { this->enabled = enabled; }
135 
136 private:
137  /// List of captured instructions.
138  SmallVector<llvm::Instruction *> capturedInstructions;
139 
140  /// Whether the collection is enabled.
141  bool enabled = false;
142 };
143 
144 using CapturingIRBuilder =
145  llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
146 } // namespace
147 
148 InstructionCapturingInserter::CollectionScope::CollectionScope(
149  llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
150 
151  if (!isBuilderCapturing)
152  return;
153 
154  auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
155  inserter = &capturingIRBuilder.getInserter();
156  wasEnabled = inserter->enabled;
157  if (wasEnabled)
158  previouslyCollectedInstructions.swap(inserter->capturedInstructions);
159  inserter->setEnabled(true);
160 }
161 
162 InstructionCapturingInserter::CollectionScope::~CollectionScope() {
163  if (!inserter)
164  return;
165 
166  previouslyCollectedInstructions.swap(inserter->capturedInstructions);
167  // If collection was enabled (likely in another, surrounding scope), keep
168  // the instructions collected in this scope.
169  if (wasEnabled) {
170  llvm::append_range(inserter->capturedInstructions,
171  previouslyCollectedInstructions);
172  }
173  inserter->setEnabled(wasEnabled);
174 }
175 
176 /// Translates the given data layout spec attribute to the LLVM IR data layout.
177 /// Only integer, float, pointer and endianness entries are currently supported.
179 translateDataLayout(DataLayoutSpecInterface attribute,
180  const DataLayout &dataLayout,
181  std::optional<Location> loc = std::nullopt) {
182  if (!loc)
183  loc = UnknownLoc::get(attribute.getContext());
184 
185  // Translate the endianness attribute.
186  std::string llvmDataLayout;
187  llvm::raw_string_ostream layoutStream(llvmDataLayout);
188  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
189  auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
190  if (!key)
191  continue;
192  if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
193  auto value = cast<StringAttr>(entry.getValue());
194  bool isLittleEndian =
195  value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
196  layoutStream << "-" << (isLittleEndian ? "e" : "E");
197  layoutStream.flush();
198  continue;
199  }
200  if (key.getValue() == DLTIDialect::kDataLayoutProgramMemorySpaceKey) {
201  auto value = cast<IntegerAttr>(entry.getValue());
202  uint64_t space = value.getValue().getZExtValue();
203  // Skip the default address space.
204  if (space == 0)
205  continue;
206  layoutStream << "-P" << space;
207  layoutStream.flush();
208  continue;
209  }
210  if (key.getValue() == DLTIDialect::kDataLayoutGlobalMemorySpaceKey) {
211  auto value = cast<IntegerAttr>(entry.getValue());
212  uint64_t space = value.getValue().getZExtValue();
213  // Skip the default address space.
214  if (space == 0)
215  continue;
216  layoutStream << "-G" << space;
217  layoutStream.flush();
218  continue;
219  }
220  if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
221  auto value = cast<IntegerAttr>(entry.getValue());
222  uint64_t space = value.getValue().getZExtValue();
223  // Skip the default address space.
224  if (space == 0)
225  continue;
226  layoutStream << "-A" << space;
227  layoutStream.flush();
228  continue;
229  }
230  if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
231  auto value = cast<IntegerAttr>(entry.getValue());
232  uint64_t alignment = value.getValue().getZExtValue();
233  // Skip the default stack alignment.
234  if (alignment == 0)
235  continue;
236  layoutStream << "-S" << alignment;
237  layoutStream.flush();
238  continue;
239  }
240  emitError(*loc) << "unsupported data layout key " << key;
241  return failure();
242  }
243 
244  // Go through the list of entries to check which types are explicitly
245  // specified in entries. Where possible, data layout queries are used instead
246  // of directly inspecting the entries.
247  for (DataLayoutEntryInterface entry : attribute.getEntries()) {
248  auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
249  if (!type)
250  continue;
251  // Data layout for the index type is irrelevant at this point.
252  if (isa<IndexType>(type))
253  continue;
254  layoutStream << "-";
255  LogicalResult result =
257  .Case<IntegerType, Float16Type, Float32Type, Float64Type,
258  Float80Type, Float128Type>([&](Type type) -> LogicalResult {
259  if (auto intType = dyn_cast<IntegerType>(type)) {
260  if (intType.getSignedness() != IntegerType::Signless)
261  return emitError(*loc)
262  << "unsupported data layout for non-signless integer "
263  << intType;
264  layoutStream << "i";
265  } else {
266  layoutStream << "f";
267  }
268  uint64_t size = dataLayout.getTypeSizeInBits(type);
269  uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
270  uint64_t preferred =
271  dataLayout.getTypePreferredAlignment(type) * 8u;
272  layoutStream << size << ":" << abi;
273  if (abi != preferred)
274  layoutStream << ":" << preferred;
275  return success();
276  })
277  .Case([&](LLVMPointerType type) {
278  layoutStream << "p" << type.getAddressSpace() << ":";
279  uint64_t size = dataLayout.getTypeSizeInBits(type);
280  uint64_t abi = dataLayout.getTypeABIAlignment(type) * 8u;
281  uint64_t preferred =
282  dataLayout.getTypePreferredAlignment(type) * 8u;
283  uint64_t index = *dataLayout.getTypeIndexBitwidth(type);
284  layoutStream << size << ":" << abi << ":" << preferred << ":"
285  << index;
286  return success();
287  })
288  .Default([loc](Type type) {
289  return emitError(*loc)
290  << "unsupported type in data layout: " << type;
291  });
292  if (failed(result))
293  return failure();
294  }
295  layoutStream.flush();
296  StringRef layoutSpec(llvmDataLayout);
297  if (layoutSpec.starts_with("-"))
298  layoutSpec = layoutSpec.drop_front();
299 
300  return llvm::DataLayout(layoutSpec);
301 }
302 
303 /// Builds a constant of a sequential LLVM type `type`, potentially containing
304 /// other sequential types recursively, from the individual constant values
305 /// provided in `constants`. `shape` contains the number of elements in nested
306 /// sequential types. Reports errors at `loc` and returns nullptr on error.
307 static llvm::Constant *
309  ArrayRef<int64_t> shape, llvm::Type *type,
310  Location loc) {
311  if (shape.empty()) {
312  llvm::Constant *result = constants.front();
313  constants = constants.drop_front();
314  return result;
315  }
316 
317  llvm::Type *elementType;
318  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
319  elementType = arrayTy->getElementType();
320  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
321  elementType = vectorTy->getElementType();
322  } else {
323  emitError(loc) << "expected sequential LLVM types wrapping a scalar";
324  return nullptr;
325  }
326 
328  nested.reserve(shape.front());
329  for (int64_t i = 0; i < shape.front(); ++i) {
330  nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
331  elementType, loc));
332  if (!nested.back())
333  return nullptr;
334  }
335 
336  if (shape.size() == 1 && type->isVectorTy())
337  return llvm::ConstantVector::get(nested);
339  llvm::ArrayType::get(elementType, shape.front()), nested);
340 }
341 
342 /// Returns the first non-sequential type nested in sequential types.
343 static llvm::Type *getInnermostElementType(llvm::Type *type) {
344  do {
345  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
346  type = arrayTy->getElementType();
347  } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
348  type = vectorTy->getElementType();
349  } else {
350  return type;
351  }
352  } while (true);
353 }
354 
355 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
356 /// storage if possible. This supports elements attributes of tensor or vector
357 /// type and avoids constructing separate objects for individual values of the
358 /// innermost dimension. Constants for other dimensions are still constructed
359 /// recursively. Returns null if constructing from raw data is not supported for
360 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
361 /// other errors at `loc`.
362 static llvm::Constant *
364  llvm::Type *llvmType,
365  const ModuleTranslation &moduleTranslation) {
366  if (!denseElementsAttr)
367  return nullptr;
368 
369  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
370  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
371  return nullptr;
372 
373  ShapedType type = denseElementsAttr.getType();
374  if (type.getNumElements() == 0)
375  return nullptr;
376 
377  // Check that the raw data size matches what is expected for the scalar size.
378  // TODO: in theory, we could repack the data here to keep constructing from
379  // raw data.
380  // TODO: we may also need to consider endianness when cross-compiling to an
381  // architecture where it is different.
382  int64_t elementByteSize = denseElementsAttr.getRawData().size() /
383  denseElementsAttr.getNumElements();
384  if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits())
385  return nullptr;
386 
387  // Compute the shape of all dimensions but the innermost. Note that the
388  // innermost dimension may be that of the vector element type.
389  bool hasVectorElementType = isa<VectorType>(type.getElementType());
390  int64_t numAggregates =
391  denseElementsAttr.getNumElements() /
392  (hasVectorElementType ? 1
393  : denseElementsAttr.getType().getShape().back());
394  ArrayRef<int64_t> outerShape = type.getShape();
395  if (!hasVectorElementType)
396  outerShape = outerShape.drop_back();
397 
398  // Handle the case of vector splat, LLVM has special support for it.
399  if (denseElementsAttr.isSplat() &&
400  (isa<VectorType>(type) || hasVectorElementType)) {
401  llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
402  innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
403  moduleTranslation);
404  llvm::Constant *splatVector =
405  llvm::ConstantDataVector::getSplat(0, splatValue);
406  SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
407  ArrayRef<llvm::Constant *> constantsRef = constants;
408  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
409  }
410  if (denseElementsAttr.isSplat())
411  return nullptr;
412 
413  // In case of non-splat, create a constructor for the innermost constant from
414  // a piece of raw data.
415  std::function<llvm::Constant *(StringRef)> buildCstData;
416  if (isa<TensorType>(type)) {
417  auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
418  if (vectorElementType && vectorElementType.getRank() == 1) {
419  buildCstData = [&](StringRef data) {
420  return llvm::ConstantDataVector::getRaw(
421  data, vectorElementType.getShape().back(), innermostLLVMType);
422  };
423  } else if (!vectorElementType) {
424  buildCstData = [&](StringRef data) {
425  return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
426  innermostLLVMType);
427  };
428  }
429  } else if (isa<VectorType>(type)) {
430  buildCstData = [&](StringRef data) {
431  return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
432  innermostLLVMType);
433  };
434  }
435  if (!buildCstData)
436  return nullptr;
437 
438  // Create innermost constants and defer to the default constant creation
439  // mechanism for other dimensions.
441  int64_t aggregateSize = denseElementsAttr.getType().getShape().back() *
442  (innermostLLVMType->getScalarSizeInBits() / 8);
443  constants.reserve(numAggregates);
444  for (unsigned i = 0; i < numAggregates; ++i) {
445  StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
446  aggregateSize);
447  constants.push_back(buildCstData(data));
448  }
449 
450  ArrayRef<llvm::Constant *> constantsRef = constants;
451  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
452 }
453 
454 /// Convert a dense resource elements attribute to an LLVM IR constant using its
455 /// raw data storage if possible. This supports elements attributes of tensor or
456 /// vector type and avoids constructing separate objects for individual values
457 /// of the innermost dimension. Constants for other dimensions are still
458 /// constructed recursively. Returns nullptr on failure and emits errors at
459 /// `loc`.
460 static llvm::Constant *convertDenseResourceElementsAttr(
461  Location loc, DenseResourceElementsAttr denseResourceAttr,
462  llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
463  assert(denseResourceAttr && "expected non-null attribute");
464 
465  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
466  if (!llvm::ConstantDataSequential::isElementTypeCompatible(
467  innermostLLVMType)) {
468  emitError(loc, "no known conversion for innermost element type");
469  return nullptr;
470  }
471 
472  ShapedType type = denseResourceAttr.getType();
473  assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
474 
475  AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
476  if (!blob) {
477  emitError(loc, "resource does not exist");
478  return nullptr;
479  }
480 
481  ArrayRef<char> rawData = blob->getData();
482 
483  // Check that the raw data size matches what is expected for the scalar size.
484  // TODO: in theory, we could repack the data here to keep constructing from
485  // raw data.
486  // TODO: we may also need to consider endianness when cross-compiling to an
487  // architecture where it is different.
488  int64_t numElements = denseResourceAttr.getType().getNumElements();
489  int64_t elementByteSize = rawData.size() / numElements;
490  if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
491  emitError(loc, "raw data size does not match element type size");
492  return nullptr;
493  }
494 
495  // Compute the shape of all dimensions but the innermost. Note that the
496  // innermost dimension may be that of the vector element type.
497  bool hasVectorElementType = isa<VectorType>(type.getElementType());
498  int64_t numAggregates =
499  numElements / (hasVectorElementType
500  ? 1
501  : denseResourceAttr.getType().getShape().back());
502  ArrayRef<int64_t> outerShape = type.getShape();
503  if (!hasVectorElementType)
504  outerShape = outerShape.drop_back();
505 
506  // Create a constructor for the innermost constant from a piece of raw data.
507  std::function<llvm::Constant *(StringRef)> buildCstData;
508  if (isa<TensorType>(type)) {
509  auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
510  if (vectorElementType && vectorElementType.getRank() == 1) {
511  buildCstData = [&](StringRef data) {
512  return llvm::ConstantDataVector::getRaw(
513  data, vectorElementType.getShape().back(), innermostLLVMType);
514  };
515  } else if (!vectorElementType) {
516  buildCstData = [&](StringRef data) {
517  return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
518  innermostLLVMType);
519  };
520  }
521  } else if (isa<VectorType>(type)) {
522  buildCstData = [&](StringRef data) {
523  return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
524  innermostLLVMType);
525  };
526  }
527  if (!buildCstData) {
528  emitError(loc, "unsupported dense_resource type");
529  return nullptr;
530  }
531 
532  // Create innermost constants and defer to the default constant creation
533  // mechanism for other dimensions.
535  int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
536  (innermostLLVMType->getScalarSizeInBits() / 8);
537  constants.reserve(numAggregates);
538  for (unsigned i = 0; i < numAggregates; ++i) {
539  StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
540  constants.push_back(buildCstData(data));
541  }
542 
543  ArrayRef<llvm::Constant *> constantsRef = constants;
544  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
545 }
546 
547 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
548 /// This currently supports integer, floating point, splat and dense element
549 /// attributes and combinations thereof. Also, an array attribute with two
550 /// elements is supported to represent a complex constant. In case of error,
551 /// report it to `loc` and return nullptr.
553  llvm::Type *llvmType, Attribute attr, Location loc,
554  const ModuleTranslation &moduleTranslation) {
555  if (!attr)
556  return llvm::UndefValue::get(llvmType);
557  if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
558  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
559  if (!arrayAttr || arrayAttr.size() != 2) {
560  emitError(loc, "expected struct type to be a complex number");
561  return nullptr;
562  }
563  llvm::Type *elementType = structType->getElementType(0);
564  llvm::Constant *real =
565  getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
566  if (!real)
567  return nullptr;
568  llvm::Constant *imag =
569  getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
570  if (!imag)
571  return nullptr;
572  return llvm::ConstantStruct::get(structType, {real, imag});
573  }
574  // For integer types, we allow a mismatch in sizes as the index type in
575  // MLIR might have a different size than the index type in the LLVM module.
576  if (auto intAttr = dyn_cast<IntegerAttr>(attr))
577  return llvm::ConstantInt::get(
578  llvmType,
579  intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
580  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
581  const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
582  // Special case for 8-bit floats, which are represented by integers due to
583  // the lack of native fp8 types in LLVM at the moment. Additionally, handle
584  // targets (like AMDGPU) that don't implement bfloat and convert all bfloats
585  // to i16.
586  unsigned floatWidth = APFloat::getSizeInBits(sem);
587  if (llvmType->isIntegerTy(floatWidth))
588  return llvm::ConstantInt::get(llvmType,
589  floatAttr.getValue().bitcastToAPInt());
590  if (llvmType !=
591  llvm::Type::getFloatingPointTy(llvmType->getContext(),
592  floatAttr.getValue().getSemantics())) {
593  emitError(loc, "FloatAttr does not match expected type of the constant");
594  return nullptr;
595  }
596  return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
597  }
598  if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
599  return llvm::ConstantExpr::getBitCast(
600  moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
601  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
602  llvm::Type *elementType;
603  uint64_t numElements;
604  bool isScalable = false;
605  if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
606  elementType = arrayTy->getElementType();
607  numElements = arrayTy->getNumElements();
608  } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
609  elementType = fVectorTy->getElementType();
610  numElements = fVectorTy->getNumElements();
611  } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
612  elementType = sVectorTy->getElementType();
613  numElements = sVectorTy->getMinNumElements();
614  isScalable = true;
615  } else {
616  llvm_unreachable("unrecognized constant vector type");
617  }
618  // Splat value is a scalar. Extract it only if the element type is not
619  // another sequence type. The recursion terminates because each step removes
620  // one outer sequential type.
621  bool elementTypeSequential =
622  isa<llvm::ArrayType, llvm::VectorType>(elementType);
623  llvm::Constant *child = getLLVMConstant(
624  elementType,
625  elementTypeSequential ? splatAttr
626  : splatAttr.getSplatValue<Attribute>(),
627  loc, moduleTranslation);
628  if (!child)
629  return nullptr;
630  if (llvmType->isVectorTy())
631  return llvm::ConstantVector::getSplat(
632  llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
633  if (llvmType->isArrayTy()) {
634  auto *arrayType = llvm::ArrayType::get(elementType, numElements);
635  SmallVector<llvm::Constant *, 8> constants(numElements, child);
636  return llvm::ConstantArray::get(arrayType, constants);
637  }
638  }
639 
640  // Try using raw elements data if possible.
641  if (llvm::Constant *result =
642  convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
643  llvmType, moduleTranslation)) {
644  return result;
645  }
646 
647  if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
648  return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
649  moduleTranslation);
650  }
651 
652  // Fall back to element-by-element construction otherwise.
653  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
654  assert(elementsAttr.getShapedType().hasStaticShape());
655  assert(!elementsAttr.getShapedType().getShape().empty() &&
656  "unexpected empty elements attribute shape");
657 
659  constants.reserve(elementsAttr.getNumElements());
660  llvm::Type *innermostType = getInnermostElementType(llvmType);
661  for (auto n : elementsAttr.getValues<Attribute>()) {
662  constants.push_back(
663  getLLVMConstant(innermostType, n, loc, moduleTranslation));
664  if (!constants.back())
665  return nullptr;
666  }
667  ArrayRef<llvm::Constant *> constantsRef = constants;
668  llvm::Constant *result = buildSequentialConstant(
669  constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
670  assert(constantsRef.empty() && "did not consume all elemental constants");
671  return result;
672  }
673 
674  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
676  moduleTranslation.getLLVMContext(),
677  ArrayRef<char>{stringAttr.getValue().data(),
678  stringAttr.getValue().size()});
679  }
680  emitError(loc, "unsupported constant value");
681  return nullptr;
682 }
683 
/// Sets up the translation of `module` into `llvmModule`, taking ownership of
/// the latter. `module` must satisfy `satisfiesLLVMModule`; this is asserted.
ModuleTranslation::ModuleTranslation(Operation *module,
                                     std::unique_ptr<llvm::Module> llvmModule)
    : mlirModule(module), llvmModule(std::move(llvmModule)),
      // The initializers below use `this->llvmModule` (the member that now
      // owns the module), not the moved-from parameter of the same name.
      debugTranslation(
          std::make_unique<DebugTranslation>(module, *this->llvmModule)),
      loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
          *this, *this->llvmModule)),
      typeTranslator(this->llvmModule->getContext()),
      iface(module->getContext()) {
  assert(satisfiesLLVMModule(mlirModule) &&
         "mlirModule should honor LLVM's module semantics.");
}
696 
697 ModuleTranslation::~ModuleTranslation() {
698  if (ompBuilder)
699  ompBuilder->finalize();
700 }
701 
703  SmallVector<Region *> toProcess;
704  toProcess.push_back(&region);
705  while (!toProcess.empty()) {
706  Region *current = toProcess.pop_back_val();
707  for (Block &block : *current) {
708  blockMapping.erase(&block);
709  for (Value arg : block.getArguments())
710  valueMapping.erase(arg);
711  for (Operation &op : block) {
712  for (Value value : op.getResults())
713  valueMapping.erase(value);
714  if (op.hasSuccessors())
715  branchMapping.erase(&op);
716  if (isa<LLVM::GlobalOp>(op))
717  globalsMapping.erase(&op);
718  if (isa<LLVM::CallOp>(op))
719  callMapping.erase(&op);
720  llvm::append_range(
721  toProcess,
722  llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
723  }
724  }
725  }
726 }
727 
728 /// Get the SSA value passed to the current block from the terminator operation
729 /// of its predecessor.
730 static Value getPHISourceValue(Block *current, Block *pred,
731  unsigned numArguments, unsigned index) {
732  Operation &terminator = *pred->getTerminator();
733  if (isa<LLVM::BrOp>(terminator))
734  return terminator.getOperand(index);
735 
736 #ifndef NDEBUG
737  llvm::SmallPtrSet<Block *, 4> seenSuccessors;
738  for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
739  Block *successor = terminator.getSuccessor(i);
740  auto branch = cast<BranchOpInterface>(terminator);
741  SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
742  assert(
743  (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
744  "successors with arguments in LLVM branches must be different blocks");
745  seenSuccessors.insert(successor);
746  }
747 #endif
748 
749  // For instructions that branch based on a condition value, we need to take
750  // the operands for the branch that was taken.
751  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
752  // For conditional branches, we take the operands from either the "true" or
753  // the "false" branch.
754  return condBranchOp.getSuccessor(0) == current
755  ? condBranchOp.getTrueDestOperands()[index]
756  : condBranchOp.getFalseDestOperands()[index];
757  }
758 
759  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
760  // For switches, we take the operands from either the default case, or from
761  // the case branch that was taken.
762  if (switchOp.getDefaultDestination() == current)
763  return switchOp.getDefaultOperands()[index];
764  for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
765  if (i.value() == current)
766  return switchOp.getCaseOperands(i.index())[index];
767  }
768 
769  if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
770  return invokeOp.getNormalDest() == current
771  ? invokeOp.getNormalDestOperands()[index]
772  : invokeOp.getUnwindDestOperands()[index];
773  }
774 
775  llvm_unreachable(
776  "only branch, switch or invoke operations can be terminators "
777  "of a block that has successors");
778 }
779 
780 /// Connect the PHI nodes to the results of preceding blocks.
782  const ModuleTranslation &state) {
783  // Skip the first block, it cannot be branched to and its arguments correspond
784  // to the arguments of the LLVM function.
785  for (Block &bb : llvm::drop_begin(region)) {
786  llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
787  auto phis = llvmBB->phis();
788  auto numArguments = bb.getNumArguments();
789  assert(numArguments == std::distance(phis.begin(), phis.end()));
790  for (auto [index, phiNode] : llvm::enumerate(phis)) {
791  for (auto *pred : bb.getPredecessors()) {
792  // Find the LLVM IR block that contains the converted terminator
793  // instruction and use it in the PHI node. Note that this block is not
794  // necessarily the same as state.lookupBlock(pred), some operations
795  // (in particular, OpenMP operations using OpenMPIRBuilder) may have
796  // split the blocks.
797  llvm::Instruction *terminator =
798  state.lookupBranch(pred->getTerminator());
799  assert(terminator && "missing the mapping for a terminator");
800  phiNode.addIncoming(state.lookupValue(getPHISourceValue(
801  &bb, pred, numArguments, index)),
802  terminator->getParent());
803  }
804  }
805  }
806 }
807 
809  llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
811  llvm::Module *module = builder.GetInsertBlock()->getModule();
812  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
813  return builder.CreateCall(fn, args);
814 }
815 
817  llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
818  Operation *intrOp, llvm::Intrinsic::ID intrinsic, unsigned numResults,
819  ArrayRef<unsigned> overloadedResults, ArrayRef<unsigned> overloadedOperands,
820  ArrayRef<unsigned> immArgPositions,
821  ArrayRef<StringLiteral> immArgAttrNames) {
822  assert(immArgPositions.size() == immArgAttrNames.size() &&
823  "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
824  "length");
825 
826  // Map operands and attributes to LLVM values.
827  auto operands = moduleTranslation.lookupValues(intrOp->getOperands());
828  SmallVector<llvm::Value *> args(immArgPositions.size() + operands.size());
829  for (auto [immArgPos, immArgName] :
830  llvm::zip(immArgPositions, immArgAttrNames)) {
831  auto attr = llvm::cast<TypedAttr>(intrOp->getAttr(immArgName));
832  assert(attr.getType().isIntOrFloat() && "expected int or float immarg");
833  auto *type = moduleTranslation.convertType(attr.getType());
834  args[immArgPos] = LLVM::detail::getLLVMConstant(
835  type, attr, intrOp->getLoc(), moduleTranslation);
836  }
837  unsigned opArg = 0;
838  for (auto &arg : args) {
839  if (!arg)
840  arg = operands[opArg++];
841  }
842 
843  // Resolve overloaded intrinsic declaration.
844  SmallVector<llvm::Type *> overloadedTypes;
845  for (unsigned overloadedResultIdx : overloadedResults) {
846  if (numResults > 1) {
847  // More than one result is mapped to an LLVM struct.
848  overloadedTypes.push_back(moduleTranslation.convertType(
849  llvm::cast<LLVM::LLVMStructType>(intrOp->getResult(0).getType())
850  .getBody()[overloadedResultIdx]));
851  } else {
852  overloadedTypes.push_back(
853  moduleTranslation.convertType(intrOp->getResult(0).getType()));
854  }
855  }
856  for (unsigned overloadedOperandIdx : overloadedOperands)
857  overloadedTypes.push_back(args[overloadedOperandIdx]->getType());
858  llvm::Module *module = builder.GetInsertBlock()->getModule();
859  llvm::Function *llvmIntr =
860  llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
861 
862  return builder.CreateCall(llvmIntr, args);
863 }
864 
865 /// Given a single MLIR operation, create the corresponding LLVM IR operation
866 /// using the `builder`.
867 LogicalResult ModuleTranslation::convertOperation(Operation &op,
868  llvm::IRBuilderBase &builder,
869  bool recordInsertions) {
870  const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
871  if (!opIface)
872  return op.emitError("cannot be converted to LLVM IR: missing "
873  "`LLVMTranslationDialectInterface` registration for "
874  "dialect for op: ")
875  << op.getName();
876 
877  InstructionCapturingInserter::CollectionScope scope(builder,
878  recordInsertions);
879  if (failed(opIface->convertOperation(&op, builder, *this)))
880  return op.emitError("LLVM Translation failed for operation: ")
881  << op.getName();
882 
883  return convertDialectAttributes(&op, scope.getCapturedInstructions());
884 }
885 
886 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
887 /// to define values corresponding to the MLIR block arguments. These nodes
888 /// are not connected to the source basic blocks, which may not exist yet. Uses
889 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
890 /// been created for `bb` and included in the block mapping. Inserts new
891 /// instructions at the end of the block and leaves `builder` in a state
892 /// suitable for further insertion into the end of the block.
893 LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
894  bool ignoreArguments,
895  llvm::IRBuilderBase &builder,
896  bool recordInsertions) {
897  builder.SetInsertPoint(lookupBlock(&bb));
898  auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
899 
900  // Before traversing operations, make block arguments available through
901  // value remapping and PHI nodes, but do not add incoming edges for the PHI
902  // nodes just yet: those values may be defined by this or following blocks.
903  // This step is omitted if "ignoreArguments" is set. The arguments of the
904  // first block have been already made available through the remapping of
905  // LLVM function arguments.
906  if (!ignoreArguments) {
907  auto predecessors = bb.getPredecessors();
908  unsigned numPredecessors =
909  std::distance(predecessors.begin(), predecessors.end());
910  for (auto arg : bb.getArguments()) {
911  auto wrappedType = arg.getType();
912  if (!isCompatibleType(wrappedType))
913  return emitError(bb.front().getLoc(),
914  "block argument does not have an LLVM type");
915  llvm::Type *type = convertType(wrappedType);
916  llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
917  mapValue(arg, phi);
918  }
919  }
920 
921  // Traverse operations.
922  for (auto &op : bb) {
923  // Set the current debug location within the builder.
924  builder.SetCurrentDebugLocation(
925  debugTranslation->translateLoc(op.getLoc(), subprogram));
926 
927  if (failed(convertOperation(op, builder, recordInsertions)))
928  return failure();
929 
930  // Set the branch weight metadata on the translated instruction.
931  if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
933  }
934 
935  return success();
936 }
937 
938 /// A helper method to get the single Block in an operation honoring LLVM's
939 /// module requirements.
940 static Block &getModuleBody(Operation *module) {
941  return module->getRegion(0).front();
942 }
943 
944 /// A helper method to decide if a constant must not be set as a global variable
945 /// initializer. For an external linkage variable, the variable with an
946 /// initializer is considered externally visible and defined in this module, the
947 /// variable without an initializer is externally available and is defined
948 /// elsewhere.
949 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
950  llvm::Constant *cst) {
951  return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
952  linkage == llvm::GlobalVariable::ExternalWeakLinkage;
953 }
954 
955 /// Sets the runtime preemption specifier of `gv` to dso_local if
956 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
957 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
958  llvm::GlobalValue *gv) {
959  if (dsoLocalRequested)
960  gv->setDSOLocal(true);
961 }
962 
963 /// Create named global variables that correspond to llvm.mlir.global
964 /// definitions. Convert llvm.global_ctors and global_dtors ops.
965 LogicalResult ModuleTranslation::convertGlobals() {
966  // Mapping from compile unit to its respective set of global variables.
968 
969  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
970  llvm::Type *type = convertType(op.getType());
971  llvm::Constant *cst = nullptr;
972  if (op.getValueOrNull()) {
973  // String attributes are treated separately because they cannot appear as
974  // in-function constants and are thus not supported by getLLVMConstant.
975  if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
976  cst = llvm::ConstantDataArray::getString(
977  llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
978  type = cst->getType();
979  } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
980  *this))) {
981  return failure();
982  }
983  }
984 
985  auto linkage = convertLinkageToLLVM(op.getLinkage());
986  auto addrSpace = op.getAddrSpace();
987 
988  // LLVM IR requires constant with linkage other than external or weak
989  // external to have initializers. If MLIR does not provide an initializer,
990  // default to undef.
991  bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
992  if (!dropInitializer && !cst)
993  cst = llvm::UndefValue::get(type);
994  else if (dropInitializer && cst)
995  cst = nullptr;
996 
997  auto *var = new llvm::GlobalVariable(
998  *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
999  /*InsertBefore=*/nullptr,
1000  op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
1001  : llvm::GlobalValue::NotThreadLocal,
1002  addrSpace);
1003 
1004  if (std::optional<mlir::SymbolRefAttr> comdat = op.getComdat()) {
1005  auto selectorOp = cast<ComdatSelectorOp>(
1007  var->setComdat(comdatMapping.lookup(selectorOp));
1008  }
1009 
1010  if (op.getUnnamedAddr().has_value())
1011  var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
1012 
1013  if (op.getSection().has_value())
1014  var->setSection(*op.getSection());
1015 
1016  addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
1017 
1018  std::optional<uint64_t> alignment = op.getAlignment();
1019  if (alignment.has_value())
1020  var->setAlignment(llvm::MaybeAlign(alignment.value()));
1021 
1022  var->setVisibility(convertVisibilityToLLVM(op.getVisibility_()));
1023 
1024  globalsMapping.try_emplace(op, var);
1025 
1026  // Add debug information if present.
1027  if (op.getDbgExpr()) {
1028  llvm::DIGlobalVariableExpression *diGlobalExpr =
1029  debugTranslation->translateGlobalVariableExpression(op.getDbgExpr());
1030  llvm::DIGlobalVariable *diGlobalVar = diGlobalExpr->getVariable();
1031  var->addDebugInfo(diGlobalExpr);
1032 
1033  // Get the compile unit (scope) of the the global variable.
1034  if (llvm::DICompileUnit *compileUnit =
1035  dyn_cast_if_present<llvm::DICompileUnit>(
1036  diGlobalVar->getScope())) {
1037  // Update the compile unit with this incoming global variable expression
1038  // during the finalizing step later.
1039  allGVars[compileUnit].push_back(diGlobalExpr);
1040  }
1041  }
1042  }
1043 
1044  // Convert global variable bodies. This is done after all global variables
1045  // have been created in LLVM IR because a global body may refer to another
1046  // global or itself. So all global variables need to be mapped first.
1047  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
1048  if (Block *initializer = op.getInitializerBlock()) {
1049  llvm::IRBuilder<> builder(llvmModule->getContext());
1050 
1051  [[maybe_unused]] int numConstantsHit = 0;
1052  [[maybe_unused]] int numConstantsErased = 0;
1053  DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
1054 
1055  for (auto &op : initializer->without_terminator()) {
1056  if (failed(convertOperation(op, builder)))
1057  return emitError(op.getLoc(), "fail to convert global initializer");
1058  auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
1059  if (!cst)
1060  return emitError(op.getLoc(), "unemittable constant value");
1061 
1062  // When emitting an LLVM constant, a new constant is created and the old
1063  // constant may become dangling and take space. We should remove the
1064  // dangling constants to avoid memory explosion especially for constant
1065  // arrays whose number of elements is large.
1066  // Because multiple operations may refer to the same constant, we need
1067  // to count the number of uses of each constant array and remove it only
1068  // when the count becomes zero.
1069  if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
1070  numConstantsHit++;
1071  Value result = op.getResult(0);
1072  int numUsers = std::distance(result.use_begin(), result.use_end());
1073  auto [iterator, inserted] =
1074  constantAggregateUseMap.try_emplace(agg, numUsers);
1075  if (!inserted) {
1076  // Key already exists, update the value
1077  iterator->second += numUsers;
1078  }
1079  }
1080  // Scan the operands of the operation to decrement the use count of
1081  // constants. Erase the constant if the use count becomes zero.
1082  for (Value v : op.getOperands()) {
1083  auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
1084  if (!cst)
1085  continue;
1086  auto iter = constantAggregateUseMap.find(cst);
1087  assert(iter != constantAggregateUseMap.end() && "constant not found");
1088  iter->second--;
1089  if (iter->second == 0) {
1090  // NOTE: cannot call removeDeadConstantUsers() here because it
1091  // may remove the constant which has uses not be converted yet.
1092  if (cst->user_empty()) {
1093  cst->destroyConstant();
1094  numConstantsErased++;
1095  }
1096  constantAggregateUseMap.erase(iter);
1097  }
1098  }
1099  }
1100 
1101  ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
1102  llvm::Constant *cst =
1103  cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
1104  auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
1105  if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
1106  global->setInitializer(cst);
1107 
1108  // Try to remove the dangling constants again after all operations are
1109  // converted.
1110  for (auto it : constantAggregateUseMap) {
1111  auto cst = it.first;
1112  cst->removeDeadConstantUsers();
1113  if (cst->user_empty()) {
1114  cst->destroyConstant();
1115  numConstantsErased++;
1116  }
1117  }
1118 
1119  LLVM_DEBUG(llvm::dbgs()
1120  << "Convert initializer for " << op.getName() << "\n";
1121  llvm::dbgs() << numConstantsHit << " new constants hit\n";
1122  llvm::dbgs()
1123  << numConstantsErased << " dangling constants erased\n";);
1124  }
1125  }
1126 
1127  // Convert llvm.mlir.global_ctors and dtors.
1128  for (Operation &op : getModuleBody(mlirModule)) {
1129  auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
1130  auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
1131  if (!ctorOp && !dtorOp)
1132  continue;
1133  auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
1134  : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
1135  auto appendGlobalFn =
1136  ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
1137  for (auto symbolAndPriority : range) {
1138  llvm::Function *f = lookupFunction(
1139  cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
1140  appendGlobalFn(*llvmModule, f,
1141  cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
1142  /*Data=*/nullptr);
1143  }
1144  }
1145 
1146  for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
1147  if (failed(convertDialectAttributes(op, {})))
1148  return failure();
1149 
1150  // Finally, update the compile units their respective sets of global variables
1151  // created earlier.
1152  for (const auto &[compileUnit, globals] : allGVars) {
1153  compileUnit->replaceGlobalVariables(
1154  llvm::MDTuple::get(getLLVMContext(), globals));
1155  }
1156 
1157  return success();
1158 }
1159 
1160 /// Attempts to add an attribute identified by `key`, optionally with the given
1161 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
1162 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
1163 /// otherwise keep it as a string attribute. Performs additional checks for
1164 /// attributes known to have or not have a value in order to avoid assertions
1165 /// inside LLVM upon construction.
1167  llvm::Function *llvmFunc,
1168  StringRef key,
1169  StringRef value = StringRef()) {
1170  auto kind = llvm::Attribute::getAttrKindFromName(key);
1171  if (kind == llvm::Attribute::None) {
1172  llvmFunc->addFnAttr(key, value);
1173  return success();
1174  }
1175 
1176  if (llvm::Attribute::isIntAttrKind(kind)) {
1177  if (value.empty())
1178  return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
1179 
1180  int64_t result;
1181  if (!value.getAsInteger(/*Radix=*/0, result))
1182  llvmFunc->addFnAttr(
1183  llvm::Attribute::get(llvmFunc->getContext(), kind, result));
1184  else
1185  llvmFunc->addFnAttr(key, value);
1186  return success();
1187  }
1188 
1189  if (!value.empty())
1190  return emitError(loc) << "LLVM attribute '" << key
1191  << "' does not expect a value, found '" << value
1192  << "'";
1193 
1194  llvmFunc->addFnAttr(kind);
1195  return success();
1196 }
1197 
1198 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1199 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
1200 /// to be an array attribute containing either string attributes, treated as
1201 /// value-less LLVM attributes, or array attributes containing two string
1202 /// attributes, with the first string being the name of the corresponding LLVM
1203 /// attribute and the second string beings its value. Note that even integer
1204 /// attributes are expected to have their values expressed as strings.
1205 static LogicalResult
1206 forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
1207  llvm::Function *llvmFunc) {
1208  if (!attributes)
1209  return success();
1210 
1211  for (Attribute attr : *attributes) {
1212  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1213  if (failed(
1214  checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
1215  return failure();
1216  continue;
1217  }
1218 
1219  auto arrayAttr = dyn_cast<ArrayAttr>(attr);
1220  if (!arrayAttr || arrayAttr.size() != 2)
1221  return emitError(loc)
1222  << "expected 'passthrough' to contain string or array attributes";
1223 
1224  auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
1225  auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
1226  if (!keyAttr || !valueAttr)
1227  return emitError(loc)
1228  << "expected arrays within 'passthrough' to contain two strings";
1229 
1230  if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
1231  valueAttr.getValue())))
1232  return failure();
1233  }
1234  return success();
1235 }
1236 
1237 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1238  // Clear the block, branch value mappings, they are only relevant within one
1239  // function.
1240  blockMapping.clear();
1241  valueMapping.clear();
1242  branchMapping.clear();
1243  llvm::Function *llvmFunc = lookupFunction(func.getName());
1244 
1245  // Add function arguments to the value remapping table.
1246  for (auto [mlirArg, llvmArg] :
1247  llvm::zip(func.getArguments(), llvmFunc->args()))
1248  mapValue(mlirArg, &llvmArg);
1249 
1250  // Check the personality and set it.
1251  if (func.getPersonality()) {
1252  llvm::Type *ty = llvm::PointerType::getUnqual(llvmFunc->getContext());
1253  if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
1254  func.getLoc(), *this))
1255  llvmFunc->setPersonalityFn(pfunc);
1256  }
1257 
1258  if (std::optional<StringRef> section = func.getSection())
1259  llvmFunc->setSection(*section);
1260 
1261  if (func.getArmStreaming())
1262  llvmFunc->addFnAttr("aarch64_pstate_sm_enabled");
1263  else if (func.getArmLocallyStreaming())
1264  llvmFunc->addFnAttr("aarch64_pstate_sm_body");
1265  else if (func.getArmStreamingCompatible())
1266  llvmFunc->addFnAttr("aarch64_pstate_sm_compatible");
1267 
1268  if (func.getArmNewZa())
1269  llvmFunc->addFnAttr("aarch64_new_za");
1270  else if (func.getArmInZa())
1271  llvmFunc->addFnAttr("aarch64_in_za");
1272  else if (func.getArmOutZa())
1273  llvmFunc->addFnAttr("aarch64_out_za");
1274  else if (func.getArmInoutZa())
1275  llvmFunc->addFnAttr("aarch64_inout_za");
1276  else if (func.getArmPreservesZa())
1277  llvmFunc->addFnAttr("aarch64_preserves_za");
1278 
1279  if (auto targetCpu = func.getTargetCpu())
1280  llvmFunc->addFnAttr("target-cpu", *targetCpu);
1281 
1282  if (auto targetFeatures = func.getTargetFeatures())
1283  llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString());
1284 
1285  if (auto attr = func.getVscaleRange())
1286  llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
1287  getLLVMContext(), attr->getMinRange().getInt(),
1288  attr->getMaxRange().getInt()));
1289 
1290  if (auto unsafeFpMath = func.getUnsafeFpMath())
1291  llvmFunc->addFnAttr("unsafe-fp-math", llvm::toStringRef(*unsafeFpMath));
1292 
1293  if (auto noInfsFpMath = func.getNoInfsFpMath())
1294  llvmFunc->addFnAttr("no-infs-fp-math", llvm::toStringRef(*noInfsFpMath));
1295 
1296  if (auto noNansFpMath = func.getNoNansFpMath())
1297  llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath));
1298 
1299  if (auto approxFuncFpMath = func.getApproxFuncFpMath())
1300  llvmFunc->addFnAttr("approx-func-fp-math",
1301  llvm::toStringRef(*approxFuncFpMath));
1302 
1303  if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath())
1304  llvmFunc->addFnAttr("no-signed-zeros-fp-math",
1305  llvm::toStringRef(*noSignedZerosFpMath));
1306 
1307  // Add function attribute frame-pointer, if found.
1308  if (FramePointerKindAttr attr = func.getFramePointerAttr())
1309  llvmFunc->addFnAttr("frame-pointer",
1310  LLVM::framePointerKind::stringifyFramePointerKind(
1311  (attr.getFramePointerKind())));
1312 
1313  // First, create all blocks so we can jump to them.
1314  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1315  for (auto &bb : func) {
1316  auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
1317  llvmBB->insertInto(llvmFunc);
1318  mapBlock(&bb, llvmBB);
1319  }
1320 
1321  // Then, convert blocks one by one in topological order to ensure defs are
1322  // converted before uses.
1323  auto blocks = getTopologicallySortedBlocks(func.getBody());
1324  for (Block *bb : blocks) {
1325  CapturingIRBuilder builder(llvmContext);
1326  if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
1327  /*recordInsertions=*/true)))
1328  return failure();
1329  }
1330 
1331  // After all blocks have been traversed and values mapped, connect the PHI
1332  // nodes to the results of preceding blocks.
1333  detail::connectPHINodes(func.getBody(), *this);
1334 
1335  // Finally, convert dialect attributes attached to the function.
1336  return convertDialectAttributes(func, {});
1337 }
1338 
1339 LogicalResult ModuleTranslation::convertDialectAttributes(
1340  Operation *op, ArrayRef<llvm::Instruction *> instructions) {
1341  for (NamedAttribute attribute : op->getDialectAttrs())
1342  if (failed(iface.amendOperation(op, instructions, attribute, *this)))
1343  return failure();
1344  return success();
1345 }
1346 
1347 /// Converts the function attributes from LLVMFuncOp and attaches them to the
1348 /// llvm::Function.
1349 static void convertFunctionAttributes(LLVMFuncOp func,
1350  llvm::Function *llvmFunc) {
1351  if (!func.getMemory())
1352  return;
1353 
1354  MemoryEffectsAttr memEffects = func.getMemoryAttr();
1355 
1356  // Add memory effects incrementally.
1357  llvm::MemoryEffects newMemEffects =
1358  llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
1359  convertModRefInfoToLLVM(memEffects.getArgMem()));
1360  newMemEffects |= llvm::MemoryEffects(
1361  llvm::MemoryEffects::Location::InaccessibleMem,
1362  convertModRefInfoToLLVM(memEffects.getInaccessibleMem()));
1363  newMemEffects |=
1364  llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
1365  convertModRefInfoToLLVM(memEffects.getOther()));
1366  llvmFunc->setMemoryEffects(newMemEffects);
1367 }
1368 
1370 ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
1371  DictionaryAttr paramAttrs) {
1372  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
1373  auto attrNameToKindMapping = getAttrNameToKindMapping();
1374 
1375  for (auto namedAttr : paramAttrs) {
1376  auto it = attrNameToKindMapping.find(namedAttr.getName());
1377  if (it != attrNameToKindMapping.end()) {
1378  llvm::Attribute::AttrKind llvmKind = it->second;
1379 
1380  llvm::TypeSwitch<Attribute>(namedAttr.getValue())
1381  .Case<TypeAttr>([&](auto typeAttr) {
1382  attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
1383  })
1384  .Case<IntegerAttr>([&](auto intAttr) {
1385  attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
1386  })
1387  .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
1388  } else if (namedAttr.getNameDialect()) {
1389  if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
1390  return failure();
1391  }
1392  }
1393 
1394  return attrBuilder;
1395 }
1396 
1397 LogicalResult ModuleTranslation::convertFunctionSignatures() {
1398  // Declare all functions first because there may be function calls that form a
1399  // call graph with cycles, or global initializers that reference functions.
1400  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1401  llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1402  function.getName(),
1403  cast<llvm::FunctionType>(convertType(function.getFunctionType())));
1404  llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1405  llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
1406  llvmFunc->setCallingConv(convertCConvToLLVM(function.getCConv()));
1407  mapFunction(function.getName(), llvmFunc);
1408  addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
1409 
1410  // Convert function attributes.
1411  convertFunctionAttributes(function, llvmFunc);
1412 
1413  // Convert function_entry_count attribute to metadata.
1414  if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
1415  llvmFunc->setEntryCount(entryCount.value());
1416 
1417  // Convert result attributes.
1418  if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
1419  DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
1420  FailureOr<llvm::AttrBuilder> attrBuilder =
1421  convertParameterAttrs(function, -1, resultAttrs);
1422  if (failed(attrBuilder))
1423  return failure();
1424  llvmFunc->addRetAttrs(*attrBuilder);
1425  }
1426 
1427  // Convert argument attributes.
1428  for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
1429  if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
1430  FailureOr<llvm::AttrBuilder> attrBuilder =
1431  convertParameterAttrs(function, argIdx, argAttrs);
1432  if (failed(attrBuilder))
1433  return failure();
1434  llvmArg.addAttrs(*attrBuilder);
1435  }
1436  }
1437 
1438  // Forward the pass-through attributes to LLVM.
1440  function.getLoc(), function.getPassthrough(), llvmFunc)))
1441  return failure();
1442 
1443  // Convert visibility attribute.
1444  llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_()));
1445 
1446  // Convert the comdat attribute.
1447  if (std::optional<mlir::SymbolRefAttr> comdat = function.getComdat()) {
1448  auto selectorOp = cast<ComdatSelectorOp>(
1449  SymbolTable::lookupNearestSymbolFrom(function, *comdat));
1450  llvmFunc->setComdat(comdatMapping.lookup(selectorOp));
1451  }
1452 
1453  if (auto gc = function.getGarbageCollector())
1454  llvmFunc->setGC(gc->str());
1455 
1456  if (auto unnamedAddr = function.getUnnamedAddr())
1457  llvmFunc->setUnnamedAddr(convertUnnamedAddrToLLVM(*unnamedAddr));
1458 
1459  if (auto alignment = function.getAlignment())
1460  llvmFunc->setAlignment(llvm::MaybeAlign(*alignment));
1461 
1462  // Translate the debug information for this function.
1463  debugTranslation->translate(function, *llvmFunc);
1464  }
1465 
1466  return success();
1467 }
1468 
1469 LogicalResult ModuleTranslation::convertFunctions() {
1470  // Convert functions.
1471  for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1472  // Do not convert external functions, but do process dialect attributes
1473  // attached to them.
1474  if (function.isExternal()) {
1475  if (failed(convertDialectAttributes(function, {})))
1476  return failure();
1477  continue;
1478  }
1479 
1480  if (failed(convertOneFunction(function)))
1481  return failure();
1482  }
1483 
1484  return success();
1485 }
1486 
1487 LogicalResult ModuleTranslation::convertComdats() {
1488  for (auto comdatOp : getModuleBody(mlirModule).getOps<ComdatOp>()) {
1489  for (auto selectorOp : comdatOp.getOps<ComdatSelectorOp>()) {
1490  llvm::Module *module = getLLVMModule();
1491  if (module->getComdatSymbolTable().contains(selectorOp.getSymName()))
1492  return emitError(selectorOp.getLoc())
1493  << "comdat selection symbols must be unique even in different "
1494  "comdat regions";
1495  llvm::Comdat *comdat = module->getOrInsertComdat(selectorOp.getSymName());
1496  comdat->setSelectionKind(convertComdatToLLVM(selectorOp.getComdat()));
1497  comdatMapping.try_emplace(selectorOp, comdat);
1498  }
1499  }
1500  return success();
1501 }
1502 
1503 void ModuleTranslation::setAccessGroupsMetadata(AccessGroupOpInterface op,
1504  llvm::Instruction *inst) {
1505  if (llvm::MDNode *node = loopAnnotationTranslation->getAccessGroups(op))
1506  inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1507 }
1508 
1509 llvm::MDNode *
1510 ModuleTranslation::getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr) {
1511  auto [scopeIt, scopeInserted] =
1512  aliasScopeMetadataMapping.try_emplace(aliasScopeAttr, nullptr);
1513  if (!scopeInserted)
1514  return scopeIt->second;
1515  llvm::LLVMContext &ctx = llvmModule->getContext();
1516  auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
1517  // Convert the domain metadata node if necessary.
1518  auto [domainIt, insertedDomain] = aliasDomainMetadataMapping.try_emplace(
1519  aliasScopeAttr.getDomain(), nullptr);
1520  if (insertedDomain) {
1522  // Placeholder for self-reference.
1523  operands.push_back(dummy.get());
1524  if (StringAttr description = aliasScopeAttr.getDomain().getDescription())
1525  operands.push_back(llvm::MDString::get(ctx, description));
1526  domainIt->second = llvm::MDNode::get(ctx, operands);
1527  // Self-reference for uniqueness.
1528  domainIt->second->replaceOperandWith(0, domainIt->second);
1529  }
1530  // Convert the scope metadata node.
1531  assert(domainIt->second && "Scope's domain should already be valid");
1533  // Placeholder for self-reference.
1534  operands.push_back(dummy.get());
1535  operands.push_back(domainIt->second);
1536  if (StringAttr description = aliasScopeAttr.getDescription())
1537  operands.push_back(llvm::MDString::get(ctx, description));
1538  scopeIt->second = llvm::MDNode::get(ctx, operands);
1539  // Self-reference for uniqueness.
1540  scopeIt->second->replaceOperandWith(0, scopeIt->second);
1541  return scopeIt->second;
1542 }
1543 
1545  ArrayRef<AliasScopeAttr> aliasScopeAttrs) {
1547  nodes.reserve(aliasScopeAttrs.size());
1548  for (AliasScopeAttr aliasScopeAttr : aliasScopeAttrs)
1549  nodes.push_back(getOrCreateAliasScope(aliasScopeAttr));
1550  return llvm::MDNode::get(getLLVMContext(), nodes);
1551 }
1552 
1553 void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
1554  llvm::Instruction *inst) {
1555  auto populateScopeMetadata = [&](ArrayAttr aliasScopeAttrs, unsigned kind) {
1556  if (!aliasScopeAttrs || aliasScopeAttrs.empty())
1557  return;
1558  llvm::MDNode *node = getOrCreateAliasScopes(
1559  llvm::to_vector(aliasScopeAttrs.getAsRange<AliasScopeAttr>()));
1560  inst->setMetadata(kind, node);
1561  };
1562 
1563  populateScopeMetadata(op.getAliasScopesOrNull(),
1564  llvm::LLVMContext::MD_alias_scope);
1565  populateScopeMetadata(op.getNoAliasScopesOrNull(),
1566  llvm::LLVMContext::MD_noalias);
1567 }
1568 
1569 llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
1570  return tbaaMetadataMapping.lookup(tbaaAttr);
1571 }
1572 
1573 void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
1574  llvm::Instruction *inst) {
1575  ArrayAttr tagRefs = op.getTBAATagsOrNull();
1576  if (!tagRefs || tagRefs.empty())
1577  return;
1578 
1579  // LLVM IR currently does not support attaching more than one TBAA access tag
1580  // to a memory accessing instruction. It may be useful to support this in
1581  // future, but for the time being just ignore the metadata if MLIR operation
1582  // has multiple access tags.
1583  if (tagRefs.size() > 1) {
1584  op.emitWarning() << "TBAA access tags were not translated, because LLVM "
1585  "IR only supports a single tag per instruction";
1586  return;
1587  }
1588 
1589  llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
1590  inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
1591 }
1592 
1593 void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
1594  DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
1595  if (!weightsAttr)
1596  return;
1597 
1598  llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
1599  assert(inst && "expected the operation to have a mapping to an instruction");
1600  SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
1601  inst->setMetadata(
1602  llvm::LLVMContext::MD_prof,
1603  llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
1604 }
1605 
/// Builds an LLVM metadata node for every TBAA attribute reachable from the
/// TBAA tags of the AliasAnalysisOpInterface ops nested in the MLIR module.
/// Each node is recorded in `tbaaMetadataMapping`, where getTBAANode() later
/// looks it up. Always returns success().
// NOTE(review): listing lines 1625, 1637 and 1646 are missing below. They
// appear to hold the declarations of the `operands` vectors used by the two
// lambdas and the operand appended when `tag.getConstant()` is set. Confirm
// against the source.
1606 LogicalResult ModuleTranslation::createTBAAMetadata() {
1607  llvm::LLVMContext &ctx = llvmModule->getContext();
  // Member and tag offsets are emitted as i64 constants.
1608  llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
1609 
1610  // Walk the entire module and create all metadata nodes for the TBAA
1611  // attributes. The code below relies on two invariants of the
1612  // `AttrTypeWalker`:
1613  // 1. Attributes are visited in post-order: Since the attributes create a DAG,
1614  // this ensures that any lookups into `tbaaMetadataMapping` for child
1615  // attributes succeed.
1616  // 2. Attributes are only ever visited once: This way we don't leak any
1617  // LLVM metadata instances.
1618  AttrTypeWalker walker;
  // Root: a node whose single operand is the root's identifier string.
1619  walker.addWalk([&](TBAARootAttr root) {
1620  tbaaMetadataMapping.insert(
1621  {root, llvm::MDNode::get(ctx, llvm::MDString::get(ctx, root.getId()))});
1622  });
1623 
  // Type descriptor: [identifier, (member type node, member offset)*], where
  // member type nodes were already created (post-order invariant above).
1624  walker.addWalk([&](TBAATypeDescriptorAttr descriptor) {
1626  operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
1627  for (TBAAMemberAttr member : descriptor.getMembers()) {
1628  operands.push_back(tbaaMetadataMapping.lookup(member.getTypeDesc()));
1629  operands.push_back(llvm::ConstantAsMetadata::get(
1630  llvm::ConstantInt::get(offsetTy, member.getOffset())));
1631  }
1632 
1633  tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
1634  });
1635 
  // Access tag: [base type node, access type node, offset], plus one extra
  // operand when the tag is marked constant.
1636  walker.addWalk([&](TBAATagAttr tag) {
1638 
1639  operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
1640  operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
1641 
1642  operands.push_back(llvm::ConstantAsMetadata::get(
1643  llvm::ConstantInt::get(offsetTy, tag.getOffset())));
1644  if (tag.getConstant())
1645  operands.push_back(
1647 
1648  tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
1649  });
1650 
  // Only the tag arrays attached to ops are walked explicitly. The walker
  // recursively visits the type descriptors and roots they reference.
1651  mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
1652  if (auto attr = analysisOpInterface.getTBAATagsOrNull())
1653  walker.walk(attr);
1654  });
1655 
1656  return success();
1657 }
1658 
1660  llvm::Instruction *inst) {
  // NOTE(review): listing lines 1659 and 1662 are missing here. Line 1659 is
  // the start of the signature, presumably
  // `void ModuleTranslation::setLoopMetadata(Operation *op,` per this page's
  // member summary. Line 1662 is the head of the type switch over `op`.
  // Confirm against the source.
  // Fetch the loop annotation from the branch ops handled by the switch:
  // LLVM::BrOp and LLVM::CondBrOp.
1661  LoopAnnotationAttr attr =
1663  .Case<LLVM::BrOp, LLVM::CondBrOp>(
1664  [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
  // Without an annotation, `inst` receives no loop metadata.
1665  if (!attr)
1666  return;
  // Translate the annotation for `op` and attach it to `inst` as MD_loop
  // metadata.
1667  llvm::MDNode *loopMD =
1668  loopAnnotationTranslation->translateLoopAnnotation(attr, op);
1669  inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
1670 }
1671 
  // NOTE(review): listing line 1672 is missing. It is the signature, which
  // this page's member summary gives as `llvm::Type *convertType(Type type)`.
  // Confirm against the source.
  // Converts an MLIR LLVM dialect type to LLVM by delegating to the type
  // translator's translateType().
1673  return typeTranslator.translateType(type);
1674 }
1675 
// NOTE(review): listing line 1677 (the signature, given in this page's member
// summary as `SmallVector<llvm::Value *> lookupValues(ValueRange values)`) is
// missing below. Confirm against the source.
1676 /// A helper to look up remapped operands in the value remapping table.
/// Each value in `values` is resolved with lookupValue(); the results keep
/// the input order.
1678  SmallVector<llvm::Value *> remapped;
  // Exactly one result is produced per input value.
1679  remapped.reserve(values.size());
1680  for (Value v : values)
1681  remapped.push_back(lookupValue(v));
1682  return remapped;
1683 }
1684 
1685 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
1686  if (!ompBuilder) {
1687  ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
1688  ompBuilder->initialize();
1689 
1690  // Flags represented as top-level OpenMP dialect attributes are set in
1691  // `OpenMPDialectLLVMIRTranslationInterface::amendOperation()`. Here we set
1692  // the default configuration.
1693  ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
1694  /* IsTargetDevice = */ false, /* IsGPU = */ false,
1695  /* OpenMPOffloadMandatory = */ false,
1696  /* HasRequiresReverseOffload = */ false,
1697  /* HasRequiresUnifiedAddress = */ false,
1698  /* HasRequiresUnifiedSharedMemory = */ false,
1699  /* HasRequiresDynamicAllocators = */ false));
1700  }
1701  return ompBuilder.get();
1702 }
1703 
1705  llvm::DILocalScope *scope) {
  // NOTE(review): listing line 1704 (the start of the signature) is missing.
  // This page's member summary gives
  // `llvm::DILocation *translateLoc(Location loc, llvm::DILocalScope *scope)`.
  // Translates `loc` within `scope` by forwarding to the debug translation.
1706  return debugTranslation->translateLoc(loc, scope);
1707 }
1708 
1709 llvm::DIExpression *
1710 ModuleTranslation::translateExpression(LLVM::DIExpressionAttr attr) {
1711  return debugTranslation->translateExpression(attr);
1712 }
1713 
// NOTE(review): listing line 1715
// (`ModuleTranslation::translateGlobalVariableExpression(`) is missing from
// the signature below. Confirm against the source.
/// Translates an LLVM dialect global variable expression attribute by
/// forwarding to the debug translation.
1714 llvm::DIGlobalVariableExpression *
1716  LLVM::DIGlobalVariableExpressionAttr attr) {
1717  return debugTranslation->translateGlobalVariableExpression(attr);
1718 }
1719 
  // NOTE(review): listing line 1720 (the signature) is missing. This page's
  // member summary gives it as
  // `llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr)`.
  // Translates a debug-info node attribute by forwarding to the debug
  // translation's generic translate().
1721  return debugTranslation->translate(attr);
1722 }
1723 
1724 llvm::RoundingMode
1725 ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
1726  return convertRoundingModeToLLVM(rounding);
1727 }
1728 
1730  LLVM::FPExceptionBehavior exceptionBehavior) {
  // NOTE(review): listing line 1729 (the start of the signature) is missing.
  // This page's member summary gives it as
  // `llvm::fp::ExceptionBehavior translateFPExceptionBehavior(...)`.
  // Maps the dialect enum to LLVM with convertFPExceptionBehaviorToLLVM.
1731  return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
1732 }
1733 
// NOTE(review): listing line 1735
// (`ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {`, per
// this page's member summary) is missing below. Confirm against the source.
/// Returns the named metadata `name` of the LLVM module being built, creating
/// it via getOrInsertNamedMetadata() if it does not exist yet.
1734 llvm::NamedMDNode *
1736  return llvmModule->getOrInsertNamedMetadata(name);
1737 }
1738 
1739 void ModuleTranslation::StackFrame::anchor() {}
1740 
1741 static std::unique_ptr<llvm::Module>
1742 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1743  StringRef name) {
1744  m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1745  auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1746  if (auto dataLayoutAttr =
1747  m->getDiscardableAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1748  llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
1749  } else {
1750  FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1751  if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1752  if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1753  llvmDataLayout =
1754  translateDataLayout(spec, DataLayout(iface), m->getLoc());
1755  }
1756  } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1757  if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1758  llvmDataLayout =
1759  translateDataLayout(spec, DataLayout(mod), m->getLoc());
1760  }
1761  }
1762  if (failed(llvmDataLayout))
1763  return nullptr;
1764  llvmModule->setDataLayout(*llvmDataLayout);
1765  }
1766  if (auto targetTripleAttr =
1767  m->getDiscardableAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1768  llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
1769 
1770  return llvmModule;
1771 }
1772 
/// Translates `module` into a new llvm::Module named `name` in `llvmContext`,
/// returned to the caller as a unique_ptr.
///
/// Returns nullptr when:
///  - `module` does not satisfy the LLVM module requirements (an op error is
///    emitted);
///  - preparing the LLVM module fails;
///  - any conversion step below fails;
///  - llvm::verifyModule rejects the result.
1773 std::unique_ptr<llvm::Module>
1774 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1775  StringRef name) {
1776  if (!satisfiesLLVMModule(module)) {
1777  module->emitOpError("can not be translated to an LLVMIR module");
1778  return nullptr;
1779  }
1780 
  // Creates the empty LLVM module with data layout and target triple set.
1781  std::unique_ptr<llvm::Module> llvmModule =
1782  prepareLLVMModule(module, llvmContext, name);
1783  if (!llvmModule)
1784  return nullptr;
1785 
  // NOTE(review): listing lines 1786-1787 are missing here. This page has
  // tooltips for LLVM::ensureDistinctSuccessors and
  // LLVM::legalizeDIExpressionsRecursively, which suggests the module is
  // preprocessed at this point. Confirm against the source.
1788 
1789  ModuleTranslation translator(module, std::move(llvmModule));
1790  llvm::IRBuilder<> llvmBuilder(llvmContext);
1791 
1792  // Convert module before functions and operations inside, so dialect
1793  // attributes can be used to change dialect-specific global configurations via
1794  // `amendOperation()`. These configurations can then influence the translation
1795  // of operations afterwards.
1796  if (failed(translator.convertOperation(*module, llvmBuilder)))
1797  return nullptr;
1798 
  // Module-level entities (comdats, function signatures, globals and TBAA
  // metadata) are created before any other top-level op or function body is
  // converted.
1799  if (failed(translator.convertComdats()))
1800  return nullptr;
1801  if (failed(translator.convertFunctionSignatures()))
1802  return nullptr;
1803  if (failed(translator.convertGlobals()))
1804  return nullptr;
1805  if (failed(translator.createTBAAMetadata()))
1806  return nullptr;
1807 
1808  // Convert other top-level operations if possible.
  // This loop skips functions, globals, global ctors/dtors, comdats and
  // terminators.
1809  for (Operation &o : getModuleBody(module).getOperations()) {
1810  if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1811  LLVM::GlobalDtorsOp, LLVM::ComdatOp>(&o) &&
1812  !o.hasTrait<OpTrait::IsTerminator>() &&
1813  failed(translator.convertOperation(o, llvmBuilder))) {
1814  return nullptr;
1815  }
1816  }
1817 
1818  // Operations in function bodies with symbolic references must be converted
1819  // after the top-level operations they refer to are declared, so we do it
1820  // last.
1821  if (failed(translator.convertFunctions()))
1822  return nullptr;
1823 
  // A true result from llvm::verifyModule is treated as failure. Verifier
  // output goes to llvm::errs().
1824  if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1825  return nullptr;
1826 
  // Hand ownership of the finished module to the caller.
1827  return std::move(translator.llvmModule);
1828 }
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index)
Get the SSA value passed to the current block from the terminator operation of its predecessor.
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
static Block & getModuleBody(Operation *module)
A helper method to get the single Block in an operation honoring LLVM's module requirements.
static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, llvm::GlobalValue *gv)
Sets the runtime preemption specifier of gv to dso_local if dsoLocalRequested is true,...
static LogicalResult checkedAddLLVMFnAttribute(Location loc, llvm::Function *llvmFunc, StringRef key, StringRef value=StringRef())
Attempts to add an attribute identified by key, optionally with the given value to LLVM function llvm...
static void convertFunctionAttributes(LLVMFuncOp func, llvm::Function *llvmFunc)
Converts the function attributes from LLVMFuncOp and attaches them to the llvm::Function.
static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage, llvm::Constant *cst)
A helper method to decide if a constant must not be set as a global variable initializer.
static llvm::Type * getInnermostElementType(llvm::Type *type)
Returns the first non-sequential type nested in sequential types.
static std::unique_ptr< llvm::Module > prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, StringRef name)
static llvm::Constant * convertDenseResourceElementsAttr(Location loc, DenseResourceElementsAttr denseResourceAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense resource elements attribute to an LLVM IR constant using its raw data storage if poss...
static LogicalResult forwardPassthroughAttributes(Location loc, std::optional< ArrayAttr > attributes, llvm::Function *llvmFunc)
Attaches the attributes listed in the given array attribute to llvmFunc.
static llvm::Constant * buildSequentialConstant(ArrayRef< llvm::Constant * > &constants, ArrayRef< int64_t > shape, llvm::Type *type, Location loc)
Builds a constant of a sequential LLVM type type, potentially containing other sequential types recur...
The following classes enable support for parsing and printing resources within MLIR assembly formats.
Definition: AsmState.h:88
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:142
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
The main mechanism for performing data layout queries.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
uint64_t getTypePreferredAlignment(Type t) const
Returns the preferred alignment of the given type in the current scope.
uint64_t getTypeABIAlignment(Type t) const
Returns the required alignment of the given type in the current scope.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
const InterfaceType * getInterfaceFor(Object *obj) const
Get the interface for a given object, or null if one is not registered.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
Base class for dialect interfaces providing translation to LLVM IR.
virtual LogicalResult convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const
Hook for derived dialect interface to provide translation of the operations to LLVM IR.
virtual LogicalResult convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given function operation using the interface implemented by the dialect of one of the fun...
virtual LogicalResult amendOperation(Operation *op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const
Acts on the given operation using the interface implemented by the dialect of one of the operation's ...
This class represents the base attribute for all debug info attributes.
Definition: LLVMAttrs.h:27
Implementation class for module translation.
llvm::fp::ExceptionBehavior translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior)
Translates the given LLVM FP exception behavior metadata.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
llvm::DIGlobalVariableExpression * translateGlobalVariableExpression(LLVM::DIGlobalVariableExpressionAttr attr)
Translates the given LLVM global variable expression metadata.
llvm::NamedMDNode * getOrInsertNamedModuleMetadata(StringRef name)
Gets the named metadata in the LLVM IR module being constructed, creating it if it does not exist.
llvm::Instruction * lookupBranch(Operation *op) const
Finds an LLVM IR instruction that corresponds to the given MLIR operation with successors.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up a list of remapped values.
void mapFunction(StringRef name, llvm::Function *func)
Stores the mapping between a function name and its LLVM IR representation.
llvm::DILocation * translateLoc(Location loc, llvm::DILocalScope *scope)
Translates the given location.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
void setBranchWeightsMetadata(BranchWeightOpInterface op)
Sets LLVM profiling metadata for operations that have branch weights.
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding)
Translates the given LLVM rounding mode metadata.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::DIExpression * translateExpression(LLVM::DIExpressionAttr attr)
Translates the given LLVM DWARF expression metadata.
llvm::OpenMPIRBuilder * getOpenMPBuilder()
Returns the OpenMP IR builder associated with the LLVM IR module being constructed.
llvm::CallInst * lookupCall(Operation *op) const
Finds an LLVM call instruction that corresponds to the given MLIR call operation.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::MDNode * getOrCreateAliasScopes(ArrayRef< AliasScopeAttr > aliasScopeAttrs)
Returns the LLVM metadata corresponding to an array of MLIR LLVM dialect alias scope attributes.
void mapBlock(Block *mlir, llvm::BasicBlock *llvm)
Stores the mapping between an MLIR block and LLVM IR basic block.
llvm::MDNode * getOrCreateAliasScope(AliasScopeAttr aliasScopeAttr)
Returns the LLVM metadata corresponding to an MLIR LLVM dialect alias scope attribute.
void forgetMapping(Region &region)
Removes the mapping for blocks contained in the region and values defined in these blocks.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
llvm::Type * translateType(Type type)
Translates the given MLIR LLVM dialect type to LLVM IR.
Definition: TypeToLLVM.cpp:192
A helper class that converts LoopAnnotationAttrs and AccessGroupAttrs into corresponding llvm::MDNode...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
Definition: MLIRContext.h:97
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Attribute getDiscardableAttr(StringRef name)
Access a discardable attribute by name, returns a null Attribute if the discardable attribute does n...
Definition: Operation.h:448
Value getOperand(unsigned idx)
Definition: Operation.h:345
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
Block * getSuccessor(unsigned index)
Definition: Operation.h:704
unsigned getNumSuccessors()
Definition: Operation.h:702
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
dialect_attr_range getDialectAttrs()
Return a range corresponding to the dialect attributes for this operation.
Definition: Operation.h:632
bool hasSuccessors()
Definition: Operation.h:701
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & front()
Definition: Region.h:65
This class models how operands are forwarded to block arguments in control flow.
bool empty() const
Returns true if there are no successor operands.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_iterator use_end() const
Definition: Value.h:209
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_iterator use_begin() const
Definition: Value.h:208
Include the generated interface declarations.
Definition: CallGraph.h:229
void connectPHINodes(Region &region, const ModuleTranslation &state)
For all blocks in the region that were converted to LLVM IR using the given ModuleTranslation,...
llvm::CallInst * createIntrinsicCall(llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic, ArrayRef< llvm::Value * > args={}, ArrayRef< llvm::Type * > tys={})
Creates a call to an LLVM IR intrinsic function with the given arguments.
static llvm::DenseMap< llvm::StringRef, llvm::Attribute::AttrKind > getAttrNameToKindMapping()
Returns a dense map from LLVM attribute name to their kind in LLVM IR dialect.
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
void legalizeDIExpressionsRecursively(Operation *op)
Register all known legalization patterns declared here and apply them to all ops in op.
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
void ensureDistinctSuccessors(Operation *op)
Make argument-taking successors of each block distinct.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SetVector< Block * > getTopologicallySortedBlocks(Region &region)
Get a topologically sorted list of blocks of the given region.
DataLayoutSpecInterface translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context)
Translate the given LLVM data layout into an MLIR equivalent using the DLTI dialect.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< llvm::Module > translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, llvm::StringRef name="LLVMDialectModule")
Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in ...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26