MLIR  20.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // First time this thread gets here, we have to get an exclusive access to
35  // inset in the map
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
50  const DataLayoutAnalysis *analysis)
51  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52  dataLayoutAnalysis(analysis) {
53  assert(llvmDialect && "LLVM IR dialect is not registered");
54 
55  // Register conversions for the builtin types.
56  addConversion([&](ComplexType type) { return convertComplexType(type); });
57  addConversion([&](FloatType type) { return convertFloatType(type); });
58  addConversion([&](FunctionType type) { return convertFunctionType(type); });
59  addConversion([&](IndexType type) { return convertIndexType(type); });
60  addConversion([&](IntegerType type) { return convertIntegerType(type); });
61  addConversion([&](MemRefType type) { return convertMemRefType(type); });
63  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64  addConversion([&](VectorType type) -> std::optional<Type> {
65  FailureOr<Type> llvmType = convertVectorType(type);
66  if (failed(llvmType))
67  return std::nullopt;
68  return llvmType;
69  });
70 
71  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72  // before the conversions below since conversions are attempted in reverse
73  // order and those should take priority.
74  addConversion([](Type type) {
75  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76  : std::nullopt;
77  });
78 
80  -> std::optional<LogicalResult> {
81  // Fastpath for types that won't be converted by this callback anyway.
82  if (LLVM::isCompatibleType(type)) {
83  results.push_back(type);
84  return success();
85  }
86 
87  if (type.isIdentified()) {
88  auto convertedType = LLVM::LLVMStructType::getIdentified(
89  type.getContext(), ("_Converted." + type.getName()).str());
90 
92  if (llvm::count(recursiveStack, type)) {
93  results.push_back(convertedType);
94  return success();
95  }
96  recursiveStack.push_back(type);
97  auto popConversionCallStack = llvm::make_scope_exit(
98  [&recursiveStack]() { recursiveStack.pop_back(); });
99 
100  SmallVector<Type> convertedElemTypes;
101  convertedElemTypes.reserve(type.getBody().size());
102  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
103  return std::nullopt;
104 
105  // If the converted type has not been initialized yet, just set its body
106  // to be the converted arguments and return.
107  if (!convertedType.isInitialized()) {
108  if (failed(
109  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110  return failure();
111  }
112  results.push_back(convertedType);
113  return success();
114  }
115 
116  // If it has been initialized, has the same body and packed bit, just use
117  // it. This ensures that recursive structs keep being recursive rather
118  // than including a non-updated name.
119  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120  convertedType.isPacked() == type.isPacked()) {
121  results.push_back(convertedType);
122  return success();
123  }
124 
125  return failure();
126  }
127 
128  SmallVector<Type> convertedSubtypes;
129  convertedSubtypes.reserve(type.getBody().size());
130  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
131  return std::nullopt;
132 
133  results.push_back(LLVM::LLVMStructType::getLiteral(
134  type.getContext(), convertedSubtypes, type.isPacked()));
135  return success();
136  });
137  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138  if (auto element = convertType(type.getElementType()))
139  return LLVM::LLVMArrayType::get(element, type.getNumElements());
140  return std::nullopt;
141  });
142  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143  Type convertedResType = convertType(type.getReturnType());
144  if (!convertedResType)
145  return std::nullopt;
146 
147  SmallVector<Type> convertedArgTypes;
148  convertedArgTypes.reserve(type.getNumParams());
149  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
150  return std::nullopt;
151 
152  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153  type.isVarArg());
154  });
155 
156  // Argument materializations convert from the new block argument types
157  // (multiple SSA values that make up a memref descriptor) back to the
158  // original block argument type. The dialect conversion framework will then
159  // insert a target materialization from the original block argument type to
160  // a legal type.
162  [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
163  Location loc) -> std::optional<Value> {
164  if (inputs.size() == 1) {
165  // Bare pointers are not supported for unranked memrefs because a
166  // memref descriptor cannot be built just from a bare pointer.
167  return std::nullopt;
168  }
169  Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
170  resultType, inputs);
171  // An argument materialization must return a value of type
172  // `resultType`, so insert a cast from the memref descriptor type
173  // (!llvm.struct) to the original memref type.
174  return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175  .getResult(0);
176  });
177  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178  ValueRange inputs,
179  Location loc) -> std::optional<Value> {
180  Value desc;
181  if (inputs.size() == 1) {
182  // This is a bare pointer. We allow bare pointers only for function entry
183  // blocks.
184  BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
185  if (!barePtr)
186  return std::nullopt;
187  Block *block = barePtr.getOwner();
188  if (!block->isEntryBlock() ||
189  !isa<FunctionOpInterface>(block->getParentOp()))
190  return std::nullopt;
191  desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
192  inputs[0]);
193  } else {
194  desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
195  }
196  // An argument materialization must return a value of type `resultType`,
197  // so insert a cast from the memref descriptor type (!llvm.struct) to the
198  // original memref type.
199  return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
200  .getResult(0);
201  });
202  // Add generic source and target materializations to handle cases where
203  // non-LLVM types persist after an LLVM conversion.
204  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
205  ValueRange inputs,
206  Location loc) -> std::optional<Value> {
207  if (inputs.size() != 1)
208  return std::nullopt;
209 
210  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211  .getResult(0);
212  });
213  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214  ValueRange inputs,
215  Location loc) -> std::optional<Value> {
216  if (inputs.size() != 1)
217  return std::nullopt;
218 
219  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
220  .getResult(0);
221  });
222 
223  // Integer memory spaces map to themselves.
225  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
226 }
227 
228 /// Returns the MLIR context.
230  return *getDialect()->getContext();
231 }
232 
235 }
236 
237 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
238  return options.dataLayout.getPointerSizeInBits(addressSpace);
239 }
240 
241 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
242  return getIndexType();
243 }
244 
245 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
246  return IntegerType::get(&getContext(), type.getWidth());
247 }
248 
249 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
250  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
251  type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
252  type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
253  type.isFloat6E3M2FN())
254  return IntegerType::get(&getContext(), type.getWidth());
255  return type;
256 }
257 
258 // Convert a `ComplexType` to an LLVM type. The result is a complex number
259 // struct with entries for the
260 // 1. real part and for the
261 // 2. imaginary part.
262 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
263  auto elementType = convertType(type.getElementType());
265  {elementType, elementType});
266 }
267 
268 // Except for signatures, MLIR function types are converted into LLVM
269 // pointer-to-function types.
270 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
271  return LLVM::LLVMPointerType::get(type.getContext());
272 }
273 
274 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
275 /// function arguments. Returns an empty container if none of these attributes
276 /// are found in any of the arguments.
277 static void
278 filterByValRefArgAttrs(FunctionOpInterface funcOp,
279  SmallVectorImpl<std::optional<NamedAttribute>> &result) {
280  assert(result.empty() && "Unexpected non-empty output");
281  result.resize(funcOp.getNumArguments(), std::nullopt);
282  bool foundByValByRefAttrs = false;
283  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
284  for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
285  if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
286  namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
287  foundByValByRefAttrs = true;
288  result[argIdx] = namedAttr;
289  break;
290  }
291  }
292  }
293 
294  if (!foundByValByRefAttrs)
295  result.clear();
296 }
297 
298 // Function types are converted to LLVM Function types by recursively converting
299 // argument and result types. If MLIR Function has zero results, the LLVM
300 // Function has one VoidType result. If MLIR Function has more than one result,
301 // they are into an LLVM StructType in their order of appearance.
302 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
303 // `llvm.byref` function arguments which are not LLVM pointers are overridden
304 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
305 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
306 Type LLVMTypeConverter::convertFunctionSignatureImpl(
307  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
308  LLVMTypeConverter::SignatureConversion &result,
309  SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
310  // Select the argument converter depending on the calling convention.
311  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
312  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
314  // Convert argument types one by one and check for errors.
315  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
316  SmallVector<Type, 8> converted;
317  if (failed(funcArgConverter(*this, type, converted)))
318  return {};
319 
320  // Rewrite converted type of `llvm.byval` or `llvm.byref` function
321  // argument that was not converted to an LLVM pointer types.
322  if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
323  converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
324  // If the argument was already converted to an LLVM pointer type, we stop
325  // tracking it as it doesn't need more processing.
326  if (isa<LLVM::LLVMPointerType>(converted[0]))
327  (*byValRefNonPtrAttrs)[idx] = std::nullopt;
328  else
329  converted[0] = LLVM::LLVMPointerType::get(&getContext());
330  }
331 
332  result.addInputs(idx, converted);
333  }
334 
335  // If function does not return anything, create the void result type,
336  // if it returns on element, convert it, otherwise pack the result types into
337  // a struct.
338  Type resultType =
339  funcTy.getNumResults() == 0
341  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
342  if (!resultType)
343  return {};
344  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
345  isVariadic);
346 }
347 
349  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
351  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
352  result,
353  /*byValRefNonPtrAttrs=*/nullptr);
354 }
355 
357  FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
359  SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
360  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
361  // that were not converted to LLVM pointer types will be returned for further
362  // processing.
363  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
364  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
365  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
366  result, &byValRefNonPtrAttrs);
367 }
368 
369 /// Converts the function type to a C-compatible format, in particular using
370 /// pointers to memref descriptors for arguments.
371 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
373  SmallVector<Type, 4> inputs;
374 
375  Type resultType = type.getNumResults() == 0
377  : packFunctionResults(type.getResults());
378  if (!resultType)
379  return {};
380 
381  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
382  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
383  if (structType) {
384  // Struct types cannot be safely returned via C interface. Make this a
385  // pointer argument, instead.
386  inputs.push_back(ptrType);
387  resultType = LLVM::LLVMVoidType::get(&getContext());
388  }
389 
390  for (Type t : type.getInputs()) {
391  auto converted = convertType(t);
392  if (!converted || !LLVM::isCompatibleType(converted))
393  return {};
394  if (isa<MemRefType, UnrankedMemRefType>(t))
395  converted = ptrType;
396  inputs.push_back(converted);
397  }
398 
399  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
400 }
401 
402 /// Convert a memref type into a list of LLVM IR types that will form the
403 /// memref descriptor. The result contains the following types:
404 /// 1. The pointer to the allocated data buffer, followed by
405 /// 2. The pointer to the aligned data buffer, followed by
406 /// 3. A lowered `index`-type integer containing the distance between the
407 /// beginning of the buffer and the first element to be accessed through the
408 /// view, followed by
409 /// 4. An array containing as many `index`-type integers as the rank of the
410 /// MemRef: the array represents the size, in number of elements, of the memref
411 /// along the given dimension. For constant MemRef dimensions, the
412 /// corresponding size entry is a constant whose runtime value must match the
413 /// static value, followed by
414 /// 5. A second array containing as many `index`-type integers as the rank of
415 /// the MemRef: the second array represents the "stride" (in tensor abstraction
416 /// sense), i.e. the number of consecutive elements of the underlying buffer.
417 /// TODO: add assertions for the static cases.
418 ///
419 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
420 /// are expanded into individual index-type elements.
421 ///
422 /// template <typename Elem, typename Index, size_t Rank>
423 /// struct {
424 /// Elem *allocatedPtr;
425 /// Elem *alignedPtr;
426 /// Index offset;
427 /// Index sizes[Rank]; // omitted when rank == 0
428 /// Index strides[Rank]; // omitted when rank == 0
429 /// };
431 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
432  bool unpackAggregates) const {
433  if (!isStrided(type)) {
434  emitError(
435  UnknownLoc::get(type.getContext()),
436  "conversion to strided form failed either due to non-strided layout "
437  "maps (which should have been normalized away) or other reasons");
438  return {};
439  }
440 
441  Type elementType = convertType(type.getElementType());
442  if (!elementType)
443  return {};
444 
445  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
446  if (failed(addressSpace)) {
447  emitError(UnknownLoc::get(type.getContext()),
448  "conversion of memref memory space ")
449  << type.getMemorySpace()
450  << " to integer address space "
451  "failed. Consider adding memory space conversions.";
452  return {};
453  }
454  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
455 
456  auto indexTy = getIndexType();
457 
458  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
459  auto rank = type.getRank();
460  if (rank == 0)
461  return results;
462 
463  if (unpackAggregates)
464  results.insert(results.end(), 2 * rank, indexTy);
465  else
466  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
467  return results;
468 }
469 
470 unsigned
472  const DataLayout &layout) const {
473  // Compute the descriptor size given that of its components indicated above.
474  unsigned space = *getMemRefAddressSpace(type);
475  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
476  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
477 }
478 
479 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
480 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
481 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
482  // When converting a MemRefType to a struct with descriptor fields, do not
483  // unpack the `sizes` and `strides` arrays.
484  SmallVector<Type, 5> types =
485  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
486  if (types.empty())
487  return {};
489 }
490 
491 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
492 /// that will form the unranked memref descriptor. In particular, the fields
493 /// for an unranked memref descriptor are:
494 /// 1. index-typed rank, the dynamic rank of this MemRef
495 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
496 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
497 /// be unranked.
499 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
501 }
502 
504  UnrankedMemRefType type, const DataLayout &layout) const {
505  // Compute the descriptor size given that of its components indicated above.
506  unsigned space = *getMemRefAddressSpace(type);
507  return layout.getTypeSize(getIndexType()) +
509 }
510 
511 Type LLVMTypeConverter::convertUnrankedMemRefType(
512  UnrankedMemRefType type) const {
513  if (!convertType(type.getElementType()))
514  return {};
516  getUnrankedMemRefDescriptorFields());
517 }
518 
519 FailureOr<unsigned>
521  if (!type.getMemorySpace()) // Default memory space -> 0.
522  return 0;
523  std::optional<Attribute> converted =
524  convertTypeAttribute(type, type.getMemorySpace());
525  if (!converted)
526  return failure();
527  if (!(*converted)) // Conversion to default is 0.
528  return 0;
529  if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
530  return explicitSpace.getInt();
531  return failure();
532 }
533 
534 // Check if a memref type can be converted to a bare pointer.
536  if (isa<UnrankedMemRefType>(type))
537  // Unranked memref is not supported in the bare pointer calling convention.
538  return false;
539 
540  // Check that the memref has static shape, strides and offset. Otherwise, it
541  // cannot be lowered to a bare pointer.
542  auto memrefTy = cast<MemRefType>(type);
543  if (!memrefTy.hasStaticShape())
544  return false;
545 
546  int64_t offset = 0;
547  SmallVector<int64_t, 4> strides;
548  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
549  return false;
550 
551  for (int64_t stride : strides)
552  if (ShapedType::isDynamic(stride))
553  return false;
554 
555  return !ShapedType::isDynamic(offset);
556 }
557 
558 /// Convert a memref type to a bare pointer to the memref element type.
559 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
560  if (!canConvertToBarePtr(type))
561  return {};
562  Type elementType = convertType(type.getElementType());
563  if (!elementType)
564  return {};
565  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
566  if (failed(addressSpace))
567  return {};
568  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
569 }
570 
571 /// Convert an n-D vector type to an LLVM vector type:
572 /// * 0-D `vector<T>` are converted to vector<1xT>
573 /// * 1-D `vector<axT>` remains as is while,
574 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
575 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
576 /// As LLVM supports arrays of scalable vectors, this method will also convert
577 /// n-D scalable vectors provided that only the trailing dim is scalable.
578 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
579  auto elementType = convertType(type.getElementType());
580  if (!elementType)
581  return {};
582  if (type.getShape().empty())
583  return VectorType::get({1}, elementType);
584  Type vectorType = VectorType::get(type.getShape().back(), elementType,
585  type.getScalableDims().back());
586  assert(LLVM::isCompatibleVectorType(vectorType) &&
587  "expected vector type compatible with the LLVM dialect");
588  // For n-D vector types for which a _non-trailing_ dim is scalable,
589  // return a failure. Supporting such cases would require LLVM
590  // to support something akin "scalable arrays" of vectors.
591  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
592  return failure();
593  auto shape = type.getShape();
594  for (int i = shape.size() - 2; i >= 0; --i)
595  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
596  return vectorType;
597 }
598 
599 /// Convert a type in the context of the default or bare pointer calling
600 /// convention. Calling convention sensitive types, such as MemRefType and
601 /// UnrankedMemRefType, are converted following the specific rules for the
602 /// calling convention. Calling convention independent types are converted
603 /// following the default LLVM type conversions.
605  Type type, bool useBarePtrCallConv) const {
606  if (useBarePtrCallConv)
607  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
608  return convertMemRefToBarePtr(memrefTy);
609 
610  return convertType(type);
611 }
612 
613 /// Promote the bare pointers in 'values' that resulted from memrefs to
614 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
615 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
617  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
618  SmallVectorImpl<Value> &values) const {
619  assert(stdTypes.size() == values.size() &&
620  "The number of types and values doesn't match");
621  for (unsigned i = 0, end = values.size(); i < end; ++i)
622  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
623  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
624  memrefTy, values[i]);
625 }
626 
627 /// Convert a non-empty list of types of values produced by an operation into an
628 /// LLVM-compatible type. In particular, if more than one value is
629 /// produced, create a literal structure with elements that correspond to each
630 /// of the types converted with `convertType`.
632  assert(!types.empty() && "expected non-empty list of type");
633  if (types.size() == 1)
634  return convertType(types[0]);
635 
636  SmallVector<Type> resultTypes;
637  resultTypes.reserve(types.size());
638  for (Type type : types) {
639  Type converted = convertType(type);
640  if (!converted || !LLVM::isCompatibleType(converted))
641  return {};
642  resultTypes.push_back(converted);
643  }
644 
645  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
646 }
647 
648 /// Convert a non-empty list of types to be returned from a function into an
649 /// LLVM-compatible type. In particular, if more than one value is returned,
650 /// create an LLVM dialect structure type with elements that correspond to each
651 /// of the types converted with `convertCallingConventionType`.
653  bool useBarePtrCallConv) const {
654  assert(!types.empty() && "expected non-empty list of type");
655 
656  useBarePtrCallConv |= options.useBarePtrCallConv;
657  if (types.size() == 1)
658  return convertCallingConventionType(types.front(), useBarePtrCallConv);
659 
660  SmallVector<Type> resultTypes;
661  resultTypes.reserve(types.size());
662  for (auto t : types) {
663  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
664  if (!converted || !LLVM::isCompatibleType(converted))
665  return {};
666  resultTypes.push_back(converted);
667  }
668 
669  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
670 }
671 
673  OpBuilder &builder) const {
674  // Alloca with proper alignment. We do not expect optimizations of this
675  // alloca op and so we omit allocating at the entry block.
676  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
677  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
678  builder.getIndexAttr(1));
679  Value allocated =
680  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
681  // Store into the alloca'ed descriptor.
682  builder.create<LLVM::StoreOp>(loc, operand, allocated);
683  return allocated;
684 }
685 
688  ValueRange operands, OpBuilder &builder,
689  bool useBarePtrCallConv) const {
690  SmallVector<Value, 4> promotedOperands;
691  promotedOperands.reserve(operands.size());
692  useBarePtrCallConv |= options.useBarePtrCallConv;
693  for (auto it : llvm::zip(opOperands, operands)) {
694  auto operand = std::get<0>(it);
695  auto llvmOperand = std::get<1>(it);
696 
697  if (useBarePtrCallConv) {
698  // For the bare-ptr calling convention, we only have to extract the
699  // aligned pointer of a memref.
700  if (dyn_cast<MemRefType>(operand.getType())) {
701  MemRefDescriptor desc(llvmOperand);
702  llvmOperand = desc.alignedPtr(builder, loc);
703  } else if (isa<UnrankedMemRefType>(operand.getType())) {
704  llvm_unreachable("Unranked memrefs are not supported");
705  }
706  } else {
707  if (isa<UnrankedMemRefType>(operand.getType())) {
708  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
709  promotedOperands);
710  continue;
711  }
712  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
713  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
714  promotedOperands);
715  continue;
716  }
717  }
718 
719  promotedOperands.push_back(llvmOperand);
720  }
721  return promotedOperands;
722 }
723 
724 /// Callback to convert function argument types. It converts a MemRef function
725 /// argument to a list of non-aggregate types containing descriptor
726 /// information, and an UnrankedmemRef function argument to a list containing
727 /// the rank and a pointer to a descriptor struct.
728 LogicalResult
730  SmallVectorImpl<Type> &result) {
731  if (auto memref = dyn_cast<MemRefType>(type)) {
732  // In signatures, Memref descriptors are expanded into lists of
733  // non-aggregate values.
734  auto converted =
735  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
736  if (converted.empty())
737  return failure();
738  result.append(converted.begin(), converted.end());
739  return success();
740  }
741  if (isa<UnrankedMemRefType>(type)) {
742  auto converted = converter.getUnrankedMemRefDescriptorFields();
743  if (converted.empty())
744  return failure();
745  result.append(converted.begin(), converted.end());
746  return success();
747  }
748  auto converted = converter.convertType(type);
749  if (!converted)
750  return failure();
751  result.push_back(converted);
752  return success();
753 }
754 
755 /// Callback to convert function argument types. It converts MemRef function
756 /// arguments to bare pointers to the MemRef element type.
757 LogicalResult
759  SmallVectorImpl<Type> &result) {
760  auto llvmTy = converter.convertCallingConventionType(
761  type, /*useBarePointerCallConv=*/true);
762  if (!llvmTy)
763  return failure();
764 
765  result.push_back(llvmTy);
766  return success();
767 }
static llvm::ManagedStatic< PassManagerOptions > options
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:146
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
IntegerType getI64Type()
Definition: Builders.cpp:97
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:108
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:452
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
Definition: LLVMTypes.cpp:424
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This class helps build Operations.
Definition: Builders.h:212
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isFloat8E4M3FN() const
Definition: Types.cpp:40
bool isFloat8E3M4() const
Definition: Types.cpp:50
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:44
bool isFloat8E4M3B11FNUZ() const
Definition: Types.cpp:47
bool isFloat6E3M2FN() const
Definition: Types.cpp:37
bool isFloat8E5M2() const
Definition: Types.cpp:38
bool isFloat8E4M3() const
Definition: Types.cpp:39
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:41
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:874
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.