MLIR  18.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // The first time this thread gets here, we need exclusive access to
35  // insert into the map.
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
50  const DataLayoutAnalysis *analysis)
51  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52  dataLayoutAnalysis(analysis) {
53  assert(llvmDialect && "LLVM IR dialect is not registered");
54 
55  // Register conversions for the builtin types.
56  addConversion([&](ComplexType type) { return convertComplexType(type); });
57  addConversion([&](FloatType type) { return convertFloatType(type); });
58  addConversion([&](FunctionType type) { return convertFunctionType(type); });
59  addConversion([&](IndexType type) { return convertIndexType(type); });
60  addConversion([&](IntegerType type) { return convertIntegerType(type); });
61  addConversion([&](MemRefType type) { return convertMemRefType(type); });
63  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64  addConversion([&](VectorType type) -> std::optional<Type> {
65  FailureOr<Type> llvmType = convertVectorType(type);
66  if (failed(llvmType))
67  return std::nullopt;
68  return llvmType;
69  });
70 
71  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72  // before the conversions below since conversions are attempted in reverse
73  // order and those should take priority.
74  addConversion([](Type type) {
75  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76  : std::nullopt;
77  });
78 
79  // LLVM container types may (recursively) contain other types that must be
80  // converted even when the outer type is compatible.
81  addConversion([&](LLVM::LLVMPointerType type) -> std::optional<Type> {
82  if (type.isOpaque())
83  return type;
84  if (auto pointee = convertType(type.getElementType()))
85  return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
86  return std::nullopt;
87  });
88 
90  -> std::optional<LogicalResult> {
91  // Fastpath for types that won't be converted by this callback anyway.
92  if (LLVM::isCompatibleType(type)) {
93  results.push_back(type);
94  return success();
95  }
96 
97  if (type.isIdentified()) {
98  auto convertedType = LLVM::LLVMStructType::getIdentified(
99  type.getContext(), ("_Converted_" + type.getName()).str());
100  unsigned counter = 1;
101  while (convertedType.isInitialized()) {
102  assert(counter != UINT_MAX &&
103  "about to overflow struct renaming counter in conversion");
104  convertedType = LLVM::LLVMStructType::getIdentified(
105  type.getContext(),
106  ("_Converted_" + std::to_string(counter) + type.getName()).str());
107  }
108 
110  if (llvm::count(recursiveStack, type)) {
111  results.push_back(convertedType);
112  return success();
113  }
114  recursiveStack.push_back(type);
115  auto popConversionCallStack = llvm::make_scope_exit(
116  [&recursiveStack]() { recursiveStack.pop_back(); });
117 
118  SmallVector<Type> convertedElemTypes;
119  convertedElemTypes.reserve(type.getBody().size());
120  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
121  return std::nullopt;
122 
123  if (failed(convertedType.setBody(convertedElemTypes, type.isPacked())))
124  return failure();
125  results.push_back(convertedType);
126  return success();
127  }
128 
129  SmallVector<Type> convertedSubtypes;
130  convertedSubtypes.reserve(type.getBody().size());
131  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
132  return std::nullopt;
133 
134  results.push_back(LLVM::LLVMStructType::getLiteral(
135  type.getContext(), convertedSubtypes, type.isPacked()));
136  return success();
137  });
138  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
139  if (auto element = convertType(type.getElementType()))
140  return LLVM::LLVMArrayType::get(element, type.getNumElements());
141  return std::nullopt;
142  });
143  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
144  Type convertedResType = convertType(type.getReturnType());
145  if (!convertedResType)
146  return std::nullopt;
147 
148  SmallVector<Type> convertedArgTypes;
149  convertedArgTypes.reserve(type.getNumParams());
150  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
151  return std::nullopt;
152 
153  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
154  type.isVarArg());
155  });
156 
157  // Materialization for memrefs creates descriptor structs from individual
158  // values constituting them, when descriptors are used, i.e. more than one
159  // value represents a memref.
161  [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
162  Location loc) -> std::optional<Value> {
163  if (inputs.size() == 1)
164  return std::nullopt;
165  return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
166  inputs);
167  });
168  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
169  ValueRange inputs,
170  Location loc) -> std::optional<Value> {
171  // TODO: bare ptr conversion could be handled here but we would need a way
172  // to distinguish between FuncOp and other regions.
173  if (inputs.size() == 1)
174  return std::nullopt;
175  return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
176  });
177  // Add generic source and target materializations to handle cases where
178  // non-LLVM types persist after an LLVM conversion.
179  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
180  ValueRange inputs,
181  Location loc) -> std::optional<Value> {
182  if (inputs.size() != 1)
183  return std::nullopt;
184 
185  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
186  .getResult(0);
187  });
188  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
189  ValueRange inputs,
190  Location loc) -> std::optional<Value> {
191  if (inputs.size() != 1)
192  return std::nullopt;
193 
194  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
195  .getResult(0);
196  });
197 
198  // Integer memory spaces map to themselves.
200  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
201 }
202 
203 /// Returns the MLIR context.
205  return *getDialect()->getContext();
206 }
207 
210 }
211 
212 LLVM::LLVMPointerType
214  unsigned int addressSpace) const {
215  if (useOpaquePointers())
216  return LLVM::LLVMPointerType::get(&getContext(), addressSpace);
217  return LLVM::LLVMPointerType::get(elementType, addressSpace);
218 }
219 
220 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
221  return options.dataLayout.getPointerSizeInBits(addressSpace);
222 }
223 
224 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
225  return getIndexType();
226 }
227 
228 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
229  return IntegerType::get(&getContext(), type.getWidth());
230 }
231 
232 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
233  if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
234  type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
235  return IntegerType::get(&getContext(), type.getWidth());
236  return type;
237 }
238 
239 // Convert a `ComplexType` to an LLVM type. The result is a complex number
240 // struct with entries for the
241 // 1. real part and for the
242 // 2. imaginary part.
243 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
244  auto elementType = convertType(type.getElementType());
246  {elementType, elementType});
247 }
248 
249 // Except for signatures, MLIR function types are converted into LLVM
250 // pointer-to-function types.
251 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
252  SignatureConversion conversion(type.getNumInputs());
253  Type converted = convertFunctionSignature(
254  type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion);
255  if (!converted)
256  return {};
257  return getPointerType(converted);
258 }
259 
260 // Function types are converted to LLVM Function types by recursively converting
261 // argument and result types. If MLIR Function has zero results, the LLVM
262 // Function has one VoidType result. If MLIR Function has more than one result,
263 // they are packed into an LLVM StructType in their order of appearance.
265  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
267  // Select the argument converter depending on the calling convention.
268  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
269  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
271  // Convert argument types one by one and check for errors.
272  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
273  SmallVector<Type, 8> converted;
274  if (failed(funcArgConverter(*this, type, converted)))
275  return {};
276  result.addInputs(idx, converted);
277  }
278 
279  // If function does not return anything, create the void result type,
280  // if it returns one element, convert it, otherwise pack the result types into
281  // a struct.
282  Type resultType =
283  funcTy.getNumResults() == 0
285  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
286  if (!resultType)
287  return {};
288  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
289  isVariadic);
290 }
291 
292 /// Converts the function type to a C-compatible format, in particular using
293 /// pointers to memref descriptors for arguments.
294 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
296  SmallVector<Type, 4> inputs;
297 
298  Type resultType = type.getNumResults() == 0
300  : packFunctionResults(type.getResults());
301  if (!resultType)
302  return {};
303 
304  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
305  if (structType) {
306  // Struct types cannot be safely returned via C interface. Make this a
307  // pointer argument, instead.
308  inputs.push_back(getPointerType(structType));
309  resultType = LLVM::LLVMVoidType::get(&getContext());
310  }
311 
312  for (Type t : type.getInputs()) {
313  auto converted = convertType(t);
314  if (!converted || !LLVM::isCompatibleType(converted))
315  return {};
316  if (isa<MemRefType, UnrankedMemRefType>(t))
317  converted = getPointerType(converted);
318  inputs.push_back(converted);
319  }
320 
321  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
322 }
323 
324 /// Convert a memref type into a list of LLVM IR types that will form the
325 /// memref descriptor. The result contains the following types:
326 /// 1. The pointer to the allocated data buffer, followed by
327 /// 2. The pointer to the aligned data buffer, followed by
328 /// 3. A lowered `index`-type integer containing the distance between the
329 /// beginning of the buffer and the first element to be accessed through the
330 /// view, followed by
331 /// 4. An array containing as many `index`-type integers as the rank of the
332 /// MemRef: the array represents the size, in number of elements, of the memref
333 /// along the given dimension. For constant MemRef dimensions, the
334 /// corresponding size entry is a constant whose runtime value must match the
335 /// static value, followed by
336 /// 5. A second array containing as many `index`-type integers as the rank of
337 /// the MemRef: the second array represents the "stride" (in tensor abstraction
338 /// sense), i.e. the number of consecutive elements of the underlying buffer.
339 /// TODO: add assertions for the static cases.
340 ///
341 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
342 /// are expanded into individual index-type elements.
343 ///
344 /// template <typename Elem, typename Index, size_t Rank>
345 /// struct {
346 /// Elem *allocatedPtr;
347 /// Elem *alignedPtr;
348 /// Index offset;
349 /// Index sizes[Rank]; // omitted when rank == 0
350 /// Index strides[Rank]; // omitted when rank == 0
351 /// };
353 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
354  bool unpackAggregates) const {
355  if (!isStrided(type)) {
356  emitError(
357  UnknownLoc::get(type.getContext()),
358  "conversion to strided form failed either due to non-strided layout "
359  "maps (which should have been normalized away) or other reasons");
360  return {};
361  }
362 
363  Type elementType = convertType(type.getElementType());
364  if (!elementType)
365  return {};
366 
367  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
368  if (failed(addressSpace)) {
369  emitError(UnknownLoc::get(type.getContext()),
370  "conversion of memref memory space ")
371  << type.getMemorySpace()
372  << " to integer address space "
373  "failed. Consider adding memory space conversions.";
374  return {};
375  }
376  auto ptrTy = getPointerType(elementType, *addressSpace);
377 
378  auto indexTy = getIndexType();
379 
380  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
381  auto rank = type.getRank();
382  if (rank == 0)
383  return results;
384 
385  if (unpackAggregates)
386  results.insert(results.end(), 2 * rank, indexTy);
387  else
388  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
389  return results;
390 }
391 
392 unsigned
394  const DataLayout &layout) const {
395  // Compute the descriptor size given that of its components indicated above.
396  unsigned space = *getMemRefAddressSpace(type);
397  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
398  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
399 }
400 
401 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
402 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
403 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
404  // When converting a MemRefType to a struct with descriptor fields, do not
405  // unpack the `sizes` and `strides` arrays.
406  SmallVector<Type, 5> types =
407  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
408  if (types.empty())
409  return {};
411 }
412 
413 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
414 /// that will form the unranked memref descriptor. In particular, the fields
415 /// for an unranked memref descriptor are:
416 /// 1. index-typed rank, the dynamic rank of this MemRef
417 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
418 /// stack-allocated (alloca) copy of a MemRef descriptor that was cast to
419 /// be unranked.
421 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
423 }
424 
426  UnrankedMemRefType type, const DataLayout &layout) const {
427  // Compute the descriptor size given that of its components indicated above.
428  unsigned space = *getMemRefAddressSpace(type);
429  return layout.getTypeSize(getIndexType()) +
430  llvm::divideCeil(getPointerBitwidth(space), 8);
431 }
432 
433 Type LLVMTypeConverter::convertUnrankedMemRefType(
434  UnrankedMemRefType type) const {
435  if (!convertType(type.getElementType()))
436  return {};
438  getUnrankedMemRefDescriptorFields());
439 }
440 
443  if (!type.getMemorySpace()) // Default memory space -> 0.
444  return 0;
445  std::optional<Attribute> converted =
446  convertTypeAttribute(type, type.getMemorySpace());
447  if (!converted)
448  return failure();
449  if (!(*converted)) // Conversion to default is 0.
450  return 0;
451  if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
452  return explicitSpace.getInt();
453  return failure();
454 }
455 
456 // Check if a memref type can be converted to a bare pointer.
458  if (isa<UnrankedMemRefType>(type))
459  // Unranked memref is not supported in the bare pointer calling convention.
460  return false;
461 
462  // Check that the memref has static shape, strides and offset. Otherwise, it
463  // cannot be lowered to a bare pointer.
464  auto memrefTy = cast<MemRefType>(type);
465  if (!memrefTy.hasStaticShape())
466  return false;
467 
468  int64_t offset = 0;
469  SmallVector<int64_t, 4> strides;
470  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
471  return false;
472 
473  for (int64_t stride : strides)
474  if (ShapedType::isDynamic(stride))
475  return false;
476 
477  return !ShapedType::isDynamic(offset);
478 }
479 
480 /// Convert a memref type to a bare pointer to the memref element type.
481 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
482  if (!canConvertToBarePtr(type))
483  return {};
484  Type elementType = convertType(type.getElementType());
485  if (!elementType)
486  return {};
487  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
488  if (failed(addressSpace))
489  return {};
490  return getPointerType(elementType, *addressSpace);
491 }
492 
493 /// Convert an n-D vector type to an LLVM vector type:
494 /// * 0-D `vector<T>` are converted to vector<1xT>
495 /// * 1-D `vector<axT>` remains as is while,
496 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
497 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
498 /// Returns failure for n-D scalable vector types as LLVM does not support
499 /// arrays of scalable vectors.
500 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
501  auto elementType = convertType(type.getElementType());
502  if (!elementType)
503  return {};
504  if (type.getShape().empty())
505  return VectorType::get({1}, elementType);
506  Type vectorType = VectorType::get(type.getShape().back(), elementType,
507  type.getScalableDims().back());
508  assert(LLVM::isCompatibleVectorType(vectorType) &&
509  "expected vector type compatible with the LLVM dialect");
510  // Only the trailing dimension can be scalable.
511  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
512  return failure();
513  auto shape = type.getShape();
514  for (int i = shape.size() - 2; i >= 0; --i)
515  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
516  return vectorType;
517 }
518 
519 /// Convert a type in the context of the default or bare pointer calling
520 /// convention. Calling convention sensitive types, such as MemRefType and
521 /// UnrankedMemRefType, are converted following the specific rules for the
522 /// calling convention. Calling convention independent types are converted
523 /// following the default LLVM type conversions.
525  Type type, bool useBarePtrCallConv) const {
526  if (useBarePtrCallConv)
527  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
528  return convertMemRefToBarePtr(memrefTy);
529 
530  return convertType(type);
531 }
532 
533 /// Promote the bare pointers in 'values' that resulted from memrefs to
534 /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
535 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
537  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
538  SmallVectorImpl<Value> &values) const {
539  assert(stdTypes.size() == values.size() &&
540  "The number of types and values doesn't match");
541  for (unsigned i = 0, end = values.size(); i < end; ++i)
542  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
543  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
544  memrefTy, values[i]);
545 }
546 
547 /// Convert a non-empty list of types of values produced by an operation into an
548 /// LLVM-compatible type. In particular, if more than one value is
549 /// produced, create a literal structure with elements that correspond to each
550 /// of the types converted with `convertType`.
552  assert(!types.empty() && "expected non-empty list of type");
553  if (types.size() == 1)
554  return convertType(types[0]);
555 
556  SmallVector<Type> resultTypes;
557  resultTypes.reserve(types.size());
558  for (Type type : types) {
559  Type converted = convertType(type);
560  if (!converted || !LLVM::isCompatibleType(converted))
561  return {};
562  resultTypes.push_back(converted);
563  }
564 
565  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
566 }
567 
568 /// Convert a non-empty list of types to be returned from a function into an
569 /// LLVM-compatible type. In particular, if more than one value is returned,
570 /// create an LLVM dialect structure type with elements that correspond to each
571 /// of the types converted with `convertCallingConventionType`.
573  bool useBarePtrCallConv) const {
574  assert(!types.empty() && "expected non-empty list of type");
575 
576  useBarePtrCallConv |= options.useBarePtrCallConv;
577  if (types.size() == 1)
578  return convertCallingConventionType(types.front(), useBarePtrCallConv);
579 
580  SmallVector<Type> resultTypes;
581  resultTypes.reserve(types.size());
582  for (auto t : types) {
583  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
584  if (!converted || !LLVM::isCompatibleType(converted))
585  return {};
586  resultTypes.push_back(converted);
587  }
588 
589  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
590 }
591 
593  OpBuilder &builder) const {
594  // Alloca with proper alignment. We do not expect optimizations of this
595  // alloca op and so we omit allocating at the entry block.
596  auto ptrType = getPointerType(operand.getType());
597  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
598  builder.getIndexAttr(1));
599  Value allocated =
600  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
601  // Store into the alloca'ed descriptor.
602  builder.create<LLVM::StoreOp>(loc, operand, allocated);
603  return allocated;
604 }
605 
608  ValueRange operands, OpBuilder &builder,
609  bool useBarePtrCallConv) const {
610  SmallVector<Value, 4> promotedOperands;
611  promotedOperands.reserve(operands.size());
612  useBarePtrCallConv |= options.useBarePtrCallConv;
613  for (auto it : llvm::zip(opOperands, operands)) {
614  auto operand = std::get<0>(it);
615  auto llvmOperand = std::get<1>(it);
616 
617  if (useBarePtrCallConv) {
618  // For the bare-ptr calling convention, we only have to extract the
619  // aligned pointer of a memref.
620  if (dyn_cast<MemRefType>(operand.getType())) {
621  MemRefDescriptor desc(llvmOperand);
622  llvmOperand = desc.alignedPtr(builder, loc);
623  } else if (isa<UnrankedMemRefType>(operand.getType())) {
624  llvm_unreachable("Unranked memrefs are not supported");
625  }
626  } else {
627  if (isa<UnrankedMemRefType>(operand.getType())) {
628  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
629  promotedOperands);
630  continue;
631  }
632  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
633  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
634  promotedOperands);
635  continue;
636  }
637  }
638 
639  promotedOperands.push_back(llvmOperand);
640  }
641  return promotedOperands;
642 }
643 
644 /// Callback to convert function argument types. It converts a MemRef function
645 /// argument to a list of non-aggregate types containing descriptor
646 /// information, and an UnrankedMemRef function argument to a list containing
647 /// the rank and a pointer to a descriptor struct.
650  SmallVectorImpl<Type> &result) {
651  if (auto memref = dyn_cast<MemRefType>(type)) {
652  // In signatures, Memref descriptors are expanded into lists of
653  // non-aggregate values.
654  auto converted =
655  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
656  if (converted.empty())
657  return failure();
658  result.append(converted.begin(), converted.end());
659  return success();
660  }
661  if (isa<UnrankedMemRefType>(type)) {
662  auto converted = converter.getUnrankedMemRefDescriptorFields();
663  if (converted.empty())
664  return failure();
665  result.append(converted.begin(), converted.end());
666  return success();
667  }
668  auto converted = converter.convertType(type);
669  if (!converted)
670  return failure();
671  result.push_back(converted);
672  return success();
673 }
674 
675 /// Callback to convert function argument types. It converts MemRef function
676 /// arguments to bare pointers to the MemRef element type.
679  SmallVectorImpl<Type> &result) {
680  auto llvmTy = converter.convertCallingConventionType(
681  type, /*useBarePointerCallConv=*/true);
682  if (!llvmTy)
683  return failure();
684 
685  result.push_back(llvmTy);
686  return success();
687 }
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:137
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerType getI64Type()
Definition: Builders.cpp:85
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
unsigned getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:91
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0) const
Creates an LLVM pointer type with the given element type and address space.
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
bool useOpaquePointers() const
Returns true if using opaque pointers was enabled in the lowering options.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:108
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:502
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
Definition: LLVMTypes.cpp:474
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isFloat8E4M3FN() const
Definition: Types.cpp:38
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:42
bool isFloat8E4M3B11FNUZ() const
Definition: Types.cpp:45
bool isFloat8E5M2() const
Definition: Types.cpp:37
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:39
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:931
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:913
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26