MLIR  14.0.0git
TypeUtilities.cpp
Go to the documentation of this file.
1 //===- TypeUtilities.cpp - Helper function for type queries ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines generic type utilities.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #include <numeric>
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Types.h"
20 #include "mlir/IR/Value.h"
21 
22 using namespace mlir;
23 
25  if (auto st = type.dyn_cast<ShapedType>())
26  return st.getElementType();
27  return type;
28 }
29 
31  return getElementTypeOrSelf(val.getType());
32 }
33 
35  return getElementTypeOrSelf(attr.getType());
36 }
37 
39  SmallVector<Type, 10> fTypes;
40  t.getFlattenedTypes(fTypes);
41  return fTypes;
42 }
43 
44 /// Return true if the specified type is an opaque type with the specified
45 /// dialect and typeData.
46 bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
47  StringRef typeData) {
48  if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
49  return opaque.getDialectNamespace() == dialect &&
50  opaque.getTypeData() == typeData;
51  return false;
52 }
53 
54 /// Returns success if the given two shapes are compatible. That is, they have
55 /// the same size and, for each pair of corresponding elements, the two are
56 /// equal or at least one of them is dynamic.
58  ArrayRef<int64_t> shape2) {
59  if (shape1.size() != shape2.size())
60  return failure();
61  for (auto dims : llvm::zip(shape1, shape2)) {
62  int64_t dim1 = std::get<0>(dims);
63  int64_t dim2 = std::get<1>(dims);
64  if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
65  dim1 != dim2)
66  return failure();
67  }
68  return success();
69 }
70 
71 /// Returns success if the given two types have compatible shape. That is,
72 /// they are both scalars (not shaped), or they are both shaped types and at
73 /// least one is unranked or they have compatible dimensions. Dimensions are
74 /// compatible if at least one is dynamic or both are equal. The element type
75 /// does not matter.
77  auto sType1 = type1.dyn_cast<ShapedType>();
78  auto sType2 = type2.dyn_cast<ShapedType>();
79 
80  // Either both or neither type should be shaped.
81  if (!sType1)
82  return success(!sType2);
83  if (!sType2)
84  return failure();
85 
86  if (!sType1.hasRank() || !sType2.hasRank())
87  return success();
88 
89  return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
90 }
91 
92 /// Returns success if the given two arrays have the same number of elements and
93 /// each pair of corresponding entries has compatible shapes.
95  if (types1.size() != types2.size())
96  return failure();
97  for (auto it : llvm::zip_first(types1, types2))
98  if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it))))
99  return failure();
100  return success();
101 }
102 
104  if (dims.empty())
105  return success();
106  auto staticDim = std::accumulate(
107  dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
108  return ShapedType::isDynamic(dim) ? fold : dim;
109  });
110  return success(llvm::all_of(dims, [&](auto dim) {
111  return ShapedType::isDynamic(dim) || dim == staticDim;
112  }));
113 }
114 
115 /// Returns success if all given types have compatible shapes. That is, they are
116 /// all scalars (not shaped), or they are all shaped types and any ranked shapes
117 /// have compatible dimensions. Dimensions are compatible if all non-dynamic
118 /// dims are equal. The element type does not matter.
120  auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
121  types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
122  // Return failure if some, but not all are not shaped. Return early if none
123  // are shaped also.
124  if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
125  return success();
126  if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
127  return failure();
128 
129  // Remove all unranked shapes
130  auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
131  shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
132  if (shapes.empty())
133  return success();
134 
135  // All ranks should be equal
136  auto firstRank = shapes.front().getRank();
137  if (llvm::any_of(shapes,
138  [&](auto shape) { return firstRank != shape.getRank(); }))
139  return failure();
140 
141  for (unsigned i = 0; i < firstRank; ++i) {
142  // Retrieve all ranked dimensions
143  auto dims = llvm::to_vector<8>(llvm::map_range(
144  llvm::make_filter_range(
145  shapes, [&](auto shape) { return shape.getRank() >= i; }),
146  [&](auto shape) { return shape.getDimSize(i); }));
147  if (verifyCompatibleDims(dims).failed())
148  return failure();
149  }
150 
151  return success();
152 }
153 
155  return value.getType().cast<ShapedType>().getElementType();
156 }
157 
159  return value.getType().cast<ShapedType>().getElementType();
160 }
Include the generated interface declarations.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Get the types within a nested Tuple.
static constexpr const bool value
bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData)
Return true if the specified type is an opaque type with the specified dialect and typeData...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
Type mapElement(Value value) const
Map the element to the iterator result type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
Type getType() const
Return the type of this value.
Definition: Value.h:117
Type mapElement(Value value) const
Map the element to the iterator result type.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of corresponding entries has compatible shapes.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
U cast() const
Definition: Types.h:250