//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// TOSA canonicalization patterns and folders.
//
//===----------------------------------------------------------------------===//

19#include "mlir/Dialect/Traits.h"
22#include "mlir/IR/Matchers.h"
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28
29#include <functional>
30
31using namespace mlir;
32using namespace mlir::tosa;
33
//===----------------------------------------------------------------------===//
// Operator Canonicalizers.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Tensor Data Engine Operators.
//===----------------------------------------------------------------------===//
41
42// Check that the zero point of the tensor and padding operations are aligned.
43static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
44 // Check that padConst is a constant value and a scalar tensor
45 DenseElementsAttr padConstAttr;
46 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
47 (padConstAttr.size() != 1)) {
48 return false;
49 }
50
51 // Check that floating point pad is zero
52 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
53 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
54 return padConstVal == 0.0f;
55 }
56
57 // Check that the zp and padConst align for the integer (quantized) case
58 if (auto padConstIntAttr =
59 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
61 // Check that zp is a constant value and a scalar tensor
62 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
63 return false;
64 }
65
66 // Check equality
67 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
68 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
69 return zpVal == padConstVal;
70 }
71
72 // Bail-out on unsupported type
73 return false;
74}
75
76namespace {
77template <typename OpTy>
78struct PoolPadFoldAdaptor;
79
80template <>
81struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
82 using OpTy = tosa::MaxPool2dOp;
83 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
84 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
85 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
86 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
87 return false;
88 return true;
89 }
90 static bool checkPadConstCompliance(OpTy, Value padConst) {
91 // Check that padConst is a constant value and a scalar tensor
92 DenseElementsAttr padConstAttr;
93 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
94 padConstAttr.size() != 1) {
95 return false;
96 }
97
98 // Pad needs to be in the minimum value to be able to merge
99 if (auto padConstFpAttr =
100 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
101 const APFloat padConstVal = *padConstFpAttr.begin();
102 const APFloat lowestVal =
103 APFloat::getLargest(padConstVal.getSemantics(), true);
104 return padConstVal == lowestVal;
105 }
106 if (auto padConstIntAttr =
107 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
108 const APInt padConstVal = *padConstIntAttr.begin();
109 const unsigned int bitWidth = padConstVal.getBitWidth();
110 const APInt lowestVal =
111 padConstIntAttr.getElementType().isUnsignedInteger()
112 ? APInt::getZero(bitWidth)
113 : APInt::getSignedMinValue(bitWidth);
114 return padConstVal == lowestVal;
115 }
116
117 // Bail-out on unsupported type
118 return false;
119 }
120 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
121 Value padInput, ArrayRef<int64_t> newPad) {
122 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
123 op, op.getType(), padInput, op.getKernel(), op.getStride(),
124 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
125 }
126};
127
128template <typename OpTy>
129struct ConvPadFoldAdaptor {
130 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
131 return true;
132 }
133 static bool checkPadConstCompliance(OpTy op, Value padConst) {
134 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
135 }
136 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
137 Value padInput, ArrayRef<int64_t> newPad) {
138 rewriter.replaceOpWithNewOp<OpTy>(
139 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
140 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
141 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
142 }
143};
144
145// Pattern attempts to fold a `tosa.pad` operator to a following tensor
146// operation like `tosa.conv2d` by merging the padding associated with the
147// pad operator directly to the implicit padding of the tensor operation.
148// This helps eliminate the explicit padding operator if unused.
149template <typename OpTy, typename AdaptorTy>
150struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
151 using OpRewritePattern<OpTy>::OpRewritePattern;
152
153 LogicalResult matchAndRewrite(OpTy tensorOp,
154 PatternRewriter &rewriter) const override {
155 // Check producer is a tosa::PadOp
156 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
157 if (!padOp)
158 return rewriter.notifyMatchFailure(tensorOp,
159 "Producer must be a tosa::PadOp.");
160
161 // Validate that tensor operation has sane padding
162 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
163 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
164 return rewriter.notifyMatchFailure(
165 tensorOp, "Tensor operation padding shall have 4 elements.");
166
167 // Validate tosa::PadOp padding
168 DenseIntElementsAttr padOpPadding;
169 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
170 return rewriter.notifyMatchFailure(
171 tensorOp,
172 "The `padding` input specified on the tosa::PadOp must be constant.");
173 }
174 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
175 // C_after
176 if (padOpPadding.size() != 8)
177 return rewriter.notifyMatchFailure(tensorOp,
178 "Pad padding should have 8 elements.");
179 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
180 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
181 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
182 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
183 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
184 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
185 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
186 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
187
188 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
189 return rewriter.notifyMatchFailure(
190 tensorOp, "Folding padding in N or C dimensions is not supported.");
191
192 // Fold padding from Pad into the tensor operation
193 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
194 SmallVector<int64_t> foldedPad(tensorOpPad.size());
195 foldedPad[0] = padHBefore + tensorOpPad[0];
196 foldedPad[1] = padHAfter + tensorOpPad[1];
197 foldedPad[2] = padWBefore + tensorOpPad[2];
198 foldedPad[3] = padWAfter + tensorOpPad[3];
199
200 // Check kernel related restrictions
201 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
202 return rewriter.notifyMatchFailure(
203 tensorOp, "Padding size not aligned with kernel restrictions.");
204 }
205
206 // Check padding constant restrictions
207 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
208 return rewriter.notifyMatchFailure(
209 tensorOp,
210 "Padding constant is not aligned with operator zero-point.");
211 }
212
213 // Check that padding doesn't grow more than 8K level (8192) for now
214 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
215 return rewriter.notifyMatchFailure(
216 tensorOp, "Padding size more than the 8K level limit.");
217 }
218
219 // Create operator
220 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
221 foldedPad);
222
223 return success();
224 }
225};
226} // namespace
227
228void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
229 MLIRContext *context) {
230 results.add<
231 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
232 context);
233}
234
235void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
236 MLIRContext *context) {
237 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
238 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
239 context);
240}
241
242struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
244
245 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
246 PatternRewriter &rewriter) const override {
247 Value input = op.getInput();
248 Value output = op.getOutput();
249 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
250 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
251
252 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
253 return failure();
254 }
255
256 // If the output and input shapes are 1x1, then this is a no op.
257 ArrayRef<int64_t> outputShape = outputType.getShape();
258 if (outputShape[1] != 1 || outputShape[2] != 1) {
259 return failure();
260 }
261
262 ArrayRef<int64_t> inputShape = inputType.getShape();
263 if (inputShape[1] != 1 || inputShape[2] != 1) {
264 return failure();
265 }
266
267 rewriter.replaceOp(op, input);
268 return success();
269 }
270};
271
272void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
273 MLIRContext *context) {
274 results.add<MaxPool2dIsNoOp,
275 FoldPadToTensorOp<tosa::MaxPool2dOp,
276 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
277 context);
278}
279
//===----------------------------------------------------------------------===//
// Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
283
284struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
285 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
286
287 LogicalResult matchAndRewrite(tosa::ConcatOp op,
288 PatternRewriter &rewriter) const override {
289 if (op.getInput1().size() != 1)
290 return failure();
291 if (op.getInput1().front().getType() != op.getType()) {
292 rewriter
293 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
294 op.getInput1().front())
295 .getResult();
296 return success();
297 }
298
299 rewriter.replaceOp(op, op.getInput1().front());
300 return success();
301 }
302};
303
304void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
305 MLIRContext *context) {
306 results.add<ConcatOptimization>(context);
307}
308
309LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
310 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
311 if (!notOp)
312 return failure();
313 rewriter.modifyOpInPlace(op, [&]() {
314 op.getOperation()->setOperands(
315 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
316 });
317 return success();
318}
319
321 : public OpRewritePattern<tosa::TransposeOp> {
323
324 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
325 PatternRewriter &rewriter) const override {
326 // Input is also TransposeOp - transpose(transpose(A)).
327 auto innerTranspose =
328 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
329 if (!innerTranspose)
330 return rewriter.notifyMatchFailure(transposeOp,
331 "input must be transpose operation");
332
333 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
334 const llvm::ArrayRef<int32_t> innerTransposePerms =
335 innerTranspose.getPerms();
336
337 if (transposePerms.size() != innerTransposePerms.size())
338 return rewriter.notifyMatchFailure(
339 transposeOp,
340 "transpose and inner transpose perms sizes must be equal");
341 if (transposePerms.empty())
342 return rewriter.notifyMatchFailure(
343 transposeOp, "transpose perms sizes must be positive");
344
345 // Consolidate transposes into one transpose.
346 SmallVector<int32_t> perms(transposePerms.size());
347 for (int i = 0, s = transposePerms.size(); i < s; ++i)
348 perms[i] = innerTransposePerms[transposePerms[i]];
349
350 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
351 transposeOp, transposeOp.getResult().getType(),
352 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
353
354 return success();
355 }
356};
357
358// Determines the case when tosa.transpose is a tosa.reshape operation.
359struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
361
362 LogicalResult matchAndRewrite(tosa::TransposeOp op,
363 PatternRewriter &rewriter) const override {
364 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
365 return rewriter.notifyMatchFailure(
366 op, "Src is from transpose, can compose transposes");
367
368 Value result = op.getResult();
369 for (Operation *subop : result.getUsers()) {
370 if (isa_and_nonnull<tosa::TransposeOp>(subop))
371 return rewriter.notifyMatchFailure(
372 op, "Dest is used by transpose, can compose transposes");
373 }
374
375 auto input = op.getInput1();
376 auto inputTy = llvm::cast<ShapedType>(input.getType());
377 if (!inputTy.hasRank())
378 return rewriter.notifyMatchFailure(op, "Unranked input.");
379
380 int64_t numDynDims = 0;
381 for (int i = 0; i < inputTy.getRank(); ++i)
382 if (inputTy.isDynamicDim(i))
383 numDynDims++;
384
385 if (numDynDims > 1)
386 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
387
388 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
389
390 SmallVector<int64_t> nonZeroPerms;
391 nonZeroPerms.reserve(permValues.size());
392 for (auto idx : permValues) {
393 auto sz = inputTy.getDimSize(idx);
394 if (sz != 1)
395 nonZeroPerms.push_back(idx);
396 }
397
398 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
399 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
400 return rewriter.notifyMatchFailure(op,
401 "Transpose changes memory layout.");
402
403 SmallVector<int64_t> newShape;
404 newShape.reserve(inputTy.getRank());
405 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
406 newShape.push_back(inputTy.getDimSize(permValues[i]));
407
408 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
409 op, op.getType(), op.getInput1(),
410 getTosaConstShape(rewriter, op.getLoc(), newShape));
411 return success();
412 }
413};
414
415void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
416 MLIRContext *context) {
417 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
418}
419
420struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
422
423 LogicalResult matchAndRewrite(tosa::ClampOp op,
424 PatternRewriter &rewriter) const override {
425 Value input = op.getInput();
426 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
427 auto inputElementType = inputType.getElementType();
428
429 if (isa<FloatType>(inputElementType)) {
430 // Unlike integer types, floating point types can represent infinity.
431 const auto minClamp =
432 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
433 const auto maxClamp =
434 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
435 const bool isMin = minClamp.isNegInfinity();
436 const bool isMax = maxClamp.isInfinity();
437
438 if (isMin && isMax) {
439 rewriter.replaceOp(op, input);
440 return success();
441 }
442 return failure();
443 }
444
445 // i1 types are boolean in TOSA
446 const bool isBoolean = inputElementType.isInteger(1);
447 if (inputElementType.isUnsignedInteger() || isBoolean) {
448 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
449 .getValue()
450 .getZExtValue();
451 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
452 .getValue()
453 .getZExtValue();
454
455 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
456 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
457 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
458
459 if (minClamp <= intMin && maxClamp >= intMax) {
460 rewriter.replaceOp(op, input);
461 return success();
462 }
463 return failure();
464 }
465
466 if (llvm::isa<IntegerType>(inputElementType)) {
467 const int64_t minClamp =
468 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
469 const int64_t maxClamp =
470 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
471
472 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
473 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
474 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
475
476 if (minClamp <= intMin && maxClamp >= intMax) {
477 rewriter.replaceOp(op, input);
478 return success();
479 }
480 return failure();
481 }
482
483 return failure();
484 }
485};
486
// Attempts the following transformation:
//
// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
// tensor X the following identity holds:
//
// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
//
// subject to the following valid NaN propagation semantics:
// --------------------------------------------
// | OUTER CLAMP | INNER CLAMP | RESULT MODE  |
// |-------------|-------------|--------------|
// | PROPAGATE   | PROPAGATE   | PROPAGATE    |
// | PROPAGATE   | IGNORE      | IGNORE       |
// | IGNORE      | PROPAGATE   | INVALID      |
// | IGNORE      | IGNORE      | IGNORE       |
// |------------------------------------------|

504struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
505 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
506
507 // Helper structure to describe the range of a clamp operation.
508 template <typename T>
509 struct ClampRange {
510 ClampRange(const T &start, const T &end) : start(start), end(end) {}
513
514 // Helper function to determine if two Clamp ranges intersect.
515 bool intersects(const ClampRange<T> &otherRange) {
516 return start < otherRange.end && otherRange.start < end;
517 }
518 };
519
520 LogicalResult matchAndRewrite(tosa::ClampOp op,
521 PatternRewriter &rewriter) const override {
522 Value input = op.getInput();
523
524 // Check the input to the CLAMP op is itself a CLAMP.
525 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
526 if (!clampOp)
527 return failure();
528
529 // Check we have a valid NaN propagation combination.
530 const auto opNanMode = op.getNanMode();
531 const auto clampNanMode = clampOp.getNanMode();
532 if (opNanMode == NanPropagationMode::IGNORE &&
533 clampNanMode == NanPropagationMode::PROPAGATE)
534 return failure();
535
536 auto maxValAttr = op.getMaxValAttr();
537 auto minValAttr = op.getMinValAttr();
538 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
539 auto clampOpMinValAttr = clampOp.getMinValAttr();
540
541 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
542 if (auto quantType =
543 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
544 inputEType = getStorageElementTypeFromQuantized(quantType);
545 }
546
547 Attribute newMinValAttr, newMaxValAttr;
548 if (mlir::isa<FloatType>(inputEType)) {
549 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
550 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
551 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
552 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
553
554 // Check we have intersecting ranges.
555 const auto opMinFloat = floatMinValAttr.getValue();
556 const auto opMaxFloat = floatMaxValAttr.getValue();
557 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
558 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
559 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
560 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
561 clampOpMaxFloat);
562 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
563 return failure();
564
565 // Run the transformation.
566 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
567 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
568 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
569 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
570 } else {
571 assert(mlir::isa<IntegerType>(inputEType));
572 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
573 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
574 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
575 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
576
577 if (inputEType.isUnsignedInteger()) {
578 // Check we have intersecting ranges.
579 const auto opMinInt = intMinValAttr.getUInt();
580 const auto opMaxInt = intMaxValAttr.getUInt();
581 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
582 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
583 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
584 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
585 clampOpMaxInt);
586 if (!opRangeIntRange.intersects(clampRangeIntRange))
587 return failure();
588
589 // Run the transformation.
590 auto newMinVal = std::max(opMinInt, clampOpMinInt);
591 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
592 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
593 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
594 } else {
595 // Check we have intersecting ranges.
596 const auto opMinInt = intMinValAttr.getInt();
597 const auto opMaxInt = intMaxValAttr.getInt();
598 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
599 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
600 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
601 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
602 clampOpMaxInt);
603 if (!opRangeIntRange.intersects(clampRangeIntRange))
604 return failure();
605
606 // Run the transformation.
607 auto newMinVal = std::max(opMinInt, clampOpMinInt);
608 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
609 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
610 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
611 }
612 }
613
614 auto newMode = (opNanMode != clampNanMode)
615 ? tosa::NanPropagationMode::IGNORE
616 : opNanMode;
617
618 auto newModeAttr =
619 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
620
621 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
622 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
623 newModeAttr);
624 return success();
625 }
626};
627
628void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
629 MLIRContext *context) {
630 results.add<ClampIsNoOp>(context);
631 results.add<ClampClampOptimization>(context);
632}
633
634struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
635 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
636
637 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
638 PatternRewriter &rewriter) const override {
639 Value sliceInput = sliceOp.getInput1();
640 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
641 if (!concatOp)
642 return rewriter.notifyMatchFailure(
643 sliceOp, "slice input must be concat operation");
644
645 OperandRange inputs = concatOp.getInput1();
646 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
647 if (!concatType || !concatType.hasStaticShape())
648 return rewriter.notifyMatchFailure(
649 sliceOp, "slice input must be a static ranked tensor");
650 int32_t axis = concatOp.getAxis();
651
652 DenseElementsAttr startElems;
653 DenseElementsAttr sizeElems;
654
655 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
656 return rewriter.notifyMatchFailure(
657 sliceOp, "start of slice must be a static ranked shape");
658
659 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
660 return rewriter.notifyMatchFailure(
661 sliceOp, "size of slice must be a static ranked shape");
662
663 llvm::SmallVector<int64_t> sliceStarts =
664 llvm::to_vector(startElems.getValues<int64_t>());
665 llvm::SmallVector<int64_t> sliceSizes =
666 llvm::to_vector(sizeElems.getValues<int64_t>());
667
668 // Validate slice on the concatenated axis. Slicing along this
669 // axis should span only one of the inputs to the concatenate
670 // operation.
671 std::optional<Value> replaceWithSlice;
672 for (auto input : inputs) {
673 auto inputType = dyn_cast<RankedTensorType>(input.getType());
674 if (!inputType || !inputType.hasStaticShape())
675 return rewriter.notifyMatchFailure(
676 sliceOp, "concat input must be a static ranked tensor");
677
678 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
679 inputType.getDimSize(axis)) {
680 auto start_op =
681 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
682 auto size_op =
683 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
684 replaceWithSlice =
685 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
686 input, start_op, size_op)
687 .getResult();
688 break;
689 }
690 sliceStarts[axis] -= inputType.getDimSize(axis);
691 }
692
693 if (!replaceWithSlice)
694 return rewriter.notifyMatchFailure(
695 sliceOp, "corresponding concat input not found for slice");
696
697 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
698 return success();
699 }
700};
701
702struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
703 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
704
705 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
706 PatternRewriter &rewriter) const override {
707 Value sliceInput = sliceOp.getInput1();
708
709 // Check if producer is a PadOp
710 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
711 if (!padOp)
712 return rewriter.notifyMatchFailure(sliceOp,
713 "slice input must be a pad operation");
714
715 // Check PadOp has a single consumer
716 if (!padOp->hasOneUse())
717 return rewriter.notifyMatchFailure(sliceOp,
718 "pad shall have a single consumer");
719
720 // Check input is statically ranked
721 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
722 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
723 if (!inputTy || !padTy || !inputTy.hasRank())
724 return rewriter.notifyMatchFailure(sliceOp,
725 "slice input must be a ranked tensor");
726
727 // Validate and extract tosa::PadOp padding
728 DenseIntElementsAttr paddingElems;
729 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
730 return rewriter.notifyMatchFailure(
731 sliceOp,
732 "`padding` input specified on the tosa::PadOp must be constant.");
733 }
734 llvm::SmallVector<int64_t> padPaddings =
735 llvm::to_vector(paddingElems.getValues<int64_t>());
736
737 // Extract slice parameters
738 DenseElementsAttr startElems;
739 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
740 return rewriter.notifyMatchFailure(
741 sliceOp, "start of slice must be a static ranked shape");
742 llvm::SmallVector<int64_t> sliceStarts =
743 llvm::to_vector(startElems.getValues<int64_t>());
744
745 DenseElementsAttr sizeElems;
746 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
747 return rewriter.notifyMatchFailure(
748 sliceOp, "size of slice must be a static ranked shape");
749 llvm::SmallVector<int64_t> sliceSizes =
750 llvm::to_vector(sizeElems.getValues<int64_t>());
751
752 // Check if dynamic dimensions are sliced
753 const int64_t rank = inputTy.getRank();
754 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
755 const bool isDimDynamic = inputTy.isDynamicDim(i);
756 const bool isDimSliced =
757 (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
758
759 return isDimDynamic && isDimSliced;
760 })) {
761 return rewriter.notifyMatchFailure(
762 sliceOp, "axis that are sliced shall be statically known.");
763 }
764
765 // Update the parameters
766 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
767 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
768 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
769 bool updated = false;
770
771 for (int64_t i = 0; i < rank; ++i) {
772 const int64_t padLo = padPaddings[i * 2];
773 const int64_t padHi = padPaddings[i * 2 + 1];
774 const int64_t sliceStart = sliceStarts[i];
775 const int64_t sliceSize = sliceSizes[i];
776 const int64_t sliceEnd = sliceStart + sliceSize;
777
778 // If dimension is dynamic pass-through
779 if (inputTy.isDynamicDim(i)) {
780 newPadPaddings[i * 2] = padLo;
781 newPadPaddings[i * 2 + 1] = padHi;
782 newSliceStarts[i] = sliceStart;
783 continue;
784 }
785
786 // Handle static dimensions
787 const int64_t dimSize = inputTy.getShape()[i];
788 const int64_t dimTotal = padLo + dimSize + padHi;
789
790 // Check slice within bounds
791 if (sliceStart < 0 || sliceEnd > dimTotal)
792 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
793
794 // Compute updated slice start parameter
795 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
796 newSliceStarts[i] = newSliceStart;
797 updated |= newSliceStart != sliceStart;
798
799 // Compute updated pad parameters
800 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
801 const int64_t newPadHi =
802 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
803 newPadPaddings[i * 2] = newPadLo;
804 newPadPaddings[i * 2 + 1] = newPadHi;
805 updated |= (newPadLo != padLo) || (newPadHi != padHi);
806
807 // Calculate new pad output shape
808 newPadShape[i] =
809 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
810 }
811
812 // Check that we actually need to proceed with the rewrite
813 if (!updated)
814 return rewriter.notifyMatchFailure(
815 sliceOp, "terminate condition; nothing to rewrite");
816
817 // Create a PadOp with updated padding
818 auto newPaddingsOp =
819 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
820 auto newPadTy =
821 RankedTensorType::get(newPadShape, inputTy.getElementType());
822 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
823 padOp.getInput1(), newPaddingsOp,
824 padOp.getPadConst());
825
826 // Update SliceOp and point to new PadOp
827 auto newStartOp =
828 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
829 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
830 newPadOp.getResult(), newStartOp,
831 sliceOp.getSize());
832
833 return success();
834 }
835};
836
837// Update size operand of tosa.slice if size has dynamic dims but corresponding
838// output dim is static
840 : public OpRewritePattern<tosa::SliceOp> {
841 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
842
843 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
844 PatternRewriter &rewriter) const override {
845 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
846
847 ElementsAttr sizeElems;
848 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
849 return rewriter.notifyMatchFailure(
850 sliceOp, "size of slice must be a static ranked shape");
851 }
852
853 llvm::SmallVector<int64_t> sliceSizes =
854 llvm::to_vector(sizeElems.getValues<int64_t>());
855
856 bool replaceSliceSize{false};
857 // if size op has kInferableDimSize indicating dynamic shape but
858 // corresponding dim on the output is statically known, update size to match
859 // with known output dim shape
860 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
861 if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
862 sliceSizes[index] = resultType.getDimSize(index);
863 replaceSliceSize = true;
864 }
865 }
866
867 if (!replaceSliceSize) {
868 return rewriter.notifyMatchFailure(
869 sliceOp, "no dimension of size of slice is dynamic that resolves "
870 "to static output shape");
871 }
872
873 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
874 auto newSliceOp =
875 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
876 sliceOp.getInput1(), sliceOp.getStart(), size_op);
877
878 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
879 return success();
880 }
881};
882
883void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
884 MLIRContext *context) {
885 results.add<ConcatSliceOptimization, PadSliceOptimization,
886 SliceDynamicSizeCanonicalization>(context);
887}
888
889struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
890 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
891
892 LogicalResult matchAndRewrite(tosa::CastOp castOp,
893 PatternRewriter &rewriter) const override {
894 const Value castInput = castOp.getInput();
895 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
896 if (!innerCastOp)
897 return rewriter.notifyMatchFailure(castOp,
898 "input must be cast operation");
899
900 const Value innerCastInput = innerCastOp.getInput();
901
902 const auto innerInputType =
903 llvm::cast<ShapedType>(innerCastInput.getType());
904 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
905 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
906
907 const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
908 outerOutputType};
909 if (llvm::any_of(types, [](const ShapedType type) {
910 return !type.getElementType().isInteger();
911 }))
912 return rewriter.notifyMatchFailure(castOp,
913 "only integer types are supported");
914
915 // Check inner cast is non-narrowing
916 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
917 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
918 return rewriter.notifyMatchFailure(castOp,
919 "inner cast operation is narrowing");
920
921 // Check outer cast is non-narrowing from the inner cast input
922 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
923 return rewriter.notifyMatchFailure(castOp,
924 "outer cast operation is narrowing");
925
926 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
927 innerCastInput);
928
929 return success();
930 }
931};
932
// Registers cast canonicalizations: currently just the folding of chained
// non-narrowing integer casts.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
}
937
938//===----------------------------------------------------------------------===//
939// Operator Folders.
940//===----------------------------------------------------------------------===//
941
942template <typename Folder>
943static DenseElementsAttr
945 bool foldDenseValues = false) {
946 if (!lhs || !rhs)
947 return {};
948
949 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
950 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
951 if (lETy != rETy)
952 return {};
953
954 if (lhs.isSplat() && rhs.isSplat()) {
955 if (isa<FloatType>(lETy)) {
956 const APFloat l = lhs.getSplatValue<APFloat>();
957 const APFloat r = rhs.getSplatValue<APFloat>();
958 const auto maybeResult = Folder::fold(l, r);
959 if (failed(maybeResult))
960 return {};
961 return DenseElementsAttr::get(returnTy, maybeResult.value());
962 }
963
964 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
965 const APInt l = lhs.getSplatValue<APInt>();
966 const APInt r = rhs.getSplatValue<APInt>();
967 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
968 if (failed(maybeResult))
969 return {};
970 return DenseElementsAttr::get(returnTy, maybeResult.value());
971 }
972 }
973
974 if (foldDenseValues) {
975 assert(lETy.isIntOrIndex() &&
976 "Only integer types are currently supported.");
977 SmallVector<APInt> resultValues;
978 for (auto [l, r] :
979 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
980 const auto maybeResult = Folder::fold(l, r, false);
981 if (failed(maybeResult))
982 return {};
983 resultValues.push_back(maybeResult.value());
984 }
985 return DenseElementsAttr::get(returnTy, resultValues);
986 }
987
988 return {};
989}
990
991template <typename Folder>
992static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
993 bool foldDenseValues = false) {
994 if (!val)
995 return {};
996
997 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
998
999 if (val.isSplat()) {
1000 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1001 const APInt v = val.getSplatValue<APInt>();
1002 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1003 if (failed(maybeResult))
1004 return {};
1005 return DenseElementsAttr::get(returnTy, maybeResult.value());
1006 }
1007 }
1008
1009 if (foldDenseValues) {
1010 mlir::Type elemTy = val.getElementType();
1011 if (elemTy.isIntOrIndex()) {
1012 SmallVector<APInt> resultValues;
1013 for (auto const &v : val.getValues<APInt>()) {
1014 const auto maybeResult = Folder::fold(v, false);
1015 if (failed(maybeResult))
1016 return {};
1017 resultValues.push_back(maybeResult.value());
1018 }
1019 return DenseElementsAttr::get(returnTy, resultValues);
1020 }
1021 }
1022
1023 // Folding arbitrarily sized tensor operations is not supported
1024 return {};
1025}
1026
1028 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1029 const bool isUnsigned) {
1030 bool overflow;
1031 const APInt result =
1032 isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
1033 if (overflow)
1034 return failure();
1035 return result;
1036 }
1037
1038 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1039 return lhs + rhs;
1040 }
1041};
1042
1044 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1045 const bool isUnsigned) {
1046 bool overflow;
1047 const APInt result =
1048 isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
1049 if (overflow)
1050 return failure();
1051 return result;
1052 }
1053
1054 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1055 return lhs - rhs;
1056 }
1057};
1058
  // Integer multiplication; fails (blocking the fold) on signed/unsigned
  // overflow so constant folding never changes runtime behavior.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {

    const unsigned originalWidth = lhs.getBitWidth();

    // Check same type
    if (lhs.getBitWidth() != rhs.getBitWidth()) {
      return failure();
    }

    // If either is `0`
    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
    APInt const result =
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    if (overflow)
      return failure();

    // NOTE(review): *mul_ov preserves the operand width, so this trunc looks
    // like a no-op — confirm before relying on it elsewhere.
    return result.trunc(originalWidth);
  }

  // Floating-point multiplication (never fails).
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs * rhs;
  }
};
1088
1089static bool signsDiffer(const APInt &a, const APInt &b) {
1090 return a.isNegative() != b.isNegative();
1091}
1092
1093template <bool Ceil>
1095 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1096 bool isUnsigned) {
1097 if (lhs.getBitWidth() != rhs.getBitWidth())
1098 return failure();
1099 if (rhs.isZero())
1100 return failure();
1101
1102 if (isUnsigned) {
1103 APInt q{};
1104 APInt r{};
1105 APInt::udivrem(lhs, rhs, q, r);
1106 if (!r.isZero() && Ceil) {
1107 return q + 1;
1108 }
1109 return q;
1110 }
1111
1112 // Signed: start from trunc-toward-zero, then adjust to ceil.
1113 bool overflow{false};
1114 APInt const q = lhs.sdiv_ov(rhs, overflow);
1115 if (overflow)
1116 return failure();
1117 APInt const r = lhs.srem(rhs);
1118
1119 if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
1120 // Same sign => exact quotient is positive; trunc is below ceil =>
1121 // increment q.
1122 return q + 1;
1123 }
1124 return q;
1125 }
1126
1127 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1128 return lhs / rhs;
1129 }
1130};
1131
  // Integer remainder. Conservatively folds only when lhs is non-negative
  // and rhs is strictly positive (interpreted as signed), where signed and
  // unsigned remainder agree.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
      return failure();

    if (isUnsigned) {
      return lhs.urem(rhs);
    }

    return lhs.srem(rhs);
  }

  // Floating-point remainder via APFloat::mod (fmod-style semantics); folds
  // only when the operation reports opOK.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto t = lhs;
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {
      return t;
    }
    return failure();
  }
};
1156
1158 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1159 bool isUnsigned) {
1160 if (lhs.getBitWidth() != rhs.getBitWidth())
1161 return failure();
1162 return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
1163 }
1164
1165 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1166 return lhs >= rhs ? lhs : rhs;
1167 }
1168};
1169
1171 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1172 bool isUnsigned) {
1173 if (lhs.getBitWidth() != rhs.getBitWidth())
1174 return failure();
1175 return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
1176 }
1177
1178 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1179 return lhs <= rhs ? lhs : rhs;
1180 }
1181};
1182
  // Folds to the power-of-two constant `1 << value` (a single set bit at
  // position `value`). Fails when `value` is negative or >= the bit width,
  // i.e. when the shift would not be representable.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
    if (isUnsigned) {
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
        return failure();
      return APInt::getOneBitSet(numBits, zextv);
    }
    auto const sextv = value.getSExtValue();
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
      return failure();
    return APInt::getOneBitSet(numBits, sextv);
  }
};
1198
  // Ceiling of log2(value); defined only for strictly positive inputs.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
  }
};
1206
  // Floor of log2(value); defined only for strictly positive inputs.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
  }
};
1214
1216 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1217 const bool isUnsigned) {
1218 return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
1219 }
1220
1221 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1222 return APInt(1, lhs > rhs);
1223 }
1224};
1225
1227 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1228 const bool isUnsigned) {
1229 return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
1230 }
1231
1232 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1233 return APInt(1, lhs >= rhs);
1234 }
1235};
1236
1238 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1239 const bool isUnsigned) {
1240 return APInt(1, lhs == rhs);
1241 }
1242
1243 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1244 return APInt(1, lhs == rhs);
1245 }
1246};
1247
1248static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1249 if (llvm::isa<FloatType>(elemType))
1250 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1251 if (llvm::isa<IntegerType>(elemType))
1252 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1253 return false;
1254}
1255
1256static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1257 if (llvm::isa<FloatType>(elemType))
1258 return val && val.isSplat() &&
1259 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1260 if (llvm::isa<IntegerType>(elemType)) {
1261 const int64_t shifted = 1LL << shift;
1262 return val && val.isSplat() &&
1263 val.getSplatValue<APInt>().getSExtValue() == shifted;
1264 }
1265 return false;
1266}
1267
/// Folds tosa.add:
///  - `x + splat(0) -> x` (and symmetrically) when the surviving operand's
///    type already equals the result type and the shapes are statically
///    known broadcastable,
///  - constant splat + constant splat via AddFoldAdaptor (which fails, and
///    thus blocks the fold, on overflow).
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Additive-identity folds: x + 0 == x and 0 + x == x.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1298
1299OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1300 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1301 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1302 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1303 !outputTy.hasStaticShape())
1304 return {};
1305
1306 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1307 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1308 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1309 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1310 return DenseElementsAttr::get(outputTy, zero);
1311 }
1312
1313 return {};
1314}
1315
/// Folds tosa.int_div:
///  - splat-zero dividend -> zero splat (static result shape required),
///  - `x / splat(1) -> x` when shapes are statically broadcastable,
///  - splat / splat -> truncating quotient via DivFoldAdaptor.
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy.getElementType() != rhsTy.getElementType())
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized
  // type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  // A splat-zero dividend folds to a zero splat of the result type.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr.resizeSplat(resultTy);
  }

  // Division by one is the identity.
  if (rhsAttr && rhsAttr.isSplat()) {
    const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
        lhsTy.getShape(), rhsTy.getShape());
    if (isBroadcastable && lhsTy == resultTy &&
        llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Both operands are constant splats: compute the truncating quotient,
  // guarding against division by zero.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
      auto const result =
          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
      if (failed(result))
        return {};
      return DenseElementsAttr::get(resultTy, result.value());
    }
  }

  return {};
}
1363
1364namespace {
1365// calculate lhs * rhs >> shift according to TOSA Spec
1366// return nullopt if result is not in range of int32_t when shift > 0
// calculate lhs * rhs >> shift according to TOSA Spec
// return nullopt if result is not in range of int32_t when shift > 0
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  // Multiply in 64 bits so the product itself cannot silently wrap.
  bool overflow = false;
  APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);

  if (overflow)
    return std::nullopt;

  if (shift > 0) {
    // Round to nearest before the arithmetic right shift.
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);
    // REQUIRE(product >= minimum_s<i32_t>() && product <=
    // maximum_s<i32_t>())
    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}
