MLIR  19.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // First time this thread gets here, we have to get an exclusive access to
35  // inset in the map
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
50  const DataLayoutAnalysis *analysis)
51  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52  dataLayoutAnalysis(analysis) {
53  assert(llvmDialect && "LLVM IR dialect is not registered");
54 
55  // Register conversions for the builtin types.
56  addConversion([&](ComplexType type) { return convertComplexType(type); });
57  addConversion([&](FloatType type) { return convertFloatType(type); });
58  addConversion([&](FunctionType type) { return convertFunctionType(type); });
59  addConversion([&](IndexType type) { return convertIndexType(type); });
60  addConversion([&](IntegerType type) { return convertIntegerType(type); });
61  addConversion([&](MemRefType type) { return convertMemRefType(type); });
63  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64  addConversion([&](VectorType type) -> std::optional<Type> {
65  FailureOr<Type> llvmType = convertVectorType(type);
66  if (failed(llvmType))
67  return std::nullopt;
68  return llvmType;
69  });
70 
71  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72  // before the conversions below since conversions are attempted in reverse
73  // order and those should take priority.
74  addConversion([](Type type) {
75  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76  : std::nullopt;
77  });
78 
80  -> std::optional<LogicalResult> {
81  // Fastpath for types that won't be converted by this callback anyway.
82  if (LLVM::isCompatibleType(type)) {
83  results.push_back(type);
84  return success();
85  }
86 
87  if (type.isIdentified()) {
88  auto convertedType = LLVM::LLVMStructType::getIdentified(
89  type.getContext(), ("_Converted." + type.getName()).str());
90 
92  if (llvm::count(recursiveStack, type)) {
93  results.push_back(convertedType);
94  return success();
95  }
96  recursiveStack.push_back(type);
97  auto popConversionCallStack = llvm::make_scope_exit(
98  [&recursiveStack]() { recursiveStack.pop_back(); });
99 
100  SmallVector<Type> convertedElemTypes;
101  convertedElemTypes.reserve(type.getBody().size());
102  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
103  return std::nullopt;
104 
105  // If the converted type has not been initialized yet, just set its body
106  // to be the converted arguments and return.
107  if (!convertedType.isInitialized()) {
108  if (failed(
109  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110  return failure();
111  }
112  results.push_back(convertedType);
113  return success();
114  }
115 
116  // If it has been initialized, has the same body and packed bit, just use
117  // it. This ensures that recursive structs keep being recursive rather
118  // than including a non-updated name.
119  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120  convertedType.isPacked() == type.isPacked()) {
121  results.push_back(convertedType);
122  return success();
123  }
124 
125  return failure();
126  }
127 
128  SmallVector<Type> convertedSubtypes;
129  convertedSubtypes.reserve(type.getBody().size());
130  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
131  return std::nullopt;
132 
133  results.push_back(LLVM::LLVMStructType::getLiteral(
134  type.getContext(), convertedSubtypes, type.isPacked()));
135  return success();
136  });
137  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138  if (auto element = convertType(type.getElementType()))
139  return LLVM::LLVMArrayType::get(element, type.getNumElements());
140  return std::nullopt;
141  });
142  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143  Type convertedResType = convertType(type.getReturnType());
144  if (!convertedResType)
145  return std::nullopt;
146 
147  SmallVector<Type> convertedArgTypes;
148  convertedArgTypes.reserve(type.getNumParams());
149  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
150  return std::nullopt;
151 
152  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153  type.isVarArg());
154  });
155 
156  // Materialization for memrefs creates descriptor structs from individual
157  // values constituting them, when descriptors are used, i.e. more than one
158  // value represents a memref.
160  [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
161  Location loc) -> std::optional<Value> {
162  if (inputs.size() == 1)
163  return std::nullopt;
164  return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
165  inputs);
166  });
167  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
168  ValueRange inputs,
169  Location loc) -> std::optional<Value> {
170  // TODO: bare ptr conversion could be handled here but we would need a way
171  // to distinguish between FuncOp and other regions.
172  if (inputs.size() == 1)
173  return std::nullopt;
174  return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
175  });
176  // Add generic source and target materializations to handle cases where
177  // non-LLVM types persist after an LLVM conversion.
178  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
179  ValueRange inputs,
180  Location loc) -> std::optional<Value> {
181  if (inputs.size() != 1)
182  return std::nullopt;
183 
184  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
185  .getResult(0);
186  });
187  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
188  ValueRange inputs,
189  Location loc) -> std::optional<Value> {
190  if (inputs.size() != 1)
191  return std::nullopt;
192 
193  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
194  .getResult(0);
195  });
196 
197  // Integer memory spaces map to themselves.
199  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
200 }
201 
202 /// Returns the MLIR context.
204  return *getDialect()->getContext();
205 }
206 
209 }
210 
211 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
212  return options.dataLayout.getPointerSizeInBits(addressSpace);
213 }
214 
215 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
216  return getIndexType();
217 }
218 
219 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
220  return IntegerType::get(&getContext(), type.getWidth());
221 }
222 
223 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
224  if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
225  type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
226  return IntegerType::get(&getContext(), type.getWidth());
227  return type;
228 }
229 
230 // Convert a `ComplexType` to an LLVM type. The result is a complex number
231 // struct with entries for the
232 // 1. real part and for the
233 // 2. imaginary part.
234 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
235  auto elementType = convertType(type.getElementType());
237  {elementType, elementType});
238 }
239 
240 // Except for signatures, MLIR function types are converted into LLVM
241 // pointer-to-function types.
242 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
243  return LLVM::LLVMPointerType::get(type.getContext());
244 }
245 
246 // Function types are converted to LLVM Function types by recursively converting
247 // argument and result types. If MLIR Function has zero results, the LLVM
248 // Function has one VoidType result. If MLIR Function has more than one result,
249 // they are packed into an LLVM StructType in their order of appearance.
251  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
253  // Select the argument converter depending on the calling convention.
254  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
255  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
257  // Convert argument types one by one and check for errors.
258  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
259  SmallVector<Type, 8> converted;
260  if (failed(funcArgConverter(*this, type, converted)))
261  return {};
262  result.addInputs(idx, converted);
263  }
264 
265  // If function does not return anything, create the void result type,
266  // if it returns on element, convert it, otherwise pack the result types into
267  // a struct.
268  Type resultType =
269  funcTy.getNumResults() == 0
271  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
272  if (!resultType)
273  return {};
274  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
275  isVariadic);
276 }
277 
278 /// Converts the function type to a C-compatible format, in particular using
279 /// pointers to memref descriptors for arguments.
280 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
282  SmallVector<Type, 4> inputs;
283 
284  Type resultType = type.getNumResults() == 0
286  : packFunctionResults(type.getResults());
287  if (!resultType)
288  return {};
289 
290  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
291  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
292  if (structType) {
293  // Struct types cannot be safely returned via C interface. Make this a
294  // pointer argument, instead.
295  inputs.push_back(ptrType);
296  resultType = LLVM::LLVMVoidType::get(&getContext());
297  }
298 
299  for (Type t : type.getInputs()) {
300  auto converted = convertType(t);
301  if (!converted || !LLVM::isCompatibleType(converted))
302  return {};
303  if (isa<MemRefType, UnrankedMemRefType>(t))
304  converted = ptrType;
305  inputs.push_back(converted);
306  }
307 
308  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
309 }
310 
311 /// Convert a memref type into a list of LLVM IR types that will form the
312 /// memref descriptor. The result contains the following types:
313 /// 1. The pointer to the allocated data buffer, followed by
314 /// 2. The pointer to the aligned data buffer, followed by
315 /// 3. A lowered `index`-type integer containing the distance between the
316 /// beginning of the buffer and the first element to be accessed through the
317 /// view, followed by
318 /// 4. An array containing as many `index`-type integers as the rank of the
319 /// MemRef: the array represents the size, in number of elements, of the memref
320 /// along the given dimension. For constant MemRef dimensions, the
321 /// corresponding size entry is a constant whose runtime value must match the
322 /// static value, followed by
323 /// 5. A second array containing as many `index`-type integers as the rank of
324 /// the MemRef: the second array represents the "stride" (in tensor abstraction
325 /// sense), i.e. the number of elements of the underlying buffer between two consecutive indices along the given dimension.
326 /// TODO: add assertions for the static cases.
327 ///
328 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
329 /// are expanded into individual index-type elements.
330 ///
331 /// template <typename Elem, typename Index, size_t Rank>
332 /// struct {
333 /// Elem *allocatedPtr;
334 /// Elem *alignedPtr;
335 /// Index offset;
336 /// Index sizes[Rank]; // omitted when rank == 0
337 /// Index strides[Rank]; // omitted when rank == 0
338 /// };
340 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
341  bool unpackAggregates) const {
342  if (!isStrided(type)) {
343  emitError(
344  UnknownLoc::get(type.getContext()),
345  "conversion to strided form failed either due to non-strided layout "
346  "maps (which should have been normalized away) or other reasons");
347  return {};
348  }
349 
350  Type elementType = convertType(type.getElementType());
351  if (!elementType)
352  return {};
353 
354  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
355  if (failed(addressSpace)) {
356  emitError(UnknownLoc::get(type.getContext()),
357  "conversion of memref memory space ")
358  << type.getMemorySpace()
359  << " to integer address space "
360  "failed. Consider adding memory space conversions.";
361  return {};
362  }
363  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
364 
365  auto indexTy = getIndexType();
366 
367  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
368  auto rank = type.getRank();
369  if (rank == 0)
370  return results;
371 
372  if (unpackAggregates)
373  results.insert(results.end(), 2 * rank, indexTy);
374  else
375  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
376  return results;
377 }
378 
379 unsigned
381  const DataLayout &layout) const {
382  // Compute the descriptor size given that of its components indicated above.
383  unsigned space = *getMemRefAddressSpace(type);
384  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
385  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
386 }
387 
388 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
389 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
390 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
391  // When converting a MemRefType to a struct with descriptor fields, do not
392  // unpack the `sizes` and `strides` arrays.
393  SmallVector<Type, 5> types =
394  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
395  if (types.empty())
396  return {};
398 }
399 
400 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
401 /// that will form the unranked memref descriptor. In particular, the fields
402 /// for an unranked memref descriptor are:
403 /// 1. index-typed rank, the dynamic rank of this MemRef
404 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
405 /// stack-allocated (alloca) copy of a MemRef descriptor that was cast to
406 /// be unranked.
408 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
410 }
411 
413  UnrankedMemRefType type, const DataLayout &layout) const {
414  // Compute the descriptor size given that of its components indicated above.
415  unsigned space = *getMemRefAddressSpace(type);
416  return layout.getTypeSize(getIndexType()) +
418 }
419 
420 Type LLVMTypeConverter::convertUnrankedMemRefType(
421  UnrankedMemRefType type) const {
422  if (!convertType(type.getElementType()))
423  return {};
425  getUnrankedMemRefDescriptorFields());
426 }
427 
430  if (!type.getMemorySpace()) // Default memory space -> 0.
431  return 0;
432  std::optional<Attribute> converted =
433  convertTypeAttribute(type, type.getMemorySpace());
434  if (!converted)
435  return failure();
436  if (!(*converted)) // Conversion to default is 0.
437  return 0;
438  if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
439  return explicitSpace.getInt();
440  return failure();
441 }
442 
443 // Check if a memref type can be converted to a bare pointer.
445  if (isa<UnrankedMemRefType>(type))
446  // Unranked memref is not supported in the bare pointer calling convention.
447  return false;
448 
449  // Check that the memref has static shape, strides and offset. Otherwise, it
450  // cannot be lowered to a bare pointer.
451  auto memrefTy = cast<MemRefType>(type);
452  if (!memrefTy.hasStaticShape())
453  return false;
454 
455  int64_t offset = 0;
456  SmallVector<int64_t, 4> strides;
457  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
458  return false;
459 
460  for (int64_t stride : strides)
461  if (ShapedType::isDynamic(stride))
462  return false;
463 
464  return !ShapedType::isDynamic(offset);
465 }
466 
467 /// Convert a memref type to a bare pointer to the memref element type.
468 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
469  if (!canConvertToBarePtr(type))
470  return {};
471  Type elementType = convertType(type.getElementType());
472  if (!elementType)
473  return {};
474  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
475  if (failed(addressSpace))
476  return {};
477  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
478 }
479 
480 /// Convert an n-D vector type to an LLVM vector type:
481 /// * 0-D `vector<T>` are converted to vector<1xT>
482 /// * 1-D `vector<axT>` remains as is while,
483 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
484 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
485 /// Returns failure for n-D scalable vector types as LLVM does not support
486 /// arrays of scalable vectors.
487 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
488  auto elementType = convertType(type.getElementType());
489  if (!elementType)
490  return {};
491  if (type.getShape().empty())
492  return VectorType::get({1}, elementType);
493  Type vectorType = VectorType::get(type.getShape().back(), elementType,
494  type.getScalableDims().back());
495  assert(LLVM::isCompatibleVectorType(vectorType) &&
496  "expected vector type compatible with the LLVM dialect");
497  // Only the trailing dimension can be scalable.
498  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
499  return failure();
500  auto shape = type.getShape();
501  for (int i = shape.size() - 2; i >= 0; --i)
502  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
503  return vectorType;
504 }
505 
506 /// Convert a type in the context of the default or bare pointer calling
507 /// convention. Calling convention sensitive types, such as MemRefType and
508 /// UnrankedMemRefType, are converted following the specific rules for the
509 /// calling convention. Calling convention independent types are converted
510 /// following the default LLVM type conversions.
512  Type type, bool useBarePtrCallConv) const {
513  if (useBarePtrCallConv)
514  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
515  return convertMemRefToBarePtr(memrefTy);
516 
517  return convertType(type);
518 }
519 
520 /// Promote the bare pointers in 'values' that resulted from memrefs to
521 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
522 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
524  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
525  SmallVectorImpl<Value> &values) const {
526  assert(stdTypes.size() == values.size() &&
527  "The number of types and values doesn't match");
528  for (unsigned i = 0, end = values.size(); i < end; ++i)
529  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
530  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
531  memrefTy, values[i]);
532 }
533 
534 /// Convert a non-empty list of types of values produced by an operation into an
535 /// LLVM-compatible type. In particular, if more than one value is
536 /// produced, create a literal structure with elements that correspond to each
537 /// of the types converted with `convertType`.
539  assert(!types.empty() && "expected non-empty list of type");
540  if (types.size() == 1)
541  return convertType(types[0]);
542 
543  SmallVector<Type> resultTypes;
544  resultTypes.reserve(types.size());
545  for (Type type : types) {
546  Type converted = convertType(type);
547  if (!converted || !LLVM::isCompatibleType(converted))
548  return {};
549  resultTypes.push_back(converted);
550  }
551 
552  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
553 }
554 
555 /// Convert a non-empty list of types to be returned from a function into an
556 /// LLVM-compatible type. In particular, if more than one value is returned,
557 /// create an LLVM dialect structure type with elements that correspond to each
558 /// of the types converted with `convertCallingConventionType`.
560  bool useBarePtrCallConv) const {
561  assert(!types.empty() && "expected non-empty list of type");
562 
563  useBarePtrCallConv |= options.useBarePtrCallConv;
564  if (types.size() == 1)
565  return convertCallingConventionType(types.front(), useBarePtrCallConv);
566 
567  SmallVector<Type> resultTypes;
568  resultTypes.reserve(types.size());
569  for (auto t : types) {
570  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
571  if (!converted || !LLVM::isCompatibleType(converted))
572  return {};
573  resultTypes.push_back(converted);
574  }
575 
576  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
577 }
578 
580  OpBuilder &builder) const {
581  // Alloca with proper alignment. We do not expect optimizations of this
582  // alloca op and so we omit allocating at the entry block.
583  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
584  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
585  builder.getIndexAttr(1));
586  Value allocated =
587  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
588  // Store into the alloca'ed descriptor.
589  builder.create<LLVM::StoreOp>(loc, operand, allocated);
590  return allocated;
591 }
592 
595  ValueRange operands, OpBuilder &builder,
596  bool useBarePtrCallConv) const {
597  SmallVector<Value, 4> promotedOperands;
598  promotedOperands.reserve(operands.size());
599  useBarePtrCallConv |= options.useBarePtrCallConv;
600  for (auto it : llvm::zip(opOperands, operands)) {
601  auto operand = std::get<0>(it);
602  auto llvmOperand = std::get<1>(it);
603 
604  if (useBarePtrCallConv) {
605  // For the bare-ptr calling convention, we only have to extract the
606  // aligned pointer of a memref.
607  if (dyn_cast<MemRefType>(operand.getType())) {
608  MemRefDescriptor desc(llvmOperand);
609  llvmOperand = desc.alignedPtr(builder, loc);
610  } else if (isa<UnrankedMemRefType>(operand.getType())) {
611  llvm_unreachable("Unranked memrefs are not supported");
612  }
613  } else {
614  if (isa<UnrankedMemRefType>(operand.getType())) {
615  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
616  promotedOperands);
617  continue;
618  }
619  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
620  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
621  promotedOperands);
622  continue;
623  }
624  }
625 
626  promotedOperands.push_back(llvmOperand);
627  }
628  return promotedOperands;
629 }
630 
631 /// Callback to convert function argument types. It converts a MemRef function
632 /// argument to a list of non-aggregate types containing descriptor
633 /// information, and an UnrankedmemRef function argument to a list containing
634 /// the rank and a pointer to a descriptor struct.
637  SmallVectorImpl<Type> &result) {
638  if (auto memref = dyn_cast<MemRefType>(type)) {
639  // In signatures, Memref descriptors are expanded into lists of
640  // non-aggregate values.
641  auto converted =
642  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
643  if (converted.empty())
644  return failure();
645  result.append(converted.begin(), converted.end());
646  return success();
647  }
648  if (isa<UnrankedMemRefType>(type)) {
649  auto converted = converter.getUnrankedMemRefDescriptorFields();
650  if (converted.empty())
651  return failure();
652  result.append(converted.begin(), converted.end());
653  return success();
654  }
655  auto converted = converter.convertType(type);
656  if (!converted)
657  return failure();
658  result.push_back(converted);
659  return success();
660 }
661 
662 /// Callback to convert function argument types. It converts MemRef function
663 /// arguments to bare pointers to the MemRef element type.
666  SmallVectorImpl<Type> &result) {
667  auto llvmTy = converter.convertCallingConventionType(
668  type, /*useBarePointerCallConv=*/true);
669  if (!llvmTy)
670  return failure();
671 
672  result.push_back(llvmTy);
673  return success();
674 }
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IntegerType getI64Type()
Definition: Builders.cpp:85
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:92
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:109
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
Definition: LLVMTypes.cpp:425
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
Register a materialization function, which must be convertible to the following form: std::optional<V...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isFloat8E4M3FN() const
Definition: Types.cpp:38
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:42
bool isFloat8E4M3B11FNUZ() const
Definition: Types.cpp:45
bool isFloat8E5M2() const
Definition: Types.cpp:37
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:39
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:876
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26