//===- ConversionUtils.h - Helper functions for tosa conversion -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_TOSA_UTILS_COVERSION_UTILS_H_
#define DIALECT_TOSA_UTILS_COVERSION_UTILS_H_

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

namespace mlir {
namespace tosa {

// Creates a SmallVector of utils::IteratorType describing N parallel loops.
SmallVector<utils::IteratorType>
getNParallelLoopsAttrs(unsigned nParallelLoops);

// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);

// Takes the parameters for a clamp and turns them into a series of ops for
// float inputs.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
                       OpBuilder &rewriter);

// Takes the parameters for a clamp and turns them into a series of ops for
// integer inputs.
Value clampIntHelper(Location loc, Value arg, Value min, Value max,
                     OpBuilder &rewriter, bool isUnsigned);
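
// Usage sketch for clampFloatHelper / clampIntHelper (illustrative only; the
// values `arg`, `lo`, `hi` and the builder are assumed to come from the
// calling conversion pattern):
//
//   Value clampedF = clampFloatHelper(loc, arg, lo, hi, rewriter);
//   Value clampedI = clampIntHelper(loc, arg, lo, hi, rewriter,
//                                   /*isUnsigned=*/false);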

// Determines whether the integer value falls within the range of the integer
// type.
bool validIntegerRange(IntegerType ty, int64_t value);

// Checks for a dynamic batch dim in any of the passed parameters of an op.
// The batch dimension must be #0 and the rest of the dimensions must be static.
template <typename Op>
std::optional<SmallVector<Value>>
checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
                         ArrayRef<Value> params) {
  SmallVector<ShapedType> dynTypes;
  SmallVector<Value> dynamicDims;
  for (const Value &param : params) {
    auto paramTy = cast<ShapedType>(param.getType());
    if (!paramTy.hasStaticShape())
      dynTypes.push_back(paramTy);
  }

  if (dynTypes.empty())
    return dynamicDims;

  for (const ShapedType &dynTy : dynTypes) {
    if (llvm::any_of(dynTy.getShape().drop_front(), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(
          op, "input can only be dynamic for batch size");
      return std::nullopt;
    }
  }

  dynamicDims.push_back(
      rewriter.create<tensor::DimOp>(op->getLoc(), params[0], 0));
  return dynamicDims;
}
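
// Typical use in a conversion pattern (sketch; `op`, `input` and the failure
// handling are placeholders for the caller's own code):
//
//   auto dynamicDimsOr =
//       checkHasDynamicBatchDims(rewriter, op, {input, op->getResult(0)});
//   if (!dynamicDimsOr.has_value())
//     return failure();
//   SmallVector<Value> dynamicDims = *dynamicDimsOr;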

/// Common code to create a reshape op where necessary to make the ranks of
/// two values equal. input1 and input2 will be updated when a rank has
/// changed. The caller is expected to use these to rewrite the original
/// operator with the RESHAPE now in the graph.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
                            Value &input1, Value &input2);

LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
                            Value &input2);
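
// Sketch of how EqualizeRanks is typically used (values are illustrative):
// given %a : tensor<2x3xf32> and %b : tensor<3xf32>, the lower-ranked %b is
// reshaped (here to tensor<1x3xf32>) so both values can feed a broadcasting
// op:
//
//   Value lhs = ..., rhs = ...; // the two tensor values to be broadcast
//   if (failed(EqualizeRanks(rewriter, loc, lhs, rhs)))
//     return rewriter.notifyMatchFailure(op, "ranks cannot be equalized");
//   auto addOp = rewriter.create<tosa::AddOp>(loc, resultTy, lhs, rhs);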

namespace {

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference when lowering down to TOSA.
template <typename TosaOp, typename... Args>
TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
                             Args &&...args) {
  auto op = builder.create<TosaOp>(resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getPropertiesStorage(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  Type newTy =
      newKnowledge.hasRank
          ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
                                             newKnowledge.dtype)}
          : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
  result.setType(newTy);
  return op;
}

} // namespace

// Creates a TOSA operation by:
// - first equalizing ranks for ops with the SameOperandsAndResultRank trait
// - creating the operator
// - performing shape inference on the created operator
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
                             Args &&...args) {
  if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
    // op requires same ranks for tensor operands
    if constexpr (sizeof...(Args) == 2) {
      auto argX = std::get<0>(std::tie(args...));
      auto argY = std::get<1>(std::tie(args...));
      using ArgX = decltype(argX);
      using ArgY = decltype(argY);
      if constexpr (std::is_same_v<ArgX, Value> &&
                    std::is_same_v<ArgY, Value>) {
        Value x = std::get<0>(std::tie(args...));
        Value y = std::get<1>(std::tie(args...));
        if (EqualizeRanks(builder, x, y).failed()) {
          // incompatible broadcast shapes, no reshape is inserted
          // ResultsBroadcastableShape verify will handle this
        }
        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
      }
    }
    if constexpr (sizeof...(Args) == 3) {
      auto argX = std::get<0>(std::tie(args...));
      auto argY = std::get<1>(std::tie(args...));
      auto argZ = std::get<2>(std::tie(args...));
      using ArgX = decltype(argX);
      using ArgY = decltype(argY);
      using ArgZ = decltype(argZ);
      if constexpr (std::is_same_v<ArgX, Value> &&
                    std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
        // special case for ArithmeticRightShiftOp
        Value x = std::get<0>(std::tie(args...));
        Value y = std::get<1>(std::tie(args...));
        bool round = std::get<2>(std::tie(args...));
        if (EqualizeRanks(builder, x, y).failed()) {
          // incompatible broadcast shapes, no reshape is inserted
          // ResultsBroadcastableShape verify will handle this
        }
        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
      }
      if constexpr (std::is_same_v<ArgX, Value> &&
                    std::is_same_v<ArgY, Value> &&
                    std::is_same_v<ArgZ, Value>) {
        // special case for Select
        Value x = std::get<0>(std::tie(args...));
        Value y = std::get<1>(std::tie(args...));
        Value z = std::get<2>(std::tie(args...));

        if (EqualizeRanks(builder, x, y).failed() ||
            EqualizeRanks(builder, x, z).failed() ||
            EqualizeRanks(builder, y, z).failed()) {
          // incompatible broadcast shapes, no reshape is inserted
          // ResultsBroadcastableShape verify will handle this
        }

        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
      }
    }
  }

  return createOpAndInferShape<TosaOp>(builder, resultTy, args...);
}

// Creates a TOSA operation by:
// - first equalizing ranks for ops with the SameOperandsAndResultRank trait
// - creating the operator
// - performing shape inference on the created operator
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
                             Type resultTy, Args &&...args) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
}
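
// Illustrative call (operand and type names are placeholders): creating a
// broadcasting tosa.sub whose operand ranks are equalized and whose result
// shape is re-inferred from the operands:
//
//   auto subOp = CreateOpAndInferShape<tosa::SubOp>(rewriter, op->getLoc(),
//                                                   resultTy, lhs, rhs);
//   rewriter.replaceOp(op, subOp.getResult());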

// Applies an int32_t permutation `perms` to `input`. `input` must have the
// same size as `perms`, and `perms` must contain a permutation of
// 0 .. perms.size() - 1.
template <typename T>
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
                                    ArrayRef<int32_t> perms) {
  SmallVector<T> permuted;
  size_t N = input.size();
  permuted.resize_for_overwrite(N);
  for (size_t i = 0; i < N; i++)
    permuted[i] = input[perms[i]];
  return permuted;
}
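
// Worked example (values are illustrative): permuting an NHWC shape vector to
// NCHW with perms = {0, 3, 1, 2}:
//
//   SmallVector<int64_t> shape = {1, 8, 8, 3};  // N, H, W, C
//   SmallVector<int64_t> nchw =
//       applyTOSAPermutation<int64_t>(shape, {0, 3, 1, 2});
//   // nchw == {1, 3, 8, 8}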

} // namespace tosa
} // namespace mlir

#endif // DIALECT_TOSA_UTILS_COVERSION_UTILS_H_