MLIR  22.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/Support/Threading.h"
15 #include <memory>
16 #include <mutex>
17 #include <optional>
18 
19 using namespace mlir;
20 
// Fast path: look up this thread's conversion call stack under a shared
// (reader) lock. The lock is only acquired when multithreading is enabled on
// the context; otherwise the lookup runs unlocked.
22  {
23  // Most of the time, the entry already exists in the map.
24  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
25  std::defer_lock);
26  if (getContext().isMultithreadingEnabled())
27  lock.lock();
28  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
29  if (recursiveStack != conversionCallStack.end())
30  return *recursiveStack->second;
31  }
32 
33  // First time this thread gets here, we have to get an exclusive access to
34  // insert in the map
// Note: unlike the fast path above, this exclusive lock is taken
// unconditionally, whether or not multithreading is enabled.
35  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
// Each stack is heap-allocated with make_unique, so the reference returned
// below points at the vector itself rather than into the map's storage; it
// stays valid as long as this thread's entry is not erased.
36  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
37  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
38  return *recursiveStackInserted.first->second;
39 }
40 
41 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
45 
46 /// Helper function that checks if the given value range is a bare pointer.
47 static bool isBarePointer(ValueRange values) {
48  return values.size() == 1 &&
49  isa<LLVM::LLVMPointerType>(values.front().getType());
50 }
51 
52 /// Pack SSA values into an unranked memref descriptor struct.
54  UnrankedMemRefType resultType,
55  ValueRange inputs, Location loc,
56  const LLVMTypeConverter &converter) {
57  // Note: Bare pointers are not supported for unranked memrefs because a
58  // memref descriptor cannot be built just from a bare pointer.
59  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
60  return Value();
61  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
62  inputs);
63 }
64 
65 /// Pack SSA values into a ranked memref descriptor struct.
66 static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
67  ValueRange inputs, Location loc,
68  const LLVMTypeConverter &converter) {
69  assert(resultType && "expected non-null result type");
70  if (isBarePointer(inputs))
71  return MemRefDescriptor::fromStaticShape(builder, loc, converter,
72  resultType, inputs[0]);
73  if (TypeRange(inputs) ==
74  converter.getMemRefDescriptorFields(resultType,
75  /*unpackAggregates=*/true))
76  return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
77  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
78  // This materialization function cannot be used.
79  return Value();
80 }
81 
82 /// MemRef descriptor elements -> UnrankedMemRefType
84  UnrankedMemRefType resultType,
85  ValueRange inputs, Location loc,
86  const LLVMTypeConverter &converter) {
87  // A source materialization must return a value of type
88  // `resultType`, so insert a cast from the memref descriptor type
89  // (!llvm.struct) to the original memref type.
90  Value packed =
91  packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
92  if (!packed)
93  return Value();
94  return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
95  .getResult(0);
96 }
97 
98 /// MemRef descriptor elements -> MemRefType
100  MemRefType resultType,
101  ValueRange inputs, Location loc,
102  const LLVMTypeConverter &converter) {
103  // A source materialization must return a value of type `resultType`,
104  // so insert a cast from the memref descriptor type (!llvm.struct) to the
105  // original memref type.
106  Value packed =
107  packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
108  if (!packed)
109  return Value();
110  return UnrealizedConversionCastOp::create(builder, loc, resultType, packed)
111  .getResult(0);
112 }
113 
114 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
///
/// Registers:
///  * conversions for the builtin types;
///  * a pass-through for types that are already LLVM-compatible;
///  * recursive conversions for LLVM struct, array and function types;
///  * the source/target materializations used during dialect conversion,
///    including the memref-descriptor ones.
/// The LLVM dialect is obtained with `getOrLoadDialect`, and the assert below
/// checks that it is available.
118  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
119  dataLayoutAnalysis(analysis) {
120  assert(llvmDialect && "LLVM IR dialect is not registered");
121 
122  // Register conversions for the builtin types.
123  addConversion([&](ComplexType type) { return convertComplexType(type); });
124  addConversion([&](FloatType type) { return convertFloatType(type); });
125  addConversion([&](FunctionType type) { return convertFunctionType(type); });
126  addConversion([&](IndexType type) { return convertIndexType(type); });
127  addConversion([&](IntegerType type) { return convertIntegerType(type); });
128  addConversion([&](MemRefType type) { return convertMemRefType(type); });
130  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
// Vectors: a failed conversion is reported as std::nullopt.
131  addConversion([&](VectorType type) -> std::optional<Type> {
132  FailureOr<Type> llvmType = convertVectorType(type);
133  if (failed(llvmType))
134  return std::nullopt;
135  return llvmType;
136  });
137 
138  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
139  // before the conversions below since conversions are attempted in reverse
140  // order and those should take priority.
141  addConversion([](Type type) {
142  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
143  : std::nullopt;
144  });
145 
// LLVM struct types: convert every body element. Identified structs are
// mapped to a new identified struct named "_Converted.<original name>".
// Literal structs become a new literal struct with the same packing.
146  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
147  -> std::optional<LogicalResult> {
148  // Fastpath for types that won't be converted by this callback anyway.
149  if (LLVM::isCompatibleType(type)) {
150  results.push_back(type);
151  return success();
152  }
153 
154  if (type.isIdentified()) {
155  auto convertedType = LLVM::LLVMStructType::getIdentified(
156  type.getContext(), ("_Converted." + type.getName()).str());
157 
// `recursiveStack` holds the identified structs whose conversion is in
// progress (pushed below, popped on scope exit). Meeting `type` again means
// it refers to itself, directly or indirectly. Return the converted struct by
// name instead of recursing forever.
159  if (llvm::count(recursiveStack, type)) {
160  results.push_back(convertedType);
161  return success();
162  }
163  recursiveStack.push_back(type);
164  auto popConversionCallStack = llvm::make_scope_exit(
165  [&recursiveStack]() { recursiveStack.pop_back(); });
166 
167  SmallVector<Type> convertedElemTypes;
168  convertedElemTypes.reserve(type.getBody().size());
169  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
170  return std::nullopt;
171 
172  // If the converted type has not been initialized yet, just set its body
173  // to be the converted arguments and return.
174  if (!convertedType.isInitialized()) {
175  if (failed(
176  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
177  return failure();
178  }
179  results.push_back(convertedType);
180  return success();
181  }
182 
183  // If it has been initialized, has the same body and packed bit, just use
184  // it. This ensures that recursive structs keep being recursive rather
185  // than including a non-updated name.
186  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
187  convertedType.isPacked() == type.isPacked()) {
188  results.push_back(convertedType);
189  return success();
190  }
191 
// The "_Converted." name is already bound to a different body or packing,
// so it cannot be reused for this struct.
192  return failure();
193  }
194 
// Literal struct: convert the element types and keep the packing.
195  SmallVector<Type> convertedSubtypes;
196  convertedSubtypes.reserve(type.getBody().size());
197  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
198  return std::nullopt;
199 
200  results.push_back(LLVM::LLVMStructType::getLiteral(
201  type.getContext(), convertedSubtypes, type.isPacked()));
202  return success();
203  });
// LLVM arrays: convert the element type and keep the element count.
204  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
205  if (auto element = convertType(type.getElementType()))
206  return LLVM::LLVMArrayType::get(element, type.getNumElements());
207  return std::nullopt;
208  });
// LLVM function types: convert the return and parameter types, and keep the
// variadic flag.
209  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
210  Type convertedResType = convertType(type.getReturnType());
211  if (!convertedResType)
212  return std::nullopt;
213 
214  SmallVector<Type> convertedArgTypes;
215  convertedArgTypes.reserve(type.getNumParams());
216  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
217  return std::nullopt;
218 
219  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
220  type.isVarArg());
221  });
222 
223  // Add generic source and target materializations to handle cases where
224  // non-LLVM types persist after an LLVM conversion.
225  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
226  ValueRange inputs, Location loc) {
227  return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
228  .getResult(0);
229  });
230  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
231  ValueRange inputs, Location loc) {
232  return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs)
233  .getResult(0);
234  });
235 
236  // Source materializations convert from the new block argument types
237  // (multiple SSA values that make up a memref descriptor) back to the
238  // original block argument type.
239  addSourceMaterialization([&](OpBuilder &builder,
240  UnrankedMemRefType resultType, ValueRange inputs,
241  Location loc) {
242  return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
243  *this);
244  });
245  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
246  ValueRange inputs, Location loc) {
247  return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
248  });
249 
250  // Bare pointer -> Packed MemRef descriptor
// Unpacked descriptor fields are accepted as well, since packRankedMemRefDesc
// and packUnrankedMemRefDesc handle both. Unranked memrefs cannot be built
// from a bare pointer.
251  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
252  ValueRange inputs, Location loc,
253  Type originalType) -> Value {
254  // The original MemRef type is required to build a MemRef descriptor
255  // because the sizes/strides of the MemRef cannot be inferred from just the
256  // bare pointer.
257  if (!originalType)
258  return Value();
// Only materialize when `resultType` is what this converter produces for
// `originalType`.
259  if (resultType != convertType(originalType))
260  return Value();
261  if (auto memrefType = dyn_cast<MemRefType>(originalType))
262  return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
263  if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
264  return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
265  *this);
266  return Value();
267  });
268 
269  // Integer memory spaces map to themselves.
271  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
272 }
273 
274 /// Returns the MLIR context.
276  return *getDialect()->getContext();
277 }
278 
281 }
282 
283 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
284  return options.dataLayout.getPointerSizeInBits(addressSpace);
285 }
286 
287 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
288  return getIndexType();
289 }
290 
291 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
292  return IntegerType::get(&getContext(), type.getWidth());
293 }
294 
295 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
296  // Valid LLVM float types are used directly.
297  if (LLVM::isCompatibleType(type))
298  return type;
299 
300  // F4, F6, F8 types are converted to integer types with the same bit width.
301  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
302  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
303  Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
304  Float8E8M0FNUType>(type))
305  return IntegerType::get(&getContext(), type.getWidth());
306 
307  // Other floating-point types: A custom type conversion rule must be
308  // specified by the user.
309  return Type();
310 }
311 
312 // Convert a `ComplexType` to an LLVM type. The result is a complex number
313 // struct with entries for the
314 // 1. real part and for the
315 // 2. imaginary part.
316 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
317  auto elementType = convertType(type.getElementType());
318  return LLVM::LLVMStructType::getLiteral(&getContext(),
319  {elementType, elementType});
320 }
321 
322 // Except for signatures, MLIR function types are converted into LLVM
323 // pointer-to-function types.
324 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
325  return LLVM::LLVMPointerType::get(type.getContext());
326 }
327 
328 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
329 /// function arguments. Returns an empty container if none of these attributes
330 /// are found in any of the arguments.
331 static void
332 filterByValRefArgAttrs(FunctionOpInterface funcOp,
333  SmallVectorImpl<std::optional<NamedAttribute>> &result) {
334  assert(result.empty() && "Unexpected non-empty output");
335  result.resize(funcOp.getNumArguments(), std::nullopt);
336  bool foundByValByRefAttrs = false;
337  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
338  for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
339  if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
340  namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
341  foundByValByRefAttrs = true;
342  result[argIdx] = namedAttr;
343  break;
344  }
345  }
346  }
347 
348  if (!foundByValByRefAttrs)
349  result.clear();
350 }
351 
352 // Function types are converted to LLVM Function types by recursively converting
353 // argument and result types. If MLIR Function has zero results, the LLVM
354 // Function has one VoidType result. If MLIR Function has more than one result,
355 // they are packed into an LLVM StructType in their order of appearance.
356 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
357 // `llvm.byref` function arguments which are not LLVM pointers are overridden
358 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
359 // converted to LLVM pointer types are reset to std::nullopt in `byValRefNonPtrAttrs`.
// `useBarePtrCallConv` is OR-ed with `options.useBarePtrCallConv`. Only
// arguments that convert to exactly one type are subject to the byval/byref
// rewrite. Returns a null type if any argument or result fails to convert.
360 Type LLVMTypeConverter::convertFunctionSignatureImpl(
361  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
362  LLVMTypeConverter::SignatureConversion &result,
363  SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
364  // Select the argument converter depending on the calling convention.
365  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
366  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
368 
369  // Convert argument types one by one and check for errors.
370  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371  SmallVector<Type, 8> converted;
372  if (failed(funcArgConverter(*this, type, converted)))
373  return {};
374 
375  // Rewrite converted type of `llvm.byval` or `llvm.byref` function
376  // argument that was not converted to an LLVM pointer type.
377  if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378  converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379  // If the argument was already converted to an LLVM pointer type, we stop
380  // tracking it as it doesn't need more processing.
381  if (isa<LLVM::LLVMPointerType>(converted[0]))
382  (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383  else
384  converted[0] = LLVM::LLVMPointerType::get(&getContext());
385  }
386 
387  result.addInputs(idx, converted);
388  }
389 
390  // If function does not return anything, create the void result type,
391  // if it returns one element, convert it, otherwise pack the result types into
392  // a struct.
393  Type resultType =
394  funcTy.getNumResults() == 0
396  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
397  if (!resultType)
398  return {};
399  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
400  isVariadic);
401 }
402 
404  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
// Plain FunctionType overload: no argument attributes are available here, so
// `llvm.byval`/`llvm.byref` arguments are not tracked (nullptr below).
406  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
407  result,
408  /*byValRefNonPtrAttrs=*/nullptr);
409 }
410 
412  FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
414  SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
415  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
416  // that were not converted to LLVM pointer types will be returned for further
417  // processing.
418  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
// The op's function type is expected to be a builtin FunctionType; `cast`
// does not handle any other type.
419  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421  result, &byValRefNonPtrAttrs);
422 }
423 
424 /// Converts the function type to a C-compatible format, in particular using
425 /// pointers to memref descriptors for arguments.
///
/// If the (packed) result type is an LLVM struct, it is passed through an
/// extra leading pointer argument instead, and the function returns void. The
/// struct type is reported as the second element of the returned pair; that
/// element is null when no such rewrite happened. Memref and unranked memref
/// arguments become pointers. Returns a pair of null types if the result, or
/// any argument, cannot be converted to an LLVM-compatible type.
426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
428  SmallVector<Type, 4> inputs;
429 
430  Type resultType = type.getNumResults() == 0
432  : packFunctionResults(type.getResults());
433  if (!resultType)
434  return {};
435 
436  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
438  if (structType) {
439  // Struct types cannot be safely returned via C interface. Make this a
440  // pointer argument, instead.
441  inputs.push_back(ptrType);
442  resultType = LLVM::LLVMVoidType::get(&getContext());
443  }
444 
445  for (Type t : type.getInputs()) {
446  auto converted = convertType(t);
447  if (!converted || !LLVM::isCompatibleType(converted))
448  return {};
449  if (isa<MemRefType, UnrankedMemRefType>(t))
450  converted = ptrType;
451  inputs.push_back(converted);
452  }
453 
454  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
455 }
456 
457 /// Convert a memref type into a list of LLVM IR types that will form the
458 /// memref descriptor. The result contains the following types:
459 /// 1. The pointer to the allocated data buffer, followed by
460 /// 2. The pointer to the aligned data buffer, followed by
461 /// 3. A lowered `index`-type integer containing the distance between the
462 /// beginning of the buffer and the first element to be accessed through the
463 /// view, followed by
464 /// 4. An array containing as many `index`-type integers as the rank of the
465 /// MemRef: the array represents the size, in number of elements, of the memref
466 /// along the given dimension. For constant MemRef dimensions, the
467 /// corresponding size entry is a constant whose runtime value must match the
468 /// static value, followed by
469 /// 5. A second array containing as many `index`-type integers as the rank of
470 /// the MemRef: the second array represents the "stride" (in tensor abstraction
471 /// sense), i.e. the number of consecutive elements of the underlying buffer.
472 /// TODO: add assertions for the static cases.
473 ///
474 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
475 /// are expanded into individual index-type elements.
476 ///
477 /// template <typename Elem, typename Index, size_t Rank>
478 /// struct {
479 /// Elem *allocatedPtr;
480 /// Elem *alignedPtr;
481 /// Index offset;
482 /// Index sizes[Rank]; // omitted when rank == 0
483 /// Index strides[Rank]; // omitted when rank == 0
484 /// };
///
/// An empty result signals failure. It is returned in three cases:
///  * the memref is not strided (after emitting an error at an unknown
///    location);
///  * its memory space cannot be mapped to an integer address space (after
///    emitting an error at an unknown location);
///  * the element type cannot be converted (without an error).
/// The converted element type is only checked for success. Both pointer
/// fields are LLVM pointers in the memref's address space.
487  bool unpackAggregates) const {
488  if (!type.isStrided()) {
489  emitError(
490  UnknownLoc::get(type.getContext()),
491  "conversion to strided form failed either due to non-strided layout "
492  "maps (which should have been normalized away) or other reasons");
493  return {};
494  }
495 
496  Type elementType = convertType(type.getElementType());
497  if (!elementType)
498  return {};
499 
500  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
501  if (failed(addressSpace)) {
502  emitError(UnknownLoc::get(type.getContext()),
503  "conversion of memref memory space ")
504  << type.getMemorySpace()
505  << " to integer address space "
506  "failed. Consider adding memory space conversions.";
507  return {};
508  }
509  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
510 
511  auto indexTy = getIndexType();
512 
// Allocated pointer, aligned pointer, offset.
513  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
514  auto rank = type.getRank();
515  if (rank == 0)
516  return results;
517 
// Sizes and strides: either 2 * rank scalars or two rank-sized arrays.
518  if (unpackAggregates)
519  results.insert(results.end(), 2 * rank, indexTy);
520  else
521  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
522  return results;
523 }
524 
// Size in bytes of a ranked memref descriptor. It is two pointers (each
// pointer bit width rounded up to whole bytes) plus (1 + 2 * rank) index-typed
// fields (offset, sizes, strides) sized by `layout`.
525 unsigned
527  const DataLayout &layout) const {
528  // Compute the descriptor size given that of its components indicated above.
// NOTE(review): the FailureOr from getMemRefAddressSpace is dereferenced
// without a failure check. Callers presumably only pass memrefs whose memory
// space converts to an integer address space -- confirm.
529  unsigned space = *getMemRefAddressSpace(type);
530  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
531  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
532 }
533 
534 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
535 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
536 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537  // When converting a MemRefType to a struct with descriptor fields, do not
538  // unpack the `sizes` and `strides` arrays.
539  SmallVector<Type, 5> types =
540  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
541  if (types.empty())
542  return {};
543  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544 }
545 
546 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
547 /// that will form the unranked memref descriptor. In particular, the fields
548 /// for an unranked memref descriptor are:
549 /// 1. index-typed rank, the dynamic rank of this MemRef
550 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
551 /// a stack-allocated (alloca) copy of a MemRef descriptor that was cast to
552 /// unranked.
556 }
557 
559  UnrankedMemRefType type, const DataLayout &layout) const {
560  // Compute the descriptor size given that of its components indicated above.
// NOTE(review): the FailureOr from getMemRefAddressSpace is dereferenced
// without a failure check. Confirm that callers only pass memrefs whose memory
// space converts to an integer address space.
561  unsigned space = *getMemRefAddressSpace(type);
562  return layout.getTypeSize(getIndexType()) +
564 }
565 
// Lowers an unranked memref to a literal LLVM struct. The element type is only
// checked for convertibility; a null type is returned if that check fails.
566 Type LLVMTypeConverter::convertUnrankedMemRefType(
567  UnrankedMemRefType type) const {
568  if (!convertType(type.getElementType()))
569  return {};
570  return LLVM::LLVMStructType::getLiteral(&getContext(),
572 }
573 
// Maps a memref's memory-space attribute to an integer LLVM address space.
// A missing memory space, or one that converts to a null attribute, maps to 0.
// Otherwise the registered type-attribute conversions are applied, and the
// result must be an index or signless IntegerAttr. Anything else is a failure,
// including an empty optional from convertTypeAttribute.
574 FailureOr<unsigned>
576  if (!type.getMemorySpace()) // Default memory space -> 0.
577  return 0;
578  std::optional<Attribute> converted =
579  convertTypeAttribute(type, type.getMemorySpace());
580  if (!converted)
581  return failure();
582  if (!(*converted)) // Conversion to default is 0.
583  return 0;
584  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585  if (explicitSpace.getType().isIndex() ||
586  explicitSpace.getType().isSignlessInteger())
// NOTE(review): getInt() yields a signed 64-bit value that is implicitly
// narrowed to unsigned here. Negative or out-of-range spaces are not rejected.
587  return explicitSpace.getInt();
588  }
589  return failure();
590 }
591 
592 // Check if a memref type can be converted to a bare pointer.
// That requires a ranked memref with a static shape and a strided layout whose
// strides and offset are all static. Unranked memrefs are always rejected.
594  if (isa<UnrankedMemRefType>(type))
595  // Unranked memref is not supported in the bare pointer calling convention.
596  return false;
597 
598  // Check that the memref has static shape, strides and offset. Otherwise, it
599  // cannot be lowered to a bare pointer.
600  auto memrefTy = cast<MemRefType>(type);
601  if (!memrefTy.hasStaticShape())
602  return false;
603 
604  int64_t offset = 0;
605  SmallVector<int64_t, 4> strides;
606  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607  return false;
608 
609  for (int64_t stride : strides)
610  if (ShapedType::isDynamic(stride))
611  return false;
612 
613  return ShapedType::isStatic(offset);
614 }
615 
616 /// Convert a memref type to a bare pointer to the memref element type.
617 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
618  if (!canConvertToBarePtr(type))
619  return {};
620  Type elementType = convertType(type.getElementType());
621  if (!elementType)
622  return {};
623  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
624  if (failed(addressSpace))
625  return {};
626  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
627 }
628 
629 /// Convert an n-D vector type to an LLVM vector type:
630 /// * 0-D `vector<T>` are converted to vector<1xT>
631 /// * 1-D `vector<axT>` remains as is while,
632 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
633 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
634 /// As LLVM supports arrays of scalable vectors, this method will also convert
635 /// n-D scalable vectors provided that only the trailing dim is scalable.
636 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
637  auto elementType = convertType(type.getElementType());
638  if (!elementType)
639  return {};
640  if (type.getShape().empty())
641  return VectorType::get({1}, elementType);
642  Type vectorType = VectorType::get(type.getShape().back(), elementType,
643  type.getScalableDims().back());
644  assert(LLVM::isCompatibleVectorType(vectorType) &&
645  "expected vector type compatible with the LLVM dialect");
646  // For n-D vector types for which a _non-trailing_ dim is scalable,
647  // return a failure. Supporting such cases would require LLVM
648  // to support something akin "scalable arrays" of vectors.
649  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650  return failure();
651  auto shape = type.getShape();
652  for (int i = shape.size() - 2; i >= 0; --i)
653  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
654  return vectorType;
655 }
656 
657 /// Convert a type in the context of the default or bare pointer calling
658 /// convention. Calling convention sensitive types, such as MemRefType and
659 /// UnrankedMemRefType, are converted following the specific rules for the
660 /// calling convention. Calling convention independent types are converted
661 /// following the default LLVM type conversions.
///
/// Only the `useBarePtrCallConv` argument is consulted here, not
/// `options.useBarePtrCallConv`. Under the bare-pointer convention, a memref
/// that cannot be lowered to a bare pointer is a failure rather than a
/// fallback to the descriptor; unranked memrefs therefore always fail there.
/// Under the default convention, memrefs go through the regular conversion.
663  Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
664  if (useBarePtrCallConv) {
665  if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
666  Type converted = convertMemRefToBarePtr(memrefTy);
667  if (!converted)
668  return failure();
669  result.push_back(converted);
670  return success();
671  }
672  }
673 
674  return convertType(type, result);
675 }
676 
677 /// Convert a non-empty list of types of values produced by an operation into an
678 /// LLVM-compatible type. In particular, if more than one value is
679 /// produced, create a literal structure with elements that correspond to each
680 /// of the types converted with `convertType`.
///
/// A single type is returned exactly as `convertType` produces it. The
/// LLVM-compatibility check applied to each element in the multi-value case is
/// skipped there. In the multi-value case, a null type is returned if any
/// element fails to convert or is not LLVM-compatible.
682  assert(!types.empty() && "expected non-empty list of type");
683  if (types.size() == 1)
684  return convertType(types[0]);
685 
686  SmallVector<Type> resultTypes;
687  resultTypes.reserve(types.size());
688  for (Type type : types) {
689  Type converted = convertType(type);
690  if (!converted || !LLVM::isCompatibleType(converted))
691  return {};
692  resultTypes.push_back(converted);
693  }
694 
695  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
696 }
697 
698 /// Convert a non-empty list of types to be returned from a function into an
699 /// LLVM-compatible type. In particular, if more than one value is returned,
700 /// create an LLVM dialect structure type with elements that correspond to each
701 /// of the types converted with `convertCallingConventionType`.
///
/// More precisely, the struct is built whenever the conversion yields more
/// than one type in total; a single source type may convert to several.
/// `useBarePtrCallConv` is OR-ed with `options.useBarePtrCallConv`.
///  * If `groupedTypes` is non-null, it receives, per input type, the
///    converted types that input produced. On failure it may already hold the
///    groups of the inputs converted so far.
///  * If `numConvertedTypes` is non-null, it receives the total number of
///    converted types; it is only written on the success path.
/// Returns a null type on conversion failure or when no types are produced.
703  TypeRange types, bool useBarePtrCallConv,
704  SmallVector<SmallVector<Type>> *groupedTypes,
705  int64_t *numConvertedTypes) const {
706  assert(!types.empty() && "expected non-empty list of type");
707  assert((!groupedTypes || groupedTypes->empty()) &&
708  "expected groupedTypes to be empty");
709 
710  useBarePtrCallConv |= options.useBarePtrCallConv;
711  SmallVector<Type> resultTypes;
712  resultTypes.reserve(types.size());
713  size_t sizeBefore = 0;
714  for (auto t : types) {
715  if (failed(
716  convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
717  return {};
718  if (groupedTypes) {
// The types appended by this input are those past `sizeBefore`.
719  SmallVector<Type> &group = groupedTypes->emplace_back();
720  llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
721  }
722  sizeBefore = resultTypes.size();
723  }
724 
725  if (numConvertedTypes)
726  *numConvertedTypes = resultTypes.size();
727  if (resultTypes.size() == 1)
728  return resultTypes.front();
729  if (resultTypes.empty())
730  return {};
731  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
732 }
733 
735  OpBuilder &builder) const {
// Spills `operand` into a fresh single-element stack slot created at the
// current insertion point, and returns the pointer to that slot. The operand
// is presumably a memref descriptor value -- confirm against callers.
736  // Alloca with proper alignment. We do not expect optimizations of this
737  // alloca op and so we omit allocating at the entry block.
// NOTE(review): no explicit alignment is passed to the alloca below, despite
// the comment above.
738  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
739  Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
740  builder.getIndexAttr(1));
741  Value allocated =
742  LLVM::AllocaOp::create(builder, loc, ptrType, operand.getType(), one);
743  // Store into the alloca'ed descriptor.
744  LLVM::StoreOp::create(builder, loc, operand, allocated);
745  return allocated;
746 }
747 
749  Location loc, ValueRange opOperands, ValueRange adaptorOperands,
750  OpBuilder &builder, bool useBarePtrCallConv) const {
// 1:1 operand mapping: each adaptor operand becomes its own single-element
// range. The work is then delegated to the overload that takes one
// ValueRange per operand.
752  for (size_t i = 0, e = adaptorOperands.size(); i < e; i++)
753  ranges.push_back(adaptorOperands.slice(i, 1));
754  return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
755 }
756 
758  Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands,
759  OpBuilder &builder, bool useBarePtrCallConv) const {
// Lowers call-site operands. `opOperands` are the original values, used only
// for their types. `adaptorOperands` are their converted counterparts, and
// zip_equal requires both lists to have the same length.
//  * Bare-pointer convention: a memref operand is passed as its aligned
//    pointer.
//  * Default convention: ranked and unranked memref descriptors are unpacked
//    into their fields.
//  * Every other operand is forwarded unchanged.
// `useBarePtrCallConv` is OR-ed with `options.useBarePtrCallConv`.
// NOTE(review): an unranked memref under the bare-pointer convention hits
// llvm_unreachable instead of reporting an error.
760  SmallVector<Value, 4> promotedOperands;
761  promotedOperands.reserve(adaptorOperands.size());
762  useBarePtrCallConv |= options.useBarePtrCallConv;
763  for (auto [operand, llvmOperand] :
764  llvm::zip_equal(opOperands, adaptorOperands)) {
765  if (useBarePtrCallConv) {
766  // For the bare-ptr calling convention, we only have to extract the
767  // aligned pointer of a memref.
768  if (isa<MemRefType>(operand.getType())) {
769  assert(llvmOperand.size() == 1 && "Expected a single operand");
770  MemRefDescriptor desc(llvmOperand.front());
771  promotedOperands.push_back(desc.alignedPtr(builder, loc));
772  continue;
773  } else if (isa<UnrankedMemRefType>(operand.getType())) {
774  llvm_unreachable("Unranked memrefs are not supported");
775  }
776  } else {
777  if (isa<UnrankedMemRefType>(operand.getType())) {
778  assert(llvmOperand.size() == 1 && "Expected a single operand");
779  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
780  promotedOperands);
781  continue;
782  }
783  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
784  assert(llvmOperand.size() == 1 && "Expected a single operand");
785  MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
786  promotedOperands);
787  continue;
788  }
789  }
790 
791  llvm::append_range(promotedOperands, llvmOperand);
792  }
793  return promotedOperands;
794 }
795 
796 /// Callback to convert function argument types. It converts a MemRef function
797 /// argument to a list of non-aggregate types containing descriptor
798 /// information, and an UnrankedMemRef function argument to a list containing
799 /// the rank and a pointer to a descriptor struct.
/// All other types use the converter's regular conversion. Unlike the ranked
/// case, whose descriptor computation requires a convertible element type, the
/// unranked branch does not check the element type.
800 LogicalResult
802  SmallVectorImpl<Type> &result) {
803  if (auto memref = dyn_cast<MemRefType>(type)) {
804  // In signatures, Memref descriptors are expanded into lists of
805  // non-aggregate values.
806  auto converted =
807  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
808  if (converted.empty())
809  return failure();
810  result.append(converted.begin(), converted.end());
811  return success();
812  }
813  if (isa<UnrankedMemRefType>(type)) {
814  auto converted = converter.getUnrankedMemRefDescriptorFields();
815  if (converted.empty())
816  return failure();
817  result.append(converted.begin(), converted.end());
818  return success();
819  }
820  return converter.convertType(type, result);
821 }
822 
823 /// Callback to convert function argument types. It converts MemRef function
824 /// arguments to bare pointers (in the MemRef's address space).
/// Memrefs that cannot be lowered to a bare pointer, such as unranked or
/// dynamically shaped memrefs, make the conversion fail. Other types use the
/// regular conversion.
825 LogicalResult
827  SmallVectorImpl<Type> &result) {
828  return converter.convertCallingConventionType(
829  type, result,
830  /*useBarePtrCallConv=*/true);
831 }
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:104
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getI64Type()
Definition: Builders.cpp:64
MLIRContext * getContext() const
Definition: Builders.h:56
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ArrayRef< ValueRange > adaptorOperands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false, SmallVector< SmallVector< Type >> *groupedTypes=nullptr, int64_t *numConvertedTypes=nullptr) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
LogicalResult convertCallingConventionType(Type type, SmallVectorImpl< Type > &result, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class helps build Operations.
Definition: Builders.h:207
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition: Remarks.h:497
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...