MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
27
28#include <functional>
29
30using namespace mlir;
31using namespace mlir::tosa;
32
33//===----------------------------------------------------------------------===//
34// Operator Canonicalizers.
35//===----------------------------------------------------------------------===//
36
37//===----------------------------------------------------------------------===//
38// Tensor Data Engine Operators.
39//===----------------------------------------------------------------------===//
40
41// Check that the zero point of the tensor and padding operations are aligned.
42static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
43 // Check that padConst is a constant value and a scalar tensor
44 DenseElementsAttr padConstAttr;
45 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
46 (padConstAttr.size() != 1)) {
47 return false;
48 }
49
50 // Check that floating point pad is zero
51 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
54 }
55
56 // Check that the zp and padConst align for the integer (quantized) case
57 if (auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
60 // Check that zp is a constant value and a scalar tensor
61 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
62 return false;
63 }
64
65 // Check equality
66 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
69 }
70
71 // Bail-out on unsupported type
72 return false;
73}
74
75namespace {
// Adaptor consumed by FoldPadToTensorOp. Each specialization supplies the
// op-specific legality checks and the rewrite used when merging a producing
// `tosa.pad` into the consumer's implicit padding:
//   - checkKernelCompliance: is the merged padding legal for this op?
//   - checkPadConstCompliance: does the pad constant preserve values?
//   - replaceOpWithNewPad: rebuild the op with the merged padding.
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  // Reject folds where any merged pad value reaches the kernel size.
  // newPad is {top, bottom, left, right}; kernel is {height, width}.
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  // For max_pool2d the fold is only value-preserving when the explicit pad
  // constant is the element type's lowest value (the identity of `max`).
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // Pad needs to be in the minimum value to be able to merge
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      // Lowest finite value of the float semantics (negative getLargest).
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    }
    if (auto padConstIntAttr =
            mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      // Lowest representable integer: 0 for unsigned, signed-min otherwise.
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getZero(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail-out on unsupported type
    return false;
  }
  // Rebuild the max_pool2d with the merged pad attribute, consuming the
  // pad op's input directly.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};
126
// Adaptor used by FoldPadToTensorOp for convolution-like ops (conv2d,
// depthwise_conv2d). Convolutions accept arbitrary padding sizes, so only
// the pad constant needs to agree with the op's input zero-point.
template <typename OpTy>
struct ConvPadFoldAdaptor {
  // Convolutions place no kernel-related restriction on the merged padding.
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  // The explicit pad constant must match the conv's input zero-point
  // (or be 0.0 for floats) for the fold to be value-preserving.
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  // Rebuild the convolution with the merged pad, consuming the pad op's
  // input directly; all other attributes/operands are carried over.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};
143
144// Pattern attempts to fold a `tosa.pad` operator to a following tensor
145// operation like `tosa.conv2d` by merging the padding associated with the
146// pad operator directly to the implicit padding of the tensor operation.
147// This helps eliminate the explicit padding operator if unused.
// Pattern attempts to fold a `tosa.pad` operator to a following tensor
// operation like `tosa.conv2d` by merging the padding associated with the
// pad operator directly to the implicit padding of the tensor operation.
// This helps eliminate the explicit padding operator if unused.
//
// AdaptorTy provides the op-specific legality checks and the final rewrite
// (see PoolPadFoldAdaptor / ConvPadFoldAdaptor above).
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check producer is a tosa::PadOp
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate that tensor operation has sane padding
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate tosa::PadOp padding
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }
    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
    // C_after
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    // The consumer ops only pad spatially (H/W); padding on batch or channel
    // dimensions cannot be absorbed.
    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold padding from Pad into the tensor operation
    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel related restrictions
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check padding constant restrictions
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check that padding doesn't grow more than 8K level (8192) for now
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Create operator
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
225} // namespace
226
227void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
228 MLIRContext *context) {
229 results.add<
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
231 context);
232}
233
234void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
235 MLIRContext *context) {
236 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
238 context);
239}
240
241struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
243
244 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
245 PatternRewriter &rewriter) const override {
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
250
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
252 return failure();
253 }
254
255 // If the output and input shapes are 1x1, then this is a no op.
256 ArrayRef<int64_t> outputShape = outputType.getShape();
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
258 return failure();
259 }
260
261 ArrayRef<int64_t> inputShape = inputType.getShape();
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
263 return failure();
264 }
265
266 rewriter.replaceOp(op, input);
267 return success();
268 }
269};
270
271void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
272 MLIRContext *context) {
273 results.add<MaxPool2dIsNoOp,
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
276 context);
277}
278
279//===----------------------------------------------------------------------===//
280// Data Layout / Memory Reinterpretation.
281//===----------------------------------------------------------------------===//
282
283struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
284 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
285
286 LogicalResult matchAndRewrite(tosa::ConcatOp op,
287 PatternRewriter &rewriter) const override {
288 if (op.getInput1().size() != 1)
289 return failure();
290 if (op.getInput1().front().getType() != op.getType()) {
291 rewriter
292 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
293 op.getInput1().front())
294 .getResult();
295 return success();
296 }
297
298 rewriter.replaceOp(op, op.getInput1().front());
299 return success();
300 }
301};
302
303void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
304 MLIRContext *context) {
305 results.add<ConcatOptimization>(context);
306}
307
308LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
310 if (!notOp)
311 return failure();
312 rewriter.modifyOpInPlace(op, [&]() {
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
315 });
316 return success();
317}
318
320 : public OpRewritePattern<tosa::TransposeOp> {
322
323 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
324 PatternRewriter &rewriter) const override {
325 // Input is also TransposeOp - transpose(transpose(A)).
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
328 if (!innerTranspose)
329 return rewriter.notifyMatchFailure(transposeOp,
330 "input must be transpose operation");
331
332 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
333 const llvm::ArrayRef<int32_t> innerTransposePerms =
334 innerTranspose.getPerms();
335
336 if (transposePerms.size() != innerTransposePerms.size())
337 return rewriter.notifyMatchFailure(
338 transposeOp,
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
341 return rewriter.notifyMatchFailure(
342 transposeOp, "transpose perms sizes must be positive");
343
344 // Consolidate transposes into one transpose.
345 SmallVector<int32_t> perms(transposePerms.size());
346 for (int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
348
349 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
350 transposeOp, transposeOp.getResult().getType(),
351 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
352
353 return success();
354 }
355};
356
357// Determines the case when tosa.transpose is a tosa.reshape operation.
358struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
360
361 LogicalResult matchAndRewrite(tosa::TransposeOp op,
362 PatternRewriter &rewriter) const override {
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
364 return rewriter.notifyMatchFailure(
365 op, "Src is from transpose, can compose transposes");
366
367 Value result = op.getResult();
368 for (Operation *subop : result.getUsers()) {
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
370 return rewriter.notifyMatchFailure(
371 op, "Dest is used by transpose, can compose transposes");
372 }
373
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
377 return rewriter.notifyMatchFailure(op, "Unranked input.");
378
379 int64_t numDynDims = 0;
380 for (int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
382 numDynDims++;
383
384 if (numDynDims > 1)
385 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
386
387 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
388
389 SmallVector<int64_t> nonZeroPerms;
390 nonZeroPerms.reserve(permValues.size());
391 for (auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
393 if (sz != 1)
394 nonZeroPerms.push_back(idx);
395 }
396
397 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
399 return rewriter.notifyMatchFailure(op,
400 "Transpose changes memory layout.");
401
402 SmallVector<int64_t> newShape;
403 newShape.reserve(inputTy.getRank());
404 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
406
407 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
408 op, op.getType(), op.getInput1(),
409 getTosaConstShape(rewriter, op.getLoc(), newShape));
410 return success();
411 }
412};
413
414void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
415 MLIRContext *context) {
416 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
417}
418
419struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
421
422 LogicalResult matchAndRewrite(tosa::ClampOp op,
423 PatternRewriter &rewriter) const override {
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
427
428 if (isa<FloatType>(inputElementType)) {
429 // Unlike integer types, floating point types can represent infinity.
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
436
437 if (isMin && isMax) {
438 rewriter.replaceOp(op, input);
439 return success();
440 }
441 return failure();
442 }
443
444 // i1 types are boolean in TOSA
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
448 .getValue()
449 .getZExtValue();
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
451 .getValue()
452 .getZExtValue();
453
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
457
458 if (minClamp <= intMin && maxClamp >= intMax) {
459 rewriter.replaceOp(op, input);
460 return success();
461 }
462 return failure();
463 }
464
465 if (llvm::isa<IntegerType>(inputElementType)) {
466 const int64_t minClamp =
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
468 const int64_t maxClamp =
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
470
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
474
475 if (minClamp <= intMin && maxClamp >= intMax) {
476 rewriter.replaceOp(op, input);
477 return success();
478 }
479 return failure();
480 }
481
482 return failure();
483 }
484};
485
486// Attempts the following transformation:
487//
488// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
489// tensor X the following identity holds:
490//
491// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
492//
493// subject to the following valid NaN propagation semantics:
494// --------------------------------------------
495// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
496// |-------------|--------------|-------------|
497// | PROPAGATE | PROPAGATE | PROPAGATE |
498// | PROPAGATE | IGNORE | IGNORE |
499// | IGNORE | PROPAGATE | INVALID |
500// | IGNORE | IGNORE | IGNORE |
501// |------------------------------------------|
502
503struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
504 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
505
506 // Helper structure to describe the range of a clamp operation.
507 template <typename T>
508 struct ClampRange {
509 ClampRange(const T &start, const T &end) : start(start), end(end) {}
512
513 // Helper function to determine if two Clamp ranges intersect.
514 bool intersects(const ClampRange<T> &otherRange) {
515 return start < otherRange.end && otherRange.start < end;
516 }
517 };
518
519 LogicalResult matchAndRewrite(tosa::ClampOp op,
520 PatternRewriter &rewriter) const override {
521 Value input = op.getInput();
522
523 // Check the input to the CLAMP op is itself a CLAMP.
524 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
525 if (!clampOp)
526 return failure();
527
528 // Check we have a valid NaN propagation combination.
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
533 return failure();
534
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
539
540 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
541 if (auto quantType =
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
543 inputEType = getStorageElementTypeFromQuantized(quantType);
544 }
545
546 Attribute newMinValAttr, newMaxValAttr;
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
552
553 // Check we have intersecting ranges.
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
558 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
559 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
560 clampOpMaxFloat);
561 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
562 return failure();
563
564 // Run the transformation.
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
569 } else {
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
575
576 if (inputEType.isUnsignedInteger()) {
577 // Check we have intersecting ranges.
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
582 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
583 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
584 clampOpMaxInt);
585 if (!opRangeIntRange.intersects(clampRangeIntRange))
586 return failure();
587
588 // Run the transformation.
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
591 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
592 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
593 } else {
594 // Check we have intersecting ranges.
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
599 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
600 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
601 clampOpMaxInt);
602 if (!opRangeIntRange.intersects(clampRangeIntRange))
603 return failure();
604
605 // Run the transformation.
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
608 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
609 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
610 }
611 }
612
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
615 : opNanMode;
616
617 auto newModeAttr =
618 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
619
620 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
622 newModeAttr);
623 return success();
624 }
625};
626
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.add<ClampIsNoOp>(context);
630 results.add<ClampClampOptimization>(context);
631}
632
// Rewrite slice(concat(x0, x1, ...)) as slice(xi) when the sliced region is
// entirely contained in a single concat operand, bypassing the concat.
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    // Both slice parameters must be compile-time constants.
    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      // If the slice fits entirely inside this operand, slice it directly.
      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                                  input, start_op, size_op)
                .getResult();
        break;
      }
      // Otherwise shift the start past this operand and keep looking.
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};
700
// Rewrite slice(pad(x)) so the slice consumes as much of the padding as
// possible: slice starts are shifted into the unpadded input and the pad
// amounts are shrunk to only what the slice actually reads. This can shrink
// or eliminate the intermediate padded tensor.
struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check if producer is a PadOp
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a pad operation");

    // Check PadOp has a single consumer
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "pad shall have a single consumer");

    // Check input is statically ranked
    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a ranked tensor");

    // Validate and extract tosa::PadOp padding
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(paddingElems.getValues<int64_t>());

    // Extract slice parameters
    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Check if dynamic dimensions are sliced: a dimension that is actually
    // sliced (non-zero start or explicit size) must be statically known.
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(i);
          // start != 0 or size != -1 means this dim is actually sliced.
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          sliceOp, "axis that are sliced shall be statically known.");
    }

    // Update the parameters
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If dimension is dynamic pass-through
      if (inputTy.isDynamicDim(i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // Handle static dimensions
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // Check slice within bounds
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

      // Compute updated slice start parameter: the start is expressed
      // relative to the (smaller) new padded tensor.
      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute updated pad parameters: keep only the padding the slice
      // actually reads on each side.
      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
      const int64_t newPadHi =
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Calculate new pad output shape
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Check that we actually need to proceed with the rewrite
    if (!updated)
      return rewriter.notifyMatchFailure(
          sliceOp, "terminate condition; nothing to rewrite");

    // Create a PadOp with updated padding
    auto newPaddingsOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(newPadShape, inputTy.getElementType());
    auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
                                        padOp.getInput1(), newPaddingsOp,
                                        padOp.getPadConst());

    // Update SliceOp and point to new PadOp
    auto newStartOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
                                               newPadOp.getResult(), newStartOp,
                                               sliceOp.getSize());

    return success();
  }
};
835
836// Update size operand of tosa.slice if size has dynamic dims but corresponding
837// output dim is static
839 : public OpRewritePattern<tosa::SliceOp> {
840 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
841
842 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
843 PatternRewriter &rewriter) const override {
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
845
846 ElementsAttr sizeElems;
847 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
848 return rewriter.notifyMatchFailure(
849 sliceOp, "size of slice must be a static ranked shape");
850 }
851
852 llvm::SmallVector<int64_t> sliceSizes =
853 llvm::to_vector(sizeElems.getValues<int64_t>());
854
855 bool replaceSliceSize{false};
856 // if size op has -1 indicating dynamic shape but corresponding dim on the
857 // output is statically known, update size to match with known output dim
858 // shape
859 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(index)) {
861 sliceSizes[index] = resultType.getDimSize(index);
862 replaceSliceSize = true;
863 }
864 }
865
866 if (!replaceSliceSize) {
867 return rewriter.notifyMatchFailure(
868 sliceOp, "no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
870 }
871
872 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
873 auto newSliceOp =
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
876
877 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
878 return success();
879 }
880};
881
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
886}
887
888struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
889 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
890
891 LogicalResult matchAndRewrite(tosa::CastOp castOp,
892 PatternRewriter &rewriter) const override {
893 const Value castInput = castOp.getInput();
894 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
895 if (!innerCastOp)
896 return rewriter.notifyMatchFailure(castOp,
897 "input must be cast operation");
898
899 const Value innerCastInput = innerCastOp.getInput();
900
901 const auto innerInputType =
902 llvm::cast<ShapedType>(innerCastInput.getType());
903 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
904 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
905
906 const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
907 outerOutputType};
908 if (llvm::any_of(types, [](const ShapedType type) {
909 return !type.getElementType().isInteger();
910 }))
911 return rewriter.notifyMatchFailure(castOp,
912 "only integer types are supported");
913
914 // Check inner cast is non-narrowing
915 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
916 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
917 return rewriter.notifyMatchFailure(castOp,
918 "inner cast operation is narrowing");
919
920 // Check outer cast is non-narrowing from the inner cast input
921 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
922 return rewriter.notifyMatchFailure(castOp,
923 "outer cast operation is narrowing");
924
925 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
926 innerCastInput);
927
928 return success();
929 }
930};
931
// Register CastOp canonicalization patterns.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
}
936
937//===----------------------------------------------------------------------===//
938// Operator Folders.
939//===----------------------------------------------------------------------===//
940
941template <typename Folder>
942static DenseElementsAttr
944 bool foldDenseValues = false) {
945 if (!lhs || !rhs)
946 return {};
947
948 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
949 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
950 if (lETy != rETy)
951 return {};
952
953 if (lhs.isSplat() && rhs.isSplat()) {
954 if (isa<FloatType>(lETy)) {
955 const APFloat l = lhs.getSplatValue<APFloat>();
956 const APFloat r = rhs.getSplatValue<APFloat>();
957 const auto maybeResult = Folder::fold(l, r);
958 if (failed(maybeResult))
959 return {};
960 return DenseElementsAttr::get(returnTy, maybeResult.value());
961 }
962
963 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
964 const APInt l = lhs.getSplatValue<APInt>();
965 const APInt r = rhs.getSplatValue<APInt>();
966 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
967 if (failed(maybeResult))
968 return {};
969 return DenseElementsAttr::get(returnTy, maybeResult.value());
970 }
971 }
972
973 if (foldDenseValues) {
974 assert(lETy.isIntOrIndex() &&
975 "Only integer types are currently supported.");
976 SmallVector<APInt> resultValues;
977 for (auto [l, r] :
978 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
979 const auto maybeResult = Folder::fold(l, r, false);
980 if (failed(maybeResult))
981 return {};
982 resultValues.push_back(maybeResult.value());
983 }
984 return DenseElementsAttr::get(returnTy, resultValues);
985 }
986
987 return {};
988}
989
990template <typename Folder>
991static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
992 bool foldDenseValues = false) {
993 if (!val)
994 return {};
995
996 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
997
998 if (val.isSplat()) {
999 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1000 const APInt v = val.getSplatValue<APInt>();
1001 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1002 if (failed(maybeResult))
1003 return {};
1004 return DenseElementsAttr::get(returnTy, maybeResult.value());
1005 }
1006 }
1007
1008 if (foldDenseValues) {
1009 mlir::Type elemTy = val.getElementType();
1010 if (elemTy.isIntOrIndex()) {
1011 SmallVector<APInt> resultValues;
1012 for (auto const &v : val.getValues<APInt>()) {
1013 const auto maybeResult = Folder::fold(v, false);
1014 if (failed(maybeResult))
1015 return {};
1016 resultValues.push_back(maybeResult.value());
1017 }
1018 return DenseElementsAttr::get(returnTy, resultValues);
1019 }
1020 }
1021
1022 // Folding arbitrarily sized tensor operations is not supported
1023 return {};
1024}
1025
  // Integer add with overflow detection; fails (no fold) on wrap-around.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }

  // Floating-point add; always folds.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs + rhs;
  }
};
1041
  // Integer subtract with overflow detection; fails (no fold) on
  // wrap-around.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }

  // Floating-point subtract; always folds.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs - rhs;
  }
};
1057
  // Integer multiply with overflow detection; fails (no fold) on
  // wrap-around. Multiplying by zero always folds to zero.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {

    const unsigned originalWidth = lhs.getBitWidth();

    // Check same type
    if (lhs.getBitWidth() != rhs.getBitWidth()) {
      return failure();
    }

    // If either is `0`
    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
    APInt const result =
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    if (overflow)
      return failure();

    // NOTE(review): *mul_ov preserves the operand width, so this trunc is
    // effectively a no-op kept for safety.
    return result.trunc(originalWidth);
  }

  // Floating-point multiply; always folds.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs * rhs;
  }
};
1087
1088static bool signsDiffer(const APInt &a, const APInt &b) {
1089 return a.isNegative() != b.isNegative();
1090}
1091
1092template <bool Ceil>
  // Integer division with trunc-toward-zero (Ceil=false) or ceiling
  // (Ceil=true) rounding. Fails on mismatched widths, division by zero, or
  // signed overflow (INT_MIN / -1).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (rhs.isZero())
      return failure();

    if (isUnsigned) {
      APInt q{};
      APInt r{};
      APInt::udivrem(lhs, rhs, q, r);
      // Unsigned quotients cannot be negative: any nonzero remainder means
      // the ceiling is one above the truncated quotient.
      if (!r.isZero() && Ceil) {
        return q + 1;
      }
      return q;
    }

    // Signed: start from trunc-toward-zero, then adjust to ceil.
    bool overflow{false};
    APInt const q = lhs.sdiv_ov(rhs, overflow);
    if (overflow)
      return failure();
    APInt const r = lhs.srem(rhs);

    if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
      // Same sign => exact quotient is positive; trunc is below ceil =>
      // increment q.
      return q + 1;
    }
    return q;
  }

  // Floating-point divide; always folds.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs / rhs;
  }
};
1130
  // Integer remainder. Conservatively bails out for negative lhs or
  // non-positive rhs so remainder-sign conventions cannot be misfolded.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
      return failure();

    if (isUnsigned) {
      return lhs.urem(rhs);
    }

    return lhs.srem(rhs);
  }

  // Floating-point remainder via APFloat::mod; folds only when the
  // operation reports opOK.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto t = lhs;
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {
      return t;
    }
    return failure();
  }
};
1155
1157 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1158 bool isUnsigned) {
1159 if (lhs.getBitWidth() != rhs.getBitWidth())
1160 return failure();
1161 return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
1162 }
1163
1164 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1165 return lhs >= rhs ? lhs : rhs;
1166 }
1167};
1168
1170 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1171 bool isUnsigned) {
1172 if (lhs.getBitWidth() != rhs.getBitWidth())
1173 return failure();
1174 return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
1175 }
1176
1177 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1178 return lhs <= rhs ? lhs : rhs;
1179 }
1180};
1181
  // Produce (1 << value) at the value's own bit width. Fails when the shift
  // amount is out of range (>= bit width) or, for signed inputs, negative.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
    if (isUnsigned) {
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
        return failure();
      return APInt::getOneBitSet(numBits, zextv);
    }
    auto const sextv = value.getSExtValue();
    // Negative shift amounts are rejected (isNegative() is redundant with
    // sextv < 0 but kept for clarity).
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
      return failure();
    return APInt::getOneBitSet(numBits, sextv);
  }
};
1197
  // ceil(log2(value)); only defined for strictly positive values.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
  }
};
1205
  // floor(log2(value)); only defined for strictly positive values.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
  }
};
1213
  // Elementwise lhs > rhs producing an i1 result (APInt of width 1).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
  }

  // Float comparison also folds to an i1-valued APInt.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs > rhs);
  }
};
1224
  // Elementwise lhs >= rhs producing an i1 result (APInt of width 1).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
  }

  // Float comparison also folds to an i1-valued APInt.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs >= rhs);
  }
};
1235
  // Elementwise equality; bitwise comparison, so signedness is irrelevant.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return APInt(1, lhs == rhs);
  }

  // Float equality folds to an i1-valued APInt (NaN != NaN per IEEE-754).
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs == rhs);
  }
};
1246
1247static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1248 if (llvm::isa<FloatType>(elemType))
1249 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1250 if (llvm::isa<IntegerType>(elemType))
1251 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1252 return false;
1253}
1254
1255static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1256 if (llvm::isa<FloatType>(elemType))
1257 return val && val.isSplat() &&
1258 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1259 if (llvm::isa<IntegerType>(elemType)) {
1260 const int64_t shifted = 1LL << shift;
1261 return val && val.isSplat() &&
1262 val.getSplatValue<APInt>().getSExtValue() == shifted;
1263 }
1264 return false;
1265}
1266
1267OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1268 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1269 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1270 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1271 if (!lhsTy || !rhsTy || !resultTy)
1272 return {};
1273
1274 // Cannot create an ElementsAttr from non-int/float/index types
1275 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1276 !rhsTy.getElementType().isIntOrIndexOrFloat())
1277 return {};
1278
1279 auto resultETy = resultTy.getElementType();
1280 auto lhsAttr =
1281 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1282 auto rhsAttr =
1283 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1284
1285 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1286 return getInput1();
1287 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
1288 return getInput2();
1289
1290 if (!lhsAttr || !rhsAttr)
1291 return {};
1292
1293 return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1294}
1295
1296OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1297 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1298 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1299 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1300 !outputTy.hasStaticShape())
1301 return {};
1302
1303 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1304 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1305 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1306 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1307 return DenseElementsAttr::get(outputTy, zero);
1308 }
1309
1310 return {};
1311}
1312
// Fold tosa.intdiv: 0 / x -> 0, x / 1 -> x, and splat-splat constant
// division (trunc-toward-zero).
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  // Operands must have identical (ranked) types so no broadcast is dropped.
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized
  // type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  // 0 / x -> 0.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  // x / 1 -> x.
  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Both operands are splat constants: divide the scalar values directly,
  // rejecting division by zero and signed overflow.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
      auto const result =
          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
      if (failed(result))
        return {};
      return DenseElementsAttr::get(resultTy, result.value());
    }
  }

  return {};
}
1357
1358namespace {
// calculate lhs * rhs >> shift according to TOSA Spec
// return nullopt if result is not in range of int32_t when shift > 0
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  // Widen to 64 bits so the product of two 32-bit operands cannot wrap.
  bool overflow = false;
  APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);

  if (overflow)
    return std::nullopt;

  if (shift > 0) {
    // Round-to-nearest: add half of the shifted-out range before the
    // arithmetic right shift.
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);
    // REQUIRE(product >= minimum_s<i32_t>() && product <=
    // maximum_s<i32_t>())
    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}
