MLIR 23.0.0git
ConversionUtils.h
Go to the documentation of this file.
1//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utility functions for TOSA lowering
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
14#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
15
23#include <optional>
24
25namespace mlir {
26namespace tosa {
27
28// Creates a SmallVector of Stringrefs for N parallel loops
29SmallVector<utils::IteratorType>
30getNParallelLoopsAttrs(unsigned nParallelLoops);
31
32// Takes a vector of values and condenses them to a vector with no gaps.
33SmallVector<Value> condenseValues(const SmallVector<Value> &values);
34
35// Takes the parameters for a clamp and turns it into a series of ops for float
36// inputs.
37Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
38 OpBuilder &rewriter);
39
40// Takes the parameters for a clamp and turns it into a series of ops for
41// integer inputs.
42Value clampIntHelper(Location loc, Value arg, Value min, Value max,
43 OpBuilder &rewriter, bool isUnsigned);
44
45// Determines whether the integer value falls witin the range of integer type.
46bool validIntegerRange(IntegerType ty, int64_t value);
47
48// Checks for a dynamic batch dim in any of the passed parameters of an op.
49// The batch dimention must be #0 and the rest of the dimensions must be static.
50template <typename Op>
51std::optional<SmallVector<Value>>
53 ArrayRef<Value> params) {
55 SmallVector<Value> dynamicDims;
56 for (const Value &param : params) {
57 auto paramTy = cast<ShapedType>(param.getType());
58 if (!paramTy.hasStaticShape())
59 dynTypes.push_back(paramTy);
60 }
61
62 if (dynTypes.empty())
63 return dynamicDims;
64
65 for (const ShapedType &dynTy : dynTypes) {
66 if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
67 (void)rewriter.notifyMatchFailure(
68 op, "input can only be dynamic for batch size");
69 return std::nullopt;
70 }
71 }
72
73 dynamicDims.push_back(
74 tensor::DimOp::create(rewriter, op->getLoc(), params[0], 0));
75 return dynamicDims;
76}
77
78/// Common code to create the reshape op where necessary to make the rank of two
79/// values equal. input1 and input2 will be updated when the rank has
80/// changed. The caller is expected to use these to rewrite the original
81/// operator with the RESHAPE now in the graph.
82LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
83 Value &input1, Value &input2);
84
85LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
86 Value &input2);
87
88namespace {
89
90// Creates a TOSA operation and performs shape inference on the individual
91// op. This allows shape inference when lowering down to TOSA.
92template <typename TosaOp, typename... Args>
93TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
94 Args &&...args) {
95 auto op = TosaOp::create(builder, resultTy, args...);
96
97 InferShapedTypeOpInterface shapeInterface =
98 dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
99 if (!shapeInterface)
100 return op;
101
103 if (shapeInterface
104 .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
105 op->getOperands(), op->getAttrDictionary(),
106 op->getPropertiesStorage(),
107 op->getRegions(), returnedShapes)
108 .failed())
109 return op;
110
111 // We need to use the element type of the existing result type to generate
112 // the new result shaped type. This is because rescale can include a cast to
113 // different bit-width types and does not have a TypeAttr to define the
114 // target type.
115 auto result = op->getResult(0);
116 auto predictedShape = returnedShapes[0];
117 auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);
118
119 // Compute the knowledge based on the inferred type.
120 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
121 inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
122 inferredKnowledge.hasRank = predictedShape.hasRank();
123 if (predictedShape.hasRank()) {
124 for (auto dim : predictedShape.getDims()) {
125 inferredKnowledge.sizes.push_back(dim);
126 }
127 }
128
129 // Compute the new type based on the joined version.
130 auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
131 Type newTy =
132 newKnowledge.hasRank
133 ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
134 newKnowledge.dtype)}
135 : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
136 result.setType(newTy);
137 return op;
138}
139
140} // namespace
141
142// Creates a TOSA operation by:
143// - first equalize ranks for ops with SameOperandsAndResultRank trait
144// - create operator
145// - performs shape inference on this operator
146template <typename TosaOp, typename... Args>
148 Args &&...args) {
149 if (TosaOp::template hasTrait<::mlir::OpTrait::SameOperandsAndResultRank>()) {
150 // op requires same ranks for tensor operands
151 if constexpr (sizeof...(Args) == 2) {
152 auto argX = std::get<0>(std::tie(args...));
153 auto argY = std::get<1>(std::tie(args...));
154 using ArgX = decltype(argX);
155 using ArgY = decltype(argY);
156 if constexpr (std::is_same_v<ArgX, Value> &&
157 std::is_same_v<ArgY, Value>) {
158 Value x = std::get<0>(std::tie(args...));
159 Value y = std::get<1>(std::tie(args...));
160 if (EqualizeRanks(builder, x, y).failed()) {
161 // incompatible broadcast shapes, no reshape is inserted
162 // ResultsBroadcastableShape verify will handle this
163 }
164 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
165 }
166 }
167 if constexpr (sizeof...(Args) == 3) {
168 auto argX = std::get<0>(std::tie(args...));
169 auto argY = std::get<1>(std::tie(args...));
170 auto argZ = std::get<2>(std::tie(args...));
171 using ArgX = decltype(argX);
172 using ArgY = decltype(argY);
173 using ArgZ = decltype(argZ);
174 if constexpr (std::is_same_v<ArgX, Value> &&
175 std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
176 // special case for ArithmeticRightShiftOp
177 Value x = std::get<0>(std::tie(args...));
178 Value y = std::get<1>(std::tie(args...));
179 bool round = std::get<2>(std::tie(args...));
180 if (EqualizeRanks(builder, x, y).failed()) {
181 // incompatible broadcast shapes, no reshape is inserted
182 // ResultsBroadcastableShape verify will handle this
183 }
184 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
185 }
186 if constexpr (std::is_same_v<ArgX, Value> &&
187 std::is_same_v<ArgY, Value> &&
188 std::is_same_v<ArgZ, Value>) {
189 // special case for Select
190 Value x = std::get<0>(std::tie(args...));
191 Value y = std::get<1>(std::tie(args...));
192 Value z = std::get<2>(std::tie(args...));
193
194 if (EqualizeRanks(builder, x, y).failed() ||
195 EqualizeRanks(builder, x, z).failed() ||
196 EqualizeRanks(builder, y, z).failed()) {
197 // incompatible broadcast shapes, no reshape is inserted
198 // ResultsBroadcastableShape verify will handle this
199 }
200
201 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
202 }
203 }
204 }
205
206 return createOpAndInferShape<TosaOp>(builder, resultTy, args...);
207}
208
209// Creates a TOSA operation by:
210// - first equalize ranks for ops with SameOperandsAndResultRank trait
211// - create operator
212// - performs shape inference on this operator
213template <typename TosaOp, typename... Args>
215 Type resultTy, Args &&...args) {
216 ImplicitLocOpBuilder builder(loc, rewriter);
217 return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
218}
219
220// Apply an int32_t permutation to some input, that should be of the same
221// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
222template <typename T>
224 ArrayRef<int32_t> perms) {
225 SmallVector<T> permuted;
226 size_t N = input.size();
227 permuted.resize_for_overwrite(N);
228 for (size_t i = 0; i < N; i++)
229 permuted[i] = input[perms[i]];
230 return permuted;
231}
232
233// Computes shape value using tosa const_shape op.
238
240
242 llvm::SmallVector<int64_t> &result_shape);
243
244// returns a small vector of int64_t values that attr contains
246 const int rank);
247
248// returns true iff constant indices for scatter op contains unique indices
249// per batch
250bool hasUniqueConstantScatterIndices(ShapedType indicesType,
251 DenseIntElementsAttr indicesAttr);
252
253// Try to get the values of a DenseResourceElementsAttr construct
254template <typename T>
255std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
256 if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
257 // Check that the resource memory blob exists
258 AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
259 if (!blob)
260 return std::nullopt;
261
262 // Check that the data are in a valid form
263 if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
264 blob->getData())) {
265 return std::nullopt;
266 }
267
268 return blob->template getDataAs<T>();
269 }
270
271 return std::nullopt;
272}
273
274} // namespace tosa
275} // namespace mlir
276
277#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
An attribute that represents a reference to a dense vector or tensor object.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
An attribute that represents a reference to a dense integer vector or tensor object.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location parameter.
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy, Args &&...args)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
std::optional< ArrayRef< T > > tryGetDenseResourceValues(ElementsAttr attr)
bool validIntegerRange(IntegerType ty, int64_t value)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition ShapeUtils.h:61
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45