TypeConverter.cpp
//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"
#include <memory>
#include <mutex>
#include <optional>

using namespace mlir;

SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
  {
    // Most of the time, the entry already exists in the map.
    std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
                                                    std::defer_lock);
    if (getContext().isMultithreadingEnabled())
      lock.lock();
    auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
    if (recursiveStack != conversionCallStack.end())
      return *recursiveStack->second;
  }

  // The first time this thread gets here, we have to acquire exclusive access
  // to insert into the map.
  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
      llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
  return *recursiveStackInserted.first->second;
}

/// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}

/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const LowerToLLVMOptions &options,
                                     const DataLayoutAnalysis *analysis)
    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
      dataLayoutAnalysis(analysis) {
  assert(llvmDialect && "LLVM IR dialect is not registered");

  // Register conversions for the builtin types.
  addConversion([&](ComplexType type) { return convertComplexType(type); });
  addConversion([&](FloatType type) { return convertFloatType(type); });
  addConversion([&](FunctionType type) { return convertFunctionType(type); });
  addConversion([&](IndexType type) { return convertIndexType(type); });
  addConversion([&](IntegerType type) { return convertIntegerType(type); });
  addConversion([&](MemRefType type) { return convertMemRefType(type); });
  addConversion(
      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
  addConversion([&](VectorType type) -> std::optional<Type> {
    FailureOr<Type> llvmType = convertVectorType(type);
    if (failed(llvmType))
      return std::nullopt;
    return llvmType;
  });

  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
  // before the conversions below since conversions are attempted in reverse
  // order and those should take priority.
  addConversion([](Type type) {
    return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
                                        : std::nullopt;
  });

  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
                    -> std::optional<LogicalResult> {
    // Fast path for types that won't be converted by this callback anyway.
    if (LLVM::isCompatibleType(type)) {
      results.push_back(type);
      return success();
    }

    if (type.isIdentified()) {
      auto convertedType = LLVM::LLVMStructType::getIdentified(
          type.getContext(), ("_Converted." + type.getName()).str());

      SmallVector<Type> &recursiveStack = getCurrentThreadRecursiveStack();
      if (llvm::count(recursiveStack, type)) {
        results.push_back(convertedType);
        return success();
      }
      recursiveStack.push_back(type);
      auto popConversionCallStack = llvm::make_scope_exit(
          [&recursiveStack]() { recursiveStack.pop_back(); });

      SmallVector<Type> convertedElemTypes;
      convertedElemTypes.reserve(type.getBody().size());
      if (failed(convertTypes(type.getBody(), convertedElemTypes)))
        return std::nullopt;

      // If the converted type has not been initialized yet, just set its body
      // to be the converted arguments and return.
      if (!convertedType.isInitialized()) {
        if (failed(
                convertedType.setBody(convertedElemTypes, type.isPacked()))) {
          return failure();
        }
        results.push_back(convertedType);
        return success();
      }

      // If it has been initialized, has the same body and packed bit, just use
      // it. This ensures that recursive structs keep being recursive rather
      // than including a non-updated name.
      if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
          convertedType.isPacked() == type.isPacked()) {
        results.push_back(convertedType);
        return success();
      }

      return failure();
    }

    SmallVector<Type> convertedSubtypes;
    convertedSubtypes.reserve(type.getBody().size());
    if (failed(convertTypes(type.getBody(), convertedSubtypes)))
      return std::nullopt;

    results.push_back(LLVM::LLVMStructType::getLiteral(
        type.getContext(), convertedSubtypes, type.isPacked()));
    return success();
  });
  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
    if (auto element = convertType(type.getElementType()))
      return LLVM::LLVMArrayType::get(element, type.getNumElements());
    return std::nullopt;
  });
  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
    Type convertedResType = convertType(type.getReturnType());
    if (!convertedResType)
      return std::nullopt;

    SmallVector<Type> convertedArgTypes;
    convertedArgTypes.reserve(type.getNumParams());
    if (failed(convertTypes(type.getParams(), convertedArgTypes)))
      return std::nullopt;

    return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
                                       type.isVarArg());
  });

  // Argument materializations convert from the new block argument types
  // (multiple SSA values that make up a memref descriptor) back to the
  // original block argument type. The dialect conversion framework will then
  // insert a target materialization from the original block argument type to
  // a legal type.
  addArgumentMaterialization(
      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
          Location loc) -> std::optional<Value> {
        if (inputs.size() == 1) {
          // Bare pointers are not supported for unranked memrefs because a
          // memref descriptor cannot be built just from a bare pointer.
          return std::nullopt;
        }
        Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
                                                    resultType, inputs);
        // An argument materialization must return a value of type
        // `resultType`, so insert a cast from the memref descriptor type
        // (!llvm.struct) to the original memref type.
        return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
            .getResult(0);
      });
  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                 ValueRange inputs,
                                 Location loc) -> std::optional<Value> {
    Value desc;
    if (inputs.size() == 1) {
      // This is a bare pointer. We allow bare pointers only for function entry
      // blocks.
      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
      if (!barePtr)
        return std::nullopt;
      Block *block = barePtr.getOwner();
      if (!block->isEntryBlock() ||
          !isa<FunctionOpInterface>(block->getParentOp()))
        return std::nullopt;
      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                               inputs[0]);
    } else {
      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
    }
    // An argument materialization must return a value of type `resultType`,
    // so insert a cast from the memref descriptor type (!llvm.struct) to the
    // original memref type.
    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
        .getResult(0);
  });
  // Add generic source and target materializations to handle cases where
  // non-LLVM types persist after an LLVM conversion.
  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs,
                               Location loc) -> std::optional<Value> {
    if (inputs.size() != 1)
      return std::nullopt;

    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs,
                               Location loc) -> std::optional<Value> {
    if (inputs.size() != 1)
      return std::nullopt;

    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });

  // Integer memory spaces map to themselves.
  addTypeAttributeConversion(
      [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
}

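// Example (illustrative sketch, not part of the original file): a typical way
// to build and query this converter. The option override shown here is an
// assumption about the caller's needs, not something this file prescribes.
//
//   MLIRContext context;
//   LowerToLLVMOptions options(&context);
//   options.overrideIndexBitwidth(32);        // lower `index` to i32
//   LLVMTypeConverter converter(&context, options);
//   Type loweredIndex = converter.convertType(IndexType::get(&context));
//   // `loweredIndex` is an IntegerType of width 32 (see convertIndexType).
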
/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() const {
  return *getDialect()->getContext();
}

Type LLVMTypeConverter::getIndexType() const {
  return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}

unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
  return options.dataLayout.getPointerSizeInBits(addressSpace);
}

Type LLVMTypeConverter::convertIndexType(IndexType type) const {
  return getIndexType();
}

Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
  return IntegerType::get(&getContext(), type.getWidth());
}

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
      type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
      type.isFloat8E4M3B11FNUZ())
    return IntegerType::get(&getContext(), type.getWidth());
  return type;
}

// Convert a `ComplexType` to an LLVM type. The result is a complex number
// struct with entries for the
// 1. real part and for the
// 2. imaginary part.
Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
  auto elementType = convertType(type.getElementType());
  return LLVM::LLVMStructType::getLiteral(&getContext(),
                                          {elementType, elementType});
}

// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
  return LLVM::LLVMPointerType::get(type.getContext());
}

// Function types are converted to LLVM function types by recursively
// converting argument and result types. If the MLIR function has zero results,
// the LLVM function has one VoidType result. If the MLIR function has more
// than one result, they are packed into an LLVM StructType in their order of
// appearance.
Type LLVMTypeConverter::convertFunctionSignature(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result) const {
  // Select the argument converter depending on the calling convention.
  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                             : structFuncArgTypeConverter;
  // Convert argument types one by one and check for errors.
  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
    SmallVector<Type, 8> converted;
    if (failed(funcArgConverter(*this, type, converted)))
      return {};
    result.addInputs(idx, converted);
  }

  // If the function does not return anything, create the void result type;
  // if it returns a single element, convert it; otherwise pack the result
  // types into a struct.
  Type resultType =
      funcTy.getNumResults() == 0
          ? LLVM::LLVMVoidType::get(&getContext())
          : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
  if (!resultType)
    return {};
  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
                                     isVariadic);
}

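// Illustrative example (not part of the original file), assuming a 64-bit
// index type: with the default (struct) calling convention, the signature
//   (memref<?xf32>, i32) -> f32
// converts to an LLVM function type whose memref argument is expanded into the
// descriptor fields (allocated ptr, aligned ptr, offset, size, stride):
//   !llvm.func<f32 (ptr, ptr, i64, i64, i64, i32)>
// Under the bare-pointer convention a memref argument would instead collapse
// to a single !llvm.ptr (but only for memrefs with static shape and strides).
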
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
  SmallVector<Type, 4> inputs;

  Type resultType = type.getNumResults() == 0
                        ? LLVM::LLVMVoidType::get(&getContext())
                        : packFunctionResults(type.getResults());
  if (!resultType)
    return {};

  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
  if (structType) {
    // Struct types cannot be safely returned via C interface. Make this a
    // pointer argument, instead.
    inputs.push_back(ptrType);
    resultType = LLVM::LLVMVoidType::get(&getContext());
  }

  for (Type t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    if (isa<MemRefType, UnrankedMemRefType>(t))
      converted = ptrType;
    inputs.push_back(converted);
  }

  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
}

/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
/// 1. The pointer to the allocated data buffer, followed by
/// 2. The pointer to the aligned data buffer, followed by
/// 3. A lowered `index`-type integer containing the distance between the
///    beginning of the buffer and the first element to be accessed through the
///    view, followed by
/// 4. An array containing as many `index`-type integers as the rank of the
///    MemRef: the array represents the size, in number of elements, of the
///    memref along the given dimension. For constant MemRef dimensions, the
///    corresponding size entry is a constant whose runtime value must match
///    the static value, followed by
/// 5. A second array containing as many `index`-type integers as the rank of
///    the MemRef: the second array represents the "stride" (in tensor
///    abstraction sense), i.e. the number of consecutive elements of the
///    underlying buffer one has to skip to advance by one index along that
///    dimension.
/// TODO: add assertions for the static cases.
///
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
/// are expanded into individual index-type elements.
///
///   template <typename Elem, typename Index, size_t Rank>
///   struct {
///     Elem *allocatedPtr;
///     Elem *alignedPtr;
///     Index offset;
///     Index sizes[Rank];   // omitted when rank == 0
///     Index strides[Rank]; // omitted when rank == 0
///   };
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                             bool unpackAggregates) const {
  if (!isStrided(type)) {
    emitError(
        UnknownLoc::get(type.getContext()),
        "conversion to strided form failed either due to non-strided layout "
        "maps (which should have been normalized away) or other reasons");
    return {};
  }

  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};

  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace)) {
    emitError(UnknownLoc::get(type.getContext()),
              "conversion of memref memory space ")
        << type.getMemorySpace()
        << " to integer address space "
           "failed. Consider adding memory space conversions.";
    return {};
  }
  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);

  auto indexTy = getIndexType();

  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
  auto rank = type.getRank();
  if (rank == 0)
    return results;

  if (unpackAggregates)
    results.insert(results.end(), 2 * rank, indexTy);
  else
    results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
  return results;
}

unsigned
LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
                                           const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
         (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
}

/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
  // When converting a MemRefType to a struct with descriptor fields, do not
  // unpack the `sizes` and `strides` arrays.
  SmallVector<Type, 5> types =
      getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
  if (types.empty())
    return {};
  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}

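// Illustrative example (not part of the original file), assuming a 64-bit
// index type:
//   memref<4x8xf32> -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
//   memref<f32>     -> !llvm.struct<(ptr, ptr, i64)>   (rank 0: no size/stride arrays)
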
/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, the fields
/// for an unranked memref descriptor are:
/// 1. index-typed rank, the dynamic rank of this MemRef
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
///    a stack allocated (alloca) copy of a MemRef descriptor that got casted
///    to be unranked.
SmallVector<Type, 2>
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
  return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
}

unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
    UnrankedMemRefType type, const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return layout.getTypeSize(getIndexType()) +
         llvm::divideCeil(getPointerBitwidth(space), 8);
}

Type LLVMTypeConverter::convertUnrankedMemRefType(
    UnrankedMemRefType type) const {
  if (!convertType(type.getElementType()))
    return {};
  return LLVM::LLVMStructType::getLiteral(type.getContext(),
                                          getUnrankedMemRefDescriptorFields());
}

FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
  if (!type.getMemorySpace()) // Default memory space -> 0.
    return 0;
  std::optional<Attribute> converted =
      convertTypeAttribute(type, type.getMemorySpace());
  if (!converted)
    return failure();
  if (!(*converted)) // Conversion to default is 0.
    return 0;
  if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
    return explicitSpace.getInt();
  return failure();
}

// Check if a memref type can be converted to a bare pointer.
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
  if (isa<UnrankedMemRefType>(type))
    // Unranked memref is not supported in the bare pointer calling convention.
    return false;

  // Check that the memref has static shape, strides and offset. Otherwise, it
  // cannot be lowered to a bare pointer.
  auto memrefTy = cast<MemRefType>(type);
  if (!memrefTy.hasStaticShape())
    return false;

  int64_t offset = 0;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
    return false;

  for (int64_t stride : strides)
    if (ShapedType::isDynamic(stride))
      return false;

  return !ShapedType::isDynamic(offset);
}

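// Illustrative examples (not part of the original file):
//   memref<4x8xf32>                   -> convertible (static shape, strides, offset)
//   memref<?x8xf32>                   -> not convertible (dynamic shape)
//   memref<4x8xf32, strided<[?, 1]>>  -> not convertible (dynamic stride)
//   memref<*xf32>                     -> not convertible (unranked)
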
/// Convert a memref type to a bare pointer to the memref element type.
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
  if (!canConvertToBarePtr(type))
    return {};
  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace))
    return {};
  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}

/// Convert an n-D vector type to an LLVM vector type:
/// * 0-D `vector<T>` is converted to `vector<1xT>`,
/// * 1-D `vector<axT>` remains as is, while
/// * n>1 `vector<ax...xkxT>` converts via an (n-1)-D array type to
///   `!llvm.array<ax...array<jxvector<kxT>>>`.
/// Returns failure for n-D scalable vector types as LLVM does not support
/// arrays of scalable vectors.
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
  auto elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  if (type.getShape().empty())
    return VectorType::get({1}, elementType);
  Type vectorType = VectorType::get(type.getShape().back(), elementType,
                                    type.getScalableDims().back());
  assert(LLVM::isCompatibleVectorType(vectorType) &&
         "expected vector type compatible with the LLVM dialect");
  // Only the trailing dimension can be scalable.
  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
    return failure();
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
  return vectorType;
}

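// Illustrative examples (not part of the original file):
//   vector<f32>        -> vector<1xf32>
//   vector<8xf32>      -> vector<8xf32>
//   vector<4x8xf32>    -> !llvm.array<4 x vector<8xf32>>
//   vector<4x[8]xf32>  -> !llvm.array<4 x vector<[8]xf32>>  (trailing dim scalable)
//   vector<[4]x8xf32>  -> failure (only the trailing dimension may be scalable)
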
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type LLVMTypeConverter::convertCallingConventionType(
    Type type, bool useBarePtrCallConv) const {
  if (useBarePtrCallConv)
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      return convertMemRefToBarePtr(memrefTy);

  return convertType(type);
}

/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
    SmallVectorImpl<Value> &values) const {
  assert(stdTypes.size() == values.size() &&
         "The number of types and values doesn't match");
  for (unsigned i = 0, end = values.size(); i < end; ++i)
    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                    memrefTy, values[i]);
}

/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is produced,
/// create a literal structure with elements that correspond to each of the
/// types converted with `convertType`.
Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
  assert(!types.empty() && "expected non-empty list of type");
  if (types.size() == 1)
    return convertType(types[0]);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (Type type : types) {
    Type converted = convertType(type);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}

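// Illustrative example (not part of the original file): an operation producing
// (f32, i64) gets the packed result type !llvm.struct<(f32, i64)>, while a
// single f32 result stays f32.
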
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
Type LLVMTypeConverter::packFunctionResults(TypeRange types,
                                            bool useBarePtrCallConv) const {
  assert(!types.empty() && "expected non-empty list of type");

  useBarePtrCallConv |= options.useBarePtrCallConv;
  if (types.size() == 1)
    return convertCallingConventionType(types.front(), useBarePtrCallConv);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (auto t : types) {
    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}

Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
                                                    OpBuilder &builder) const {
  // Alloca with proper alignment. We do not expect optimizations of this
  // alloca op and so we omit allocating at the entry block.
  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                               builder.getIndexAttr(1));
  Value allocated =
      builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
  // Store into the alloca'ed descriptor.
  builder.create<LLVM::StoreOp>(loc, operand, allocated);
  return allocated;
}

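// Illustrative effect (not part of the original file; syntax abbreviated): for
// a descriptor value %desc of some !llvm.struct<...> descriptor type, this
// emits roughly
//   %one   = llvm.mlir.constant(1 : i64) : i64
//   %alloc = llvm.alloca %one x !llvm.struct<...> : (i64) -> !llvm.ptr
//   llvm.store %desc, %alloc
// and returns %alloc, a pointer suitable for the C-wrapper calling convention.
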
SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
                                   ValueRange operands, OpBuilder &builder,
                                   bool useBarePtrCallConv) const {
  SmallVector<Value, 4> promotedOperands;
  promotedOperands.reserve(operands.size());
  useBarePtrCallConv |= options.useBarePtrCallConv;
  for (auto it : llvm::zip(opOperands, operands)) {
    auto operand = std::get<0>(it);
    auto llvmOperand = std::get<1>(it);

    if (useBarePtrCallConv) {
      // For the bare-ptr calling convention, we only have to extract the
      // aligned pointer of a memref.
      if (dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
      } else if (isa<UnrankedMemRefType>(operand.getType())) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
      if (isa<UnrankedMemRefType>(operand.getType())) {
        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
                                         promotedOperands);
        continue;
      }
      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
                                 promotedOperands);
        continue;
      }
    }

    promotedOperands.push_back(llvmOperand);
  }
  return promotedOperands;
}

/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult
mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                 SmallVectorImpl<Type> &result) {
  if (auto memref = dyn_cast<MemRefType>(type)) {
    // In signatures, MemRef descriptors are expanded into lists of
    // non-aggregate values.
    auto converted =
        converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  if (isa<UnrankedMemRefType>(type)) {
    auto converted = converter.getUnrankedMemRefDescriptorFields();
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  auto converted = converter.convertType(type);
  if (!converted)
    return failure();
  result.push_back(converted);
  return success();
}

/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                  SmallVectorImpl<Type> &result) {
  auto llvmTy = converter.convertCallingConventionType(
      type, /*useBarePointerCallConv=*/true);
  if (!llvmTy)
    return failure();

  result.push_back(llvmTy);
  return success();
}
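
// Illustrative comparison (not part of the original file) of the two argument
// converters above for a memref<2x2xf32> argument, assuming a 64-bit index:
//   structFuncArgTypeConverter  -> ptr, ptr, i64, i64, i64, i64, i64
//                                  (allocated, aligned, offset, 2 sizes, 2 strides)
//   barePtrFuncArgTypeConverter -> ptr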