MLIR  21.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // First time this thread gets here, we have to get an exclusive access to
35  // inset in the map
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Helper function that checks if the given value range is a bare pointer.
48 static bool isBarePointer(ValueRange values) {
49  return values.size() == 1 &&
50  isa<LLVM::LLVMPointerType>(values.front().getType());
51 }
52 
53 /// Pack SSA values into an unranked memref descriptor struct.
55  UnrankedMemRefType resultType,
56  ValueRange inputs, Location loc,
57  const LLVMTypeConverter &converter) {
58  // Note: Bare pointers are not supported for unranked memrefs because a
59  // memref descriptor cannot be built just from a bare pointer.
60  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
61  return Value();
62  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
63  inputs);
64 }
65 
66 /// Pack SSA values into a ranked memref descriptor struct.
67 static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
68  ValueRange inputs, Location loc,
69  const LLVMTypeConverter &converter) {
70  assert(resultType && "expected non-null result type");
71  if (isBarePointer(inputs))
72  return MemRefDescriptor::fromStaticShape(builder, loc, converter,
73  resultType, inputs[0]);
74  if (TypeRange(inputs) ==
75  converter.getMemRefDescriptorFields(resultType,
76  /*unpackAggregates=*/true))
77  return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
78  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79  // This materialization function cannot be used.
80  return Value();
81 }
82 
83 /// MemRef descriptor elements -> UnrankedMemRefType
85  UnrankedMemRefType resultType,
86  ValueRange inputs, Location loc,
87  const LLVMTypeConverter &converter) {
88  // A source materialization must return a value of type
89  // `resultType`, so insert a cast from the memref descriptor type
90  // (!llvm.struct) to the original memref type.
91  Value packed =
92  packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
93  if (!packed)
94  return Value();
95  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
96  .getResult(0);
97 }
98 
99 /// MemRef descriptor elements -> MemRefType
101  MemRefType resultType,
102  ValueRange inputs, Location loc,
103  const LLVMTypeConverter &converter) {
104  // A source materialization must return a value of type `resultType`,
105  // so insert a cast from the memref descriptor type (!llvm.struct) to the
106  // original memref type.
107  Value packed =
108  packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
109  if (!packed)
110  return Value();
111  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
112  .getResult(0);
113 }
114 
115 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
118  const DataLayoutAnalysis *analysis)
119  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
120  dataLayoutAnalysis(analysis) {
121  assert(llvmDialect && "LLVM IR dialect is not registered");
122 
123  // Register conversions for the builtin types.
124  addConversion([&](ComplexType type) { return convertComplexType(type); });
125  addConversion([&](FloatType type) { return convertFloatType(type); });
126  addConversion([&](FunctionType type) { return convertFunctionType(type); });
127  addConversion([&](IndexType type) { return convertIndexType(type); });
128  addConversion([&](IntegerType type) { return convertIntegerType(type); });
129  addConversion([&](MemRefType type) { return convertMemRefType(type); });
131  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
132  addConversion([&](VectorType type) -> std::optional<Type> {
133  FailureOr<Type> llvmType = convertVectorType(type);
134  if (failed(llvmType))
135  return std::nullopt;
136  return llvmType;
137  });
138 
139  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
140  // before the conversions below since conversions are attempted in reverse
141  // order and those should take priority.
142  addConversion([](Type type) {
143  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
144  : std::nullopt;
145  });
146 
147  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
148  -> std::optional<LogicalResult> {
149  // Fastpath for types that won't be converted by this callback anyway.
150  if (LLVM::isCompatibleType(type)) {
151  results.push_back(type);
152  return success();
153  }
154 
155  if (type.isIdentified()) {
156  auto convertedType = LLVM::LLVMStructType::getIdentified(
157  type.getContext(), ("_Converted." + type.getName()).str());
158 
160  if (llvm::count(recursiveStack, type)) {
161  results.push_back(convertedType);
162  return success();
163  }
164  recursiveStack.push_back(type);
165  auto popConversionCallStack = llvm::make_scope_exit(
166  [&recursiveStack]() { recursiveStack.pop_back(); });
167 
168  SmallVector<Type> convertedElemTypes;
169  convertedElemTypes.reserve(type.getBody().size());
170  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
171  return std::nullopt;
172 
173  // If the converted type has not been initialized yet, just set its body
174  // to be the converted arguments and return.
175  if (!convertedType.isInitialized()) {
176  if (failed(
177  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
178  return failure();
179  }
180  results.push_back(convertedType);
181  return success();
182  }
183 
184  // If it has been initialized, has the same body and packed bit, just use
185  // it. This ensures that recursive structs keep being recursive rather
186  // than including a non-updated name.
187  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
188  convertedType.isPacked() == type.isPacked()) {
189  results.push_back(convertedType);
190  return success();
191  }
192 
193  return failure();
194  }
195 
196  SmallVector<Type> convertedSubtypes;
197  convertedSubtypes.reserve(type.getBody().size());
198  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
199  return std::nullopt;
200 
201  results.push_back(LLVM::LLVMStructType::getLiteral(
202  type.getContext(), convertedSubtypes, type.isPacked()));
203  return success();
204  });
205  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
206  if (auto element = convertType(type.getElementType()))
207  return LLVM::LLVMArrayType::get(element, type.getNumElements());
208  return std::nullopt;
209  });
210  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
211  Type convertedResType = convertType(type.getReturnType());
212  if (!convertedResType)
213  return std::nullopt;
214 
215  SmallVector<Type> convertedArgTypes;
216  convertedArgTypes.reserve(type.getNumParams());
217  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
218  return std::nullopt;
219 
220  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
221  type.isVarArg());
222  });
223 
224  // Add generic source and target materializations to handle cases where
225  // non-LLVM types persist after an LLVM conversion.
226  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
227  ValueRange inputs, Location loc) {
228  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
229  .getResult(0);
230  });
231  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
232  ValueRange inputs, Location loc) {
233  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
234  .getResult(0);
235  });
236 
237  // Source materializations convert from the new block argument types
238  // (multiple SSA values that make up a memref descriptor) back to the
239  // original block argument type.
240  addSourceMaterialization([&](OpBuilder &builder,
241  UnrankedMemRefType resultType, ValueRange inputs,
242  Location loc) {
243  return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
244  *this);
245  });
246  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
247  ValueRange inputs, Location loc) {
248  return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
249  });
250 
251  // Bare pointer -> Packed MemRef descriptor
252  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
253  ValueRange inputs, Location loc,
254  Type originalType) -> Value {
255  // The original MemRef type is required to build a MemRef descriptor
256  // because the sizes/strides of the MemRef cannot be inferred from just the
257  // bare pointer.
258  if (!originalType)
259  return Value();
260  if (resultType != convertType(originalType))
261  return Value();
262  if (auto memrefType = dyn_cast<MemRefType>(originalType))
263  return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
264  if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
265  return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
266  *this);
267  return Value();
268  });
269 
270  // Integer memory spaces map to themselves.
272  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
273 }
274 
275 /// Returns the MLIR context.
277  return *getDialect()->getContext();
278 }
279 
282 }
283 
284 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
285  return options.dataLayout.getPointerSizeInBits(addressSpace);
286 }
287 
288 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
289  return getIndexType();
290 }
291 
292 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
293  return IntegerType::get(&getContext(), type.getWidth());
294 }
295 
296 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297  // Valid LLVM float types are used directly.
298  if (LLVM::isCompatibleType(type))
299  return type;
300 
301  // F4, F6, F8 types are converted to integer types with the same bit width.
302  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304  Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305  Float8E8M0FNUType>(type))
306  return IntegerType::get(&getContext(), type.getWidth());
307 
308  // Other floating-point types: A custom type conversion rule must be
309  // specified by the user.
310  return Type();
311 }
312 
313 // Convert a `ComplexType` to an LLVM type. The result is a complex number
314 // struct with entries for the
315 // 1. real part and for the
316 // 2. imaginary part.
317 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
318  auto elementType = convertType(type.getElementType());
319  return LLVM::LLVMStructType::getLiteral(&getContext(),
320  {elementType, elementType});
321 }
322 
323 // Except for signatures, MLIR function types are converted into LLVM
324 // pointer-to-function types.
325 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
326  return LLVM::LLVMPointerType::get(type.getContext());
327 }
328 
329 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
330 /// function arguments. Returns an empty container if none of these attributes
331 /// are found in any of the arguments.
332 static void
333 filterByValRefArgAttrs(FunctionOpInterface funcOp,
334  SmallVectorImpl<std::optional<NamedAttribute>> &result) {
335  assert(result.empty() && "Unexpected non-empty output");
336  result.resize(funcOp.getNumArguments(), std::nullopt);
337  bool foundByValByRefAttrs = false;
338  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
339  for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
340  if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
341  namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
342  foundByValByRefAttrs = true;
343  result[argIdx] = namedAttr;
344  break;
345  }
346  }
347  }
348 
349  if (!foundByValByRefAttrs)
350  result.clear();
351 }
352 
353 // Function types are converted to LLVM Function types by recursively converting
354 // argument and result types. If MLIR Function has zero results, the LLVM
355 // Function has one VoidType result. If MLIR Function has more than one result,
356 // they are into an LLVM StructType in their order of appearance.
357 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
358 // `llvm.byref` function arguments which are not LLVM pointers are overridden
359 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
360 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
361 Type LLVMTypeConverter::convertFunctionSignatureImpl(
362  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
363  LLVMTypeConverter::SignatureConversion &result,
364  SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
365  // Select the argument converter depending on the calling convention.
366  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
367  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
369  // Convert argument types one by one and check for errors.
370  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371  SmallVector<Type, 8> converted;
372  if (failed(funcArgConverter(*this, type, converted)))
373  return {};
374 
375  // Rewrite converted type of `llvm.byval` or `llvm.byref` function
376  // argument that was not converted to an LLVM pointer types.
377  if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378  converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379  // If the argument was already converted to an LLVM pointer type, we stop
380  // tracking it as it doesn't need more processing.
381  if (isa<LLVM::LLVMPointerType>(converted[0]))
382  (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383  else
384  converted[0] = LLVM::LLVMPointerType::get(&getContext());
385  }
386 
387  result.addInputs(idx, converted);
388  }
389 
390  // If function does not return anything, create the void result type,
391  // if it returns on element, convert it, otherwise pack the result types into
392  // a struct.
393  Type resultType =
394  funcTy.getNumResults() == 0
396  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
397  if (!resultType)
398  return {};
399  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
400  isVariadic);
401 }
402 
404  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
406  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
407  result,
408  /*byValRefNonPtrAttrs=*/nullptr);
409 }
410 
412  FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
414  SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
415  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
416  // that were not converted to LLVM pointer types will be returned for further
417  // processing.
418  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
419  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
421  result, &byValRefNonPtrAttrs);
422 }
423 
424 /// Converts the function type to a C-compatible format, in particular using
425 /// pointers to memref descriptors for arguments.
426 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
428  SmallVector<Type, 4> inputs;
429 
430  Type resultType = type.getNumResults() == 0
432  : packFunctionResults(type.getResults());
433  if (!resultType)
434  return {};
435 
436  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
438  if (structType) {
439  // Struct types cannot be safely returned via C interface. Make this a
440  // pointer argument, instead.
441  inputs.push_back(ptrType);
442  resultType = LLVM::LLVMVoidType::get(&getContext());
443  }
444 
445  for (Type t : type.getInputs()) {
446  auto converted = convertType(t);
447  if (!converted || !LLVM::isCompatibleType(converted))
448  return {};
449  if (isa<MemRefType, UnrankedMemRefType>(t))
450  converted = ptrType;
451  inputs.push_back(converted);
452  }
453 
454  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
455 }
456 
457 /// Convert a memref type into a list of LLVM IR types that will form the
458 /// memref descriptor. The result contains the following types:
459 /// 1. The pointer to the allocated data buffer, followed by
460 /// 2. The pointer to the aligned data buffer, followed by
461 /// 3. A lowered `index`-type integer containing the distance between the
462 /// beginning of the buffer and the first element to be accessed through the
463 /// view, followed by
464 /// 4. An array containing as many `index`-type integers as the rank of the
465 /// MemRef: the array represents the size, in number of elements, of the memref
466 /// along the given dimension. For constant MemRef dimensions, the
467 /// corresponding size entry is a constant whose runtime value must match the
468 /// static value, followed by
469 /// 5. A second array containing as many `index`-type integers as the rank of
470 /// the MemRef: the second array represents the "stride" (in tensor abstraction
471 /// sense), i.e. the number of consecutive elements of the underlying buffer.
472 /// TODO: add assertions for the static cases.
473 ///
474 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
475 /// are expanded into individual index-type elements.
476 ///
477 /// template <typename Elem, typename Index, size_t Rank>
478 /// struct {
479 /// Elem *allocatedPtr;
480 /// Elem *alignedPtr;
481 /// Index offset;
482 /// Index sizes[Rank]; // omitted when rank == 0
483 /// Index strides[Rank]; // omitted when rank == 0
484 /// };
487  bool unpackAggregates) const {
488  if (!type.isStrided()) {
489  emitError(
490  UnknownLoc::get(type.getContext()),
491  "conversion to strided form failed either due to non-strided layout "
492  "maps (which should have been normalized away) or other reasons");
493  return {};
494  }
495 
496  Type elementType = convertType(type.getElementType());
497  if (!elementType)
498  return {};
499 
500  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
501  if (failed(addressSpace)) {
502  emitError(UnknownLoc::get(type.getContext()),
503  "conversion of memref memory space ")
504  << type.getMemorySpace()
505  << " to integer address space "
506  "failed. Consider adding memory space conversions.";
507  return {};
508  }
509  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
510 
511  auto indexTy = getIndexType();
512 
513  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
514  auto rank = type.getRank();
515  if (rank == 0)
516  return results;
517 
518  if (unpackAggregates)
519  results.insert(results.end(), 2 * rank, indexTy);
520  else
521  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
522  return results;
523 }
524 
525 unsigned
527  const DataLayout &layout) const {
528  // Compute the descriptor size given that of its components indicated above.
529  unsigned space = *getMemRefAddressSpace(type);
530  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
531  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
532 }
533 
534 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
535 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
536 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537  // When converting a MemRefType to a struct with descriptor fields, do not
538  // unpack the `sizes` and `strides` arrays.
539  SmallVector<Type, 5> types =
540  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
541  if (types.empty())
542  return {};
543  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544 }
545 
546 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
547 /// that will form the unranked memref descriptor. In particular, the fields
548 /// for an unranked memref descriptor are:
549 /// 1. index-typed rank, the dynamic rank of this MemRef
550 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
551 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
552 /// be unranked.
556 }
557 
559  UnrankedMemRefType type, const DataLayout &layout) const {
560  // Compute the descriptor size given that of its components indicated above.
561  unsigned space = *getMemRefAddressSpace(type);
562  return layout.getTypeSize(getIndexType()) +
564 }
565 
566 Type LLVMTypeConverter::convertUnrankedMemRefType(
567  UnrankedMemRefType type) const {
568  if (!convertType(type.getElementType()))
569  return {};
570  return LLVM::LLVMStructType::getLiteral(&getContext(),
572 }
573 
574 FailureOr<unsigned>
576  if (!type.getMemorySpace()) // Default memory space -> 0.
577  return 0;
578  std::optional<Attribute> converted =
579  convertTypeAttribute(type, type.getMemorySpace());
580  if (!converted)
581  return failure();
582  if (!(*converted)) // Conversion to default is 0.
583  return 0;
584  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585  if (explicitSpace.getType().isIndex() ||
586  explicitSpace.getType().isSignlessInteger())
587  return explicitSpace.getInt();
588  }
589  return failure();
590 }
591 
592 // Check if a memref type can be converted to a bare pointer.
594  if (isa<UnrankedMemRefType>(type))
595  // Unranked memref is not supported in the bare pointer calling convention.
596  return false;
597 
598  // Check that the memref has static shape, strides and offset. Otherwise, it
599  // cannot be lowered to a bare pointer.
600  auto memrefTy = cast<MemRefType>(type);
601  if (!memrefTy.hasStaticShape())
602  return false;
603 
604  int64_t offset = 0;
605  SmallVector<int64_t, 4> strides;
606  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607  return false;
608 
609  for (int64_t stride : strides)
610  if (ShapedType::isDynamic(stride))
611  return false;
612 
613  return !ShapedType::isDynamic(offset);
614 }
615 
616 /// Convert a memref type to a bare pointer to the memref element type.
617 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
618  if (!canConvertToBarePtr(type))
619  return {};
620  Type elementType = convertType(type.getElementType());
621  if (!elementType)
622  return {};
623  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
624  if (failed(addressSpace))
625  return {};
626  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
627 }
628 
629 /// Convert an n-D vector type to an LLVM vector type:
630 /// * 0-D `vector<T>` are converted to vector<1xT>
631 /// * 1-D `vector<axT>` remains as is while,
632 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
633 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
634 /// As LLVM supports arrays of scalable vectors, this method will also convert
635 /// n-D scalable vectors provided that only the trailing dim is scalable.
636 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
637  auto elementType = convertType(type.getElementType());
638  if (!elementType)
639  return {};
640  if (type.getShape().empty())
641  return VectorType::get({1}, elementType);
642  Type vectorType = VectorType::get(type.getShape().back(), elementType,
643  type.getScalableDims().back());
644  assert(LLVM::isCompatibleVectorType(vectorType) &&
645  "expected vector type compatible with the LLVM dialect");
646  // For n-D vector types for which a _non-trailing_ dim is scalable,
647  // return a failure. Supporting such cases would require LLVM
648  // to support something akin "scalable arrays" of vectors.
649  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650  return failure();
651  auto shape = type.getShape();
652  for (int i = shape.size() - 2; i >= 0; --i)
653  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
654  return vectorType;
655 }
656 
657 /// Convert a type in the context of the default or bare pointer calling
658 /// convention. Calling convention sensitive types, such as MemRefType and
659 /// UnrankedMemRefType, are converted following the specific rules for the
660 /// calling convention. Calling convention independent types are converted
661 /// following the default LLVM type conversions.
663  Type type, bool useBarePtrCallConv) const {
664  if (useBarePtrCallConv)
665  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
666  return convertMemRefToBarePtr(memrefTy);
667 
668  return convertType(type);
669 }
670 
671 /// Promote the bare pointers in 'values' that resulted from memrefs to
672 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
673 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
675  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
676  SmallVectorImpl<Value> &values) const {
677  assert(stdTypes.size() == values.size() &&
678  "The number of types and values doesn't match");
679  for (unsigned i = 0, end = values.size(); i < end; ++i)
680  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
681  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
682  memrefTy, values[i]);
683 }
684 
685 /// Convert a non-empty list of types of values produced by an operation into an
686 /// LLVM-compatible type. In particular, if more than one value is
687 /// produced, create a literal structure with elements that correspond to each
688 /// of the types converted with `convertType`.
690  assert(!types.empty() && "expected non-empty list of type");
691  if (types.size() == 1)
692  return convertType(types[0]);
693 
694  SmallVector<Type> resultTypes;
695  resultTypes.reserve(types.size());
696  for (Type type : types) {
697  Type converted = convertType(type);
698  if (!converted || !LLVM::isCompatibleType(converted))
699  return {};
700  resultTypes.push_back(converted);
701  }
702 
703  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
704 }
705 
706 /// Convert a non-empty list of types to be returned from a function into an
707 /// LLVM-compatible type. In particular, if more than one value is returned,
708 /// create an LLVM dialect structure type with elements that correspond to each
709 /// of the types converted with `convertCallingConventionType`.
711  bool useBarePtrCallConv) const {
712  assert(!types.empty() && "expected non-empty list of type");
713 
714  useBarePtrCallConv |= options.useBarePtrCallConv;
715  if (types.size() == 1)
716  return convertCallingConventionType(types.front(), useBarePtrCallConv);
717 
718  SmallVector<Type> resultTypes;
719  resultTypes.reserve(types.size());
720  for (auto t : types) {
721  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
722  if (!converted || !LLVM::isCompatibleType(converted))
723  return {};
724  resultTypes.push_back(converted);
725  }
726 
727  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
728 }
729 
731  OpBuilder &builder) const {
732  // Alloca with proper alignment. We do not expect optimizations of this
733  // alloca op and so we omit allocating at the entry block.
734  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
735  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
736  builder.getIndexAttr(1));
737  Value allocated =
738  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
739  // Store into the alloca'ed descriptor.
740  builder.create<LLVM::StoreOp>(loc, operand, allocated);
741  return allocated;
742 }
743 
746  ValueRange operands, OpBuilder &builder,
747  bool useBarePtrCallConv) const {
748  SmallVector<Value, 4> promotedOperands;
749  promotedOperands.reserve(operands.size());
750  useBarePtrCallConv |= options.useBarePtrCallConv;
751  for (auto it : llvm::zip(opOperands, operands)) {
752  auto operand = std::get<0>(it);
753  auto llvmOperand = std::get<1>(it);
754 
755  if (useBarePtrCallConv) {
756  // For the bare-ptr calling convention, we only have to extract the
757  // aligned pointer of a memref.
758  if (isa<MemRefType>(operand.getType())) {
759  MemRefDescriptor desc(llvmOperand);
760  llvmOperand = desc.alignedPtr(builder, loc);
761  } else if (isa<UnrankedMemRefType>(operand.getType())) {
762  llvm_unreachable("Unranked memrefs are not supported");
763  }
764  } else {
765  if (isa<UnrankedMemRefType>(operand.getType())) {
766  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
767  promotedOperands);
768  continue;
769  }
770  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
771  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
772  promotedOperands);
773  continue;
774  }
775  }
776 
777  promotedOperands.push_back(llvmOperand);
778  }
779  return promotedOperands;
780 }
781 
782 /// Callback to convert function argument types. It converts a MemRef function
783 /// argument to a list of non-aggregate types containing descriptor
784 /// information, and an UnrankedmemRef function argument to a list containing
785 /// the rank and a pointer to a descriptor struct.
786 LogicalResult
788  SmallVectorImpl<Type> &result) {
789  if (auto memref = dyn_cast<MemRefType>(type)) {
790  // In signatures, Memref descriptors are expanded into lists of
791  // non-aggregate values.
792  auto converted =
793  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
794  if (converted.empty())
795  return failure();
796  result.append(converted.begin(), converted.end());
797  return success();
798  }
799  if (isa<UnrankedMemRefType>(type)) {
800  auto converted = converter.getUnrankedMemRefDescriptorFields();
801  if (converted.empty())
802  return failure();
803  result.append(converted.begin(), converted.end());
804  return success();
805  }
806  auto converted = converter.convertType(type);
807  if (!converted)
808  return failure();
809  result.push_back(converted);
810  return success();
811 }
812 
813 /// Callback to convert function argument types. It converts MemRef function
814 /// arguments to bare pointers to the MemRef element type.
815 LogicalResult
817  SmallVectorImpl<Type> &result) {
818  auto llvmTy = converter.convertCallingConventionType(
819  type, /*useBarePointerCallConv=*/true);
820  if (!llvmTy)
821  return failure();
822 
823  result.push_back(llvmTy);
824  return success();
825 }
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerType getI64Type()
Definition: Builders.cpp:65
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type `type` or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present `attr` from within the type `type` using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:814
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...