MLIR  21.0.0git
TosaFolders.cpp
Go to the documentation of this file.
1 //===- TosaFolders.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA operations
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <functional>
14 #include <numeric>
15 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/Pass/Pass.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/FloatingPointMode.h"
25 #include "llvm/ADT/SmallVector.h"
26 
27 using namespace mlir;
28 using namespace mlir::tosa;
29 
30 namespace {
31 
32 /// Apply the given transformation \p toApply to every element of the tensor to
33 /// be transformed \p toTransform.
34 ///
35 /// Elements of \p toTransform are extracted as \p SrcValueType.
36 ///
37 /// \returns A tensor with the same size as \p toTransform, containing
38 /// \p TargetValueType values of type \p TargetType.
39 template <class SrcValType, class TargetValType, class TargetType>
40 DenseElementsAttr applyElementWise(
41  const DenseElementsAttr &toTransform,
42  const std::function<TargetValType(const SrcValType &)> &toApply,
43  TargetType targetType) {
44  SmallVector<TargetValType> transformedValues;
45  // We already know the amount of values we will insert, reserve space for
46  // all of them to avoid dynamic resizing
47  transformedValues.reserve(toTransform.getNumElements());
48  for (auto val : toTransform.getValues<SrcValType>()) {
49  auto transformedVal = toApply(val);
50  transformedValues.push_back(transformedVal);
51  }
52 
53  // Make sure that the output tensor has the expected output type
54  auto inShape = toTransform.getType();
55  auto outTy = inShape.cloneWith({}, targetType);
56 
57  return DenseElementsAttr::get(outTy, transformedValues);
58 }
59 
60 template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
61  const DenseElementsAttr &toTransform,
62  const std::function<APFloat(const APFloat &)> &toApply,
63  FloatType targetType);
64 
65 /// Function that checks if the type contained in \p toCheck is float.
66 LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
67  PatternRewriter &rewriter) {
68  if (isa<FloatType>(toCheck.getType().getElementType())) {
69  return success();
70  }
71  return rewriter.notifyMatchFailure(location,
72  "Unexpected input tensor type: the "
73  "TOSA spec only allows floats");
74 }
75 
76 /// Function that checks if \p toCheck is a dense TOSA constant tensor.
77 LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
78  TosaOp location,
79  PatternRewriter &rewriter) {
80  // Check whether the tensor is constant and dense
81  // TODO We currently ensure the tensor is dense by using the correct type for
82  // the bind_value, however we do not actually need this value. It would be
83  // nicer to only have a check here.
85  if (!matchPattern(toCheck, m_Constant(&tmp))) {
86  return rewriter.notifyMatchFailure(location,
87  "Non-const or non-dense input tensor");
88  }
89 
90  // Make sure it actually is a TOSA constant (the match allows for other
91  // constants as well)
92  if (isa<ConstOp>(toCheck.getDefiningOp())) {
93  return success();
94  }
95 
96  return rewriter.notifyMatchFailure(location,
97  "The reciprocal can only be folded if "
98  "it operates on a TOSA constant");
99 }
100 
101 /// Function that checks if \p toCheck is a dense TOSA constant float tensor.
102 LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
103  TosaOp location,
104  PatternRewriter &rewriter) {
105  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
106  if (failed(floatCheck)) {
107  return floatCheck;
108  }
109  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
110 }
111 
112 /// Heuristic to decide when to replace a unary operation on a constant with the
113 /// folded value.
114 /// Folding operations on constants can lead to an increased memory usage
115 /// whenever the input cannot be replaced but a new constant is inserted. Hence,
116 /// this will currently only suggest folding when the memory impact is
117 /// negligible.
118 /// Takes the \p unaryOp and the constant input \p values.
119 /// \returns Whether folding should be applied.
120 bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
121  assert(unaryOp->getNumOperands() == 1);
122  auto inputOp = unaryOp->getOperand(0);
123 
124  // If the input is a splat, we don't care for the number of users
125  if (isa<SplatElementsAttr>(values)) {
126  return true;
127  }
128 
129  // If this is the only use of the tensor it should be replaced as no
130  // additional memory is required
131  return inputOp.hasOneUse();
132 }
133 
134 template <typename RangeType>
135 DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
136  ShapedType outputType,
137  llvm::ArrayRef<int64_t> permValues) {
138  using ElementType = std::decay_t<decltype(*std::begin(data))>;
139 
140  assert(inputType.getElementType() == outputType.getElementType());
141 
142  if (inputType.getNumElements() == 0)
144 
145  auto inputShape = inputType.getShape();
146 
147  // The inverted permutation map and strides of the output are used to compute
148  // the contribution of a given dimension to the destination linear index in
149  // an order-independent way.
150  auto outputStrides = computeStrides(outputType.getShape());
151  auto invertedPermValues = invertPermutationVector(permValues);
152 
153  auto initialValue = *std::begin(data);
154  SmallVector<ElementType> outputValues(inputType.getNumElements(),
155  initialValue);
156 
157  for (const auto &it : llvm::enumerate(data)) {
158  auto srcLinearIndex = it.index();
159 
160  uint64_t dstLinearIndex = 0;
161  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
162  // Compute the index into the current dimension of the source vector.
163  auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
164  srcLinearIndex /= inputShape[dim];
165 
166  // Add the contribution of the current dimension to the output using the
167  // permutation map.
168  dstLinearIndex +=
169  outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
170  }
171 
172  outputValues[dstLinearIndex] = it.value();
173  }
174 
175  return DenseElementsAttr::get(outputType,
176  llvm::ArrayRef<ElementType>(outputValues));
177 }
178 
179 // A type specialized transposition of an ElementsAttr.
180 // This implementation tries to operate on the underlying data in its raw
181 // representation when possible to avoid allocating a large number of Attribute
182 // objects.
183 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
184  ShapedType outputType,
185  llvm::ArrayRef<int64_t> permValues) {
186  if (auto data = attr.tryGetValues<bool>())
187  return transposeType(*data, inputType, outputType, permValues);
188 
189  if (auto data = attr.tryGetValues<int8_t>())
190  return transposeType(*data, inputType, outputType, permValues);
191 
192  if (auto data = attr.tryGetValues<int16_t>())
193  return transposeType(*data, inputType, outputType, permValues);
194 
195  if (auto data = attr.tryGetValues<int32_t>())
196  return transposeType(*data, inputType, outputType, permValues);
197 
198  if (auto data = attr.tryGetValues<int64_t>())
199  return transposeType(*data, inputType, outputType, permValues);
200 
201  if (auto data = attr.tryGetValues<float>())
202  return transposeType(*data, inputType, outputType, permValues);
203 
204  if (auto data = attr.tryGetValues<APFloat>())
205  return transposeType(*data, inputType, outputType, permValues);
206 
207  return nullptr;
208 }
209 
210 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
212 
213  LogicalResult matchAndRewrite(tosa::TransposeOp op,
214  PatternRewriter &rewriter) const override {
215  auto outputType = cast<ShapedType>(op.getType());
216  // TOSA supports quantized types.
217  if (!outputType.getElementType().isIntOrIndexOrFloat())
218  return failure();
219 
220  ElementsAttr inputValues;
221  if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
222  return failure();
223  // Make sure the input is a constant that has a single user.
224  if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
225  return failure();
226 
227  auto permValues = llvm::map_to_vector(
228  op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
229 
230  auto inputType = cast<ShapedType>(op.getInput1().getType());
231 
232  auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
233  if (!resultAttr) {
234  return rewriter.notifyMatchFailure(
235  op, "unsupported attribute or element type");
236  }
237 
238  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
239  return success();
240  }
241 };
242 
243 struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
244 
246 
247  LogicalResult matchAndRewrite(ReciprocalOp recip,
248  PatternRewriter &rewriter) const override {
249  auto inputTensor = recip.getInput1();
250 
251  // Check that we can apply folding
252  auto preCondCheck =
253  notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
254  if (failed(preCondCheck)) {
255  return preCondCheck;
256  }
257 
258  // Extract the tensor values
259  DenseElementsAttr inputValues;
260  matchPattern(inputTensor, m_Constant(&inputValues));
261 
262  // Check whether this should be folded.
263  if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
264  return rewriter.notifyMatchFailure(
265  recip, "Currently, reciprocals will only be folded if the input "
266  "tensor has a single user");
267  }
268 
269  // Create a new tensor with the updated values
270  auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
271  inputValues, &ReciprocalOp::calcOneElement,
272  cast<FloatType>(inputValues.getElementType()));
273 
274  // Replace the use of the reciprocal with the transformed tensor
275  rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
276  return success();
277  }
278 };
279 
280 /// Getting the axes position of the element which is located
281 /// in the tensor at the counter index
282 
284 getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
285  int64_t remaining = index;
286  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
287  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
288  position[i] = remaining % tensorShape[i];
289  remaining /= tensorShape[i];
290  }
291  return position;
292 }
293 
294 /// Getting the index of the element which is located at the
295 /// axes position in the tensor
296 
297 int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
298  llvm::ArrayRef<int64_t> tensorShape) {
299  int64_t index = 0;
300  int64_t multiplierTmp = 1;
301  for (int64_t i = position.size() - 1; i >= 0; --i) {
302  index += position[i] * multiplierTmp;
303  multiplierTmp *= tensorShape[i];
304  }
305  return index;
306 }
307 
308 template <typename OperationType>
309 llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
310  llvm::ArrayRef<int64_t> oldShape,
311  int64_t reductionAxis,
312  int64_t reductionIndex) {
313 
314  llvm::SmallVector<int64_t> newShape(oldShape);
315  newShape[reductionAxis] = 1;
316  /// Let's calculate the position of the index
317  llvm::SmallVector<int64_t> position =
318  getPositionFromIndex(reductionIndex, newShape);
319  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
320  /// Starting from the first positon along the reduction axis
321  position[reductionAxis] = 0;
322  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
323  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
324 
325  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
326  ++reductionAxisVal) {
327 
328  int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
329  oldShape.end(), 1, std::multiplies<int>());
330  int64_t index = indexAtOldTensor + stride * reductionAxisVal;
331  reducedValue =
332  OperationType::calcOneElement(reducedValue, oldTensor[index]);
333  }
334  return reducedValue;
335 }
336 
337 template <typename OperationType>
338 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
339 
340  ReduceConstantOptimization(MLIRContext *context,
341  bool aggressiveReduceConstant)
342  : OpRewritePattern<OperationType>(context),
343  aggressiveReduceConstant(aggressiveReduceConstant) {}
344 
346 
347  LogicalResult matchAndRewrite(OperationType op,
348  PatternRewriter &rewriter) const override {
349  Value inputOp = op.getInput();
350  auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
351 
352  if (!constOp)
353  return rewriter.notifyMatchFailure(
354  op, "reduce input must be const operation");
355 
356  if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
357  return rewriter.notifyMatchFailure(
358  op, "input operation has more than one user");
359 
360  auto resultType = cast<ShapedType>(op.getOutput().getType());
361 
362  if (!resultType.hasStaticShape())
363  return rewriter.notifyMatchFailure(op, "result type shape is not static");
364 
365  auto reductionAxis = op.getAxis();
366  const auto denseElementsAttr = constOp.getValues();
367  const auto shapedOldElementsValues =
368  cast<ShapedType>(denseElementsAttr.getType());
369 
370  if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
371  return rewriter.notifyMatchFailure(
372  op, "reduce input currently supported with integer type");
373 
374  auto oldShape = shapedOldElementsValues.getShape();
375  auto newShape = resultType.getShape();
376 
377  auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
378  std::multiplies<int>());
379  llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
380 
381  for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
382  ++reductionIndex) {
383 
384  /// Let's reduce all the elements along this reduction axis
385  newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
386  denseElementsAttr, oldShape, reductionAxis, reductionIndex);
387  }
388 
389  auto rankedTensorType = cast<RankedTensorType>(resultType);
390  auto denseAttr =
391  mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
392  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
393  return success();
394  }
395  const bool aggressiveReduceConstant;
396 };
397 
398 } // namespace
399 
402  bool aggressiveReduceConstant) {
403  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
404  ctx, aggressiveReduceConstant);
405  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
406  ctx, aggressiveReduceConstant);
407  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
408  ctx, aggressiveReduceConstant);
409  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
410  ctx, aggressiveReduceConstant);
411  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
412  ctx, aggressiveReduceConstant);
413  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
414  ctx, aggressiveReduceConstant);
415 }
416 
419  patterns.add<TosaFoldConstantTranspose>(ctx);
420 }
421 
424  patterns.add<TosaFoldConstantReciprocal>(ctx);
425 }
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:368