MLIR  22.0.0git
TosaFolders.cpp
Go to the documentation of this file.
1 //===- TosaFolders.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA operations
10 //
11 //===----------------------------------------------------------------------===//
12 
#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::tosa;
28 
29 namespace {
30 
31 /// Apply the given transformation \p toApply to every element of the tensor to
32 /// be transformed \p toTransform.
33 ///
34 /// Elements of \p toTransform are extracted as \p SrcValueType.
35 ///
36 /// \returns A tensor with the same size as \p toTransform, containing
37 /// \p TargetValueType values of type \p TargetType.
38 template <class SrcValType, class TargetValType, class TargetType>
39 DenseElementsAttr applyElementWise(
40  const DenseElementsAttr &toTransform,
41  const std::function<TargetValType(const SrcValType &)> &toApply,
42  TargetType targetType) {
43  SmallVector<TargetValType> transformedValues;
44  // We already know the amount of values we will insert, reserve space for
45  // all of them to avoid dynamic resizing
46  transformedValues.reserve(toTransform.getNumElements());
47  for (auto val : toTransform.getValues<SrcValType>()) {
48  auto transformedVal = toApply(val);
49  transformedValues.push_back(transformedVal);
50  }
51 
52  // Make sure that the output tensor has the expected output type
53  auto inShape = toTransform.getType();
54  auto outTy = inShape.cloneWith({}, targetType);
55 
56  return DenseElementsAttr::get(outTy, transformedValues);
57 }
58 
// Explicit instantiation of the float->float element-wise transform; this is
// the only variant currently needed (used by the reciprocal folding below).
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);
63 
64 /// Function that checks if the type contained in \p toCheck is float.
65 LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
66  PatternRewriter &rewriter) {
67  if (isa<FloatType>(toCheck.getType().getElementType())) {
68  return success();
69  }
70  return rewriter.notifyMatchFailure(location,
71  "Unexpected input tensor type: the "
72  "TOSA spec only allows floats");
73 }
74 
75 /// Function that checks if \p toCheck is a dense TOSA constant tensor.
76 LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
77  TosaOp location,
78  PatternRewriter &rewriter) {
79  // Check whether the tensor is constant and dense
80  // TODO We currently ensure the tensor is dense by using the correct type for
81  // the bind_value, however we do not actually need this value. It would be
82  // nicer to only have a check here.
84  if (!matchPattern(toCheck, m_Constant(&tmp))) {
85  return rewriter.notifyMatchFailure(location,
86  "Non-const or non-dense input tensor");
87  }
88 
89  // Make sure it actually is a TOSA constant (the match allows for other
90  // constants as well)
91  if (isa<ConstOp>(toCheck.getDefiningOp())) {
92  return success();
93  }
94 
95  return rewriter.notifyMatchFailure(location,
96  "The reciprocal can only be folded if "
97  "it operates on a TOSA constant");
98 }
99 
100 /// Function that checks if \p toCheck is a dense TOSA constant float tensor.
101 LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
102  TosaOp location,
103  PatternRewriter &rewriter) {
104  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
105  if (failed(floatCheck)) {
106  return floatCheck;
107  }
108  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
109 }
110 
111 /// Heuristic to decide when to replace a unary operation on a constant with the
112 /// folded value.
113 /// Folding operations on constants can lead to an increased memory usage
114 /// whenever the input cannot be replaced but a new constant is inserted. Hence,
115 /// this will currently only suggest folding when the memory impact is
116 /// negligible.
117 /// Takes the \p unaryOp and the constant input \p values.
118 /// \returns Whether folding should be applied.
119 bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
120  assert(unaryOp->getNumOperands() == 1);
121  auto inputOp = unaryOp->getOperand(0);
122 
123  // If the input is a splat, we don't care for the number of users
124  if (isa<SplatElementsAttr>(values)) {
125  return true;
126  }
127 
128  // If this is the only use of the tensor it should be replaced as no
129  // additional memory is required
130  return inputOp.hasOneUse();
131 }
132 
133 template <typename RangeType>
134 DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
135  ShapedType outputType,
136  llvm::ArrayRef<int64_t> permValues) {
137  using ElementType = std::decay_t<decltype(*std::begin(data))>;
138 
139  assert(inputType.getElementType() == outputType.getElementType());
140 
141  if (inputType.getNumElements() == 0)
143 
144  auto inputShape = inputType.getShape();
145 
146  // The inverted permutation map and strides of the output are used to compute
147  // the contribution of a given dimension to the destination linear index in
148  // an order-independent way.
149  auto outputStrides = computeStrides(outputType.getShape());
150  auto invertedPermValues = invertPermutationVector(permValues);
151 
152  auto initialValue = *std::begin(data);
153  SmallVector<ElementType> outputValues(inputType.getNumElements(),
154  initialValue);
155 
156  for (const auto &it : llvm::enumerate(data)) {
157  auto srcLinearIndex = it.index();
158 
159  uint64_t dstLinearIndex = 0;
160  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
161  // Compute the index into the current dimension of the source vector.
162  auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
163  srcLinearIndex /= inputShape[dim];
164 
165  // Add the contribution of the current dimension to the output using the
166  // permutation map.
167  dstLinearIndex +=
168  outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
169  }
170 
171  outputValues[dstLinearIndex] = it.value();
172  }
173 
174  return DenseElementsAttr::get(outputType,
175  llvm::ArrayRef<ElementType>(outputValues));
176 }
177 
178 // Try to get the values of a DenseResourceElementsAttr construct
179 template <typename T>
180 std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
181  if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
182  // Check that the resource memory blob exists
183  AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
184  if (!blob)
185  return std::nullopt;
186 
187  // Check that the data are in a valid form
188  bool isSplat = false;
189  if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
190  blob->getData(), isSplat)) {
191  return std::nullopt;
192  }
193 
194  return blob->template getDataAs<T>();
195  }
196 
197  return std::nullopt;
198 }
199 
200 // A type specialized transposition of an ElementsAttr.
201 // This implementation tries to operate on the underlying data in its raw
202 // representation when possible to avoid allocating a large number of Attribute
203 // objects.
204 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
205  ShapedType outputType,
206  llvm::ArrayRef<int64_t> permValues) {
207  // Handle generic ElementsAttr
208  if (auto data = attr.tryGetValues<bool>())
209  return transposeType(*data, inputType, outputType, permValues);
210 
211  if (auto data = attr.tryGetValues<int8_t>())
212  return transposeType(*data, inputType, outputType, permValues);
213 
214  if (auto data = attr.tryGetValues<int16_t>())
215  return transposeType(*data, inputType, outputType, permValues);
216 
217  if (auto data = attr.tryGetValues<int32_t>())
218  return transposeType(*data, inputType, outputType, permValues);
219 
220  if (auto data = attr.tryGetValues<int64_t>())
221  return transposeType(*data, inputType, outputType, permValues);
222 
223  if (auto data = attr.tryGetValues<float>())
224  return transposeType(*data, inputType, outputType, permValues);
225 
226  if (auto data = attr.tryGetValues<APFloat>())
227  return transposeType(*data, inputType, outputType, permValues);
228 
229  // Handle DenseResourceElementsAttr
230  if (isa<DenseResourceElementsAttr>(attr)) {
231  auto elementTy = attr.getElementType();
232 
233  if (auto data = tryGetDenseResourceValues<bool>(attr);
234  data && elementTy.isInteger(1))
235  return transposeType(*data, inputType, outputType, permValues);
236 
237  if (auto data = tryGetDenseResourceValues<int8_t>(attr);
238  data && elementTy.isInteger(8))
239  return transposeType(*data, inputType, outputType, permValues);
240 
241  if (auto data = tryGetDenseResourceValues<int16_t>(attr);
242  data && elementTy.isInteger(16))
243  return transposeType(*data, inputType, outputType, permValues);
244 
245  if (auto data = tryGetDenseResourceValues<int32_t>(attr);
246  data && elementTy.isInteger(32))
247  return transposeType(*data, inputType, outputType, permValues);
248 
249  if (auto data = tryGetDenseResourceValues<int64_t>(attr);
250  data && elementTy.isInteger(64))
251  return transposeType(*data, inputType, outputType, permValues);
252 
253  if (auto data = tryGetDenseResourceValues<float>(attr);
254  data && elementTy.isF32())
255  return transposeType(*data, inputType, outputType, permValues);
256  }
257 
258  return nullptr;
259 }
260 
261 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
263 
264  LogicalResult matchAndRewrite(tosa::TransposeOp op,
265  PatternRewriter &rewriter) const override {
266  auto outputType = cast<ShapedType>(op.getType());
267  // TOSA supports quantized types.
268  if (!outputType.getElementType().isIntOrIndexOrFloat())
269  return failure();
270 
271  ElementsAttr inputValues;
272  if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
273  return failure();
274  // Make sure the input is a constant that has a single user.
275  if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
276  return failure();
277 
278  auto permValues = llvm::map_to_vector(
279  op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
280 
281  auto inputType = cast<ShapedType>(op.getInput1().getType());
282 
283  auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
284  if (!resultAttr) {
285  return rewriter.notifyMatchFailure(
286  op, "unsupported attribute or element type");
287  }
288 
289  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
290  return success();
291  }
292 };
293 
294 struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
295 
297 
298  LogicalResult matchAndRewrite(ReciprocalOp recip,
299  PatternRewriter &rewriter) const override {
300  auto inputTensor = recip.getInput1();
301 
302  // Check that we can apply folding
303  auto preCondCheck =
304  notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
305  if (failed(preCondCheck)) {
306  return preCondCheck;
307  }
308 
309  // Extract the tensor values
310  DenseElementsAttr inputValues;
311  matchPattern(inputTensor, m_Constant(&inputValues));
312 
313  // Check whether this should be folded.
314  if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
315  return rewriter.notifyMatchFailure(
316  recip, "Currently, reciprocals will only be folded if the input "
317  "tensor has a single user");
318  }
319 
320  // Create a new tensor with the updated values
321  auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
322  inputValues, &ReciprocalOp::calcOneElement,
323  cast<FloatType>(inputValues.getElementType()));
324 
325  // Replace the use of the reciprocal with the transformed tensor
326  rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
327  return success();
328  }
329 };
330 
331 /// Getting the axes position of the element which is located
332 /// in the tensor at the counter index
333 
335 getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
336  int64_t remaining = index;
337  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
338  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
339  position[i] = remaining % tensorShape[i];
340  remaining /= tensorShape[i];
341  }
342  return position;
343 }
344 
345 /// Getting the index of the element which is located at the
346 /// axes position in the tensor
347 
348 int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
349  llvm::ArrayRef<int64_t> tensorShape) {
350  int64_t index = 0;
351  int64_t multiplierTmp = 1;
352  for (int64_t i = position.size() - 1; i >= 0; --i) {
353  index += position[i] * multiplierTmp;
354  multiplierTmp *= tensorShape[i];
355  }
356  return index;
357 }
358 
359 template <typename OperationType>
360 llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
361  llvm::ArrayRef<int64_t> oldShape,
362  int64_t reductionAxis,
363  int64_t reductionIndex) {
364 
365  llvm::SmallVector<int64_t> newShape(oldShape);
366  newShape[reductionAxis] = 1;
367  /// Let's calculate the position of the index
368  llvm::SmallVector<int64_t> position =
369  getPositionFromIndex(reductionIndex, newShape);
370  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
371  /// Starting from the first positon along the reduction axis
372  position[reductionAxis] = 0;
373  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
374  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
375 
376  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
377  ++reductionAxisVal) {
378 
379  int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1));
380  int64_t index = indexAtOldTensor + stride * reductionAxisVal;
381  reducedValue =
382  OperationType::calcOneElement(reducedValue, oldTensor[index]);
383  }
384  return reducedValue;
385 }
386 
387 template <typename OperationType>
388 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
389 
390  ReduceConstantOptimization(MLIRContext *context,
391  bool aggressiveReduceConstant)
392  : OpRewritePattern<OperationType>(context),
393  aggressiveReduceConstant(aggressiveReduceConstant) {}
394 
396 
397  LogicalResult matchAndRewrite(OperationType op,
398  PatternRewriter &rewriter) const override {
399  Value inputOp = op.getInput();
400  auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
401 
402  if (!constOp)
403  return rewriter.notifyMatchFailure(
404  op, "reduce input must be const operation");
405 
406  if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
407  return rewriter.notifyMatchFailure(
408  op, "input operation has more than one user");
409 
410  auto resultType = cast<ShapedType>(op.getOutput().getType());
411 
412  if (!resultType.hasStaticShape())
413  return rewriter.notifyMatchFailure(op, "result type shape is not static");
414 
415  auto reductionAxis = op.getAxis();
416  const auto denseElementsAttr = constOp.getValues();
417  const auto shapedOldElementsValues =
418  cast<ShapedType>(denseElementsAttr.getType());
419 
420  if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
421  return rewriter.notifyMatchFailure(
422  op, "reduce input currently supported with integer type");
423 
424  auto oldShape = shapedOldElementsValues.getShape();
425  auto newShape = resultType.getShape();
426 
427  int64_t newNumOfElements = llvm::product_of(newShape);
428  llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
429 
430  for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
431  ++reductionIndex) {
432 
433  /// Let's reduce all the elements along this reduction axis
434  newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
435  denseElementsAttr, oldShape, reductionAxis, reductionIndex);
436  }
437 
438  auto rankedTensorType = cast<RankedTensorType>(resultType);
439  auto denseAttr =
440  mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
441  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
442  return success();
443  }
444  const bool aggressiveReduceConstant;
445 };
446 
447 } // namespace
448 
451  bool aggressiveReduceConstant) {
452  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
453  ctx, aggressiveReduceConstant);
454  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
455  ctx, aggressiveReduceConstant);
456  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
457  ctx, aggressiveReduceConstant);
458  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
459  ctx, aggressiveReduceConstant);
460  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
461  ctx, aggressiveReduceConstant);
462  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
463  ctx, aggressiveReduceConstant);
464 }
465 
468  patterns.add<TosaFoldConstantTranspose>(ctx);
469 }
470 
473  patterns.add<TosaFoldConstantReciprocal>(ctx);
474 }
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:322