1384
// Fold a tosa.mul of two splat constants, applying the TOSA shift
// semantics for integer elements. Returns a null attribute when folding is
// not possible.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // No shift: plain (wrapping) APInt multiply.
      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      // Shifted multiply per the TOSA spec; bails out if the spec's range
      // REQUIRE fails.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
1413} // namespace
1414
// Fold tosa.mul: x * splat(0) -> 0, x * identity-splat -> x (the identity
// is (1 << shift) for i32), and splat-splat constant multiplication.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification,
  // synthesize a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // 0 * x -> 0 and identity * x -> x.
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
      // constant values can only be resized if resulting type is static
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  // x * 0 -> 0 and x * identity -> x.
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
1459
1460OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1461 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1462 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1463 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1464 if (!lhsTy || !rhsTy || !resultTy)
1465 return {};
1466
1467 // Cannot create an ElementsAttr from non-int/float/index types
1468 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1469 !rhsTy.getElementType().isIntOrIndexOrFloat())
1470 return {};
1471
1472 auto resultETy = resultTy.getElementType();
1473 auto lhsAttr =
1474 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1475 auto rhsAttr =
1476 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1477
1478 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1479 return getInput1();
1480
1481 if (!lhsAttr || !rhsAttr)
1482 return {};
1483
1484 return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1485}
1486
1487OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1488 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1489 auto lhsAttr =
1490 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1491 auto rhsAttr =
1492 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1493
1494 if (!lhsAttr || !rhsAttr)
1495 return {};
1496
1497 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1498}
1499
1500OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1501 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1502 auto lhsAttr =
1503 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1504 auto rhsAttr =
1505 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1506
1507 if (!lhsAttr || !rhsAttr)
1508 return {};
1509
1510 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1511}
1512
1513OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1514 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1515 auto lhsAttr =
1516 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1517 auto rhsAttr =
1518 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1519 Value lhs = getInput1();
1520 Value rhs = getInput2();
1521 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1522
1523 // If we are comparing an integer value to itself it is always true. We
1524 // can not do this with float due to float values.
1525 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1526 resultTy.hasStaticShape() && lhs == rhs) {
1527 return DenseElementsAttr::get(resultTy, true);
1528 }
1529
1530 if (!lhsAttr || !rhsAttr)
1531 return {};
1532
1533 return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1534}
1535
// Fold tosa.cast: identity casts and splat constants. The four int/float
// conversion combinations use round-to-nearest-even semantics.
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  // Identity cast.
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: reconvert to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float: the source type's signedness drives the conversion.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round to nearest even; out-of-range behavior follows
    // APFloat::convertToInteger (the `exact` flag is intentionally unused).
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate, zero-extend, or sign-extend as appropriate.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      // i1 types are boolean in TOSA
      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
      } else if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