1390
// Folds a constant tosa.mul when both operands are splats: integers go
// through mulInt (TOSA shift/rounding semantics), floats multiply directly.
// Returns a null attribute when folding is not possible.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // No shift: plain (wrapping) multiply.
      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
1419} // namespace
1420
/// Folds tosa.mul:
///  - a splat-zero operand -> zero splat (static result shape required),
///  - multiplication by the identity (1.0, or 1 << shift for i32) -> the
///    other operand,
///  - splat * splat via mulBinaryFolder.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification,
  // synthesize a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // Zero annihilates the product regardless of the other operand.
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
      resultTy.hasStaticShape())
    // constant values can only be resized if resulting type is static
    return lhsAttr.resizeSplat(resultTy);
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
      resultTy.hasStaticShape())
    return rhsAttr.resizeSplat(resultTy);

  // Multiplicative identity (accounting for the shift) returns the other
  // operand, provided its type already matches the result.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && rhsTy == resultTy &&
      isSplatOne(resultETy, lhsAttr, shift))
    return rhs;
  if (isBroadcastable && lhsTy == resultTy &&
      isSplatOne(resultETy, rhsAttr, shift))
    return lhs;

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
1468
/// Folds tosa.sub:
///  - `x - splat(0) -> x` when types already match and shapes are
///    statically known broadcastable (no symmetric case: 0 - x != x),
///  - constant splat - constant splat via SubFoldAdaptor (which fails, and
///    thus blocks the fold, on overflow).
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1497
1498OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1499 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1500 auto lhsAttr =
1501 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1502 auto rhsAttr =
1503 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1504
1505 if (!lhsAttr || !rhsAttr)
1506 return {};
1507
1508 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1509}
1510
1511OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1512 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1513 auto lhsAttr =
1514 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1515 auto rhsAttr =
1516 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1517
1518 if (!lhsAttr || !rhsAttr)
1519 return {};
1520
1521 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1522}
1523
/// Folds tosa.equal:
///  - `x == x -> true` for integer tensors (not floats, where NaN != NaN),
///  - constant comparison via EqualFoldAdaptor.
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we are comparing an integer value to itself it is always true. We
  // can not do this with float due to float values.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1546
