MLIR 22.0.0git
TosaFolders.cpp
//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor
/// to be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same shape as \p toTransform, containing
/// \p TargetValType values of element type \p targetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inType = toTransform.getType();
  auto outTy = inType.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

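// Explicit instantiation for the APFloat case, which is the one used by the
// constant reciprocal folding pattern below.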
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

/// Function that checks if the element type of \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the correct type
  // for the bind_value; however, we do not actually need this value. It would
  // be nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "The reciprocal can only be folded if "
                                     "it operates on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can lead to increased memory usage
/// whenever the input cannot be replaced but a new constant is inserted.
/// Hence, this will currently only suggest folding when the memory impact is
/// negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor, it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

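/// Transpose the elements of \p data, read in row-major order with shape
/// \p inputType, into \p outputType according to the permutation
/// \p permValues.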
template <typename RangeType>
DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  using ElementType = std::decay_t<decltype(*std::begin(data))>;

  assert(inputType.getElementType() == outputType.getElementType());

  // Nothing to do for empty tensors.
  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<ElementType>());

  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to compute
  // the contribution of a given dimension to the destination linear index in
  // an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

  auto initialValue = *std::begin(data);
  SmallVector<ElementType> outputValues(inputType.getNumElements(),
                                        initialValue);

  for (const auto &it : llvm::enumerate(data)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using the
      // permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<ElementType>(outputValues));
}

// Try to get the values of a DenseResourceElementsAttr.
template <typename T>
std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
  if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
    // Check that the resource memory blob exists.
    AsmResourceBlob *blob = denseResource.getRawHandle().getBlob();
    if (!blob)
      return std::nullopt;

    // Check that the data are in a valid form.
    bool isSplat = false;
    if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
                                             blob->getData(), isSplat)) {
      return std::nullopt;
    }

    return blob->template getDataAs<T>();
  }

  return std::nullopt;
}

// A type specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of
// Attribute objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  // Handle generic ElementsAttr.
  if (auto data = attr.tryGetValues<bool>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int8_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int16_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int32_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<int64_t>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<float>())
    return transposeType(*data, inputType, outputType, permValues);

  if (auto data = attr.tryGetValues<APFloat>())
    return transposeType(*data, inputType, outputType, permValues);

  // Handle DenseResourceElementsAttr.
  if (isa<DenseResourceElementsAttr>(attr)) {
    auto elementTy = attr.getElementType();

    if (auto data = tryGetDenseResourceValues<bool>(attr);
        data && elementTy.isInteger(1))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int8_t>(attr);
        data && elementTy.isInteger(8))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int16_t>(attr);
        data && elementTy.isInteger(16))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int32_t>(attr);
        data && elementTy.isInteger(32))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<int64_t>(attr);
        data && elementTy.isInteger(64))
      return transposeType(*data, inputType, outputType, permValues);

    if (auto data = tryGetDenseResourceValues<float>(attr);
        data && elementTy.isF32())
      return transposeType(*data, inputType, outputType, permValues);
  }

  return nullptr;
}

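/// Replace a tosa.transpose of a single-use constant with a new constant that
/// holds the transposed values.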
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types, but this folding does not handle them,
    // so bail on any element type that is not an int, index, or float.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    auto permValues = llvm::map_to_vector(
        op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    if (!resultAttr) {
      return rewriter.notifyMatchFailure(
          op, "unsupported attribute or element type");
    }

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

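/// Replace a tosa.reciprocal that operates on a dense float constant with a
/// new constant holding the element-wise reciprocals.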
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the position along each axis of the element that is located at
/// the linear \p index in a tensor of shape \p tensorShape.
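/// For example, for shape {2, 3}, index 4 maps to position {1, 1}.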
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

/// Compute the linear index of the element that is located at the given axes
/// \p position in a tensor of shape \p tensorShape.
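/// For example, for shape {2, 3}, position {1, 1} maps back to index 4.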
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

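/// Reduce all elements of \p oldTensorAttr (of shape \p oldShape) that
/// collapse onto the output element at linear index \p reductionIndex along
/// \p reductionAxis, combining them with OperationType::calcOneElement.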
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the reduction index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Starting from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {

    int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1));
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

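/// Fold a tosa reduce operation applied to an integer constant tensor by
/// performing the reduction at compile time and replacing the op with a new
/// tosa.const.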
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValues();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    int64_t newNumOfElements = llvm::product_of(newShape);
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all the elements along the reduction axis into this output
      // element.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }
  const bool aggressiveReduceConstant;
};

} // namespace

void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}
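
// Typical driver usage (a minimal sketch): a pass would collect these
// patterns and apply them with the greedy rewriter. `func` stands for
// whatever operation is being rewritten and is a placeholder here.
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
//   mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
//   mlir::tosa::populateTosaConstantReduction(
//       ctx, patterns, /*aggressiveReduceConstant=*/false);
//   (void)applyPatternsGreedily(func, std::move(patterns));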