1603
// Constants trivially fold to their stored attribute values.
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1607
// Fold reductions that cannot change the tensor: rank-0 inputs or a unit
// reduction axis, provided the input and result types agree exactly.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
1627
1628OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1629 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1630 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1631
1632 if (!inputTy || !outputTy)
1633 return {};
1634
1635 // Fold when the input and output types are the same. This is only safe
1636 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1637 // dimensions, there may still be a productive reshape.
1638 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1639 return getInput1();
1640
1641 // reshape(reshape(x)) -> reshape(x)
1642 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1643 getInput1().getDefiningOp())) {
1644 getInput1Mutable().assign(reshapeOp.getInput1());
1645 return getResult();
1646 }
1647
1648 // Cannot create an ElementsAttr from non-int/float/index types
1649 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1650 return {};
1651
1652 // reshape(const(x)) -> const(reshape-attr(x))
1653 if (auto operand =
1654 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1655 // Constants must have static shape.
1656 if (!outputTy.hasStaticShape())
1657 return {};
1658
1659 // Okay to duplicate splat constants.
1660 if (operand.isSplat())
1661 return SplatElementsAttr::get(outputTy,
1662 operand.getSplatValue<Attribute>());
1663
1664 // Don't duplicate other constants.
1665 if (!getInput1().hasOneUse())
1666 return {};
1667
1669 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1670 return {};
1671
1672 return operand.reshape(
1673 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1674 }
1675
1676 return {};
1677}
1678
1679OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1680 // If the pad is all zeros we can fold this operation away.
1681 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1682 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1683 if (densePad && densePad.isSplat() &&
1684 densePad.getSplatValue<APInt>().isZero()) {
1685 return getInput1();
1686 }
1687 }
1688
1689 return {};
1690}
1691
1692// Fold away cases where a tosa.resize operation returns a copy
1693// of the input image.
1694OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1695 auto scaleAttr =
1696 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1697 auto offsetAttr =
1698 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1699 auto borderAttr =
1700 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1701 if (!scaleAttr || !offsetAttr || !borderAttr) {
1702 return {};
1703 }
1704
1705 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1706 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1707 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1708 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1709 return {};
1710 }
1711
1712 // Check unit scaling.
1713 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1714 return {};
1715 }
1716
1717 // There should be no offset.
1718 if (offset[0] != 0 || offset[1] != 0) {
1719 return {};
1720 }
1721
1722 // There should be no border.
1723 if (border[0] != 0 || border[1] != 0) {
1724 return {};
1725 }
1726
1727 auto input = getInput();
1728 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1729 auto resultTy = llvm::cast<RankedTensorType>(getType());
1730 if (inputTy != resultTy)
1731 return {};
1732
1733 return input;
1734}
1735
1736OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1737 auto operand = getInput1();
1738 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1739 auto axis = getAxis();
1740 auto operandAttr =
1741 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1742 if (operandAttr)
1743 return operandAttr;
1744
1745 // If the dim-length is 1, tosa.reverse is a no-op.
1746 if (operandTy.hasRank() &&
1747 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1748 return operand;
1749
1750 return {};
1751}
1752
// Fold tosa.slice in three cases: identity slice, slice of a splat, and a
// static single-element slice (extract the element at the start indices).
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Identity slice: identical static types mean the whole tensor is taken.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  // Slicing a splat produces a (smaller) splat of the same value.
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A one-element result can be extracted directly: the constant start
  // indices address exactly that element in the input attribute.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