/// Folds tosa.cast:
///  - identity cast (same type) -> input,
///  - splat constants are converted at compile time for every combination
///    of integer and floating-point element types.
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: re-round the value into the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float: signedness of the conversion comes from the input type.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round to nearest even; inexact conversions are accepted.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate, or zero/sign extend based on input signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      // i1 types are boolean in TOSA
      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
      } else if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
1614
// tosa.const and tosa.const_shape fold directly to their value attribute.
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1618
// Defines OP::fold for a reduce operation: the reduction is the identity
// (and folds to its input) when input and result types match and the tensor
// is rank 0 or the reduced axis has extent 1. No comments inside the macro:
// a `//` would swallow the line-continuation backslash.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }
1630
1631REDUCE_FOLDER(ReduceAllOp)
1632REDUCE_FOLDER(ReduceAnyOp)
1633REDUCE_FOLDER(ReduceMaxOp)
1634REDUCE_FOLDER(ReduceMinOp)
1635REDUCE_FOLDER(ReduceProductOp)
1636REDUCE_FOLDER(ReduceSumOp)
1637#undef REDUCE_FOLDER
1638
1639OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1640 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1641 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1642
1643 if (!inputTy || !outputTy)
1644 return {};
1645
1646 // Fold when the input and output types are the same. This is only safe
1647 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1648 // dimensions, there may still be a productive reshape.
1649 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1650 return getInput1();
1651
1652 // reshape(reshape(x)) -> reshape(x)
1653 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1654 getInput1().getDefiningOp())) {
1655 getInput1Mutable().assign(reshapeOp.getInput1());
1656 return getResult();
1657 }
1658
1659 // Cannot create an ElementsAttr from non-int/float/index types
1660 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1661 return {};
1662
1663 // reshape(const(x)) -> const(reshape-attr(x))
1664 if (auto operand =
1665 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1666 // Constants must have static shape.
1667 if (!outputTy.hasStaticShape())
1668 return {};
1669
1670 // Okay to duplicate splat constants.
1671 if (operand.isSplat())
1672 return SplatElementsAttr::get(outputTy,
1673 operand.getSplatValue<Attribute>());
1674
1675 // Don't duplicate other constants.
1676 if (!getInput1().hasOneUse())
1677 return {};
1678
1680 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1681 return {};
1682
1683 return operand.reshape(
1684 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1685 }
1686
1687 return {};
1688}
1689
1690OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1691 // If the pad is all zeros we can fold this operation away.
1692 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1693 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1694 if (densePad && densePad.isSplat() &&
1695 densePad.getSplatValue<APInt>().isZero()) {
1696 return getInput1();
1697 }
1698 }
1699
1700 return {};
1701}
1702
1703// Fold away cases where a tosa.resize operation returns a copy
1704// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  // Scale, offset and border must all be known constants to reason about
  // the resize statically.
  auto scaleAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
  auto offsetAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
  auto borderAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {
    return {};
  }

  auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
  auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
  auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
    return {};
  }

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  // With unit scale and no offset/border, the resize is the identity
  // provided the result type is exactly the input type.
  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
