MLIR 22.0.0git
ConversionUtils.h
Go to the documentation of this file.
1//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utility functions for TOSA lowering
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
14#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
15
22#include <optional>
23
24namespace mlir {
25namespace tosa {
26
27// Creates a SmallVector of utils::IteratorType entries for N parallel loops
28SmallVector<utils::IteratorType>
29getNParallelLoopsAttrs(unsigned nParallelLoops);
30
31// Takes a vector of values and condenses it into a vector with no gaps (presumably drops null Values; TODO confirm).
32SmallVector<Value> condenseValues(const SmallVector<Value> &values);
33
34// Takes the parameters for a clamp and turns them into a series of ops that
35// clamp `arg` to [`min`, `max`] for float inputs.
36Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
37 OpBuilder &rewriter);
38
39// Takes the parameters for a clamp and turns them into a series of ops for
40// integer inputs (`isUnsigned` presumably selects unsigned comparisons).
41Value clampIntHelper(Location loc, Value arg, Value min, Value max,
42 OpBuilder &rewriter, bool isUnsigned);
43
44// Determines whether the integer `value` falls within the range of integer type `ty`.
45bool validIntegerRange(IntegerType ty, int64_t value);
46
47// Checks for a dynamic batch dim in any of the passed parameters of an op.
48// The batch dimension must be #0 and the rest of the dimensions must be static.
49template <typename Op>
50std::optional<SmallVector<Value>>
52 ArrayRef<Value> params) {
54 SmallVector<Value> dynamicDims;
55 for (const Value &param : params) {
56 auto paramTy = cast<ShapedType>(param.getType());
57 if (!paramTy.hasStaticShape())
58 dynTypes.push_back(paramTy);
59 }
60
61 if (dynTypes.empty())
62 return dynamicDims;
63
64 for (const ShapedType &dynTy : dynTypes) {
65 if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
66 (void)rewriter.notifyMatchFailure(
67 op, "input can only be dynamic for batch size");
68 return std::nullopt;
69 }
70 }
71
72 dynamicDims.push_back(
73 tensor::DimOp::create(rewriter, op->getLoc(), params[0], 0));
74 return dynamicDims;
75}
76
77/// Creates a reshape op where necessary so that two values have equal rank.
78/// input1 and input2 are updated in place when a rank changes; the caller is
79/// expected to use them to rewrite the original operator with the RESHAPE now in
80/// the graph. Failure means incompatible broadcast shapes; no reshape is inserted.
81LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
82 Value &input1, Value &input2);
83/// Overload of EqualizeRanks that presumably takes its Location from `builder`.
84LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
85 Value &input2);
86
87namespace {
88
89// Creates a TOSA operation and performs shape inference on the individual
90// op. This allows shape inference when lowering down to TOSA.
91template <typename TosaOp, typename... Args>
92TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
93 Args &&...args) {
94 auto op = TosaOp::create(builder, resultTy, args...);
95
96 InferShapedTypeOpInterface shapeInterface =
97 dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
98 if (!shapeInterface)
99 return op;
100
102 if (shapeInterface
103 .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
104 op->getOperands(), op->getAttrDictionary(),
105 op->getPropertiesStorage(),
106 op->getRegions(), returnedShapes)
107 .failed())
108 return op;
109
110 // We need to use the element type of the existing result type to generate
111 // the new result shaped type. This is because rescale can include a cast to
112 // different bit-width types and does not have a TypeAttr to define the
113 // target type.
114 auto result = op->getResult(0);
115 auto predictedShape = returnedShapes[0];
116 auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);
117
118 // Compute the knowledge based on the inferred type.
119 auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
120 inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
121 inferredKnowledge.hasRank = predictedShape.hasRank();
122 if (predictedShape.hasRank()) {
123 for (auto dim : predictedShape.getDims()) {
124 inferredKnowledge.sizes.push_back(dim);
125 }
126 }
127
128 // Compute the new type based on the joined version.
129 auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
130 Type newTy =
131 newKnowledge.hasRank
132 ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
133 newKnowledge.dtype)}
134 : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
135 result.setType(newTy);
136 return op;
137}
138
139} // namespace
140
141// Creates a TOSA operation by:
142// - first equalizing ranks for ops with the SameOperandsAndResultRank trait
143// - creating the operator
144// - performing shape inference on the created operator
145template <typename TosaOp, typename... Args>
147 Args &&...args) {
148 if (TosaOp::template hasTrait<::mlir::OpTrait::SameOperandsAndResultRank>()) {
149 // op requires same ranks for tensor operands
150 if constexpr (sizeof...(Args) == 2) {
151 auto argX = std::get<0>(std::tie(args...));
152 auto argY = std::get<1>(std::tie(args...));
153 using ArgX = decltype(argX);
154 using ArgY = decltype(argY);
155 if constexpr (std::is_same_v<ArgX, Value> &&
156 std::is_same_v<ArgY, Value>) {
157 Value x = std::get<0>(std::tie(args...));
158 Value y = std::get<1>(std::tie(args...));
159 if (EqualizeRanks(builder, x, y).failed()) {
160 // incompatible broadcast shapes, no reshape is inserted
161 // ResultsBroadcastableShape verify will handle this
162 }
163 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
164 }
165 }
166 if constexpr (sizeof...(Args) == 3) {
167 auto argX = std::get<0>(std::tie(args...));
168 auto argY = std::get<1>(std::tie(args...));
169 auto argZ = std::get<2>(std::tie(args...));
170 using ArgX = decltype(argX);
171 using ArgY = decltype(argY);
172 using ArgZ = decltype(argZ);
173 if constexpr (std::is_same_v<ArgX, Value> &&
174 std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
175 // special case for ArithmeticRightShiftOp
176 Value x = std::get<0>(std::tie(args...));
177 Value y = std::get<1>(std::tie(args...));
178 bool round = std::get<2>(std::tie(args...));
179 if (EqualizeRanks(builder, x, y).failed()) {
180 // incompatible broadcast shapes, no reshape is inserted
181 // ResultsBroadcastableShape verify will handle this
182 }
183 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
184 }
185 if constexpr (std::is_same_v<ArgX, Value> &&
186 std::is_same_v<ArgY, Value> &&
187 std::is_same_v<ArgZ, Value>) {
188 // special case for Select
189 Value x = std::get<0>(std::tie(args...));
190 Value y = std::get<1>(std::tie(args...));
191 Value z = std::get<2>(std::tie(args...));
192
193 if (EqualizeRanks(builder, x, y).failed() ||
194 EqualizeRanks(builder, x, z).failed() ||
195 EqualizeRanks(builder, y, z).failed()) {
196 // incompatible broadcast shapes, no reshape is inserted
197 // ResultsBroadcastableShape verify will handle this
198 }
199
200 return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
201 }
202 }
203 }
204
205 return createOpAndInferShape<TosaOp>(builder, resultTy, args...);
206}
207
208// Creates a TOSA operation by:
209// - first equalizing ranks for ops with the SameOperandsAndResultRank trait
210// - creating the operator
211// - performing shape inference on the created operator
212template <typename TosaOp, typename... Args>
214 Type resultTy, Args &&...args) {
215 ImplicitLocOpBuilder builder(loc, rewriter);
216 return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
217}
218
219// Applies an int32_t permutation to `input` (result[i] = input[perms[i]]).
220// `input` should be the same size as `perms`, a permutation of 0..size()-1.
221template <typename T>
223 ArrayRef<int32_t> perms) {
224 SmallVector<T> permuted;
225 size_t N = input.size();
226 permuted.resize_for_overwrite(N);
227 for (size_t i = 0; i < N; i++)
228 permuted[i] = input[perms[i]];
229 return permuted;
230}
231
232// Computes shape value using tosa const_shape op.
237
239
241 llvm::SmallVector<int64_t> &result_shape);
242
243// Returns a SmallVector of the int64_t values that `attr` contains.
245 const int rank);
246
247// Returns true iff the constant indices for a scatter op contain unique indices
248// per batch.
249bool hasUniqueConstantScatterIndices(ShapedType indicesType,
250 DenseIntElementsAttr indicesAttr);
251} // namespace tosa
252} // namespace mlir
253
254#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
An attribute that represents a reference to a dense vector or tensor object.
An attribute that represents a reference to a dense integer vector or tensor object.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
bool hasUniqueConstantScatterIndices(ShapedType indicesType, DenseIntElementsAttr indicesAttr)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy, Args &&...args)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
bool validIntegerRange(IntegerType ty, int64_t value)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition ShapeUtils.h:61
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45