1790
1791static bool
1793 const auto isDynamic = [](Type ty) {
1794 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1795 return !shapedTy || !shapedTy.hasStaticShape();
1796 };
1797
1798 return llvm::any_of(operandTypes, isDynamic) ||
1799 failed(verifyCompatibleShapes(operandTypes));
1800}
1801
1802OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1803 // Select allows operand shapes to be broadcast to the output shape. For
1804 // now, don't support folding when we cannot prove no broadcasting is
1805 // involved.
1806 if (mayRequireBroadcast(getOperandTypes()))
1807 return {};
1808
1809 if (getOnTrue() == getOnFalse())
1810 return getOnTrue();
1811
1812 auto predicate =
1813 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1814 if (!predicate)
1815 return {};
1816
1817 if (!predicate.isSplat())
1818 return {};
1819 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1820 : getOnFalse();
1821}
1822
1823OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1824 if (getInput1().getType() == getType()) {
1825 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1826 adaptor.getMultiples())) {
1827 if (multiples.isSplat() &&
1828 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1829 return getInput1();
1830 if (auto int_array_attr =
1831 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1832 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1833 [](APInt v) { return v.getSExtValue() == 1; }))
1834 return getInput1();
1835 }
1836 }
1837 }
1838 return {};
1839}
1840
1841OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1842 auto resultTy = llvm::cast<ShapedType>(getType());
1843
1844 // Transposing splat values just means reshaping.
1845 if (auto input =
1846 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1847 if (input.isSplat() && resultTy.hasStaticShape() &&
1848 input.getType().getElementType() == resultTy.getElementType())
1849 return input.reshape(resultTy);
1850 }
1851
1852 // Transpose is not the identity transpose.
1853 const llvm::ArrayRef<int32_t> perms = getPerms();
1854
1855 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1856 return {};
1857
1858 return getInput1();
1859}
1860
1861OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1862 // Element-wise negate(negate(x)) = x
1863 // iff all zero points are constant 0
1864 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1865 if (!definingOp) {
1866 // defining op of input1 is not a negate, cannot fold
1867 return {};
1868 }
1869
1870 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1871 failed(maybeIZp) || *maybeIZp != 0) {
1872 // input1 zero point is not constant 0, cannot fold
1873 return {};
1874 }
1875 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1876 failed(maybeOZp) || *maybeOZp != 0) {
1877 // output zero point is not constant 0, cannot fold
1878 return {};
1879 }
1880 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1881 failed(maybeIZp) || *maybeIZp != 0) {
1882 // definingOp's input1 zero point is not constant 0, cannot fold
1883 return {};
1884 }
1885 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1886 failed(maybeOZp) || *maybeOZp != 0) {
1887 // definingOp's output zero point is not constant 0, cannot fold
1888 return {};
1889 }
1890
1891 return definingOp.getInput1();
1892}
1893
1894OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1895 auto input = getInput1();
1896 // Element-wise abs(abs(x)) = abs(x)
1897 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1898 return input;
1899 }
1900
1901 return {};
1902}
1903
1904OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1905 // Fold consecutive concats on the same axis into a single op.
1906 // Keep track of the operands so we are able to construct a new concat
1907 // later. Conservatively assume that we double the number of operands when
1908 // folding
1909 SmallVector<Value, 8> concatOperands;
1910 concatOperands.reserve(2 * getNumOperands());
1911
1912 // Find all operands that are foldable concats
1913 bool foundFoldableConcat = false;
1914 for (Value operand : getOperands()) {
1915 concatOperands.emplace_back(operand);
1916
1917 auto producer = operand.getDefiningOp<ConcatOp>();
1918 if (!producer)
1919 continue;
1920
1921 // Not foldable if axes are not the same
1922 if (getAxis() != producer.getAxis())
1923 continue;
1924
1925 // Replace the original operand with all incoming operands
1926 foundFoldableConcat = true;
1927 concatOperands.pop_back();
1928 llvm::append_range(concatOperands, producer->getOperands());
1929 }
1930
1931 if (!foundFoldableConcat)
1932 return {};
1933
1934 getOperation()->setOperands(concatOperands);
1935 return getResult();
1936}
1937
1938OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1939 auto input = adaptor.getInput1();
1940
1941 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1942 // Fold splat inputs only.
1943 if (!inputAttr || !inputAttr.isSplat())
1944 return {};
1945
1946 auto shapeType = llvm::cast<ShapedType>(getType());
1947 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1948 auto floatVal = inputAttr.getSplatValue<APFloat>();
1949 return DenseElementsAttr::get(shapeType,
1950 ReciprocalOp::calcOneElement(floatVal));
1951 }
1952
1953 return {};
1954}
1955
1956template <typename Op, typename OpFoldAdaptor>
1958 auto input1ConstShape =
1959 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
1960 if (!input1ConstShape)
1961 return {};
1962
1963 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1964
1965 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
1966 /*foldDenseValues=*/true);
1967}
1968
1969template <typename Op, typename OpFoldAdaptor>
1971 auto input1ConstShape =
1972 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
1973 auto input2ConstShape =
1974 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
1975 if (!input1ConstShape || !input2ConstShape)
1976 return {};
1977
1978 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1979 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
1980
1981 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
1982 input1Attr.getType(),
1983 /*foldDenseValues=*/true);
1984}
1985
1986OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
1987 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
1988 if (!inputTy || !inputTy.hasRank())
1989 return {};
1990 const int32_t axis = getAxis();
1991 const int64_t dimSize = inputTy.getDimSize(axis);
1992 if (ShapedType::isDynamic(dimSize))
1993 return {};
1994
1995 OpBuilder builder(getContext());
1996 const auto resultAttrTy =
1997 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
1998 return DenseElementsAttr::get(resultAttrTy, dimSize);
1999}
2000
2001OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2003}
2004
2005OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2007}
2008
2009OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2011}
2012
2013OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2014 return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
2015}
2016
2017OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2018 return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
2019}
2020
2021OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2023}
2024
2025OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2027}
2028
2029OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2031}
2032
2033OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2035}
2036
2037OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2039}
2040
2041OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2043}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
#define REDUCE_FOLDER(OP)
static bool mayRequireBroadcast(ValueTypeRange< mlir::OperandRange > operandTypes)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class implements iteration on the types of a given range of values.
Definition TypeRange.h:135
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...