1746
1747OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1748 auto operand = getInput1();
1749 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1750 auto axis = getAxis();
1751 auto operandAttr =
1752 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1753 if (operandAttr)
1754 return operandAttr;
1755
1756 // If the dim-length is 1, tosa.reverse is a no-op.
1757 if (operandTy.hasRank() &&
1758 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1759 return operand;
1760
1761 return {};
1762}
1763
/// Folds tosa.slice:
///  - identity slice (same static type) -> input,
///  - all-zero starts with sizes covering the whole input -> input,
///  - slice of a splat constant -> splat of the result type,
///  - single-element slice of a static constant -> the extracted element.
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  // Check if this is a no-op slice (starts at 0 and size matches input)

  DenseElementsAttr startElems;
  if (!matchPattern(getStart(), m_Constant(&startElems)))
    return {};

  // Check if all start values are zero
  bool startIsZeros =
      llvm::all_of(startElems.getValues<APInt>(),
                   [](const APInt &val) { return val.isZero(); });

  if (startIsZeros) {

    // Check if size matches input shape
    DenseElementsAttr sizeElems;
    if (!matchPattern(getSize(), m_Constant(&sizeElems)))
      return {};

    auto inputShape = inputTy.getShape();
    auto sizeValues = sizeElems.getValues<APInt>();

    bool sizeMatchesInput = true;
    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
      int64_t size = sizeVal.getSExtValue();

      if (inputTy.isDynamicDim(i)) {
        // For dynamic dimensions, check for kInferableDimSize indicating full
        // dimension is sliced
        if (size != kInferableDimSize) {
          sizeMatchesInput = false;
          break;
        }
      } else {
        // For static dimensions, check that size must match exactly or be
        // kInferableDimSize indicating full dimension is sliced
        if (size != kInferableDimSize && size != inputShape[i]) {
          sizeMatchesInput = false;
          break;
        }
      }
    }

    if (sizeMatchesInput)
      return getInput1();
  }

  // The following checks require the input to be a constant
  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // Any slice of a splat is the same splat, resized to the output type.
  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A single-element result can be read directly at the start indices.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
