MLIR  21.0.0git
TosaFolders.cpp
Go to the documentation of this file.
1 //===- TosaFolders.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold TOSA operations
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <functional>
14 #include <numeric>
15 
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/FloatingPointMode.h"
26 #include "llvm/ADT/SmallVector.h"
27 
28 using namespace mlir;
29 using namespace mlir::tosa;
30 
31 namespace {
32 
33 /// Apply the given transformation \p toApply to every element of the tensor to
34 /// be transformed \p toTransform.
35 ///
36 /// Elements of \p toTransform are extracted as \p SrcValueType.
37 ///
38 /// \returns A tensor with the same size as \p toTransform, containing
39 /// \p TargetValueType values of type \p TargetType.
40 template <class SrcValType, class TargetValType, class TargetType>
41 DenseElementsAttr applyElementWise(
42  const DenseElementsAttr &toTransform,
43  const std::function<TargetValType(const SrcValType &)> &toApply,
44  TargetType targetType) {
45  SmallVector<TargetValType> transformedValues;
46  // We already know the amount of values we will insert, reserve space for
47  // all of them to avoid dynamic resizing
48  transformedValues.reserve(toTransform.getNumElements());
49  for (auto val : toTransform.getValues<SrcValType>()) {
50  auto transformedVal = toApply(val);
51  transformedValues.push_back(transformedVal);
52  }
53 
54  // Make sure that the output tensor has the expected output type
55  auto inShape = toTransform.getType();
56  auto outTy = inShape.cloneWith({}, targetType);
57 
58  return DenseElementsAttr::get(outTy, transformedValues);
59 }
60 
61 template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
62  const DenseElementsAttr &toTransform,
63  const std::function<APFloat(const APFloat &)> &toApply,
64  FloatType targetType);
65 
66 /// Function that checks if the type contained in \p toCheck is float.
67 LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
68  PatternRewriter &rewriter) {
69  if (isa<FloatType>(toCheck.getType().getElementType())) {
70  return success();
71  }
72  return rewriter.notifyMatchFailure(location,
73  "Unexpected input tensor type: the "
74  "TOSA spec only allows floats");
75 }
76 
77 /// Function that checks if \p toCheck is a dense TOSA constant tensor.
78 LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
79  TosaOp location,
80  PatternRewriter &rewriter) {
81  // Check whether the tensor is constant and dense
82  // TODO We currently ensure the tensor is dense by using the correct type for
83  // the bind_value, however we do not actually need this value. It would be
84  // nicer to only have a check here.
86  if (!matchPattern(toCheck, m_Constant(&tmp))) {
87  return rewriter.notifyMatchFailure(location,
88  "Non-const or non-dense input tensor");
89  }
90 
91  // Make sure it actually is a TOSA constant (the match allows for other
92  // constants as well)
93  if (isa<ConstOp>(toCheck.getDefiningOp())) {
94  return success();
95  }
96 
97  return rewriter.notifyMatchFailure(location,
98  "The reciprocal can only be folded if "
99  "it operates on a TOSA constant");
100 }
101 
102 /// Function that checks if \p toCheck is a dense TOSA constant float tensor.
103 LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
104  TosaOp location,
105  PatternRewriter &rewriter) {
106  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
107  if (failed(floatCheck)) {
108  return floatCheck;
109  }
110  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
111 }
112 
113 /// Heuristic to decide when to replace a unary operation on a constant with the
114 /// folded value.
115 /// Folding operations on constants can lead to an increased memory usage
116 /// whenever the input cannot be replaced but a new constant is inserted. Hence,
117 /// this will currently only suggest folding when the memory impact is
118 /// negligible.
119 /// Takes the \p unaryOp and the constant input \p values.
120 /// \returns Whether folding should be applied.
121 bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
122  assert(unaryOp->getNumOperands() == 1);
123  auto inputOp = unaryOp->getOperand(0);
124 
125  // If the input is a splat, we don't care for the number of users
126  if (isa<SplatElementsAttr>(values)) {
127  return true;
128  }
129 
130  // If this is the only use of the tensor it should be replaced as no
131  // additional memory is required
132  return inputOp.hasOneUse();
133 }
134 
135 template <typename RangeType>
136 DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
137  ShapedType outputType,
138  llvm::ArrayRef<int64_t> permValues) {
139  using ElementType = std::decay_t<decltype(*std::begin(data))>;
140 
141  assert(inputType.getElementType() == outputType.getElementType());
142 
143  if (inputType.getNumElements() == 0)
145 
146  auto inputShape = inputType.getShape();
147 
148  // The inverted permutation map and strides of the output are used to compute
149  // the contribution of a given dimension to the destination linear index in
150  // an order-independent way.
151  auto outputStrides = computeStrides(outputType.getShape());
152  auto invertedPermValues = invertPermutationVector(permValues);
153 
154  auto initialValue = *std::begin(data);
155  SmallVector<ElementType> outputValues(inputType.getNumElements(),
156  initialValue);
157 
158  for (const auto &it : llvm::enumerate(data)) {
159  auto srcLinearIndex = it.index();
160 
161  uint64_t dstLinearIndex = 0;
162  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
163  // Compute the index into the current dimension of the source vector.
164  auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
165  srcLinearIndex /= inputShape[dim];
166 
167  // Add the contribution of the current dimension to the output using the
168  // permutation map.
169  dstLinearIndex +=
170  outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
171  }
172 
173  outputValues[dstLinearIndex] = it.value();
174  }
175 
176  return DenseElementsAttr::get(outputType,
177  llvm::ArrayRef<ElementType>(outputValues));
178 }
179 
180 // Try to get the values of a DenseResourceElementsAttr construct
181 template <typename T>
182 std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
183  if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
184  // Check that the resource memory blob exists
185  AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
186  if (!blob)
187  return std::nullopt;
188 
189  // Check that the data are in a valid form
190  bool isSplat = false;
191  if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
192  blob->getData(), isSplat)) {
193  return std::nullopt;
194  }
195 
196  return blob->template getDataAs<T>();
197  }
198 
199  return std::nullopt;
200 }
201 
202 // A type specialized transposition of an ElementsAttr.
203 // This implementation tries to operate on the underlying data in its raw
204 // representation when possible to avoid allocating a large number of Attribute
205 // objects.
206 DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
207  ShapedType outputType,
208  llvm::ArrayRef<int64_t> permValues) {
209  // Handle generic ElementsAttr
210  if (auto data = attr.tryGetValues<bool>())
211  return transposeType(*data, inputType, outputType, permValues);
212 
213  if (auto data = attr.tryGetValues<int8_t>())
214  return transposeType(*data, inputType, outputType, permValues);
215 
216  if (auto data = attr.tryGetValues<int16_t>())
217  return transposeType(*data, inputType, outputType, permValues);
218 
219  if (auto data = attr.tryGetValues<int32_t>())
220  return transposeType(*data, inputType, outputType, permValues);
221 
222  if (auto data = attr.tryGetValues<int64_t>())
223  return transposeType(*data, inputType, outputType, permValues);
224 
225  if (auto data = attr.tryGetValues<float>())
226  return transposeType(*data, inputType, outputType, permValues);
227 
228  if (auto data = attr.tryGetValues<APFloat>())
229  return transposeType(*data, inputType, outputType, permValues);
230 
231  // Handle DenseResourceElementsAttr
232  if (isa<DenseResourceElementsAttr>(attr)) {
233  auto elementTy = attr.getElementType();
234 
235  if (auto data = tryGetDenseResourceValues<bool>(attr);
236  data && elementTy.isInteger(1))
237  return transposeType(*data, inputType, outputType, permValues);
238 
239  if (auto data = tryGetDenseResourceValues<int8_t>(attr);
240  data && elementTy.isInteger(8))
241  return transposeType(*data, inputType, outputType, permValues);
242 
243  if (auto data = tryGetDenseResourceValues<int16_t>(attr);
244  data && elementTy.isInteger(16))
245  return transposeType(*data, inputType, outputType, permValues);
246 
247  if (auto data = tryGetDenseResourceValues<int32_t>(attr);
248  data && elementTy.isInteger(32))
249  return transposeType(*data, inputType, outputType, permValues);
250 
251  if (auto data = tryGetDenseResourceValues<int64_t>(attr);
252  data && elementTy.isInteger(64))
253  return transposeType(*data, inputType, outputType, permValues);
254 
255  if (auto data = tryGetDenseResourceValues<float>(attr);
256  data && elementTy.isF32())
257  return transposeType(*data, inputType, outputType, permValues);
258  }
259 
260  return nullptr;
261 }
262 
263 struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
265 
266  LogicalResult matchAndRewrite(tosa::TransposeOp op,
267  PatternRewriter &rewriter) const override {
268  auto outputType = cast<ShapedType>(op.getType());
269  // TOSA supports quantized types.
270  if (!outputType.getElementType().isIntOrIndexOrFloat())
271  return failure();
272 
273  ElementsAttr inputValues;
274  if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
275  return failure();
276  // Make sure the input is a constant that has a single user.
277  if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
278  return failure();
279 
280  auto permValues = llvm::map_to_vector(
281  op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
282 
283  auto inputType = cast<ShapedType>(op.getInput1().getType());
284 
285  auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
286  if (!resultAttr) {
287  return rewriter.notifyMatchFailure(
288  op, "unsupported attribute or element type");
289  }
290 
291  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
292  return success();
293  }
294 };
295 
296 struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
297 
299 
300  LogicalResult matchAndRewrite(ReciprocalOp recip,
301  PatternRewriter &rewriter) const override {
302  auto inputTensor = recip.getInput1();
303 
304  // Check that we can apply folding
305  auto preCondCheck =
306  notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
307  if (failed(preCondCheck)) {
308  return preCondCheck;
309  }
310 
311  // Extract the tensor values
312  DenseElementsAttr inputValues;
313  matchPattern(inputTensor, m_Constant(&inputValues));
314 
315  // Check whether this should be folded.
316  if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
317  return rewriter.notifyMatchFailure(
318  recip, "Currently, reciprocals will only be folded if the input "
319  "tensor has a single user");
320  }
321 
322  // Create a new tensor with the updated values
323  auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
324  inputValues, &ReciprocalOp::calcOneElement,
325  cast<FloatType>(inputValues.getElementType()));
326 
327  // Replace the use of the reciprocal with the transformed tensor
328  rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
329  return success();
330  }
331 };
332 
333 /// Getting the axes position of the element which is located
334 /// in the tensor at the counter index
335 
337 getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
338  int64_t remaining = index;
339  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
340  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
341  position[i] = remaining % tensorShape[i];
342  remaining /= tensorShape[i];
343  }
344  return position;
345 }
346 
347 /// Getting the index of the element which is located at the
348 /// axes position in the tensor
349 
350 int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
351  llvm::ArrayRef<int64_t> tensorShape) {
352  int64_t index = 0;
353  int64_t multiplierTmp = 1;
354  for (int64_t i = position.size() - 1; i >= 0; --i) {
355  index += position[i] * multiplierTmp;
356  multiplierTmp *= tensorShape[i];
357  }
358  return index;
359 }
360 
361 template <typename OperationType>
362 llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
363  llvm::ArrayRef<int64_t> oldShape,
364  int64_t reductionAxis,
365  int64_t reductionIndex) {
366 
367  llvm::SmallVector<int64_t> newShape(oldShape);
368  newShape[reductionAxis] = 1;
369  /// Let's calculate the position of the index
370  llvm::SmallVector<int64_t> position =
371  getPositionFromIndex(reductionIndex, newShape);
372  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
373  /// Starting from the first positon along the reduction axis
374  position[reductionAxis] = 0;
375  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
376  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
377 
378  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
379  ++reductionAxisVal) {
380 
381  int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
382  oldShape.end(), 1, std::multiplies<int>());
383  int64_t index = indexAtOldTensor + stride * reductionAxisVal;
384  reducedValue =
385  OperationType::calcOneElement(reducedValue, oldTensor[index]);
386  }
387  return reducedValue;
388 }
389 
390 template <typename OperationType>
391 struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
392 
393  ReduceConstantOptimization(MLIRContext *context,
394  bool aggressiveReduceConstant)
395  : OpRewritePattern<OperationType>(context),
396  aggressiveReduceConstant(aggressiveReduceConstant) {}
397 
399 
400  LogicalResult matchAndRewrite(OperationType op,
401  PatternRewriter &rewriter) const override {
402  Value inputOp = op.getInput();
403  auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
404 
405  if (!constOp)
406  return rewriter.notifyMatchFailure(
407  op, "reduce input must be const operation");
408 
409  if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
410  return rewriter.notifyMatchFailure(
411  op, "input operation has more than one user");
412 
413  auto resultType = cast<ShapedType>(op.getOutput().getType());
414 
415  if (!resultType.hasStaticShape())
416  return rewriter.notifyMatchFailure(op, "result type shape is not static");
417 
418  auto reductionAxis = op.getAxis();
419  const auto denseElementsAttr = constOp.getValues();
420  const auto shapedOldElementsValues =
421  cast<ShapedType>(denseElementsAttr.getType());
422 
423  if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
424  return rewriter.notifyMatchFailure(
425  op, "reduce input currently supported with integer type");
426 
427  auto oldShape = shapedOldElementsValues.getShape();
428  auto newShape = resultType.getShape();
429 
430  auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
431  std::multiplies<int>());
432  llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
433 
434  for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
435  ++reductionIndex) {
436 
437  /// Let's reduce all the elements along this reduction axis
438  newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
439  denseElementsAttr, oldShape, reductionAxis, reductionIndex);
440  }
441 
442  auto rankedTensorType = cast<RankedTensorType>(resultType);
443  auto denseAttr =
444  mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
445  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
446  return success();
447  }
448  const bool aggressiveReduceConstant;
449 };
450 
451 } // namespace
452 
455  bool aggressiveReduceConstant) {
456  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
457  ctx, aggressiveReduceConstant);
458  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
459  ctx, aggressiveReduceConstant);
460  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
461  ctx, aggressiveReduceConstant);
462  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
463  ctx, aggressiveReduceConstant);
464  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
465  ctx, aggressiveReduceConstant);
466  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
467  ctx, aggressiveReduceConstant);
468 }
469 
472  patterns.add<TosaFoldConstantTranspose>(ctx);
473 }
474 
477  patterns.add<TosaFoldConstantReciprocal>(ctx);
478 }
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:191
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319