MLIR  20.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // First time this thread gets here, we have to get exclusive access to
35  // insert into the map
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
50  const DataLayoutAnalysis *analysis)
51  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52  dataLayoutAnalysis(analysis) {
53  assert(llvmDialect && "LLVM IR dialect is not registered");
54 
55  // Register conversions for the builtin types.
56  addConversion([&](ComplexType type) { return convertComplexType(type); });
57  addConversion([&](FloatType type) { return convertFloatType(type); });
58  addConversion([&](FunctionType type) { return convertFunctionType(type); });
59  addConversion([&](IndexType type) { return convertIndexType(type); });
60  addConversion([&](IntegerType type) { return convertIntegerType(type); });
61  addConversion([&](MemRefType type) { return convertMemRefType(type); });
63  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64  addConversion([&](VectorType type) -> std::optional<Type> {
65  FailureOr<Type> llvmType = convertVectorType(type);
66  if (failed(llvmType))
67  return std::nullopt;
68  return llvmType;
69  });
70 
71  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72  // before the conversions below since conversions are attempted in reverse
73  // order and those should take priority.
74  addConversion([](Type type) {
75  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76  : std::nullopt;
77  });
78 
80  -> std::optional<LogicalResult> {
81  // Fastpath for types that won't be converted by this callback anyway.
82  if (LLVM::isCompatibleType(type)) {
83  results.push_back(type);
84  return success();
85  }
86 
87  if (type.isIdentified()) {
88  auto convertedType = LLVM::LLVMStructType::getIdentified(
89  type.getContext(), ("_Converted." + type.getName()).str());
90 
92  if (llvm::count(recursiveStack, type)) {
93  results.push_back(convertedType);
94  return success();
95  }
96  recursiveStack.push_back(type);
97  auto popConversionCallStack = llvm::make_scope_exit(
98  [&recursiveStack]() { recursiveStack.pop_back(); });
99 
100  SmallVector<Type> convertedElemTypes;
101  convertedElemTypes.reserve(type.getBody().size());
102  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
103  return std::nullopt;
104 
105  // If the converted type has not been initialized yet, just set its body
106  // to be the converted arguments and return.
107  if (!convertedType.isInitialized()) {
108  if (failed(
109  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110  return failure();
111  }
112  results.push_back(convertedType);
113  return success();
114  }
115 
116  // If it has been initialized, has the same body and packed bit, just use
117  // it. This ensures that recursive structs keep being recursive rather
118  // than including a non-updated name.
119  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120  convertedType.isPacked() == type.isPacked()) {
121  results.push_back(convertedType);
122  return success();
123  }
124 
125  return failure();
126  }
127 
128  SmallVector<Type> convertedSubtypes;
129  convertedSubtypes.reserve(type.getBody().size());
130  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
131  return std::nullopt;
132 
133  results.push_back(LLVM::LLVMStructType::getLiteral(
134  type.getContext(), convertedSubtypes, type.isPacked()));
135  return success();
136  });
137  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138  if (auto element = convertType(type.getElementType()))
139  return LLVM::LLVMArrayType::get(element, type.getNumElements());
140  return std::nullopt;
141  });
142  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143  Type convertedResType = convertType(type.getReturnType());
144  if (!convertedResType)
145  return std::nullopt;
146 
147  SmallVector<Type> convertedArgTypes;
148  convertedArgTypes.reserve(type.getNumParams());
149  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
150  return std::nullopt;
151 
152  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153  type.isVarArg());
154  });
155 
156  // Argument materializations convert from the new block argument types
157  // (multiple SSA values that make up a memref descriptor) back to the
158  // original block argument type. The dialect conversion framework will then
159  // insert a target materialization from the original block argument type to
160  // a legal type.
162  UnrankedMemRefType resultType,
163  ValueRange inputs, Location loc) {
164  if (inputs.size() == 1) {
165  // Bare pointers are not supported for unranked memrefs because a
166  // memref descriptor cannot be built just from a bare pointer.
167  return Value();
168  }
169  Value desc =
170  UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171  // An argument materialization must return a value of type
172  // `resultType`, so insert a cast from the memref descriptor type
173  // (!llvm.struct) to the original memref type.
174  return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175  .getResult(0);
176  });
177  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178  ValueRange inputs, Location loc) {
179  Value desc;
180  if (inputs.size() == 1) {
181  // This is a bare pointer. We allow bare pointers only for function entry
182  // blocks.
183  BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184  if (!barePtr)
185  return Value();
186  Block *block = barePtr.getOwner();
187  if (!block->isEntryBlock() ||
188  !isa<FunctionOpInterface>(block->getParentOp()))
189  return Value();
190  desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191  inputs[0]);
192  } else {
193  desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
194  }
195  // An argument materialization must return a value of type `resultType`,
196  // so insert a cast from the memref descriptor type (!llvm.struct) to the
197  // original memref type.
198  return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199  .getResult(0);
200  });
201  // Add generic source and target materializations to handle cases where
202  // non-LLVM types persist after an LLVM conversion.
203  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
204  ValueRange inputs, Location loc) {
205  if (inputs.size() != 1)
206  return Value();
207 
208  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
209  .getResult(0);
210  });
211  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212  ValueRange inputs, Location loc) {
213  if (inputs.size() != 1)
214  return Value();
215 
216  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217  .getResult(0);
218  });
219 
220  // Integer memory spaces map to themselves.
222  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
223 }
224 
225 /// Returns the MLIR context.
227  return *getDialect()->getContext();
228 }
229 
232 }
233 
234 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
235  return options.dataLayout.getPointerSizeInBits(addressSpace);
236 }
237 
238 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
239  return getIndexType();
240 }
241 
242 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
243  return IntegerType::get(&getContext(), type.getWidth());
244 }
245 
246 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
247  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
248  type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
249  type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
250  type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
251  type.isFloat8E8M0FNU())
252  return IntegerType::get(&getContext(), type.getWidth());
253  return type;
254 }
255 
256 // Convert a `ComplexType` to an LLVM type. The result is a complex number
257 // struct with entries for the
258 // 1. real part and for the
259 // 2. imaginary part.
260 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
261  auto elementType = convertType(type.getElementType());
263  {elementType, elementType});
264 }
265 
266 // Except for signatures, MLIR function types are converted into LLVM
267 // pointer-to-function types.
268 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
269  return LLVM::LLVMPointerType::get(type.getContext());
270 }
271 
272 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
273 /// function arguments. Returns an empty container if none of these attributes
274 /// are found in any of the arguments.
275 static void
276 filterByValRefArgAttrs(FunctionOpInterface funcOp,
277  SmallVectorImpl<std::optional<NamedAttribute>> &result) {
278  assert(result.empty() && "Unexpected non-empty output");
279  result.resize(funcOp.getNumArguments(), std::nullopt);
280  bool foundByValByRefAttrs = false;
281  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
282  for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
283  if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
284  namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
285  foundByValByRefAttrs = true;
286  result[argIdx] = namedAttr;
287  break;
288  }
289  }
290  }
291 
292  if (!foundByValByRefAttrs)
293  result.clear();
294 }
295 
296 // Function types are converted to LLVM Function types by recursively converting
297 // argument and result types. If MLIR Function has zero results, the LLVM
298 // Function has one VoidType result. If MLIR Function has more than one result,
299 // they are packed into an LLVM StructType in their order of appearance.
300 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
301 // `llvm.byref` function arguments which are not LLVM pointers are overridden
302 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
303 // converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
304 Type LLVMTypeConverter::convertFunctionSignatureImpl(
305  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
306  LLVMTypeConverter::SignatureConversion &result,
307  SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
308  // Select the argument converter depending on the calling convention.
309  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
310  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
312  // Convert argument types one by one and check for errors.
313  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
314  SmallVector<Type, 8> converted;
315  if (failed(funcArgConverter(*this, type, converted)))
316  return {};
317 
318  // Rewrite converted type of `llvm.byval` or `llvm.byref` function
319  // argument that was not converted to an LLVM pointer types.
320  if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
321  converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
322  // If the argument was already converted to an LLVM pointer type, we stop
323  // tracking it as it doesn't need more processing.
324  if (isa<LLVM::LLVMPointerType>(converted[0]))
325  (*byValRefNonPtrAttrs)[idx] = std::nullopt;
326  else
327  converted[0] = LLVM::LLVMPointerType::get(&getContext());
328  }
329 
330  result.addInputs(idx, converted);
331  }
332 
333  // If function does not return anything, create the void result type,
334  // if it returns one element, convert it, otherwise pack the result types into
335  // a struct.
336  Type resultType =
337  funcTy.getNumResults() == 0
339  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
340  if (!resultType)
341  return {};
342  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
343  isVariadic);
344 }
345 
347  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
349  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
350  result,
351  /*byValRefNonPtrAttrs=*/nullptr);
352 }
353 
355  FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
357  SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
358  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
359  // that were not converted to LLVM pointer types will be returned for further
360  // processing.
361  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
362  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
363  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
364  result, &byValRefNonPtrAttrs);
365 }
366 
367 /// Converts the function type to a C-compatible format, in particular using
368 /// pointers to memref descriptors for arguments.
369 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
371  SmallVector<Type, 4> inputs;
372 
373  Type resultType = type.getNumResults() == 0
375  : packFunctionResults(type.getResults());
376  if (!resultType)
377  return {};
378 
379  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
380  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
381  if (structType) {
382  // Struct types cannot be safely returned via C interface. Make this a
383  // pointer argument, instead.
384  inputs.push_back(ptrType);
385  resultType = LLVM::LLVMVoidType::get(&getContext());
386  }
387 
388  for (Type t : type.getInputs()) {
389  auto converted = convertType(t);
390  if (!converted || !LLVM::isCompatibleType(converted))
391  return {};
392  if (isa<MemRefType, UnrankedMemRefType>(t))
393  converted = ptrType;
394  inputs.push_back(converted);
395  }
396 
397  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
398 }
399 
400 /// Convert a memref type into a list of LLVM IR types that will form the
401 /// memref descriptor. The result contains the following types:
402 /// 1. The pointer to the allocated data buffer, followed by
403 /// 2. The pointer to the aligned data buffer, followed by
404 /// 3. A lowered `index`-type integer containing the distance between the
405 /// beginning of the buffer and the first element to be accessed through the
406 /// view, followed by
407 /// 4. An array containing as many `index`-type integers as the rank of the
408 /// MemRef: the array represents the size, in number of elements, of the memref
409 /// along the given dimension. For constant MemRef dimensions, the
410 /// corresponding size entry is a constant whose runtime value must match the
411 /// static value, followed by
412 /// 5. A second array containing as many `index`-type integers as the rank of
413 /// the MemRef: the second array represents the "stride" (in tensor abstraction
414 /// sense), i.e. the number of elements of the underlying buffer to skip to advance by one along the given dimension.
415 /// TODO: add assertions for the static cases.
416 ///
417 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
418 /// are expanded into individual index-type elements.
419 ///
420 /// template <typename Elem, typename Index, size_t Rank>
421 /// struct {
422 /// Elem *allocatedPtr;
423 /// Elem *alignedPtr;
424 /// Index offset;
425 /// Index sizes[Rank]; // omitted when rank == 0
426 /// Index strides[Rank]; // omitted when rank == 0
427 /// };
429 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
430  bool unpackAggregates) const {
431  if (!isStrided(type)) {
432  emitError(
433  UnknownLoc::get(type.getContext()),
434  "conversion to strided form failed either due to non-strided layout "
435  "maps (which should have been normalized away) or other reasons");
436  return {};
437  }
438 
439  Type elementType = convertType(type.getElementType());
440  if (!elementType)
441  return {};
442 
443  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
444  if (failed(addressSpace)) {
445  emitError(UnknownLoc::get(type.getContext()),
446  "conversion of memref memory space ")
447  << type.getMemorySpace()
448  << " to integer address space "
449  "failed. Consider adding memory space conversions.";
450  return {};
451  }
452  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
453 
454  auto indexTy = getIndexType();
455 
456  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
457  auto rank = type.getRank();
458  if (rank == 0)
459  return results;
460 
461  if (unpackAggregates)
462  results.insert(results.end(), 2 * rank, indexTy);
463  else
464  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
465  return results;
466 }
467 
468 unsigned
470  const DataLayout &layout) const {
471  // Compute the descriptor size given that of its components indicated above.
472  unsigned space = *getMemRefAddressSpace(type);
473  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
474  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
475 }
476 
477 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
478 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
479 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
480  // When converting a MemRefType to a struct with descriptor fields, do not
481  // unpack the `sizes` and `strides` arrays.
482  SmallVector<Type, 5> types =
483  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
484  if (types.empty())
485  return {};
487 }
488 
489 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
490 /// that will form the unranked memref descriptor. In particular, the fields
491 /// for an unranked memref descriptor are:
492 /// 1. index-typed rank, the dynamic rank of this MemRef
493 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
494 /// stack allocated (alloca) copy of a MemRef descriptor that was cast to
495 /// be unranked.
497 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
499 }
500 
502  UnrankedMemRefType type, const DataLayout &layout) const {
503  // Compute the descriptor size given that of its components indicated above.
504  unsigned space = *getMemRefAddressSpace(type);
505  return layout.getTypeSize(getIndexType()) +
507 }
508 
509 Type LLVMTypeConverter::convertUnrankedMemRefType(
510  UnrankedMemRefType type) const {
511  if (!convertType(type.getElementType()))
512  return {};
514  getUnrankedMemRefDescriptorFields());
515 }
516 
517 FailureOr<unsigned>
519  if (!type.getMemorySpace()) // Default memory space -> 0.
520  return 0;
521  std::optional<Attribute> converted =
522  convertTypeAttribute(type, type.getMemorySpace());
523  if (!converted)
524  return failure();
525  if (!(*converted)) // Conversion to default is 0.
526  return 0;
527  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
528  if (explicitSpace.getType().isIndex() ||
529  explicitSpace.getType().isSignlessInteger())
530  return explicitSpace.getInt();
531  }
532  return failure();
533 }
534 
535 // Check if a memref type can be converted to a bare pointer.
537  if (isa<UnrankedMemRefType>(type))
538  // Unranked memref is not supported in the bare pointer calling convention.
539  return false;
540 
541  // Check that the memref has static shape, strides and offset. Otherwise, it
542  // cannot be lowered to a bare pointer.
543  auto memrefTy = cast<MemRefType>(type);
544  if (!memrefTy.hasStaticShape())
545  return false;
546 
547  int64_t offset = 0;
548  SmallVector<int64_t, 4> strides;
549  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
550  return false;
551 
552  for (int64_t stride : strides)
553  if (ShapedType::isDynamic(stride))
554  return false;
555 
556  return !ShapedType::isDynamic(offset);
557 }
558 
559 /// Convert a memref type to a bare pointer to the memref element type.
560 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
561  if (!canConvertToBarePtr(type))
562  return {};
563  Type elementType = convertType(type.getElementType());
564  if (!elementType)
565  return {};
566  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
567  if (failed(addressSpace))
568  return {};
569  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
570 }
571 
572 /// Convert an n-D vector type to an LLVM vector type:
573 /// * 0-D `vector<T>` are converted to vector<1xT>
574 /// * 1-D `vector<axT>` remains as is while,
575 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
576 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
577 /// As LLVM supports arrays of scalable vectors, this method will also convert
578 /// n-D scalable vectors provided that only the trailing dim is scalable.
579 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
580  auto elementType = convertType(type.getElementType());
581  if (!elementType)
582  return {};
583  if (type.getShape().empty())
584  return VectorType::get({1}, elementType);
585  Type vectorType = VectorType::get(type.getShape().back(), elementType,
586  type.getScalableDims().back());
587  assert(LLVM::isCompatibleVectorType(vectorType) &&
588  "expected vector type compatible with the LLVM dialect");
589  // For n-D vector types for which a _non-trailing_ dim is scalable,
590  // return a failure. Supporting such cases would require LLVM
591  // to support something akin "scalable arrays" of vectors.
592  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
593  return failure();
594  auto shape = type.getShape();
595  for (int i = shape.size() - 2; i >= 0; --i)
596  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
597  return vectorType;
598 }
599 
600 /// Convert a type in the context of the default or bare pointer calling
601 /// convention. Calling convention sensitive types, such as MemRefType and
602 /// UnrankedMemRefType, are converted following the specific rules for the
603 /// calling convention. Calling convention independent types are converted
604 /// following the default LLVM type conversions.
606  Type type, bool useBarePtrCallConv) const {
607  if (useBarePtrCallConv)
608  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
609  return convertMemRefToBarePtr(memrefTy);
610 
611  return convertType(type);
612 }
613 
614 /// Promote the bare pointers in 'values' that resulted from memrefs to
615 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
616 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
618  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
619  SmallVectorImpl<Value> &values) const {
620  assert(stdTypes.size() == values.size() &&
621  "The number of types and values doesn't match");
622  for (unsigned i = 0, end = values.size(); i < end; ++i)
623  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
624  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
625  memrefTy, values[i]);
626 }
627 
628 /// Convert a non-empty list of types of values produced by an operation into an
629 /// LLVM-compatible type. In particular, if more than one value is
630 /// produced, create a literal structure with elements that correspond to each
631 /// of the types converted with `convertType`.
633  assert(!types.empty() && "expected non-empty list of type");
634  if (types.size() == 1)
635  return convertType(types[0]);
636 
637  SmallVector<Type> resultTypes;
638  resultTypes.reserve(types.size());
639  for (Type type : types) {
640  Type converted = convertType(type);
641  if (!converted || !LLVM::isCompatibleType(converted))
642  return {};
643  resultTypes.push_back(converted);
644  }
645 
646  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
647 }
648 
649 /// Convert a non-empty list of types to be returned from a function into an
650 /// LLVM-compatible type. In particular, if more than one value is returned,
651 /// create an LLVM dialect structure type with elements that correspond to each
652 /// of the types converted with `convertCallingConventionType`.
654  bool useBarePtrCallConv) const {
655  assert(!types.empty() && "expected non-empty list of type");
656 
657  useBarePtrCallConv |= options.useBarePtrCallConv;
658  if (types.size() == 1)
659  return convertCallingConventionType(types.front(), useBarePtrCallConv);
660 
661  SmallVector<Type> resultTypes;
662  resultTypes.reserve(types.size());
663  for (auto t : types) {
664  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
665  if (!converted || !LLVM::isCompatibleType(converted))
666  return {};
667  resultTypes.push_back(converted);
668  }
669 
670  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
671 }
672 
674  OpBuilder &builder) const {
675  // Alloca with proper alignment. We do not expect optimizations of this
676  // alloca op and so we omit allocating at the entry block.
677  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
678  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
679  builder.getIndexAttr(1));
680  Value allocated =
681  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
682  // Store into the alloca'ed descriptor.
683  builder.create<LLVM::StoreOp>(loc, operand, allocated);
684  return allocated;
685 }
686 
689  ValueRange operands, OpBuilder &builder,
690  bool useBarePtrCallConv) const {
691  SmallVector<Value, 4> promotedOperands;
692  promotedOperands.reserve(operands.size());
693  useBarePtrCallConv |= options.useBarePtrCallConv;
694  for (auto it : llvm::zip(opOperands, operands)) {
695  auto operand = std::get<0>(it);
696  auto llvmOperand = std::get<1>(it);
697 
698  if (useBarePtrCallConv) {
699  // For the bare-ptr calling convention, we only have to extract the
700  // aligned pointer of a memref.
701  if (dyn_cast<MemRefType>(operand.getType())) {
702  MemRefDescriptor desc(llvmOperand);
703  llvmOperand = desc.alignedPtr(builder, loc);
704  } else if (isa<UnrankedMemRefType>(operand.getType())) {
705  llvm_unreachable("Unranked memrefs are not supported");
706  }
707  } else {
708  if (isa<UnrankedMemRefType>(operand.getType())) {
709  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
710  promotedOperands);
711  continue;
712  }
713  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
714  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
715  promotedOperands);
716  continue;
717  }
718  }
719 
720  promotedOperands.push_back(llvmOperand);
721  }
722  return promotedOperands;
723 }
724 
725 /// Callback to convert function argument types. It converts a MemRef function
726 /// argument to a list of non-aggregate types containing descriptor
728 /// information, and an UnrankedMemRef function argument to a list containing
728 /// the rank and a pointer to a descriptor struct.
729 LogicalResult
731  SmallVectorImpl<Type> &result) {
732  if (auto memref = dyn_cast<MemRefType>(type)) {
733  // In signatures, Memref descriptors are expanded into lists of
734  // non-aggregate values.
735  auto converted =
736  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
737  if (converted.empty())
738  return failure();
739  result.append(converted.begin(), converted.end());
740  return success();
741  }
742  if (isa<UnrankedMemRefType>(type)) {
743  auto converted = converter.getUnrankedMemRefDescriptorFields();
744  if (converted.empty())
745  return failure();
746  result.append(converted.begin(), converted.end());
747  return success();
748  }
749  auto converted = converter.convertType(type);
750  if (!converted)
751  return failure();
752  result.push_back(converted);
753  return success();
754 }
755 
756 /// Callback to convert function argument types. It converts MemRef function
757 /// arguments to bare pointers to the MemRef element type.
758 LogicalResult
760  SmallVectorImpl<Type> &result) {
761  auto llvmTy = converter.convertCallingConventionType(
762  type, /*useBarePointerCallConv=*/true);
763  if (!llvmTy)
764  return failure();
765 
766  result.push_back(llvmTy);
767  return success();
768 }
static llvm::ManagedStatic< PassManagerOptions > options
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerType getI64Type()
Definition: Builders.cpp:109
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
unsigned getWidth()
Return the bitwidth of this float type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type`, or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:108
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:452
static LLVMStructType getIdentified(MLIRContext *context, StringRef name)
Gets or creates an identified struct with the given name in the provided context.
Definition: LLVMTypes.cpp:424
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute `attr` present within the type `type` using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addArgumentMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isFloat8E4M3FN() const
Definition: Types.cpp:42
bool isFloat8E3M4() const
Definition: Types.cpp:55
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:46
bool isFloat8E4M3B11FNUZ() const
Definition: Types.cpp:49
bool isFloat6E3M2FN() const
Definition: Types.cpp:39
bool isFloat8E5M2() const
Definition: Types.cpp:40
bool isFloat8E8M0FNU() const
Definition: Types.cpp:52
bool isFloat4E2M1FN() const
Definition: Types.cpp:37
bool isFloat6E2M3FN() const
Definition: Types.cpp:38
bool isFloat8E4M3() const
Definition: Types.cpp:41
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:43
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:876
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.