1844
1845OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1846 const Value pred = getPred();
1847 const Value onTrue = getOnTrue();
1848 const Value onFalse = getOnFalse();
1849
1850 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
1851 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
1852 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
1853 if (!predTy || !onTrueTy || !onFalseTy)
1854 return {};
1855
1856 const Type resultTy = getType();
1857
1858 const ArrayRef<int64_t> predShape = predTy.getShape();
1859 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
1860
1861 if (onTrue == onFalse && onTrueTy == resultTy &&
1862 OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
1863 return onTrue;
1864
1865 auto predicate =
1866 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1867 if (!predicate)
1868 return {};
1869 if (!predicate.isSplat())
1870 return {};
1871
1872 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
1873
1874 SmallVector<SmallVector<int64_t>, 3> shapes;
1875 shapes.emplace_back(predShape);
1876 shapes.emplace_back(onTrueShape);
1877 shapes.emplace_back(onFalseTy.getShape());
1878 const bool isBroadcastable =
1880
1881 if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
1882 return onTrue;
1883 if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
1884 return onFalse;
1885 return {};
1886}
1887
1888OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1889 if (getInput1().getType() == getType()) {
1890 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1891 adaptor.getMultiples())) {
1892 if (multiples.isSplat() &&
1893 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1894 return getInput1();
1895 if (auto int_array_attr =
1896 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1897 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1898 [](APInt v) { return v.getSExtValue() == 1; }))
1899 return getInput1();
1900 }
1901 }
1902 }
1903 return {};
1904}
1905
1906OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1907 auto resultTy = llvm::cast<ShapedType>(getType());
1908
1909 // Transposing splat values just means reshaping.
1910 if (auto input =
1911 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1912 if (input.isSplat() && resultTy.hasStaticShape() &&
1913 input.getType().getElementType() == resultTy.getElementType())
1914 return input.reshape(resultTy);
1915 }
1916
1917 // Transpose is not the identity transpose.
1918 const llvm::ArrayRef<int32_t> perms = getPerms();
1919
1920 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1921 return {};
1922
1923 return getInput1();
1924}
1925
1926OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1927 // Element-wise negate(negate(x)) = x
1928 // iff all zero points are constant 0
1929 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1930 if (!definingOp) {
1931 // defining op of input1 is not a negate, cannot fold
1932 return {};
1933 }
1934
1935 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1936 failed(maybeIZp) || *maybeIZp != 0) {
1937 // input1 zero point is not constant 0, cannot fold
1938 return {};
1939 }
1940 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1941 failed(maybeOZp) || *maybeOZp != 0) {
1942 // output zero point is not constant 0, cannot fold
1943 return {};
1944 }
1945 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1946 failed(maybeIZp) || *maybeIZp != 0) {
1947 // definingOp's input1 zero point is not constant 0, cannot fold
1948 return {};
1949 }
1950 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1951 failed(maybeOZp) || *maybeOZp != 0) {
1952 // definingOp's output zero point is not constant 0, cannot fold
1953 return {};
1954 }
1955
1956 return definingOp.getInput1();
1957}
1958
1959OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1960 auto input = getInput1();
1961 // Element-wise abs(abs(x)) = abs(x)
1962 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1963 return input;
1964 }
1965
1966 return {};
1967}
1968
1969OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1970 // Fold consecutive concats on the same axis into a single op.
1971 // Keep track of the operands so we are able to construct a new concat
1972 // later. Conservatively assume that we double the number of operands when
1973 // folding
1974 SmallVector<Value, 8> concatOperands;
1975 concatOperands.reserve(2 * getNumOperands());
1976
1977 // Find all operands that are foldable concats
1978 bool foundFoldableConcat = false;
1979 for (Value operand : getOperands()) {
1980 concatOperands.emplace_back(operand);
1981
1982 auto producer = operand.getDefiningOp<ConcatOp>();
1983 if (!producer)
1984 continue;
1985
1986 // Not foldable if axes are not the same
1987 if (getAxis() != producer.getAxis())
1988 continue;
1989
1990 // Replace the original operand with all incoming operands
1991 foundFoldableConcat = true;
1992 concatOperands.pop_back();
1993 llvm::append_range(concatOperands, producer->getOperands());
1994 }
1995
1996 if (!foundFoldableConcat)
1997 return {};
1998
1999 getOperation()->setOperands(concatOperands);
2000 return getResult();
2001}
2002
2003OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2004 auto input = adaptor.getInput1();
2005
2006 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2007 // Fold splat inputs only.
2008 if (!inputAttr || !inputAttr.isSplat())
2009 return {};
2010
2011 auto shapeType = llvm::cast<ShapedType>(getType());
2012 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2013 auto floatVal = inputAttr.getSplatValue<APFloat>();
2014 return DenseElementsAttr::get(shapeType,
2015 ReciprocalOp::calcOneElement(floatVal));
2016 }
2017
2018 return {};
2019}
2020
2021template <typename Op, typename OpFoldAdaptor>
2023 auto input1ConstShape =
2024 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2025 if (!input1ConstShape)
2026 return {};
2027
2028 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2029
2030 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
2031 /*foldDenseValues=*/true);
2032}
2033
2034template <typename Op, typename OpFoldAdaptor>
2036 auto input1ConstShape =
2037 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2038 auto input2ConstShape =
2039 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2040 if (!input1ConstShape || !input2ConstShape)
2041 return {};
2042
2043 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2044 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2045
2046 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
2047 input1Attr.getType(),
2048 /*foldDenseValues=*/true);
2049}
2050
2051OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2052 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2053 if (!inputTy || !inputTy.hasRank())
2054 return {};
2055 const int32_t axis = getAxis();
2056 const int64_t dimSize = inputTy.getDimSize(axis);
2057 if (ShapedType::isDynamic(dimSize))
2058 return {};
2059
2060 OpBuilder builder(getContext());
2061 const auto resultAttrTy =
2062 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2063 return DenseElementsAttr::get(resultAttrTy, dimSize);
2064}
2065
// Constant-folds tosa.add_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2067) was lost in
// extraction — presumably `return binaryFold<AddShapeOp, ...>(this);` like
// DivCeilShapeOp below; restore from upstream.
2066OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2068}
2069
// Constant-folds tosa.sub_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2071) was lost in
// extraction — presumably a `binaryFold<SubShapeOp, ...>(this)` call; restore
// from upstream.
2070OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2072}
2073
// Constant-folds tosa.mul_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2075) was lost in
// extraction — presumably a `binaryFold<MulShapeOp, ...>(this)` call; restore
// from upstream.
2074OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2076}
2077
// Constant-folds tosa.div_ceil_shape: delegates to the shared binary shape
// folder with the ceiling-division adaptor.
2078OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2079 return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
2080}
2081
// Constant-folds tosa.div_floor_shape: delegates to the shared binary shape
// folder with the floor-division adaptor.
2082OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2083 return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
2084}
2085
// Constant-folds tosa.mod_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2087) was lost in
// extraction — presumably a `binaryFold<ModShapeOp, ...>(this)` call; restore
// from upstream.
2086OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2088}
2089
// Constant-folds tosa.max_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2091) was lost in
// extraction — presumably a `binaryFold<MaxShapeOp, ...>(this)` call; restore
// from upstream.
2090OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2092}
2093
// Constant-folds tosa.min_shape when both operands are constant shapes.
// NOTE(review): the delegating statement (original line 2095) was lost in
// extraction — presumably a `binaryFold<MinShapeOp, ...>(this)` call; restore
// from upstream.
2094OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2096}
2097
// Constant-folds tosa.exp2_shape when its operand is a constant shape.
// NOTE(review): the delegating statement (original line 2099) was lost in
// extraction — presumably a `unaryShapeFold<Exp2ShapeOp, ...>(this)` call;
// restore from upstream.
2098OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2100}
2101
// Constant-folds tosa.log2_ceil_shape when its operand is a constant shape.
// NOTE(review): the delegating statement (original line 2103) was lost in
// extraction — presumably a `unaryShapeFold<Log2CeilShapeOp, ...>(this)`
// call; restore from upstream.
2102OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2104}
2105
// Constant-folds tosa.log2_floor_shape when its operand is a constant shape.
// NOTE(review): the delegating statement (original line 2107) was lost in
// extraction — presumably a `unaryShapeFold<Log2FloorShapeOp, ...>(this)`
// call; restore from upstream.
2106OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2108}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
#define REDUCE_FOLDER(OP)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...