MLIR  22.0.0git
TosaFolders.cpp
Go to the documentation of this file.
1 //===- TosaFolders.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA operations
10 //
11 //===----------------------------------------------------------------------===//
12 
#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallVector.h"
24 
25 using namespace mlir;
26 using namespace mlir::tosa;
27 
28 namespace {
29 
30 /// Apply the given transformation \p toApply to every element of the tensor to
31 /// be transformed \p toTransform.
32 ///
33 /// Elements of \p toTransform are extracted as \p SrcValueType.
34 ///
35 /// \returns A tensor with the same size as \p toTransform, containing
36 /// \p TargetValueType values of type \p TargetType.
37 template <class SrcValType, class TargetValType, class TargetType>
38 DenseElementsAttr applyElementWise(
39  const DenseElementsAttr &toTransform,
40  const std::function<TargetValType(const SrcValType &)> &toApply,
41  TargetType targetType) {
42  SmallVector<TargetValType> transformedValues;
43  // We already know the amount of values we will insert, reserve space for
44  // all of them to avoid dynamic resizing
45  transformedValues.reserve(toTransform.getNumElements());
46  for (auto val : toTransform.getValues<SrcValType>()) {
47  auto transformedVal = toApply(val);
48  transformedValues.push_back(transformedVal);
49  }
50 
51  // Make sure that the output tensor has the expected output type
52  auto inShape = toTransform.getType();
53  auto outTy = inShape.cloneWith({}, targetType);
54 
55  return DenseElementsAttr::get(outTy, transformedValues);
56 }
57 
58 template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
59  const DenseElementsAttr &toTransform,
60  const std::function<APFloat(const APFloat &)> &toApply,
61  FloatType targetType);
62 
63 /// Function that checks if the type contained in \p toCheck is float.
64 LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
65  PatternRewriter &rewriter) {
66  if (isa<FloatType>(toCheck.getType().getElementType())) {
67  return success();
68  }
69  return rewriter.notifyMatchFailure(location,
70  "Unexpected input tensor type: the "
71  "TOSA spec only allows floats");
72 }
73 
74 /// Function that checks if \p toCheck is a dense TOSA constant tensor.
75 LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
76  TosaOp location,
77  PatternRewriter &rewriter) {
78  // Check whether the tensor is constant and dense
79  // TODO We currently ensure the tensor is dense by using the correct type for
80  // the bind_value, however we do not actually need this value. It would be
81  // nicer to only have a check here.
83  if (!matchPattern(toCheck, m_Constant(&tmp))) {
84  return rewriter.notifyMatchFailure(location,
85  "Non-const or non-dense input tensor");
86  }
87 
88  // Make sure it actually is a TOSA constant (the match allows for other
89  // constants as well)
90  if (isa<ConstOp>(toCheck.getDefiningOp())) {
91  return success();
92  }
93 
94  return rewriter.notifyMatchFailure(location,
95  "The reciprocal can only be folded if "
96  "it operates on a TOSA constant");
97 }
98 
99 /// Function that checks if \p toCheck is a dense TOSA constant float tensor.
100 LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
101  TosaOp location,
102  PatternRewriter &rewriter) {
103  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
104  if (failed(floatCheck)) {
105  return floatCheck;
106  }
107  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
108 }
109 
110 /// Heuristic to decide when to replace a unary operation on a constant with the
111 /// folded value.
112 /// Folding operations on constants can lead to an increased memory usage
113 /// whenever the input cannot be replaced but a new constant is inserted. Hence,
114 /// this will currently only suggest folding when the memory impact is
115 /// negligible.
116 /// Takes the \p unaryOp and the constant input \p values.
117 /// \returns Whether folding should be applied.
118 bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
119  assert(unaryOp->getNumOperands() == 1);
120  auto inputOp = unaryOp->getOperand(0);
121 
122  // If the input is a splat, we don't care for the number of users
123  if (isa<SplatElementsAttr>(values)) {
124  return true;
125  }
126 
127  // If this is the only use of the tensor it should be replaced as no
128  // additional memory is required
129  return inputOp.hasOneUse();
130 }
131 
132 template <typename RangeType>
133 DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
134  ShapedType outputType,
135  llvm::ArrayRef<int64_t> permValues) {
136  using ElementType = std::decay_t<decltype(*std::begin(data))>;
137 
138  assert(inputType.getElementType() == outputType.getElementType());
139 
140  if (inputType.getNumElements() == 0)
142 
143  auto inputShape = inputType.getShape();
144 
145  // The inverted permutation map and strides of the output are used to compute
146  // the contribution of a given dimension to the destination linear index in
147  // an order-independent way.
148  auto outputStrides = computeStrides(outputType.getShape());
149  auto invertedPermValues = invertPermutationVector(permValues);
150 
151  auto initialValue = *std::begin(data);
152  SmallVector<ElementType> outputValues(inputType.getNumElements(),
153  initialValue);
154 
155  for (const auto &it : llvm::enumerate(data)) {
156  auto srcLinearIndex = it.index();
157 
158  uint64_t dstLinearIndex = 0;
159  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
160  // Compute the index into the current dimension of the source vector.
161  auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
162  srcLinearIndex /= inputShape[dim];
163 
164  // Add the contribution of the current dimension to the output using the
165  // permutation map.
166  dstLinearIndex +=
167  outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
168  }
169 
170  outputValues[dstLinearIndex] = it.value();
171  }
172 
173  return DenseElementsAttr::get(outputType,
174  llvm::ArrayRef<ElementType>(outputValues));
175 }
176 
177 // Try to get the values of a DenseResourceElementsAttr construct
178 template <typename T>
179 std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
180  if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
181  // Check that the resource memory blob exists
182  AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
183  if (!blob)
184  return std::nullopt;
185 
186  // Check that the data are in a valid form
187  bool isSplat = false;
188  if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
189  blob->getData(), isSplat)) {
190  return std::nullopt;
191  }
192 
193  return blob->template getDataAs<T>();
194  }
195 
196  return std::nullopt;
197 }
198 
199 // A type specialized transposition of an ElementsAttr.
200 // This implementation tries to operate on the underlying data in its raw
201 // representation when possible to avoid allocating a large number of Attribute
202 // objects.
203 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
204  ShapedType outputType,
205  llvm::ArrayRef<int64_t> permValues) {
206  // Handle generic ElementsAttr
207  if (auto data = attr.tryGetValues<bool>())
208  return transposeType(*data, inputType, outputType, permValues);
209 
210  if (auto data = attr.tryGetValues<int8_t>())
211  return transposeType(*data, inputType, outputType, permValues);
212 
213  if (auto data = attr.tryGetValues<int16_t>())
214  return transposeType(*data, inputType, outputType, permValues);
215 
216  if (auto data = attr.tryGetValues<int32_t>())
217  return transposeType(*data, inputType, outputType, permValues);
218 
219  if (auto data = attr.tryGetValues<int64_t>())
220  return transposeType(*data, inputType, outputType, permValues);
221 
222  if (auto data = attr.tryGetValues<float>())
223  return transposeType(*data, inputType, outputType, permValues);
224 
225  if (auto data = attr.tryGetValues<APFloat>())
226  return transposeType(*data, inputType, outputType, permValues);
227 
228  // Handle DenseResourceElementsAttr
229  if (isa<DenseResourceElementsAttr>(attr)) {
230  auto elementTy = attr.getElementType();
231 
232  if (auto data = tryGetDenseResourceValues<bool>(attr);
233  data && elementTy.isInteger(1))
234  return transposeType(*data, inputType, outputType, permValues);
235 
236  if (auto data = tryGetDenseResourceValues<int8_t>(attr);
237  data && elementTy.isInteger(8))
238  return transposeType(*data, inputType, outputType, permValues);
239 
240  if (auto data = tryGetDenseResourceValues<int16_t>(attr);
241  data && elementTy.isInteger(16))
242  return transposeType(*data, inputType, outputType, permValues);
243 
244  if (auto data = tryGetDenseResourceValues<int32_t>(attr);
245  data && elementTy.isInteger(32))
246  return transposeType(*data, inputType, outputType, permValues);
247 
248  if (auto data = tryGetDenseResourceValues<int64_t>(attr);
249  data && elementTy.isInteger(64))
250  return transposeType(*data, inputType, outputType, permValues);
251 
252  if (auto data = tryGetDenseResourceValues<float>(attr);
253  data && elementTy.isF32())
254  return transposeType(*data, inputType, outputType, permValues);
255  }
256 
257  return nullptr;
258 }
259 
260 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
262 
263  LogicalResult matchAndRewrite(tosa::TransposeOp op,
264  PatternRewriter &rewriter) const override {
265  auto outputType = cast<ShapedType>(op.getType());
266  // TOSA supports quantized types.
267  if (!outputType.getElementType().isIntOrIndexOrFloat())
268  return failure();
269 
270  ElementsAttr inputValues;
271  if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
272  return failure();
273  // Make sure the input is a constant that has a single user.
274  if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
275  return failure();
276 
277  auto permValues = llvm::map_to_vector(
278  op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
279 
280  auto inputType = cast<ShapedType>(op.getInput1().getType());
281 
282  auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
283  if (!resultAttr) {
284  return rewriter.notifyMatchFailure(
285  op, "unsupported attribute or element type");
286  }
287 
288  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
289  return success();
290  }
291 };
292 
293 struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
294 
296 
297  LogicalResult matchAndRewrite(ReciprocalOp recip,
298  PatternRewriter &rewriter) const override {
299  auto inputTensor = recip.getInput1();
300 
301  // Check that we can apply folding
302  auto preCondCheck =
303  notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
304  if (failed(preCondCheck)) {
305  return preCondCheck;
306  }
307 
308  // Extract the tensor values
309  DenseElementsAttr inputValues;
310  matchPattern(inputTensor, m_Constant(&inputValues));
311 
312  // Check whether this should be folded.
313  if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
314  return rewriter.notifyMatchFailure(
315  recip, "Currently, reciprocals will only be folded if the input "
316  "tensor has a single user");
317  }
318 
319  // Create a new tensor with the updated values
320  auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
321  inputValues, &ReciprocalOp::calcOneElement,
322  cast<FloatType>(inputValues.getElementType()));
323 
324  // Replace the use of the reciprocal with the transformed tensor
325  rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
326  return success();
327  }
328 };
329 
330 /// Getting the axes position of the element which is located
331 /// in the tensor at the counter index
332 
334 getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
335  int64_t remaining = index;
336  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
337  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
338  position[i] = remaining % tensorShape[i];
339  remaining /= tensorShape[i];
340  }
341  return position;
342 }
343 
344 /// Getting the index of the element which is located at the
345 /// axes position in the tensor
346 
347 int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
348  llvm::ArrayRef<int64_t> tensorShape) {
349  int64_t index = 0;
350  int64_t multiplierTmp = 1;
351  for (int64_t i = position.size() - 1; i >= 0; --i) {
352  index += position[i] * multiplierTmp;
353  multiplierTmp *= tensorShape[i];
354  }
355  return index;
356 }
357 
358 template <typename OperationType>
359 llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
360  llvm::ArrayRef<int64_t> oldShape,
361  int64_t reductionAxis,
362  int64_t reductionIndex) {
363 
364  llvm::SmallVector<int64_t> newShape(oldShape);
365  newShape[reductionAxis] = 1;
366  /// Let's calculate the position of the index
367  llvm::SmallVector<int64_t> position =
368  getPositionFromIndex(reductionIndex, newShape);
369  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
370  /// Starting from the first positon along the reduction axis
371  position[reductionAxis] = 0;
372  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
373  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
374 
375  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
376  ++reductionAxisVal) {
377 
378  int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
379  oldShape.end(), 1, std::multiplies<int>());
380  int64_t index = indexAtOldTensor + stride * reductionAxisVal;
381  reducedValue =
382  OperationType::calcOneElement(reducedValue, oldTensor[index]);
383  }
384  return reducedValue;
385 }
386 
387 template <typename OperationType>
388 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
389 
390  ReduceConstantOptimization(MLIRContext *context,
391  bool aggressiveReduceConstant)
392  : OpRewritePattern<OperationType>(context),
393  aggressiveReduceConstant(aggressiveReduceConstant) {}
394 
396 
397  LogicalResult matchAndRewrite(OperationType op,
398  PatternRewriter &rewriter) const override {
399  Value inputOp = op.getInput();
400  auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
401 
402  if (!constOp)
403  return rewriter.notifyMatchFailure(
404  op, "reduce input must be const operation");
405 
406  if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
407  return rewriter.notifyMatchFailure(
408  op, "input operation has more than one user");
409 
410  auto resultType = cast<ShapedType>(op.getOutput().getType());
411 
412  if (!resultType.hasStaticShape())
413  return rewriter.notifyMatchFailure(op, "result type shape is not static");
414 
415  auto reductionAxis = op.getAxis();
416  const auto denseElementsAttr = constOp.getValues();
417  const auto shapedOldElementsValues =
418  cast<ShapedType>(denseElementsAttr.getType());
419 
420  if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
421  return rewriter.notifyMatchFailure(
422  op, "reduce input currently supported with integer type");
423 
424  auto oldShape = shapedOldElementsValues.getShape();
425  auto newShape = resultType.getShape();
426 
427  auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
428  std::multiplies<int>());
429  llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
430 
431  for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
432  ++reductionIndex) {
433 
434  /// Let's reduce all the elements along this reduction axis
435  newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
436  denseElementsAttr, oldShape, reductionAxis, reductionIndex);
437  }
438 
439  auto rankedTensorType = cast<RankedTensorType>(resultType);
440  auto denseAttr =
441  mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
442  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
443  return success();
444  }
445  const bool aggressiveReduceConstant;
446 };
447 
448 } // namespace
449 
452  bool aggressiveReduceConstant) {
453  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
454  ctx, aggressiveReduceConstant);
455  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
456  ctx, aggressiveReduceConstant);
457  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
458  ctx, aggressiveReduceConstant);
459  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
460  ctx, aggressiveReduceConstant);
461  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
462  ctx, aggressiveReduceConstant);
463  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
464  ctx, aggressiveReduceConstant);
465 }
466 
469  patterns.add<TosaFoldConstantTranspose>(ctx);
470 }
471 
474  patterns.add<TosaFoldConstantReciprocal>(ctx);
475 }
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319