MLIR  16.0.0git
TypeUtilities.cpp
Go to the documentation of this file.
1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #include <numeric>
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/IR/Value.h"
21 
22 using namespace mlir;
23 
25  if (auto st = type.dyn_cast<ShapedType>())
26  return st.getElementType();
27  return type;
28 }
29 
31  return getElementTypeOrSelf(val.getType());
32 }
33 
35  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
36  return getElementTypeOrSelf(typedAttr.getType());
37  return {};
38 }
39 
41  SmallVector<Type, 10> fTypes;
42  t.getFlattenedTypes(fTypes);
43  return fTypes;
44 }
45 
46 /// Return true if the specified type is an opaque type with the specified
47 /// dialect and typeData.
48 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
49  StringRef typeData) {
50  if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
51  return opaque.getDialectNamespace() == dialect &&
52  opaque.getTypeData() == typeData;
53  return false;
54 }
55 
56 /// Returns success if the given two shapes are compatible. That is, they have
57 /// the same size and each pair of the elements are equal or one of them is
58 /// dynamic.
60  ArrayRef<int64_t> shape2) {
61  if (shape1.size() != shape2.size())
62  return failure();
63  for (auto dims : llvm::zip(shape1, shape2)) {
64  int64_t dim1 = std::get<0>(dims);
65  int64_t dim2 = std::get<1>(dims);
66  if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
67  dim1 != dim2)
68  return failure();
69  }
70  return success();
71 }
72 
73 /// Returns success if the given two types have compatible shape. That is,
74 /// they are both scalars (not shaped), or they are both shaped types and at
75 /// least one is unranked or they have compatible dimensions. Dimensions are
76 /// compatible if at least one is dynamic or both are equal. The element type
77 /// does not matter.
79  auto sType1 = type1.dyn_cast<ShapedType>();
80  auto sType2 = type2.dyn_cast<ShapedType>();
81 
82  // Either both or neither type should be shaped.
83  if (!sType1)
84  return success(!sType2);
85  if (!sType2)
86  return failure();
87 
88  if (!sType1.hasRank() || !sType2.hasRank())
89  return success();
90 
91  return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
92 }
93 
94 /// Returns success if the given two arrays have the same number of elements and
95 /// each pair wise entries have compatible shape.
97  if (types1.size() != types2.size())
98  return failure();
99  for (auto it : llvm::zip_first(types1, types2))
100  if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
101  return failure();
102  return success();
103 }
104 
106  if (dims.empty())
107  return success();
108  auto staticDim = std::accumulate(
109  dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
110  return ShapedType::isDynamic(dim) ? fold : dim;
111  });
112  return success(llvm::all_of(dims, [&](auto dim) {
113  return ShapedType::isDynamic(dim) || dim == staticDim;
114  }));
115 }
116 
117 /// Returns success if all given types have compatible shapes. That is, they are
118 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
119 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
120 /// dims are equal. The element type does not matter.
122  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
123  types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
124  // Return failure if some, but not all are not shaped. Return early if none
125  // are shaped also.
126  if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
127  return success();
128  if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
129  return failure();
130 
131  // Return failure if some, but not all, are scalable vectors.
132  bool hasScalableVecTypes = false;
133  bool hasNonScalableVecTypes = false;
134  for (Type t : types) {
135  auto vType = t.dyn_cast<VectorType>();
136  if (vType && vType.isScalable())
137  hasScalableVecTypes = true;
138  else
139  hasNonScalableVecTypes = true;
140  if (hasScalableVecTypes && hasNonScalableVecTypes)
141  return failure();
142  }
143 
144  // Remove all unranked shapes
145  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
146  shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
147  if (shapes.empty())
148  return success();
149 
150  // All ranks should be equal
151  auto firstRank = shapes.front().getRank();
152  if (llvm::any_of(shapes,
153  [&](auto shape) { return firstRank != shape.getRank(); }))
154  return failure();
155 
156  for (unsigned i = 0; i < firstRank; ++i) {
157  // Retrieve all ranked dimensions
158  auto dims = llvm::to_vector<8>(llvm::map_range(
159  llvm::make_filter_range(
160  shapes, [&](auto shape) { return shape.getRank() >= i; }),
161  [&](auto shape) { return shape.getDimSize(i); }));
162  if (verifyCompatibleDims(dims).failed())
163  return failure();
164  }
165 
166  return success();
167 }
168 
// NOTE(review): the function header (rendered line 169) is missing from this
// extraction, as are the Doxygen line numbers fused into the lines below. The
// tooltip residue lists `Type mapElement(Value value) const` ("Map the element
// to the iterator result type."), so this is presumably an iterator's
// mapElement definition; restore the qualified signature from upstream.
// Body: returns the element type of `value`'s type. It uses cast<ShapedType>
// (not dyn_cast), so `value` is presumably required to have a shaped type;
// confirm against callers.
170  return value.getType().cast<ShapedType>().getElementType();
171 }
172 
// NOTE(review): the function header (rendered line 173) is missing from this
// extraction, as are the Doxygen line numbers fused into the lines below. The
// tooltip residue lists a second `Type mapElement(Value value) const`, so this
// is presumably another iterator's mapElement definition; restore the
// qualified signature from upstream.
// Body: identical to the previous mapElement. It returns the element type of
// `value`'s type via cast<ShapedType>, so a shaped type is presumably required;
// confirm against callers.
174  return value.getType().cast<ShapedType>().getElementType();
175 }
Include the generated interface declarations.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
static constexpr const bool value
bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData)
Return true if the specified type is an opaque type with the specified dialect and typeData...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
Type mapElement(Value value) const
Map the element to the iterator result type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:118
Type mapElement(Value value) const
Map the element to the iterator result type.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
U dyn_cast() const
Definition: Attributes.h:127
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
U cast() const
Definition: Types.h:278