MLIR  17.0.0git
TypeUtilities.cpp
Go to the documentation of this file.
1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/IR/Attributes.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Types.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/SmallVectorExtras.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
24  if (auto st = llvm::dyn_cast<ShapedType>(type))
25  return st.getElementType();
26  return type;
27 }
28 
30  return getElementTypeOrSelf(val.getType());
31 }
32 
34  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
35  return getElementTypeOrSelf(typedAttr.getType());
36  return {};
37 }
38 
40  SmallVector<Type, 10> fTypes;
41  t.getFlattenedTypes(fTypes);
42  return fTypes;
43 }
44 
45 /// Return true if the specified type is an opaque type with the specified
46 /// dialect and typeData.
47 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
48  StringRef typeData) {
49  if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type))
50  return opaque.getDialectNamespace() == dialect &&
51  opaque.getTypeData() == typeData;
52  return false;
53 }
54 
55 /// Returns success if the given two shapes are compatible. That is, they have
56 /// the same size and each pair of the elements are equal or one of them is
57 /// dynamic.
59  ArrayRef<int64_t> shape2) {
60  if (shape1.size() != shape2.size())
61  return failure();
62  for (auto dims : llvm::zip(shape1, shape2)) {
63  int64_t dim1 = std::get<0>(dims);
64  int64_t dim2 = std::get<1>(dims);
65  if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
66  dim1 != dim2)
67  return failure();
68  }
69  return success();
70 }
71 
72 /// Returns success if the given two types have compatible shape. That is,
73 /// they are both scalars (not shaped), or they are both shaped types and at
74 /// least one is unranked or they have compatible dimensions. Dimensions are
75 /// compatible if at least one is dynamic or both are equal. The element type
76 /// does not matter.
78  auto sType1 = llvm::dyn_cast<ShapedType>(type1);
79  auto sType2 = llvm::dyn_cast<ShapedType>(type2);
80 
81  // Either both or neither type should be shaped.
82  if (!sType1)
83  return success(!sType2);
84  if (!sType2)
85  return failure();
86 
87  if (!sType1.hasRank() || !sType2.hasRank())
88  return success();
89 
90  return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
91 }
92 
93 /// Returns success if the given two arrays have the same number of elements and
94 /// each pair wise entries have compatible shape.
96  if (types1.size() != types2.size())
97  return failure();
98  for (auto it : llvm::zip_first(types1, types2))
99  if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
100  return failure();
101  return success();
102 }
103 
105  if (dims.empty())
106  return success();
107  auto staticDim = std::accumulate(
108  dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
109  return ShapedType::isDynamic(dim) ? fold : dim;
110  });
111  return success(llvm::all_of(dims, [&](auto dim) {
112  return ShapedType::isDynamic(dim) || dim == staticDim;
113  }));
114 }
115 
116 /// Returns success if all given types have compatible shapes. That is, they are
117 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
118 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
119 /// dims are equal. The element type does not matter.
121  auto shapedTypes = llvm::map_to_vector<8>(
122  types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
123  // Return failure if some, but not all are not shaped. Return early if none
124  // are shaped also.
125  if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
126  return success();
127  if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
128  return failure();
129 
130  // Return failure if some, but not all, are scalable vectors.
131  bool hasScalableVecTypes = false;
132  bool hasNonScalableVecTypes = false;
133  for (Type t : types) {
134  auto vType = llvm::dyn_cast<VectorType>(t);
135  if (vType && vType.isScalable())
136  hasScalableVecTypes = true;
137  else
138  hasNonScalableVecTypes = true;
139  if (hasScalableVecTypes && hasNonScalableVecTypes)
140  return failure();
141  }
142 
143  // Remove all unranked shapes
144  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
145  shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
146  if (shapes.empty())
147  return success();
148 
149  // All ranks should be equal
150  auto firstRank = shapes.front().getRank();
151  if (llvm::any_of(shapes,
152  [&](auto shape) { return firstRank != shape.getRank(); }))
153  return failure();
154 
155  for (unsigned i = 0; i < firstRank; ++i) {
156  // Retrieve all ranked dimensions
157  auto dims = llvm::map_to_vector<8>(
158  llvm::make_filter_range(
159  shapes, [&](auto shape) { return shape.getRank() >= i; }),
160  [&](auto shape) { return shape.getDimSize(i); });
161  if (verifyCompatibleDims(dims).failed())
162  return failure();
163  }
164 
165  return success();
166 }
167 
// NOTE(review): the defining line of this function (listing line 168) is
// missing from this extract. The index below lists the signature as
// `Type mapElement(Value value) const`; the enclosing class is not visible
// here — confirm against the header.
// Casts the value's type to ShapedType and returns its element type.
169  return llvm::cast<ShapedType>(value.getType()).getElementType();
170 }
171 
// NOTE(review): the defining line of this function (listing line 172) is
// missing from this extract. The index below lists the signature as
// `Type mapElement(Value value) const`; the enclosing class is not visible
// here — confirm against the header.
// Casts the value's type to ShapedType and returns its element type.
173  return llvm::cast<ShapedType>(value.getType()).getElementType();
174 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Type mapElement(Value value) const
Map the element to the iterator result type.
Type mapElement(Value value) const
Map the element to the iterator result type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries have compatible shape.
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData)
Return true if the specified type is an opaque type with the specified dialect and typeData.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26