TosaCanonicalizations.cpp
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
27
28#include <functional>
29
30using namespace mlir;
31using namespace mlir::tosa;
32
33//===----------------------------------------------------------------------===//
34// Operator Canonicalizers.
35//===----------------------------------------------------------------------===//
36
37//===----------------------------------------------------------------------===//
38// Tensor Data Engine Operators.
39//===----------------------------------------------------------------------===//
40
41// Check that the zero point of the tensor operation and the pad constant are aligned.
42static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
43 // Check that padConst is a constant value and a scalar tensor
44 DenseElementsAttr padConstAttr;
45 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
46 (padConstAttr.size() != 1)) {
47 return false;
48 }
49
50 // Check that floating point pad is zero
51 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
54 }
55
56 // Check that the zp and padConst align for the integer (quantized) case
57 if (auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
59 DenseIntElementsAttr zpAttr;
60 // Check that zp is a constant value and a scalar tensor
61 if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
62 return false;
63 }
64
65 // Check equality
66 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
69 }
70
71 // Bail-out on unsupported type
72 return false;
73}
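// For example, with a quantized i8 input whose input_zp is -128, the pad
// constant must be the splat value -128 (the quantized encoding of zero);
// padding with any other value would change the result once the consumer
// subtracts its zero point.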
74
75namespace {
76template <typename OpTy>
77struct PoolPadFoldAdaptor;
78
79template <>
80struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
81 using OpTy = tosa::MaxPool2dOp;
82 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
83 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
84 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
85 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
86 return false;
87 return true;
88 }
89 static bool checkPadConstCompliance(OpTy, Value padConst) {
90 // Check that padConst is a constant value and a scalar tensor
91 DenseElementsAttr padConstAttr;
92 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
93 padConstAttr.size() != 1) {
94 return false;
95 }
96
97 // The pad constant must be the type's minimum value for the merge to be valid
98 if (auto padConstFpAttr =
99 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
100 const APFloat padConstVal = *padConstFpAttr.begin();
101 const APFloat lowestVal =
102 APFloat::getLargest(padConstVal.getSemantics(), true);
103 return padConstVal == lowestVal;
104 }
105 if (auto padConstIntAttr =
106 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
107 const APInt padConstVal = *padConstIntAttr.begin();
108 const unsigned int bitWidth = padConstVal.getBitWidth();
109 const APInt lowestVal =
110 padConstIntAttr.getElementType().isUnsignedInteger()
111 ? APInt::getZero(bitWidth)
112 : APInt::getSignedMinValue(bitWidth);
113 return padConstVal == lowestVal;
114 }
115
116 // Bail-out on unsupported type
117 return false;
118 }
119 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
120 Value padInput, ArrayRef<int64_t> newPad) {
121 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
122 op, op.getType(), padInput, op.getKernel(), op.getStride(),
123 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
124 }
125};
126
127template <typename OpTy>
128struct ConvPadFoldAdaptor {
129 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
130 return true;
131 }
132 static bool checkPadConstCompliance(OpTy op, Value padConst) {
133 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
134 }
135 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
136 Value padInput, ArrayRef<int64_t> newPad) {
137 rewriter.replaceOpWithNewOp<OpTy>(
138 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
139 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
140 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
141 }
142};
143
144// Pattern attempts to fold a `tosa.pad` operator to a following tensor
145// operation like `tosa.conv2d` by merging the padding associated with the
146// pad operator directly to the implicit padding of the tensor operation.
147// This helps eliminate the explicit padding operator if unused.
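//
// Schematic example (types, zero points and shape operands elided):
//
//   %p = tosa.pad %x, [0, 0, 1, 1, 2, 2, 0, 0]   // pad H by 1/1, W by 2/2
//   %y = tosa.conv2d %p, %w, %b {pad = [0, 0, 0, 0], ...}
//
// becomes
//
//   %y = tosa.conv2d %x, %w, %b {pad = [1, 1, 2, 2], ...}
//
// provided the pad constant matches the operation's zero point and the merged
// padding stays within the limits checked below.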
148template <typename OpTy, typename AdaptorTy>
149struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
150 using OpRewritePattern<OpTy>::OpRewritePattern;
151
152 LogicalResult matchAndRewrite(OpTy tensorOp,
153 PatternRewriter &rewriter) const override {
154 // Check producer is a tosa::PadOp
155 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
156 if (!padOp)
157 return rewriter.notifyMatchFailure(tensorOp,
158 "Producer must be a tosa::PadOp.");
159
160 // Validate that tensor operation has sane padding
161 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
162 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
163 return rewriter.notifyMatchFailure(
164 tensorOp, "Tensor operation padding shall have 4 elements.");
165
166 // Validate tosa::PadOp padding
167 DenseIntElementsAttr padOpPadding;
168 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
169 return rewriter.notifyMatchFailure(
170 tensorOp,
171 "The `padding` input specified on the tosa::PadOp must be constant.");
172 }
173 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
174 // C_after
175 if (padOpPadding.size() != 8)
176 return rewriter.notifyMatchFailure(tensorOp,
177 "Pad padding should have 8 elements.");
178 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
179 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
180 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
181 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
182 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
183 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
184 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
185 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
186
187 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
188 return rewriter.notifyMatchFailure(
189 tensorOp, "Folding padding in N or C dimensions is not supported.");
190
191 // Fold padding from Pad into the tensor operation
192 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
193 SmallVector<int64_t> foldedPad(tensorOpPad.size());
194 foldedPad[0] = padHBefore + tensorOpPad[0];
195 foldedPad[1] = padHAfter + tensorOpPad[1];
196 foldedPad[2] = padWBefore + tensorOpPad[2];
197 foldedPad[3] = padWAfter + tensorOpPad[3];
198
199 // Check kernel related restrictions
200 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
201 return rewriter.notifyMatchFailure(
202 tensorOp, "Padding size not aligned with kernel restrictions.");
203 }
204
205 // Check padding constant restrictions
206 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
207 return rewriter.notifyMatchFailure(
208 tensorOp,
209 "Padding constant is not aligned with operator zero-point.");
210 }
211
212 // Check that the folded padding does not exceed the 8K level limit (8192) for now
213 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
214 return rewriter.notifyMatchFailure(
215 tensorOp, "Padding size more than the 8K level limit.");
216 }
217
218 // Create operator
219 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
220 foldedPad);
221
222 return success();
223 }
224};
225} // namespace
226
227void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
228 MLIRContext *context) {
229 results.add<
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
231 context);
232}
233
234void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
235 MLIRContext *context) {
236 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
238 context);
239}
240
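// A max_pool2d whose input and output both have 1x1 spatial dimensions simply
// forwards its input, e.g. pooling a tensor<2x1x1x8xf32> down to
// tensor<2x1x1x8xf32> reduces over a single element per channel.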
241struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
242 using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
243
244 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
245 PatternRewriter &rewriter) const override {
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
250
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
252 return failure();
253 }
254
255 // If the output and input shapes are 1x1, then this is a no op.
256 ArrayRef<int64_t> outputShape = outputType.getShape();
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
258 return failure();
259 }
260
261 ArrayRef<int64_t> inputShape = inputType.getShape();
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
263 return failure();
264 }
265
266 rewriter.replaceOp(op, input);
267 return success();
268 }
269};
270
271void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
272 MLIRContext *context) {
273 results.add<MaxPool2dIsNoOp,
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
276 context);
277}
278
279//===----------------------------------------------------------------------===//
280// Data Layout / Memory Reinterpretation.
281//===----------------------------------------------------------------------===//
282
283struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
284 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
285
286 LogicalResult matchAndRewrite(tosa::ConcatOp op,
287 PatternRewriter &rewriter) const override {
288 if (op.getInput1().size() != 1)
289 return failure();
290 if (op.getInput1().front().getType() != op.getType()) {
291 rewriter
292 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
293 op.getInput1().front())
294 .getResult();
295 return success();
296 }
297
298 rewriter.replaceOp(op, op.getInput1().front());
299 return success();
300 }
301};
302
303void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
304 MLIRContext *context) {
305 results.add<ConcatOptimization>(context);
306}
307
308LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
310 if (!notOp)
311 return failure();
312 rewriter.modifyOpInPlace(op, [&]() {
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
315 });
316 return success();
317}
318
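// Folds transpose(transpose(x)) into a single transpose by composing the
// permutations. For example, an inner transpose with perms [2, 0, 1] followed
// by an outer transpose with perms [1, 2, 0] composes to
// perms[i] = inner[outer[i]] = [0, 1, 2], i.e. the identity permutation.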
319struct ConsolidateTransposeOptimization
320 : public OpRewritePattern<tosa::TransposeOp> {
321 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
322
323 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
324 PatternRewriter &rewriter) const override {
325 // Input is also TransposeOp - transpose(transpose(A)).
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
328 if (!innerTranspose)
329 return rewriter.notifyMatchFailure(transposeOp,
330 "input must be transpose operation");
331
332 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
333 const llvm::ArrayRef<int32_t> innerTransposePerms =
334 innerTranspose.getPerms();
335
336 if (transposePerms.size() != innerTransposePerms.size())
337 return rewriter.notifyMatchFailure(
338 transposeOp,
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
341 return rewriter.notifyMatchFailure(
342 transposeOp, "transpose perms sizes must be positive");
343
344 // Consolidate transposes into one transpose.
345 SmallVector<int32_t> perms(transposePerms.size());
346 for (int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
348
349 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
350 transposeOp, transposeOp.getResult().getType(),
351 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
352
353 return success();
354 }
355};
356
357// Determines the case when tosa.transpose is a tosa.reshape operation.
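// For example, transposing a tensor<1x5x1x7xf32> with perms [1, 0, 2, 3] only
// moves unit dimensions, so it can be rewritten as a reshape to
// tensor<5x1x1x7xf32>; the relative order of the non-unit dimensions (5 and 7)
// is preserved, so no data movement is required.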
358struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
359 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
360
361 LogicalResult matchAndRewrite(tosa::TransposeOp op,
362 PatternRewriter &rewriter) const override {
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
364 return rewriter.notifyMatchFailure(
365 op, "Src is from transpose, can compose transposes");
366
367 Value result = op.getResult();
368 for (Operation *subop : result.getUsers()) {
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
370 return rewriter.notifyMatchFailure(
371 op, "Dest is used by transpose, can compose transposes");
372 }
373
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
377 return rewriter.notifyMatchFailure(op, "Unranked input.");
378
379 int64_t numDynDims = 0;
380 for (int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
382 numDynDims++;
383
384 if (numDynDims > 1)
385 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
386
387 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
388
389 SmallVector<int64_t> nonZeroPerms;
390 nonZeroPerms.reserve(permValues.size());
391 for (auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
393 if (sz != 1)
394 nonZeroPerms.push_back(idx);
395 }
396
397 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
399 return rewriter.notifyMatchFailure(op,
400 "Transpose changes memory layout.");
401
402 SmallVector<int64_t> newShape;
403 newShape.reserve(inputTy.getRank());
404 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
406
407 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
408 op, op.getType(), op.getInput1(),
409 getTosaConstShape(rewriter, op.getLoc(), newShape));
410 return success();
411 }
412};
413
414void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
415 MLIRContext *context) {
416 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
417}
418
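// Removes clamps that cannot restrict their input, e.g. a clamp over the full
// i8 range [-128, 127], or a floating-point clamp whose bounds are -inf/+inf.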
419struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
420 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
421
422 LogicalResult matchAndRewrite(tosa::ClampOp op,
423 PatternRewriter &rewriter) const override {
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
427
428 if (isa<FloatType>(inputElementType)) {
429 // Unlike integer types, floating point types can represent infinity.
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
436
437 if (isMin && isMax) {
438 rewriter.replaceOp(op, input);
439 return success();
440 }
441 return failure();
442 }
443
444 // i1 types are boolean in TOSA
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
448 .getValue()
449 .getZExtValue();
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
451 .getValue()
452 .getZExtValue();
453
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
457
458 if (minClamp <= intMin && maxClamp >= intMax) {
459 rewriter.replaceOp(op, input);
460 return success();
461 }
462 return failure();
463 }
464
465 if (llvm::isa<IntegerType>(inputElementType)) {
466 const int64_t minClamp =
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
468 const int64_t maxClamp =
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
470
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
474
475 if (minClamp <= intMin && maxClamp >= intMax) {
476 rewriter.replaceOp(op, input);
477 return success();
478 }
479 return failure();
480 }
481
482 return failure();
483 }
484};
485
486// Attempts the following transformation:
487//
488// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
489// tensor X the following identity holds:
490//
491// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
492//
493// subject to the following valid NaN propagation semantics:
494// --------------------------------------------
495// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
496// |-------------|--------------|-------------|
497// | PROPAGATE | PROPAGATE | PROPAGATE |
498// | PROPAGATE | IGNORE | IGNORE |
499// | IGNORE | PROPAGATE | INVALID |
500// | IGNORE | IGNORE | IGNORE |
501// |------------------------------------------|
502
503struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
504 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
505
506 // Helper structure to describe the range of a clamp operation.
507 template <typename T>
508 struct ClampRange {
509 ClampRange(const T &start, const T &end) : start(start), end(end) {}
510 T start;
511 T end;
512
513 // Helper function to determine if two Clamp ranges intersect.
514 bool intersects(const ClampRange<T> &otherRange) {
515 return start < otherRange.end && otherRange.start < end;
516 }
517 };
518
519 LogicalResult matchAndRewrite(tosa::ClampOp op,
520 PatternRewriter &rewriter) const override {
521 Value input = op.getInput();
522
523 // Check the input to the CLAMP op is itself a CLAMP.
524 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
525 if (!clampOp)
526 return failure();
527
528 // Check we have a valid NaN propagation combination.
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
533 return failure();
534
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
539
540 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
541 if (auto quantType =
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
543 inputEType = getStorageElementTypeFromQuantized(quantType);
544 }
545
546 Attribute newMinValAttr, newMaxValAttr;
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
552
553 // Check we have intersecting ranges.
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
558 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
559 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
560 clampOpMaxFloat);
561 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
562 return failure();
563
564 // Run the transformation.
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
569 } else {
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
575
576 if (inputEType.isUnsignedInteger()) {
577 // Check we have intersecting ranges.
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
582 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
583 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
584 clampOpMaxInt);
585 if (!opRangeIntRange.intersects(clampRangeIntRange))
586 return failure();
587
588 // Run the transformation.
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
591 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
592 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
593 } else {
594 // Check we have intersecting ranges.
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
599 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
600 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
601 clampOpMaxInt);
602 if (!opRangeIntRange.intersects(clampRangeIntRange))
603 return failure();
604
605 // Run the transformation.
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
608 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
609 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
610 }
611 }
612
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
615 : opNanMode;
616
617 auto newModeAttr =
618 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
619
620 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
622 newModeAttr);
623 return success();
624 }
625};
626
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.add<ClampIsNoOp>(context);
630 results.add<ClampClampOptimization>(context);
631}
632
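// Rewrites a slice of a concat as a slice of the single concat input that
// fully contains the sliced region. For example, with %c = concat(%a, %b) on
// axis 1 where %a is 1x4 and %b is 1x6, slicing %c at start [0, 4] with size
// [1, 5] reads only %b, so it becomes a slice of %b at start [0, 0].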
633struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
634 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
635
636 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
637 PatternRewriter &rewriter) const override {
638 Value sliceInput = sliceOp.getInput1();
639 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
640 if (!concatOp)
641 return rewriter.notifyMatchFailure(
642 sliceOp, "slice input must be concat operation");
643
644 OperandRange inputs = concatOp.getInput1();
645 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
646 if (!concatType || !concatType.hasStaticShape())
647 return rewriter.notifyMatchFailure(
648 sliceOp, "slice input must be a static ranked tensor");
649 int32_t axis = concatOp.getAxis();
650
651 DenseElementsAttr startElems;
652 DenseElementsAttr sizeElems;
653
654 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
655 return rewriter.notifyMatchFailure(
656 sliceOp, "start of slice must be a static ranked shape");
657
658 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
659 return rewriter.notifyMatchFailure(
660 sliceOp, "size of slice must be a static ranked shape");
661
662 llvm::SmallVector<int64_t> sliceStarts =
663 llvm::to_vector(startElems.getValues<int64_t>());
664 llvm::SmallVector<int64_t> sliceSizes =
665 llvm::to_vector(sizeElems.getValues<int64_t>());
666
667 // Validate slice on the concatenated axis. Slicing along this
668 // axis should span only one of the inputs to the concatenate
669 // operation.
670 std::optional<Value> replaceWithSlice;
671 for (auto input : inputs) {
672 auto inputType = dyn_cast<RankedTensorType>(input.getType());
673 if (!inputType || !inputType.hasStaticShape())
674 return rewriter.notifyMatchFailure(
675 sliceOp, "concat input must be a static ranked tensor");
676
677 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
678 inputType.getDimSize(axis)) {
679 auto start_op =
680 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
681 auto size_op =
682 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
683 replaceWithSlice =
684 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
685 input, start_op, size_op)
686 .getResult();
687 break;
688 }
689 sliceStarts[axis] -= inputType.getDimSize(axis);
690 }
691
692 if (!replaceWithSlice)
693 return rewriter.notifyMatchFailure(
694 sliceOp, "corresponding concat input not found for slice");
695
696 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
697 return success();
698 }
699};
700
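// Pushes a slice above a pad so that only the padding which actually survives
// the slice is materialized. For a single axis of size 8 padded by [1, 1]
// (total 10), a slice with start 1 and size 8 selects exactly the original
// data, so the rewrite produces a pad of [0, 0] and a slice starting at 0.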
701struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
702 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
703
704 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
705 PatternRewriter &rewriter) const override {
706 Value sliceInput = sliceOp.getInput1();
707
708 // Check if producer is a PadOp
709 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
710 if (!padOp)
711 return rewriter.notifyMatchFailure(sliceOp,
712 "slice input must be a pad operation");
713
714 // Check PadOp has a single consumer
715 if (!padOp->hasOneUse())
716 return rewriter.notifyMatchFailure(sliceOp,
717 "pad shall have a single consumer");
718
719 // Check input is statically ranked
720 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
721 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
722 if (!inputTy || !padTy || !inputTy.hasRank())
723 return rewriter.notifyMatchFailure(sliceOp,
724 "slice input must be a ranked tensor");
725
726 // Validate and extract tosa::PadOp padding
727 DenseIntElementsAttr paddingElems;
728 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
729 return rewriter.notifyMatchFailure(
730 sliceOp,
731 "`padding` input specified on the tosa::PadOp must be constant.");
732 }
733 llvm::SmallVector<int64_t> padPaddings =
734 llvm::to_vector(paddingElems.getValues<int64_t>());
735
736 // Extract slice parameters
737 DenseElementsAttr startElems;
738 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
739 return rewriter.notifyMatchFailure(
740 sliceOp, "start of slice must be a static ranked shape");
741 llvm::SmallVector<int64_t> sliceStarts =
742 llvm::to_vector(startElems.getValues<int64_t>());
743
744 DenseElementsAttr sizeElems;
745 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
746 return rewriter.notifyMatchFailure(
747 sliceOp, "size of slice must be a static ranked shape");
748 llvm::SmallVector<int64_t> sliceSizes =
749 llvm::to_vector(sizeElems.getValues<int64_t>());
750
751 // Check if dynamic dimensions are sliced
752 const int64_t rank = inputTy.getRank();
753 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
754 const bool isDimDynamic = inputTy.isDynamicDim(i);
755 const bool isDimSliced =
756 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
757
758 return isDimDynamic && isDimSliced;
759 })) {
760 return rewriter.notifyMatchFailure(
761 sliceOp, "axis that are sliced shall be statically known.");
762 }
763
764 // Update the parameters
765 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
766 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
767 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
768 bool updated = false;
769
770 for (int64_t i = 0; i < rank; ++i) {
771 const int64_t padLo = padPaddings[i * 2];
772 const int64_t padHi = padPaddings[i * 2 + 1];
773 const int64_t sliceStart = sliceStarts[i];
774 const int64_t sliceSize = sliceSizes[i];
775 const int64_t sliceEnd = sliceStart + sliceSize;
776
777 // If dimension is dynamic pass-through
778 if (inputTy.isDynamicDim(i)) {
779 newPadPaddings[i * 2] = padLo;
780 newPadPaddings[i * 2 + 1] = padHi;
781 newSliceStarts[i] = sliceStart;
782 continue;
783 }
784
785 // Handle static dimensions
786 const int64_t dimSize = inputTy.getShape()[i];
787 const int64_t dimTotal = padLo + dimSize + padHi;
788
789 // Check slice within bounds
790 if (sliceStart < 0 || sliceEnd > dimTotal)
791 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
792
793 // Compute updated slice start parameter
794 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
795 newSliceStarts[i] = newSliceStart;
796 updated |= newSliceStart != sliceStart;
797
798 // Compute updated pad parameters
799 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
800 const int64_t newPadHi =
801 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
802 newPadPaddings[i * 2] = newPadLo;
803 newPadPaddings[i * 2 + 1] = newPadHi;
804 updated |= (newPadLo != padLo) || (newPadHi != padHi);
805
806 // Calculate new pad output shape
807 newPadShape[i] =
808 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
809 }
810
811 // Check that we actually need to proceed with the rewrite
812 if (!updated)
813 return rewriter.notifyMatchFailure(
814 sliceOp, "terminate condition; nothing to rewrite");
815
816 // Create a PadOp with updated padding
817 auto newPaddingsOp =
818 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
819 auto newPadTy =
820 RankedTensorType::get(newPadShape, inputTy.getElementType());
821 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
822 padOp.getInput1(), newPaddingsOp,
823 padOp.getPadConst());
824
825 // Update SliceOp and point to new PadOp
826 auto newStartOp =
827 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
828 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
829 newPadOp.getResult(), newStartOp,
830 sliceOp.getSize());
831
832 return success();
833 }
834};
835
836// Update size operand of tosa.slice if size has dynamic dims but corresponding
837// output dim is static
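// For example, a size operand of [-1, 4] with a result type of
// tensor<10x4xf32> is rewritten to use a constant size of [10, 4].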
838struct SliceDynamicSizeCanonicalization
839 : public OpRewritePattern<tosa::SliceOp> {
840 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
841
842 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
843 PatternRewriter &rewriter) const override {
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
845
846 ElementsAttr sizeElems;
847 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
848 return rewriter.notifyMatchFailure(
849 sliceOp, "size of slice must be a static ranked shape");
850 }
851
852 llvm::SmallVector<int64_t> sliceSizes =
853 llvm::to_vector(sizeElems.getValues<int64_t>());
854
855 bool replaceSliceSize{false};
856 // if size op has -1 indicating dynamic shape but corresponding dim on the
857 // output is statically known, update size to match with known output dim
858 // shape
859 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(index)) {
861 sliceSizes[index] = resultType.getDimSize(index);
862 replaceSliceSize = true;
863 }
864 }
865
866 if (!replaceSliceSize) {
867 return rewriter.notifyMatchFailure(
868 sliceOp, "no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
870 }
871
872 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
873 auto newSliceOp =
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
876
877 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
878 return success();
879 }
880};
881
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
886}
887
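// Folds a pair of integer casts where neither step narrows below the original
// width, e.g. cast(cast(x : i8 -> i16) : i16 -> i32) becomes cast(x : i8 -> i32).
// A chain such as i32 -> i8 -> i32 is left alone because the inner cast has
// already discarded the upper bits.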
888struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
889 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
890
891 LogicalResult matchAndRewrite(tosa::CastOp castOp,
892 PatternRewriter &rewriter) const override {
893 const Value castInput = castOp.getInput();
894 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
895 if (!innerCastOp)
896 return rewriter.notifyMatchFailure(castOp,
897 "input must be cast operation");
898
899 const Value innerCastInput = innerCastOp.getInput();
900
901 const auto innerInputType =
902 llvm::cast<ShapedType>(innerCastInput.getType());
903 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
904 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
905
906 const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
907 outerOutputType};
908 if (llvm::any_of(types, [](const ShapedType type) {
909 return !type.getElementType().isInteger();
910 }))
911 return rewriter.notifyMatchFailure(castOp,
912 "only integer types are supported");
913
914 // Check inner cast is non-narrowing
915 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
916 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
917 return rewriter.notifyMatchFailure(castOp,
918 "inner cast operation is narrowing");
919
920 // Check outer cast is non-narrowing from the inner cast input
921 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
922 return rewriter.notifyMatchFailure(castOp,
923 "outer cast operation is narrowing");
924
925 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
926 innerCastInput);
927
928 return success();
929 }
930};
931
932void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
933 MLIRContext *context) {
934 results.add<NonNarrowingCastsOptimization>(context);
935}
936
937//===----------------------------------------------------------------------===//
938// Operator Folders.
939//===----------------------------------------------------------------------===//
940
941template <typename Folder>
942static DenseElementsAttr
943binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
944 bool foldDenseValues = false) {
945 if (!lhs || !rhs)
946 return {};
947
948 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
949 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
950 if (lETy != rETy)
951 return {};
952
953 if (lhs.isSplat() && rhs.isSplat()) {
954 if (isa<FloatType>(lETy)) {
955 const APFloat l = lhs.getSplatValue<APFloat>();
956 const APFloat r = rhs.getSplatValue<APFloat>();
957 const auto maybeResult = Folder::fold(l, r);
958 if (failed(maybeResult))
959 return {};
960 return DenseElementsAttr::get(returnTy, maybeResult.value());
961 }
962
963 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
964 const APInt l = lhs.getSplatValue<APInt>();
965 const APInt r = rhs.getSplatValue<APInt>();
966 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
967 if (failed(maybeResult))
968 return {};
969 return DenseElementsAttr::get(returnTy, maybeResult.value());
970 }
971 }
972
973 if (foldDenseValues) {
974 assert(lETy.isIntOrIndex() &&
975 "Only integer types are currently supported.");
976 SmallVector<APInt> resultValues;
977 for (auto [l, r] :
978 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
979 const auto maybeResult = Folder::fold(l, r, false);
980 if (failed(maybeResult))
981 return {};
982 resultValues.push_back(maybeResult.value());
983 }
984 return DenseElementsAttr::get(returnTy, resultValues);
985 }
986
987 return {};
988}
989struct FoldAddAdaptor {
990 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
991 const bool isUnsigned) {
992 bool overflow;
993 const APInt result =
994 isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
995 if (overflow)
996 return failure();
997 return result;
998 }
999
1000 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1001 return lhs + rhs;
1002 }
1003};
1004
1005struct FoldSubAdaptor {
1006 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1007 const bool isUnsigned) {
1008 bool overflow;
1009 const APInt result =
1010 isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
1011 if (overflow)
1012 return failure();
1013 return result;
1014 }
1015
1016 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1017 return lhs - rhs;
1018 }
1019};
1020
1021struct FoldGreaterAdaptor {
1022 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1023 const bool isUnsigned) {
1024 return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
1025 }
1026
1027 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1028 return APInt(1, lhs > rhs);
1029 }
1030};
1031
1032struct FoldGreaterEqualAdaptor {
1033 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1034 const bool isUnsigned) {
1035 return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
1036 }
1037
1038 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1039 return APInt(1, lhs >= rhs);
1040 }
1041};
1042
1043struct FoldEqualAdaptor {
1044 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1045 const bool isUnsigned) {
1046 return APInt(1, lhs == rhs);
1047 }
1048
1049 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1050 return APInt(1, lhs == rhs);
1051 }
1052};
1053
1054static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1055 if (llvm::isa<FloatType>(elemType))
1056 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1057 if (llvm::isa<IntegerType>(elemType))
1058 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1059 return false;
1060}
1061
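// With a non-zero shift (as used by tosa.mul on i32), the multiplicative
// identity is 1 << shift rather than 1: for shift = 7 the splat value 128
// rescales back to exactly the other operand after the rounding right shift.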
1062static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1063 if (llvm::isa<FloatType>(elemType))
1064 return val && val.isSplat() &&
1065 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1066 if (llvm::isa<IntegerType>(elemType)) {
1067 const int64_t shifted = 1LL << shift;
1068 return val && val.isSplat() &&
1069 val.getSplatValue<APInt>().getSExtValue() == shifted;
1070 }
1071 return false;
1072}
1073
1074OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1075 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1076 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1077 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1078 if (!lhsTy || !rhsTy || !resultTy)
1079 return {};
1080
1081 // Cannot create an ElementsAttr from non-int/float/index types
1082 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1083 !rhsTy.getElementType().isIntOrIndexOrFloat())
1084 return {};
1085
1086 auto resultETy = resultTy.getElementType();
1087 auto lhsAttr =
1088 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1089 auto rhsAttr =
1090 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1091
1092 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1093 return getInput1();
1094 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
1095 return getInput2();
1096
1097 if (!lhsAttr || !rhsAttr)
1098 return {};
1099
1100 return binaryFolder<FoldAddAdaptor>(lhsAttr, rhsAttr, resultTy);
1101}
1102
1103OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1104 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1105 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1106 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1107 !outputTy.hasStaticShape())
1108 return {};
1109
1110 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1111 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1112 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1113 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1114 return DenseElementsAttr::get(outputTy, zero);
1115 }
1116
1117 return {};
1118}
1119
1120OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1121 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1122 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1123 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1124 if (!lhsTy || !rhsTy || !resultTy)
1125 return {};
1126 if (lhsTy != rhsTy)
1127 return {};
1128
1129 // IntDivOp inputs must be integer type, no need to check for quantized type
1130 auto resultETy = resultTy.getElementType();
1131 auto lhsAttr =
1132 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1133 auto rhsAttr =
1134 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1135 if (lhsAttr && lhsAttr.isSplat()) {
1136 if (llvm::isa<IntegerType>(resultETy) &&
1137 lhsAttr.getSplatValue<APInt>().isZero())
1138 return lhsAttr;
1139 }
1140
1141 if (rhsAttr && rhsAttr.isSplat()) {
1142 if (llvm::isa<IntegerType>(resultETy) &&
1143 rhsAttr.getSplatValue<APInt>().isOne())
1144 return getInput1();
1145 }
1146
1147 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1148 llvm::isa<IntegerType>(resultETy)) {
1149 APInt l = lhsAttr.getSplatValue<APInt>();
1150 APInt r = rhsAttr.getSplatValue<APInt>();
1151 if (!r.isZero()) {
1152 APInt result = l.sdiv(r);
1153 return DenseElementsAttr::get(resultTy, result);
1154 }
1155 }
1156
1157 return {};
1158}
1159
1160namespace {
1161// Calculate lhs * rhs >> shift according to the TOSA specification.
1162// Returns std::nullopt if the result is out of int32_t range when shift > 0.
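// For example, with lhs = 5, rhs = 7 and shift = 2 the 64-bit product is 35,
// the rounding term is 1 << 1 = 2, and (35 + 2) >> 2 = 9, which is then
// truncated back to the operand bitwidth.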
1163std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1164 unsigned bitwidth) {
1165 APInt result = lhs.sext(64) * rhs.sext(64);
1166
1167 if (shift > 0) {
1168 auto round = APInt(64, 1) << (shift - 1);
1169 result += round;
1170 result.ashrInPlace(shift);
1171 // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1172 if (!(result.getSExtValue() >= INT32_MIN &&
1173 result.getSExtValue() <= INT32_MAX)) {
1174 // REQUIRE failed
1175 return std::nullopt;
1176 }
1177 }
1178
1179 return result.trunc(bitwidth);
1180}
1181
1182DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1183 RankedTensorType ty, int32_t shift) {
1184 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1185 if (llvm::isa<IntegerType>(ty.getElementType())) {
1186 APInt l = lhs.getSplatValue<APInt>();
1187 APInt r = rhs.getSplatValue<APInt>();
1188
1189 if (shift == 0) {
1190 return DenseElementsAttr::get(ty, l * r);
1191 }
1192
1193 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1194 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1195 if (!result)
1196 return {};
1197 return DenseElementsAttr::get(ty, result.value());
1198 }
1199
1200 if (llvm::isa<FloatType>(ty.getElementType())) {
1201 APFloat l = lhs.getSplatValue<APFloat>();
1202 APFloat r = rhs.getSplatValue<APFloat>();
1203 APFloat result = l * r;
1204 return DenseElementsAttr::get(ty, result);
1205 }
1206 }
1207
1208 return {};
1209}
1210} // namespace
1211
1212OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1213 auto lhs = getInput1();
1214 auto rhs = getInput2();
1215 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1216 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1217 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1218 if (!lhsTy || !rhsTy || !resultTy)
1219 return {};
1220
1221 auto resultETy = resultTy.getElementType();
1222 auto lhsAttr =
1223 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1224 auto rhsAttr =
1225 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1226
1227 // Result right shift applies to the i32 data type only. For simplification,
1228 // synthesize a zero shift for other data types.
1229 int32_t shift = 0;
1230 if (resultETy.isInteger(32)) {
1231 ElementsAttr shift_elem;
1232 if (getShift().getImpl()) {
1233 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1234 // cannot be folded when the shift value is unknown.
1235 return {};
1236 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1237 }
1238 }
1239
1240 if (rhsTy == resultTy) {
1241 if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1242 // constant values can only be resized if resulting type is static
1243 return lhsAttr.resizeSplat(resultTy);
1244 if (isSplatOne(resultETy, lhsAttr, shift))
1245 return rhs;
1246 }
1247 if (lhsTy == resultTy) {
1248 if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1249 return rhsAttr.resizeSplat(resultTy);
1250 if (isSplatOne(resultETy, rhsAttr, shift))
1251 return lhs;
1252 }
1253
1254 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1255}
1256
1257OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1258 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1259 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1260 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1261 if (!lhsTy || !rhsTy || !resultTy)
1262 return {};
1263
1264 // Cannot create an ElementsAttr from non-int/float/index types
1265 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1266 !rhsTy.getElementType().isIntOrIndexOrFloat())
1267 return {};
1268
1269 auto resultETy = resultTy.getElementType();
1270 auto lhsAttr =
1271 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1272 auto rhsAttr =
1273 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1274
1275 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1276 return getInput1();
1277
1278 if (!lhsAttr || !rhsAttr)
1279 return {};
1280
1281 return binaryFolder<FoldSubAdaptor>(lhsAttr, rhsAttr, resultTy);
1282}
1283
1284OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1285 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1286 auto lhsAttr =
1287 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1288 auto rhsAttr =
1289 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1290
1291 if (!lhsAttr || !rhsAttr)
1292 return {};
1293
1294 return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
1295}
1296
1297OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1298 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1299 auto lhsAttr =
1300 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1301 auto rhsAttr =
1302 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1303
1304 if (!lhsAttr || !rhsAttr)
1305 return {};
1306
1307 return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
1308}
1309
1310OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1311 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1312 auto lhsAttr =
1313 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1314 auto rhsAttr =
1315 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1316 Value lhs = getInput1();
1317 Value rhs = getInput2();
1318 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1319
1320 // If we are comparing an integer value to itself it is always true. We
1321 // cannot do this with floats since NaN compares unequal to itself.
1322 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1323 resultTy.hasStaticShape() && lhs == rhs) {
1324 return DenseElementsAttr::get(resultTy, true);
1325 }
1326
1327 if (!lhsAttr || !rhsAttr)
1328 return {};
1329
1330 return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
1331}
1332
1333OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1334 if (getInput().getType() == getType())
1335 return getInput();
1336
1337 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1338 if (!operand)
1339 return {};
1340
1341 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1342 auto outTy = llvm::cast<ShapedType>(getType());
1343 auto inETy = inTy.getElementType();
1344 auto outETy = outTy.getElementType();
1345
1346 if (operand.isSplat()) {
1347 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1348 bool overflow;
1349 auto splatVal = operand.getSplatValue<APFloat>();
1350 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1351 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1352 &overflow);
1353 return SplatElementsAttr::get(outTy, splatVal);
1354 }
1355
1356 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1357 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1358 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1359 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1360 llvm::RoundingMode::NearestTiesToEven);
1361 return SplatElementsAttr::get(outTy, splatVal);
1362 }
1363
1364 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1365 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1366 auto intVal = APSInt(
1367 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1368 auto floatVal = operand.getSplatValue<APFloat>();
1369 bool exact;
1370 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1371 &exact);
1372 return SplatElementsAttr::get(outTy, intVal);
1373 }
1374
1375 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1376 const auto inIntType = llvm::cast<IntegerType>(inETy);
1377 auto unsignIn = inIntType.isUnsignedInteger();
1378 bool trunc =
1379 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1380 auto intVal = operand.getSplatValue<APInt>();
1381 auto bitwidth = outETy.getIntOrFloatBitWidth();
1382
1383 // i1 types are boolean in TOSA
1384 if (outETy.isInteger(1)) {
1385 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1386 } else if (trunc) {
1387 intVal = intVal.trunc(bitwidth);
1388 } else if (unsignIn || inIntType.isInteger(1)) {
1389 intVal = intVal.zext(bitwidth);
1390 } else {
1391 intVal = intVal.sext(bitwidth);
1392 }
1393
1394 return SplatElementsAttr::get(outTy, intVal);
1395 }
1396 }
1397
1398 return {};
1399}
1400
1401OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1402
1403OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1404
1405#define REDUCE_FOLDER(OP) \
1406 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1407 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1408 if (!inputTy.hasRank()) \
1409 return {}; \
1410 if (inputTy != getType()) \
1411 return {}; \
1412 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1413 return getInput(); \
1414 return {}; \
1415 }
1416
1417REDUCE_FOLDER(ReduceAllOp)
1418REDUCE_FOLDER(ReduceAnyOp)
1419REDUCE_FOLDER(ReduceMaxOp)
1420REDUCE_FOLDER(ReduceMinOp)
1421REDUCE_FOLDER(ReduceProductOp)
1422REDUCE_FOLDER(ReduceSumOp)
1423#undef REDUCE_FOLDER
1424
1425OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1426 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1427 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1428
1429 if (!inputTy || !outputTy)
1430 return {};
1431
1432 // Fold when the input and output types are the same. This is only safe when
1433 // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1434 // there may still be a productive reshape.
1435 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1436 return getInput1();
1437
1438 // reshape(reshape(x)) -> reshape(x)
1439 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1440 getInput1().getDefiningOp())) {
1441 getInput1Mutable().assign(reshapeOp.getInput1());
1442 return getResult();
1443 }
1444
1445 // Cannot create an ElementsAttr from non-int/float/index types
1446 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1447 return {};
1448
1449 // reshape(const(x)) -> const(reshape-attr(x))
1450 if (auto operand =
1451 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1452 // Constants must have static shape.
1453 if (!outputTy.hasStaticShape())
1454 return {};
1455
1456 // Okay to duplicate splat constants.
1457 if (operand.isSplat())
1458 return SplatElementsAttr::get(outputTy,
1459 operand.getSplatValue<Attribute>());
1460
1461 // Don't duplicate other constants.
1462 if (!getInput1().hasOneUse())
1463 return {};
1464
1465 llvm::SmallVector<int64_t> shapeVec;
1466 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1467 return {};
1468
1469 return operand.reshape(
1470 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1471 }
1472
1473 return {};
1474}
1475
1476OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1477 // If the pad is all zeros we can fold this operation away.
1478 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1479 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1480 if (densePad && densePad.isSplat() &&
1481 densePad.getSplatValue<APInt>().isZero()) {
1482 return getInput1();
1483 }
1484 }
1485
1486 return {};
1487}
1488
1489// Fold away cases where a tosa.resize operation returns a copy
1490// of the input image.
1491OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1492 auto scaleAttr =
1493 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1494 auto offsetAttr =
1495 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1496 auto borderAttr =
1497 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1498 if (!scaleAttr || !offsetAttr || !borderAttr) {
1499 return {};
1500 }
1501
1502 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1503 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1504 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1505 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1506 return {};
1507 }
1508
1509 // Check unit scaling.
1510 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1511 return {};
1512 }
1513
1514 // There should be no offset.
1515 if (offset[0] != 0 || offset[1] != 0) {
1516 return {};
1517 }
1518
1519 // There should be no border.
1520 if (border[0] != 0 || border[1] != 0) {
1521 return {};
1522 }
1523
1524 auto input = getInput();
1525 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1526 auto resultTy = llvm::cast<RankedTensorType>(getType());
1527 if (inputTy != resultTy)
1528 return {};
1529
1530 return input;
1531}
1532
1533OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1534 auto operand = getInput1();
1535 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1536 auto axis = getAxis();
1537 auto operandAttr =
1538 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1539 if (operandAttr)
1540 return operandAttr;
1541
1542 // If the dim-length is 1, tosa.reverse is a no-op.
1543 if (operandTy.hasRank() &&
1544 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1545 return operand;
1546
1547 return {};
1548}
1549
1550OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1551 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1552 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1553
1554 if (!inputTy || !outputTy)
1555 return {};
1556
1557 if (inputTy == outputTy && inputTy.hasStaticShape())
1558 return getInput1();
1559
1560 if (!adaptor.getInput1())
1561 return {};
1562
1563 // Cannot create an ElementsAttr from non-int/float/index types
1564 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1565 !outputTy.getElementType().isIntOrIndexOrFloat())
1566 return {};
1567
1568 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1569 if (operand.isSplat() && outputTy.hasStaticShape()) {
1570 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1571 }
1572
1573 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1574 outputTy.getNumElements() == 1) {
1575 DenseElementsAttr startElems;
1576 if (!matchPattern(getStart(), m_Constant(&startElems)))
1577 return {};
1578
1579 llvm::SmallVector<uint64_t> indices =
1580 llvm::to_vector(startElems.getValues<uint64_t>());
1581 auto value = operand.getValues<Attribute>()[indices];
1582 return SplatElementsAttr::get(outputTy, value);
1583 }
1584
1585 return {};
1586}
1587
1588static bool
1589mayRequireBroadcast(const ValueTypeRange<OperandRange> operandTypes) {
1590 const auto isDynamic = [](Type ty) {
1591 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1592 return !shapedTy || !shapedTy.hasStaticShape();
1593 };
1594
1595 return llvm::any_of(operandTypes, isDynamic) ||
1596 failed(verifyCompatibleShapes(operandTypes));
1597}
1598
1599OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1600 // Select allows operand shapes to be broadcast to the output shape. For
1601 // now, don't support folding when we cannot prove no broadcasting is
1602 // involved.
1603 if (mayRequireBroadcast(getOperandTypes()))
1604 return {};
1605
1606 if (getOnTrue() == getOnFalse())
1607 return getOnTrue();
1608
1609 auto predicate =
1610 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1611 if (!predicate)
1612 return {};
1613
1614 if (!predicate.isSplat())
1615 return {};
1616 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1617 : getOnFalse();
1618}
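// Illustrative cases for the fold above (a sketch; syntax assumed):
//   %0 = tosa.select %pred, %a, %a : ...        // folds to %a
//   %1 = tosa.select %splat_true, %a, %b : ...  // folds to %a
// Folding is skipped whenever any operand shape is dynamic or might require
// broadcasting to the result shape.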
1619
1620OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1621 if (getInput1().getType() == getType()) {
1622 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1623 adaptor.getMultiples())) {
1624 if (multiples.isSplat() &&
1625 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1626 return getInput1();
1627 if (auto int_array_attr =
1628 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1629 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1630 [](APInt v) { return v.getSExtValue() == 1; }))
1631 return getInput1();
1632 }
1633 }
1634 }
1635 return {};
1636}
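// Illustrative case for the fold above (a sketch; the multiples operand form
// is assumed): tiling with every multiple equal to 1 repeats nothing, e.g.
//   %m = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
//   %0 = tosa.tile %arg0, %m : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
// folds to %arg0 (the input and result types must already match).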
1637
1638OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1639 auto resultTy = llvm::cast<ShapedType>(getType());
1640
1641 // Transposing splat values just means reshaping.
1642 if (auto input =
1643 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1644 if (input.isSplat() && resultTy.hasStaticShape() &&
1645 input.getType().getElementType() == resultTy.getElementType())
1646 return input.reshape(resultTy);
1647 }
1648
1649 // Otherwise, fold only if this is the identity transpose.
1650 const llvm::ArrayRef<int32_t> perms = getPerms();
1651
1652 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1653 return {};
1654
1655 return getInput1();
1656}
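// Illustrative case for the fold above (a sketch; attribute syntax assumed):
//   %0 = tosa.transpose %arg0 {perms = array<i32: 0, 1, 2>}
//        : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
// uses the identity permutation and folds to %arg0; a splat constant input
// instead folds to that splat reshaped to the result type.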
1657
1658OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1659 // Element-wise negate(negate(x)) = x
1660 // iff all zero points are constant 0
1661 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1662 if (!definingOp) {
1663 // defining op of input1 is not a negate, cannot fold
1664 return {};
1665 }
1666
1667 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1668 failed(maybeIZp) || *maybeIZp != 0) {
1669 // input1 zero point is not constant 0, cannot fold
1670 return {};
1671 }
1672 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1673 failed(maybeOZp) || *maybeOZp != 0) {
1674 // output zero point is not constant 0, cannot fold
1675 return {};
1676 }
1677 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1678 failed(maybeIZp) || *maybeIZp != 0) {
1679 // definingOp's input1 zero point is not constant 0, cannot fold
1680 return {};
1681 }
1682 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1683 failed(maybeOZp) || *maybeOZp != 0) {
1684 // definingOp's output zero point is not constant 0, cannot fold
1685 return {};
1686 }
1687
1688 return definingOp.getInput1();
1689}
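// Illustrative case for the fold above (a sketch; the zero-point operands are
// elided): when every zero point involved is the constant 0,
//   %0 = tosa.negate %x, ...   // inner negate
//   %1 = tosa.negate %0, ...   // outer negate
// %1 folds directly to %x.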
1690
1691OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1692 auto input = getInput1();
1693 // Element-wise abs(abs(x)) = abs(x)
1694 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1695 return input;
1696 }
1697
1698 return {};
1699}
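// Illustrative case for the fold above (a sketch; syntax assumed): abs is
// idempotent, so
//   %0 = tosa.abs %x : (tensor<4xf32>) -> tensor<4xf32>
//   %1 = tosa.abs %0 : (tensor<4xf32>) -> tensor<4xf32>
// %1 folds to %0.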
1700
1701OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1702 // Fold consecutive concats on the same axis into a single op.
1703 // Keep track of the operands so we are able to construct a new concat
1704 // later. Conservatively assume that we double the number of operands when
1705 // folding
1706 SmallVector<Value, 8> concatOperands;
1707 concatOperands.reserve(2 * getNumOperands());
1708
1709 // Find all operands that are foldable concats
1710 bool foundFoldableConcat = false;
1711 for (Value operand : getOperands()) {
1712 concatOperands.emplace_back(operand);
1713
1714 auto producer = operand.getDefiningOp<ConcatOp>();
1715 if (!producer)
1716 continue;
1717
1718 // Not foldable if axes are not the same
1719 if (getAxis() != producer.getAxis())
1720 continue;
1721
1722 // Replace the original operand with all incoming operands
1723 foundFoldableConcat = true;
1724 concatOperands.pop_back();
1725 llvm::append_range(concatOperands, producer->getOperands());
1726 }
1727
1728 if (!foundFoldableConcat)
1729 return {};
1730
1731 getOperation()->setOperands(concatOperands);
1732 return getResult();
1733}
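// Illustrative case for the fold above (a sketch; syntax assumed): nested
// concats on the same axis are flattened in place, e.g.
//   %0 = tosa.concat %a, %b {axis = 0 : i32} : ...
//   %1 = tosa.concat %0, %c {axis = 0 : i32} : ...
// rewrites %1 to tosa.concat %a, %b, %c {axis = 0 : i32}; a producing concat
// on a different axis is left untouched.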
1734
1735OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1736 auto input = adaptor.getInput1();
1737
1738 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1739 // Fold splat inputs only.
1740 if (!inputAttr || !inputAttr.isSplat())
1741 return {};
1742
1743 auto shapeType = llvm::cast<ShapedType>(getType());
1744 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1745 auto floatVal = inputAttr.getSplatValue<APFloat>();
1746 return DenseElementsAttr::get(shapeType,
1747 ReciprocalOp::calcOneElement(floatVal));
1748 }
1749
1750 return {};
1751}
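// Illustrative case for the fold above (a sketch): a splat float constant
// input is folded directly, e.g. the reciprocal of a tensor splatted with 2.0
// becomes a tensor splatted with 0.5; non-splat or non-float inputs are left
// alone.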
1752
1753OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
1754 auto input1ConstShape =
1755 getInput1().getDefiningOp<tosa::ConstShapeOp>();
1756 auto input2ConstShape =
1757 getInput2().getDefiningOp<tosa::ConstShapeOp>();
1758 if (!input1ConstShape || !input2ConstShape)
1759 return {};
1760
1761 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1762 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
1763
1764 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(
1765 input1Attr, input2Attr, input1Attr.getType(), /*foldDenseValues=*/true);
1766}
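// Illustrative case for the fold above (a sketch): adding two constant shapes
// folds element-wise, e.g. tosa.add_shape of const_shape [2, 3] and
// const_shape [1, 1] folds to the constant shape [3, 4].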
1767
1768OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
1769 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
1770 if (!inputTy || !inputTy.hasRank())
1771 return {};
1772 const int32_t axis = getAxis();
1773 const int64_t dimSize = inputTy.getDimSize(axis);
1774 if (ShapedType::isDynamic(dimSize))
1775 return {};
1776
1777 OpBuilder builder(getContext());
1778 const auto resultAttrTy =
1779 RankedTensorType::get(/*shape=*/{1}, builder.getIndexType());
1780 return DenseElementsAttr::get(resultAttrTy, dimSize);
1781}
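// Illustrative case for the fold above (a sketch; syntax assumed): querying a
// static dimension folds to a constant, e.g.
//   %0 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<2x5xf32>) -> tensor<1xindex>
// folds to dense<5> : tensor<1xindex>; dynamic dimensions are not folded.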