MLIR  20.0.0git
TypeConverter.cpp
Go to the documentation of this file.
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "MemRefDescriptor.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
19 
20 using namespace mlir;
21 
23  {
24  // Most of the time, the entry already exists in the map.
25  std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26  std::defer_lock);
27  if (getContext().isMultithreadingEnabled())
28  lock.lock();
29  auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30  if (recursiveStack != conversionCallStack.end())
31  return *recursiveStack->second;
32  }
33 
34  // First time this thread gets here, we have to get an exclusive access to
35  // insert into the map
36  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38  llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39  return *recursiveStackInserted.first->second;
40 }
41 
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
44  const DataLayoutAnalysis *analysis)
45  : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46 
47 /// Helper function that checks if the given value range is a bare pointer.
48 static bool isBarePointer(ValueRange values) {
49  return values.size() == 1 &&
50  isa<LLVM::LLVMPointerType>(values.front().getType());
51 }
52 
53 /// Pack SSA values into an unranked memref descriptor struct.
55  UnrankedMemRefType resultType,
56  ValueRange inputs, Location loc,
57  const LLVMTypeConverter &converter) {
58  // Note: Bare pointers are not supported for unranked memrefs because a
59  // memref descriptor cannot be built just from a bare pointer.
60  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
61  return Value();
62  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
63  inputs);
64 }
65 
66 /// Pack SSA values into a ranked memref descriptor struct.
67 static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
68  ValueRange inputs, Location loc,
69  const LLVMTypeConverter &converter) {
70  assert(resultType && "expected non-null result type");
71  if (isBarePointer(inputs))
72  return MemRefDescriptor::fromStaticShape(builder, loc, converter,
73  resultType, inputs[0]);
74  if (TypeRange(inputs) ==
75  converter.getMemRefDescriptorFields(resultType,
76  /*unpackAggregates=*/true))
77  return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
78  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79  // This materialization function cannot be used.
80  return Value();
81 }
82 
83 /// MemRef descriptor elements -> UnrankedMemRefType
85  UnrankedMemRefType resultType,
86  ValueRange inputs, Location loc,
87  const LLVMTypeConverter &converter) {
88  // A source materialization must return a value of type
89  // `resultType`, so insert a cast from the memref descriptor type
90  // (!llvm.struct) to the original memref type.
91  Value packed =
92  packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
93  if (!packed)
94  return Value();
95  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
96  .getResult(0);
97 }
98 
99 /// MemRef descriptor elements -> MemRefType
101  MemRefType resultType,
102  ValueRange inputs, Location loc,
103  const LLVMTypeConverter &converter) {
104  // A source materialization must return a value of type `resultType`,
105  // so insert a cast from the memref descriptor type (!llvm.struct) to the
106  // original memref type.
107  Value packed =
108  packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
109  if (!packed)
110  return Value();
111  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
112  .getResult(0);
113 }
114 
115 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
118  const DataLayoutAnalysis *analysis)
119  : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
120  dataLayoutAnalysis(analysis) {
121  assert(llvmDialect && "LLVM IR dialect is not registered");
122 
123  // Register conversions for the builtin types.
124  addConversion([&](ComplexType type) { return convertComplexType(type); });
125  addConversion([&](FloatType type) { return convertFloatType(type); });
126  addConversion([&](FunctionType type) { return convertFunctionType(type); });
127  addConversion([&](IndexType type) { return convertIndexType(type); });
128  addConversion([&](IntegerType type) { return convertIntegerType(type); });
129  addConversion([&](MemRefType type) { return convertMemRefType(type); });
131  [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
132  addConversion([&](VectorType type) -> std::optional<Type> {
133  FailureOr<Type> llvmType = convertVectorType(type);
134  if (failed(llvmType))
135  return std::nullopt;
136  return llvmType;
137  });
138 
139  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
140  // before the conversions below since conversions are attempted in reverse
141  // order and those should take priority.
142  addConversion([](Type type) {
143  return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
144  : std::nullopt;
145  });
146 
147  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
148  -> std::optional<LogicalResult> {
149  // Fastpath for types that won't be converted by this callback anyway.
150  if (LLVM::isCompatibleType(type)) {
151  results.push_back(type);
152  return success();
153  }
154 
155  if (type.isIdentified()) {
156  auto convertedType = LLVM::LLVMStructType::getIdentified(
157  type.getContext(), ("_Converted." + type.getName()).str());
158 
160  if (llvm::count(recursiveStack, type)) {
161  results.push_back(convertedType);
162  return success();
163  }
164  recursiveStack.push_back(type);
165  auto popConversionCallStack = llvm::make_scope_exit(
166  [&recursiveStack]() { recursiveStack.pop_back(); });
167 
168  SmallVector<Type> convertedElemTypes;
169  convertedElemTypes.reserve(type.getBody().size());
170  if (failed(convertTypes(type.getBody(), convertedElemTypes)))
171  return std::nullopt;
172 
173  // If the converted type has not been initialized yet, just set its body
174  // to be the converted arguments and return.
175  if (!convertedType.isInitialized()) {
176  if (failed(
177  convertedType.setBody(convertedElemTypes, type.isPacked()))) {
178  return failure();
179  }
180  results.push_back(convertedType);
181  return success();
182  }
183 
184  // If it has been initialized, has the same body and packed bit, just use
185  // it. This ensures that recursive structs keep being recursive rather
186  // than including a non-updated name.
187  if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
188  convertedType.isPacked() == type.isPacked()) {
189  results.push_back(convertedType);
190  return success();
191  }
192 
193  return failure();
194  }
195 
196  SmallVector<Type> convertedSubtypes;
197  convertedSubtypes.reserve(type.getBody().size());
198  if (failed(convertTypes(type.getBody(), convertedSubtypes)))
199  return std::nullopt;
200 
201  results.push_back(LLVM::LLVMStructType::getLiteral(
202  type.getContext(), convertedSubtypes, type.isPacked()));
203  return success();
204  });
205  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
206  if (auto element = convertType(type.getElementType()))
207  return LLVM::LLVMArrayType::get(element, type.getNumElements());
208  return std::nullopt;
209  });
210  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
211  Type convertedResType = convertType(type.getReturnType());
212  if (!convertedResType)
213  return std::nullopt;
214 
215  SmallVector<Type> convertedArgTypes;
216  convertedArgTypes.reserve(type.getNumParams());
217  if (failed(convertTypes(type.getParams(), convertedArgTypes)))
218  return std::nullopt;
219 
220  return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
221  type.isVarArg());
222  });
223 
224  // Add generic source and target materializations to handle cases where
225  // non-LLVM types persist after an LLVM conversion.
226  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
227  ValueRange inputs, Location loc) {
228  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
229  .getResult(0);
230  });
231  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
232  ValueRange inputs, Location loc) {
233  return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
234  .getResult(0);
235  });
236 
237  // Source materializations convert from the new block argument types
238  // (multiple SSA values that make up a memref descriptor) back to the
239  // original block argument type.
240  addSourceMaterialization([&](OpBuilder &builder,
241  UnrankedMemRefType resultType, ValueRange inputs,
242  Location loc) {
243  return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
244  *this);
245  });
246  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
247  ValueRange inputs, Location loc) {
248  return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
249  });
250 
251  // Bare pointer -> Packed MemRef descriptor
252  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
253  ValueRange inputs, Location loc,
254  Type originalType) -> Value {
255  // The original MemRef type is required to build a MemRef descriptor
256  // because the sizes/strides of the MemRef cannot be inferred from just the
257  // bare pointer.
258  if (!originalType)
259  return Value();
260  if (resultType != convertType(originalType))
261  return Value();
262  if (auto memrefType = dyn_cast<MemRefType>(originalType))
263  return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
264  if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
265  return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
266  *this);
267  return Value();
268  });
269 
270  // Integer memory spaces map to themselves.
272  [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
273 }
274 
275 /// Returns the MLIR context.
277  return *getDialect()->getContext();
278 }
279 
282 }
283 
284 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
285  return options.dataLayout.getPointerSizeInBits(addressSpace);
286 }
287 
288 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
289  return getIndexType();
290 }
291 
292 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
293  return IntegerType::get(&getContext(), type.getWidth());
294 }
295 
296 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297  // Valid LLVM float types are used directly.
298  if (LLVM::isCompatibleType(type))
299  return type;
300 
301  // F4, F6, F8 types are converted to integer types with the same bit width.
302  if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
303  type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
304  type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
305  type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
306  type.isFloat8E8M0FNU())
307  return IntegerType::get(&getContext(), type.getWidth());
308 
309  // Other floating-point types: A custom type conversion rule must be
310  // specified by the user.
311  return Type();
312 }
313 
314 // Convert a `ComplexType` to an LLVM type. The result is a complex number
315 // struct with entries for the
316 // 1. real part and for the
317 // 2. imaginary part.
318 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
319  auto elementType = convertType(type.getElementType());
320  return LLVM::LLVMStructType::getLiteral(&getContext(),
321  {elementType, elementType});
322 }
323 
324 // Except for signatures, MLIR function types are converted into LLVM
325 // pointer-to-function types.
326 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
327  return LLVM::LLVMPointerType::get(type.getContext());
328 }
329 
330 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
331 /// function arguments. Returns an empty container if none of these attributes
332 /// are found in any of the arguments.
333 static void
334 filterByValRefArgAttrs(FunctionOpInterface funcOp,
335  SmallVectorImpl<std::optional<NamedAttribute>> &result) {
336  assert(result.empty() && "Unexpected non-empty output");
337  result.resize(funcOp.getNumArguments(), std::nullopt);
338  bool foundByValByRefAttrs = false;
339  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
340  for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
341  if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
342  namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
343  foundByValByRefAttrs = true;
344  result[argIdx] = namedAttr;
345  break;
346  }
347  }
348  }
349 
350  if (!foundByValByRefAttrs)
351  result.clear();
352 }
353 
354 // Function types are converted to LLVM Function types by recursively converting
355 // argument and result types. If MLIR Function has zero results, the LLVM
356 // Function has one VoidType result. If MLIR Function has more than one result,
357 // they are packed into an LLVM StructType in their order of appearance.
358 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
359 // `llvm.byref` function arguments which are not LLVM pointers are overridden
360 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
361 // converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
362 Type LLVMTypeConverter::convertFunctionSignatureImpl(
363  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
364  LLVMTypeConverter::SignatureConversion &result,
365  SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
366  // Select the argument converter depending on the calling convention.
367  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
368  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
370  // Convert argument types one by one and check for errors.
371  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
372  SmallVector<Type, 8> converted;
373  if (failed(funcArgConverter(*this, type, converted)))
374  return {};
375 
376  // Rewrite converted type of `llvm.byval` or `llvm.byref` function
377  // argument that was not converted to an LLVM pointer types.
378  if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
379  converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
380  // If the argument was already converted to an LLVM pointer type, we stop
381  // tracking it as it doesn't need more processing.
382  if (isa<LLVM::LLVMPointerType>(converted[0]))
383  (*byValRefNonPtrAttrs)[idx] = std::nullopt;
384  else
385  converted[0] = LLVM::LLVMPointerType::get(&getContext());
386  }
387 
388  result.addInputs(idx, converted);
389  }
390 
391  // If function does not return anything, create the void result type,
392  // if it returns one element, convert it, otherwise pack the result types into
393  // a struct.
394  Type resultType =
395  funcTy.getNumResults() == 0
397  : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
398  if (!resultType)
399  return {};
400  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
401  isVariadic);
402 }
403 
405  FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
407  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
408  result,
409  /*byValRefNonPtrAttrs=*/nullptr);
410 }
411 
413  FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
415  SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
416  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
417  // that were not converted to LLVM pointer types will be returned for further
418  // processing.
419  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
420  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
421  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
422  result, &byValRefNonPtrAttrs);
423 }
424 
425 /// Converts the function type to a C-compatible format, in particular using
426 /// pointers to memref descriptors for arguments.
427 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
429  SmallVector<Type, 4> inputs;
430 
431  Type resultType = type.getNumResults() == 0
433  : packFunctionResults(type.getResults());
434  if (!resultType)
435  return {};
436 
437  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
438  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
439  if (structType) {
440  // Struct types cannot be safely returned via C interface. Make this a
441  // pointer argument, instead.
442  inputs.push_back(ptrType);
443  resultType = LLVM::LLVMVoidType::get(&getContext());
444  }
445 
446  for (Type t : type.getInputs()) {
447  auto converted = convertType(t);
448  if (!converted || !LLVM::isCompatibleType(converted))
449  return {};
450  if (isa<MemRefType, UnrankedMemRefType>(t))
451  converted = ptrType;
452  inputs.push_back(converted);
453  }
454 
455  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
456 }
457 
458 /// Convert a memref type into a list of LLVM IR types that will form the
459 /// memref descriptor. The result contains the following types:
460 /// 1. The pointer to the allocated data buffer, followed by
461 /// 2. The pointer to the aligned data buffer, followed by
462 /// 3. A lowered `index`-type integer containing the distance between the
463 /// beginning of the buffer and the first element to be accessed through the
464 /// view, followed by
465 /// 4. An array containing as many `index`-type integers as the rank of the
466 /// MemRef: the array represents the size, in number of elements, of the memref
467 /// along the given dimension. For constant MemRef dimensions, the
468 /// corresponding size entry is a constant whose runtime value must match the
469 /// static value, followed by
470 /// 5. A second array containing as many `index`-type integers as the rank of
471 /// the MemRef: the second array represents the "stride" (in tensor abstraction
472 /// sense), i.e. the number of consecutive elements of the underlying buffer.
473 /// TODO: add assertions for the static cases.
474 ///
475 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
476 /// are expanded into individual index-type elements.
477 ///
478 /// template <typename Elem, typename Index, size_t Rank>
479 /// struct {
480 /// Elem *allocatedPtr;
481 /// Elem *alignedPtr;
482 /// Index offset;
483 /// Index sizes[Rank]; // omitted when rank == 0
484 /// Index strides[Rank]; // omitted when rank == 0
485 /// };
488  bool unpackAggregates) const {
489  if (!isStrided(type)) {
490  emitError(
491  UnknownLoc::get(type.getContext()),
492  "conversion to strided form failed either due to non-strided layout "
493  "maps (which should have been normalized away) or other reasons");
494  return {};
495  }
496 
497  Type elementType = convertType(type.getElementType());
498  if (!elementType)
499  return {};
500 
501  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
502  if (failed(addressSpace)) {
503  emitError(UnknownLoc::get(type.getContext()),
504  "conversion of memref memory space ")
505  << type.getMemorySpace()
506  << " to integer address space "
507  "failed. Consider adding memory space conversions.";
508  return {};
509  }
510  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
511 
512  auto indexTy = getIndexType();
513 
514  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
515  auto rank = type.getRank();
516  if (rank == 0)
517  return results;
518 
519  if (unpackAggregates)
520  results.insert(results.end(), 2 * rank, indexTy);
521  else
522  results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
523  return results;
524 }
525 
526 unsigned
528  const DataLayout &layout) const {
529  // Compute the descriptor size given that of its components indicated above.
530  unsigned space = *getMemRefAddressSpace(type);
531  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
532  (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
533 }
534 
535 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
536 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
537 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
538  // When converting a MemRefType to a struct with descriptor fields, do not
539  // unpack the `sizes` and `strides` arrays.
540  SmallVector<Type, 5> types =
541  getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
542  if (types.empty())
543  return {};
544  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
545 }
546 
547 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
548 /// that will form the unranked memref descriptor. In particular, the fields
549 /// for an unranked memref descriptor are:
550 /// 1. index-typed rank, the dynamic rank of this MemRef
551 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
552 /// a stack-allocated (alloca) copy of a MemRef descriptor that got cast to
553 /// be unranked.
557 }
558 
560  UnrankedMemRefType type, const DataLayout &layout) const {
561  // Compute the descriptor size given that of its components indicated above.
562  unsigned space = *getMemRefAddressSpace(type);
563  return layout.getTypeSize(getIndexType()) +
565 }
566 
567 Type LLVMTypeConverter::convertUnrankedMemRefType(
568  UnrankedMemRefType type) const {
569  if (!convertType(type.getElementType()))
570  return {};
571  return LLVM::LLVMStructType::getLiteral(&getContext(),
573 }
574 
575 FailureOr<unsigned>
577  if (!type.getMemorySpace()) // Default memory space -> 0.
578  return 0;
579  std::optional<Attribute> converted =
580  convertTypeAttribute(type, type.getMemorySpace());
581  if (!converted)
582  return failure();
583  if (!(*converted)) // Conversion to default is 0.
584  return 0;
585  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
586  if (explicitSpace.getType().isIndex() ||
587  explicitSpace.getType().isSignlessInteger())
588  return explicitSpace.getInt();
589  }
590  return failure();
591 }
592 
593 // Check if a memref type can be converted to a bare pointer.
595  if (isa<UnrankedMemRefType>(type))
596  // Unranked memref is not supported in the bare pointer calling convention.
597  return false;
598 
599  // Check that the memref has static shape, strides and offset. Otherwise, it
600  // cannot be lowered to a bare pointer.
601  auto memrefTy = cast<MemRefType>(type);
602  if (!memrefTy.hasStaticShape())
603  return false;
604 
605  int64_t offset = 0;
606  SmallVector<int64_t, 4> strides;
607  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
608  return false;
609 
610  for (int64_t stride : strides)
611  if (ShapedType::isDynamic(stride))
612  return false;
613 
614  return !ShapedType::isDynamic(offset);
615 }
616 
617 /// Convert a memref type to a bare pointer to the memref element type.
618 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
619  if (!canConvertToBarePtr(type))
620  return {};
621  Type elementType = convertType(type.getElementType());
622  if (!elementType)
623  return {};
624  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
625  if (failed(addressSpace))
626  return {};
627  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
628 }
629 
630 /// Convert an n-D vector type to an LLVM vector type:
631 /// * 0-D `vector<T>` are converted to vector<1xT>
632 /// * 1-D `vector<axT>` remains as is while,
633 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
634 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
635 /// As LLVM supports arrays of scalable vectors, this method will also convert
636 /// n-D scalable vectors provided that only the trailing dim is scalable.
637 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
638  auto elementType = convertType(type.getElementType());
639  if (!elementType)
640  return {};
641  if (type.getShape().empty())
642  return VectorType::get({1}, elementType);
643  Type vectorType = VectorType::get(type.getShape().back(), elementType,
644  type.getScalableDims().back());
645  assert(LLVM::isCompatibleVectorType(vectorType) &&
646  "expected vector type compatible with the LLVM dialect");
647  // For n-D vector types for which a _non-trailing_ dim is scalable,
648  // return a failure. Supporting such cases would require LLVM
649  // to support something akin "scalable arrays" of vectors.
650  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
651  return failure();
652  auto shape = type.getShape();
653  for (int i = shape.size() - 2; i >= 0; --i)
654  vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
655  return vectorType;
656 }
657 
658 /// Convert a type in the context of the default or bare pointer calling
659 /// convention. Calling convention sensitive types, such as MemRefType and
660 /// UnrankedMemRefType, are converted following the specific rules for the
661 /// calling convention. Calling convention independent types are converted
662 /// following the default LLVM type conversions.
664  Type type, bool useBarePtrCallConv) const {
665  if (useBarePtrCallConv)
666  if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
667  return convertMemRefToBarePtr(memrefTy);
668 
669  return convertType(type);
670 }
671 
672 /// Promote the bare pointers in 'values' that resulted from memrefs to
673 /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
674 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
676  ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
677  SmallVectorImpl<Value> &values) const {
678  assert(stdTypes.size() == values.size() &&
679  "The number of types and values doesn't match");
680  for (unsigned i = 0, end = values.size(); i < end; ++i)
681  if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
682  values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
683  memrefTy, values[i]);
684 }
685 
686 /// Convert a non-empty list of types of values produced by an operation into an
687 /// LLVM-compatible type. In particular, if more than one value is
688 /// produced, create a literal structure with elements that correspond to each
689 /// of the types converted with `convertType`.
691  assert(!types.empty() && "expected non-empty list of type");
692  if (types.size() == 1)
693  return convertType(types[0]);
694 
695  SmallVector<Type> resultTypes;
696  resultTypes.reserve(types.size());
697  for (Type type : types) {
698  Type converted = convertType(type);
699  if (!converted || !LLVM::isCompatibleType(converted))
700  return {};
701  resultTypes.push_back(converted);
702  }
703 
704  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
705 }
706 
707 /// Convert a non-empty list of types to be returned from a function into an
708 /// LLVM-compatible type. In particular, if more than one value is returned,
709 /// create an LLVM dialect structure type with elements that correspond to each
710 /// of the types converted with `convertCallingConventionType`.
712  bool useBarePtrCallConv) const {
713  assert(!types.empty() && "expected non-empty list of type");
714 
715  useBarePtrCallConv |= options.useBarePtrCallConv;
716  if (types.size() == 1)
717  return convertCallingConventionType(types.front(), useBarePtrCallConv);
718 
719  SmallVector<Type> resultTypes;
720  resultTypes.reserve(types.size());
721  for (auto t : types) {
722  auto converted = convertCallingConventionType(t, useBarePtrCallConv);
723  if (!converted || !LLVM::isCompatibleType(converted))
724  return {};
725  resultTypes.push_back(converted);
726  }
727 
728  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
729 }
730 
732  OpBuilder &builder) const {
733  // Alloca with proper alignment. We do not expect optimizations of this
734  // alloca op and so we omit allocating at the entry block.
735  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
736  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
737  builder.getIndexAttr(1));
738  Value allocated =
739  builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
740  // Store into the alloca'ed descriptor.
741  builder.create<LLVM::StoreOp>(loc, operand, allocated);
742  return allocated;
743 }
744 
747  ValueRange operands, OpBuilder &builder,
748  bool useBarePtrCallConv) const {
749  SmallVector<Value, 4> promotedOperands;
750  promotedOperands.reserve(operands.size());
751  useBarePtrCallConv |= options.useBarePtrCallConv;
752  for (auto it : llvm::zip(opOperands, operands)) {
753  auto operand = std::get<0>(it);
754  auto llvmOperand = std::get<1>(it);
755 
756  if (useBarePtrCallConv) {
757  // For the bare-ptr calling convention, we only have to extract the
758  // aligned pointer of a memref.
759  if (dyn_cast<MemRefType>(operand.getType())) {
760  MemRefDescriptor desc(llvmOperand);
761  llvmOperand = desc.alignedPtr(builder, loc);
762  } else if (isa<UnrankedMemRefType>(operand.getType())) {
763  llvm_unreachable("Unranked memrefs are not supported");
764  }
765  } else {
766  if (isa<UnrankedMemRefType>(operand.getType())) {
767  UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
768  promotedOperands);
769  continue;
770  }
771  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
772  MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
773  promotedOperands);
774  continue;
775  }
776  }
777 
778  promotedOperands.push_back(llvmOperand);
779  }
780  return promotedOperands;
781 }
782 
783 /// Callback to convert function argument types. It converts a MemRef function
784 /// argument to a list of non-aggregate types containing descriptor
785 /// information, and an UnrankedMemRef function argument to a list containing
786 /// the rank and a pointer to a descriptor struct.
787 LogicalResult
789  SmallVectorImpl<Type> &result) {
790  if (auto memref = dyn_cast<MemRefType>(type)) {
791  // In signatures, Memref descriptors are expanded into lists of
792  // non-aggregate values.
793  auto converted =
794  converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
795  if (converted.empty())
796  return failure();
797  result.append(converted.begin(), converted.end());
798  return success();
799  }
800  if (isa<UnrankedMemRefType>(type)) {
801  auto converted = converter.getUnrankedMemRefDescriptorFields();
802  if (converted.empty())
803  return failure();
804  result.append(converted.begin(), converted.end());
805  return success();
806  }
807  auto converted = converter.convertType(type);
808  if (!converted)
809  return failure();
810  result.push_back(converted);
811  return success();
812 }
813 
814 /// Callback to convert function argument types. It converts MemRef function
815 /// arguments to bare pointers to the MemRef element type.
816 LogicalResult
818  SmallVectorImpl<Type> &result) {
819  auto llvmTy = converter.convertCallingConventionType(
820  type, /*useBarePointerCallConv=*/true);
821  if (!llvmTy)
822  return failure();
823 
824  result.push_back(llvmTy);
825  return success();
826 }
static llvm::ManagedStatic< PassManagerOptions > options
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into a ranked memref descriptor struct.
static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> UnrankedMemRefType.
static Value packUnrankedMemRefDesc(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
Pack SSA values into an unranked memref descriptor struct.
static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter)
MemRef descriptor elements -> MemRefType.
static bool isBarePointer(ValueRange values)
Helper function that checks if the given value range is a bare pointer.
static void filterByValRefArgAttrs(FunctionOpInterface funcOp, SmallVectorImpl< std::optional< NamedAttribute >> &result)
Returns the llvm.byval or llvm.byref attributes that are present in the function arguments.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:102
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
Type getElementType() const
Returns the element type of this memref type.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerType getI64Type()
Definition: Builders.cpp:65
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * llvmDialect
Pointer to the LLVM dialect.
llvm::sys::SmartRWMutex< true > callStackMutex
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Type packFunctionResults(TypeRange types, bool useBarePointerCallConv=false) const
Convert a non-empty list of types to be returned from a function into an LLVM-compatible type.
unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, const DataLayout &layout) const
Returns the size of the unranked memref descriptor object in bytes.
SmallVector< Type, 5 > getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const
Convert a memref type into a list of LLVM IR types that will form the memref descriptor.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, SignatureConversion &result) const
Convert a function type.
void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef< Type > stdTypes, SmallVectorImpl< Value > &values) const
Promote the bare pointers in 'values' that resulted from memrefs to descriptors.
DenseMap< uint64_t, std::unique_ptr< SmallVector< Type > > > conversionCallStack
SmallVector< Type > & getCurrentThreadRecursiveStack()
Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) const
Promote the LLVM struct representation of one MemRef descriptor to stack and use pointer to struct to...
Type convertCallingConventionType(Type type, bool useBarePointerCallConv=false) const
Convert a type in the context of the default or bare pointer calling convention.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
SmallVector< Value, 4 > promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder, bool useBarePtrCallConv=false) const
Promote the LLVM representation of all operands including promoting MemRef descriptors to stack and u...
LLVMTypeConverter(MLIRContext *ctx, const DataLayoutAnalysis *analysis=nullptr)
Create an LLVMTypeConverter using the default LowerToLLVMOptions.
unsigned getPointerBitwidth(unsigned addressSpace=0) const
Gets the pointer bitwidth.
SmallVector< Type, 2 > getUnrankedMemRefDescriptorFields() const
Convert an unranked memref type into a list of non-aggregate LLVM IR types that will form the unranke...
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
static bool canConvertToBarePtr(BaseMemRefType type)
Check if a memref type can be converted to a bare pointer.
MLIRContext & getContext() const
Returns the MLIR context.
unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout) const
Returns the size of the memref descriptor object in bytes.
std::pair< LLVM::LLVMFunctionType, LLVM::LLVMStructType > convertFunctionTypeCWrapper(FunctionType type) const
Converts the function type to a C-compatible format, in particular using pointers to memref descripto...
unsigned getIndexTypeBitwidth() const
Gets the bitwidth of the index type when converted to LLVM.
Type getIndexType() const
Gets the LLVM representation of the index type.
friend LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Give structFuncArgTypeConverter access to memref-specific functions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
llvm::DataLayout dataLayout
The data layout of the module to produce.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, ValueRange values)
Builds IR populating a MemRef descriptor structure from a list of individual values composing that de...
Value alignedPtr(OpBuilder &builder, Location loc)
Builds IR extracting the aligned pointer from the descriptor.
static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements of a MemRef descriptor structure and returning them as resul...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class provides all of the information necessary to convert a type signature.
std::optional< Attribute > convertTypeAttribute(Type type, Attribute attr) const
Convert an attribute present attr from within the type type using the registered conversion functions...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
static Value pack(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values)
Builds IR populating an unranked MemRef descriptor structure from a list of individual constituent va...
static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl< Value > &results)
Builds IR extracting individual elements that compose an unranked memref descriptor and returns them ...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:878
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:860
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl< Type > &result)
Callback to convert function argument types.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.