MLIR 22.0.0git
TosaCanonicalizations.cpp
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
27
28#include <functional>
29
30using namespace mlir;
31using namespace mlir::tosa;
32
33//===----------------------------------------------------------------------===//
34// Operator Canonicalizers.
35//===----------------------------------------------------------------------===//
36
37//===----------------------------------------------------------------------===//
38// Tensor Data Engine Operators.
39//===----------------------------------------------------------------------===//
40
41// Check that the zero point of the tensor operation and the pad constant of
// the padding operation are aligned.
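// For example (illustrative values): a quantized i8 convolution with an input
// zero point of -128 only aligns with a pad constant that is a splat of -128,
// while a floating-point convolution only aligns with a pad constant of 0.0.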
42static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
43 // Check that padConst is a constant value and a scalar tensor
44 DenseElementsAttr padConstAttr;
45 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
46 (padConstAttr.size() != 1)) {
47 return false;
48 }
49
50 // Check that floating point pad is zero
51 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
54 }
55
56 // Check that the zp and padConst align for the integer (quantized) case
57 if (auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
59 DenseIntElementsAttr zpAttr;
60 // Check that zp is a constant value and a scalar tensor
61 if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
62 return false;
63 }
64
65 // Check equality
66 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
69 }
70
71 // Bail-out on unsupported type
72 return false;
73}
74
75namespace {
76template <typename OpTy>
77struct PoolPadFoldAdaptor;
78
79template <>
80struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
81 using OpTy = tosa::MaxPool2dOp;
82 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
83 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
84 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
85 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
86 return false;
87 return true;
88 }
89 static bool checkPadConstCompliance(OpTy, Value padConst) {
90 // Check that padConst is a constant value and a scalar tensor
91 DenseElementsAttr padConstAttr;
92 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
93 padConstAttr.size() != 1) {
94 return false;
95 }
96
97 // The pad constant must be the minimum value of the type to be mergeable
98 if (auto padConstFpAttr =
99 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
100 const APFloat padConstVal = *padConstFpAttr.begin();
101 const APFloat lowestVal =
102 APFloat::getLargest(padConstVal.getSemantics(), true);
103 return padConstVal == lowestVal;
104 }
105 if (auto padConstIntAttr =
106 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
107 const APInt padConstVal = *padConstIntAttr.begin();
108 const unsigned int bitWidth = padConstVal.getBitWidth();
109 const APInt lowestVal =
110 padConstIntAttr.getElementType().isUnsignedInteger()
111 ? APInt::getZero(bitWidth)
112 : APInt::getSignedMinValue(bitWidth);
113 return padConstVal == lowestVal;
114 }
115
116 // Bail-out on unsupported type
117 return false;
118 }
119 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
120 Value padInput, ArrayRef<int64_t> newPad) {
121 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
122 op, op.getType(), padInput, op.getKernel(), op.getStride(),
123 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
124 }
125};
126
127template <typename OpTy>
128struct ConvPadFoldAdaptor {
129 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
130 return true;
131 }
132 static bool checkPadConstCompliance(OpTy op, Value padConst) {
133 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
134 }
135 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
136 Value padInput, ArrayRef<int64_t> newPad) {
137 rewriter.replaceOpWithNewOp<OpTy>(
138 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
139 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
140 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
141 }
142};
143
144// Pattern attempts to fold a `tosa.pad` operator into a following tensor
145// operation like `tosa.conv2d` by merging the padding associated with the
146// pad operator directly into the implicit padding of the tensor operation.
147// This helps eliminate the explicit pad operator if it becomes unused.
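// For example (illustrative IR, types and quantization operands elided):
//
//   %p = tosa.pad %x, padding = [0, 0, 1, 1, 2, 2, 0, 0]
//   %y = tosa.conv2d %p, %w, %b { pad = [0, 0, 0, 0], ... }
//
// is rewritten to
//
//   %y = tosa.conv2d %x, %w, %b { pad = [1, 1, 2, 2], ... }
//
// and the original tosa.pad becomes dead if it has no other users.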
148template <typename OpTy, typename AdaptorTy>
149struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
150 using OpRewritePattern<OpTy>::OpRewritePattern;
151
152 LogicalResult matchAndRewrite(OpTy tensorOp,
153 PatternRewriter &rewriter) const override {
154 // Check producer is a tosa::PadOp
155 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
156 if (!padOp)
157 return rewriter.notifyMatchFailure(tensorOp,
158 "Producer must be a tosa::PadOp.");
159
160 // Validate that tensor operation has sane padding
161 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
162 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
163 return rewriter.notifyMatchFailure(
164 tensorOp, "Tensor operation padding shall have 4 elements.");
165
166 // Validate tosa::PadOp padding
167 DenseIntElementsAttr padOpPadding;
168 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
169 return rewriter.notifyMatchFailure(
170 tensorOp,
171 "The `padding` input specified on the tosa::PadOp must be constant.");
172 }
173 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
174 // C_after
175 if (padOpPadding.size() != 8)
176 return rewriter.notifyMatchFailure(tensorOp,
177 "Pad padding should have 8 elements.");
178 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
179 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
180 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
181 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
182 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
183 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
184 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
185 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
186
187 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
188 return rewriter.notifyMatchFailure(
189 tensorOp, "Folding padding in N or C dimensions is not supported.");
190
191 // Fold padding from Pad into the tensor operation
192 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
193 SmallVector<int64_t> foldedPad(tensorOpPad.size());
194 foldedPad[0] = padHBefore + tensorOpPad[0];
195 foldedPad[1] = padHAfter + tensorOpPad[1];
196 foldedPad[2] = padWBefore + tensorOpPad[2];
197 foldedPad[3] = padWAfter + tensorOpPad[3];
198
199 // Check kernel related restrictions
200 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
201 return rewriter.notifyMatchFailure(
202 tensorOp, "Padding size not aligned with kernel restrictions.");
203 }
204
205 // Check padding constant restrictions
206 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
207 return rewriter.notifyMatchFailure(
208 tensorOp,
209 "Padding constant is not aligned with operator zero-point.");
210 }
211
212 // Check that the folded padding does not exceed the 8K level limit (8192)
213 // for now
213 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
214 return rewriter.notifyMatchFailure(
215 tensorOp, "Padding size more than the 8K level limit.");
216 }
217
218 // Create operator
219 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
220 foldedPad);
221
222 return success();
223 }
224};
225} // namespace
226
227void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
228 MLIRContext *context) {
229 results.add<
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
231 context);
232}
233
234void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
235 MLIRContext *context) {
236 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
238 context);
239}
240
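// A tosa.max_pool2d whose input and output spatial dimensions (H and W) are
// both 1x1 takes the maximum over a single element per batch and channel, so
// it is an identity and can be replaced by its input.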
241struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
242 using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
243
244 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
245 PatternRewriter &rewriter) const override {
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
250
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
252 return failure();
253 }
254
255 // If the input and output H and W dimensions are all 1, this is a no-op.
256 ArrayRef<int64_t> outputShape = outputType.getShape();
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
258 return failure();
259 }
260
261 ArrayRef<int64_t> inputShape = inputType.getShape();
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
263 return failure();
264 }
265
266 rewriter.replaceOp(op, input);
267 return success();
268 }
269};
270
271void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
272 MLIRContext *context) {
273 results.add<MaxPool2dIsNoOp,
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
276 context);
277}
278
279//===----------------------------------------------------------------------===//
280// Data Layout / Memory Reinterpretation.
281//===----------------------------------------------------------------------===//
282
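// A tosa.concat with a single input is an identity. It is replaced by its
// input, going through a tensor.cast when the result type differs from the
// operand type (for example when the concat result has a more static shape).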
283struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
284 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
285
286 LogicalResult matchAndRewrite(tosa::ConcatOp op,
287 PatternRewriter &rewriter) const override {
288 if (op.getInput1().size() != 1)
289 return failure();
290 if (op.getInput1().front().getType() != op.getType()) {
291 rewriter
292 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
293 op.getInput1().front())
294 .getResult();
295 return success();
296 }
297
298 rewriter.replaceOp(op, op.getInput1().front());
299 return success();
300 }
301};
302
303void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
304 MLIRContext *context) {
305 results.add<ConcatOptimization>(context);
306}
307
308LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
310 if (!notOp)
311 return failure();
312 rewriter.modifyOpInPlace(op, [&]() {
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
315 });
316 return success();
317}
318
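// Composes two back-to-back transposes into one by composing their
// permutations. For example (illustrative): transposing with perms [1, 0, 2]
// and then with perms [2, 0, 1] is the same as a single transpose with perms
// [2, 1, 0], i.e. perms[i] = innerPerms[outerPerms[i]].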
319struct ConsolidateTransposeOptimization
320 : public OpRewritePattern<tosa::TransposeOp> {
321 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
322
323 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
324 PatternRewriter &rewriter) const override {
325 // Input is also TransposeOp - transpose(transpose(A)).
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
328 if (!innerTranspose)
329 return rewriter.notifyMatchFailure(transposeOp,
330 "input must be transpose operation");
331
332 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
333 const llvm::ArrayRef<int32_t> innerTransposePerms =
334 innerTranspose.getPerms();
335
336 if (transposePerms.size() != innerTransposePerms.size())
337 return rewriter.notifyMatchFailure(
338 transposeOp,
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
341 return rewriter.notifyMatchFailure(
342 transposeOp, "transpose perms sizes must be positive");
343
344 // Consolidate transposes into one transpose.
345 SmallVector<int32_t> perms(transposePerms.size());
346 for (int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
348
349 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
350 transposeOp, transposeOp.getResult().getType(),
351 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
352
353 return success();
354 }
355};
356
357// Determines the case when tosa.transpose is a tosa.reshape operation.
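// For example (illustrative): transposing a tensor<1x5x1x7xf32> with perms
// [0, 2, 1, 3] only moves unit dimensions, so the element order in memory is
// unchanged and the op can be rewritten as a tosa.reshape to
// tensor<1x1x5x7xf32>.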
358struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
359 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
360
361 LogicalResult matchAndRewrite(tosa::TransposeOp op,
362 PatternRewriter &rewriter) const override {
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
364 return rewriter.notifyMatchFailure(
365 op, "Src is from transpose, can compose transposes");
366
367 Value result = op.getResult();
368 for (Operation *subop : result.getUsers()) {
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
370 return rewriter.notifyMatchFailure(
371 op, "Dest is used by transpose, can compose transposes");
372 }
373
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
377 return rewriter.notifyMatchFailure(op, "Unranked input.");
378
379 int64_t numDynDims = 0;
380 for (int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
382 numDynDims++;
383
384 if (numDynDims > 1)
385 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
386
387 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
388
389 SmallVector<int64_t> nonZeroPerms;
390 nonZeroPerms.reserve(permValues.size());
391 for (auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
393 if (sz != 1)
394 nonZeroPerms.push_back(idx);
395 }
396
397 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
399 return rewriter.notifyMatchFailure(op,
400 "Transpose changes memory layout.");
401
402 SmallVector<int64_t> newShape;
403 newShape.reserve(inputTy.getRank());
404 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
406
407 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
408 op, op.getType(), op.getInput1(),
409 getTosaConstShape(rewriter, op.getLoc(), newShape));
410 return success();
411 }
412};
413
414void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
415 MLIRContext *context) {
416 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
417}
418
419struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
420 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
421
422 LogicalResult matchAndRewrite(tosa::ClampOp op,
423 PatternRewriter &rewriter) const override {
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
427
428 if (isa<FloatType>(inputElementType)) {
429 // Unlike integer types, floating point types can represent infinity.
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
436
437 if (isMin && isMax) {
438 rewriter.replaceOp(op, input);
439 return success();
440 }
441 return failure();
442 }
443
444 // i1 types are boolean in TOSA
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
448 .getValue()
449 .getZExtValue();
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
451 .getValue()
452 .getZExtValue();
453
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
457
458 if (minClamp <= intMin && maxClamp >= intMax) {
459 rewriter.replaceOp(op, input);
460 return success();
461 }
462 return failure();
463 }
464
465 if (llvm::isa<IntegerType>(inputElementType)) {
466 const int64_t minClamp =
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
468 const int64_t maxClamp =
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
470
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
474
475 if (minClamp <= intMin && maxClamp >= intMax) {
476 rewriter.replaceOp(op, input);
477 return success();
478 }
479 return failure();
480 }
481
482 return failure();
483 }
484};
485
486// Attempts the following transformation:
487//
488// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
489// tensor X the following identity holds:
490//
491// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
492//
493// subject to the following valid NaN propagation semantics:
494// --------------------------------------------
495// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
496// |-------------|--------------|-------------|
497// | PROPAGATE | PROPAGATE | PROPAGATE |
498// | PROPAGATE | IGNORE | IGNORE |
499// | IGNORE | PROPAGATE | INVALID |
500// | IGNORE | IGNORE | IGNORE |
501// |------------------------------------------|
502
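// For example (illustrative): clamp(clamp(x, min = 0, max = 10), min = 2,
// max = 20) has intersecting ranges [0, 10] and [2, 20], so it folds to
// clamp(x, min = 2, max = 10).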
503struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
504 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
505
506 // Helper structure to describe the range of a clamp operation.
507 template <typename T>
508 struct ClampRange {
509 ClampRange(const T &start, const T &end) : start(start), end(end) {}
510 T start;
511 T end;
512
513 // Helper function to determine if two Clamp ranges intersect.
514 bool intersects(const ClampRange<T> &otherRange) {
515 return start < otherRange.end && otherRange.start < end;
516 }
517 };
518
519 LogicalResult matchAndRewrite(tosa::ClampOp op,
520 PatternRewriter &rewriter) const override {
521 Value input = op.getInput();
522
523 // Check the input to the CLAMP op is itself a CLAMP.
524 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
525 if (!clampOp)
526 return failure();
527
528 // Check we have a valid NaN propagation combination.
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
533 return failure();
534
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
539
540 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
541 if (auto quantType =
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
543 inputEType = getStorageElementTypeFromQuantized(quantType);
544 }
545
546 Attribute newMinValAttr, newMaxValAttr;
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
552
553 // Check we have intersecting ranges.
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
558 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
559 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
560 clampOpMaxFloat);
561 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
562 return failure();
563
564 // Run the transformation.
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
569 } else {
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
575
576 if (inputEType.isUnsignedInteger()) {
577 // Check we have intersecting ranges.
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
582 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
583 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
584 clampOpMaxInt);
585 if (!opRangeIntRange.intersects(clampRangeIntRange))
586 return failure();
587
588 // Run the transformation.
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
591 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
592 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
593 } else {
594 // Check we have intersecting ranges.
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
599 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
600 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
601 clampOpMaxInt);
602 if (!opRangeIntRange.intersects(clampRangeIntRange))
603 return failure();
604
605 // Run the transformation.
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
608 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
609 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
610 }
611 }
612
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
615 : opNanMode;
616
617 auto newModeAttr =
618 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
619
620 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
622 newModeAttr);
623 return success();
624 }
625};
626
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.add<ClampIsNoOp>(context);
630 results.add<ClampClampOptimization>(context);
631}
632
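// Replaces a tosa.slice of a tosa.concat with a slice of the single concat
// input that fully contains the sliced region. For example (illustrative),
// with %c = concat(%a, %b) along axis 0 where %a and %b each have 4 rows,
// slicing rows [4, 6) of %c becomes slicing rows [0, 2) of %b.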
633struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
634 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
635
636 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
637 PatternRewriter &rewriter) const override {
638 Value sliceInput = sliceOp.getInput1();
639 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
640 if (!concatOp)
641 return rewriter.notifyMatchFailure(
642 sliceOp, "slice input must be concat operation");
643
644 OperandRange inputs = concatOp.getInput1();
645 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
646 if (!concatType || !concatType.hasStaticShape())
647 return rewriter.notifyMatchFailure(
648 sliceOp, "slice input must be a static ranked tensor");
649 int32_t axis = concatOp.getAxis();
650
651 DenseElementsAttr startElems;
652 DenseElementsAttr sizeElems;
653
654 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
655 return rewriter.notifyMatchFailure(
656 sliceOp, "start of slice must be a static ranked shape");
657
658 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
659 return rewriter.notifyMatchFailure(
660 sliceOp, "size of slice must be a static ranked shape");
661
662 llvm::SmallVector<int64_t> sliceStarts =
663 llvm::to_vector(startElems.getValues<int64_t>());
664 llvm::SmallVector<int64_t> sliceSizes =
665 llvm::to_vector(sizeElems.getValues<int64_t>());
666
667 // Validate slice on the concatenated axis. Slicing along this
668 // axis should span only one of the inputs to the concatenate
669 // operation.
670 std::optional<Value> replaceWithSlice;
671 for (auto input : inputs) {
672 auto inputType = dyn_cast<RankedTensorType>(input.getType());
673 if (!inputType || !inputType.hasStaticShape())
674 return rewriter.notifyMatchFailure(
675 sliceOp, "concat input must be a static ranked tensor");
676
677 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
678 inputType.getDimSize(axis)) {
679 auto start_op =
680 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
681 auto size_op =
682 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
683 replaceWithSlice =
684 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
685 input, start_op, size_op)
686 .getResult();
687 break;
688 }
689 sliceStarts[axis] -= inputType.getDimSize(axis);
690 }
691
692 if (!replaceWithSlice)
693 return rewriter.notifyMatchFailure(
694 sliceOp, "corresponding concat input not found for slice");
695
696 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
697 return success();
698 }
699};
700
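// Moves a tosa.slice that is the sole user of a tosa.pad ahead of the pad by
// shrinking the padding to the region the slice actually keeps. For example
// (illustrative, one axis): padding a dimension of size 10 by [2, 3] and then
// slicing with start = 0 and size = 12 never reads the last 3 padded
// elements, so the padding shrinks to [2, 0] and the slice start stays 0.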
701struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
702 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
703
704 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
705 PatternRewriter &rewriter) const override {
706 Value sliceInput = sliceOp.getInput1();
707
708 // Check if producer is a PadOp
709 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
710 if (!padOp)
711 return rewriter.notifyMatchFailure(sliceOp,
712 "slice input must be a pad operation");
713
714 // Check PadOp has a single consumer
715 if (!padOp->hasOneUse())
716 return rewriter.notifyMatchFailure(sliceOp,
717 "pad shall have a single consumer");
718
719 // Check input is statically ranked
720 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
721 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
722 if (!inputTy || !padTy || !inputTy.hasRank())
723 return rewriter.notifyMatchFailure(sliceOp,
724 "slice input must be a ranked tensor");
725
726 // Validate and extract tosa::PadOp padding
727 DenseIntElementsAttr paddingElems;
728 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
729 return rewriter.notifyMatchFailure(
730 sliceOp,
731 "`padding` input specified on the tosa::PadOp must be constant.");
732 }
733 llvm::SmallVector<int64_t> padPaddings =
734 llvm::to_vector(paddingElems.getValues<int64_t>());
735
736 // Extract slice parameters
737 DenseElementsAttr startElems;
738 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
739 return rewriter.notifyMatchFailure(
740 sliceOp, "start of slice must be a static ranked shape");
741 llvm::SmallVector<int64_t> sliceStarts =
742 llvm::to_vector(startElems.getValues<int64_t>());
743
744 DenseElementsAttr sizeElems;
745 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
746 return rewriter.notifyMatchFailure(
747 sliceOp, "size of slice must be a static ranked shape");
748 llvm::SmallVector<int64_t> sliceSizes =
749 llvm::to_vector(sizeElems.getValues<int64_t>());
750
751 // Check if dynamic dimensions are sliced
752 const int64_t rank = inputTy.getRank();
753 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
754 const bool isDimDynamic = inputTy.isDynamicDim(i);
755 const bool isDimSliced =
756 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
757
758 return isDimDynamic && isDimSliced;
759 })) {
760 return rewriter.notifyMatchFailure(
761 sliceOp, "axis that are sliced shall be statically known.");
762 }
763
764 // Update the parameters
765 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
766 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
767 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
768 bool updated = false;
769
770 for (int64_t i = 0; i < rank; ++i) {
771 const int64_t padLo = padPaddings[i * 2];
772 const int64_t padHi = padPaddings[i * 2 + 1];
773 const int64_t sliceStart = sliceStarts[i];
774 const int64_t sliceSize = sliceSizes[i];
775 const int64_t sliceEnd = sliceStart + sliceSize;
776
777 // If the dimension is dynamic, pass the padding and slice start through
778 if (inputTy.isDynamicDim(i)) {
779 newPadPaddings[i * 2] = padLo;
780 newPadPaddings[i * 2 + 1] = padHi;
781 newSliceStarts[i] = sliceStart;
782 continue;
783 }
784
785 // Handle static dimensions
786 const int64_t dimSize = inputTy.getShape()[i];
787 const int64_t dimTotal = padLo + dimSize + padHi;
788
789 // Check slice within bounds
790 if (sliceStart < 0 || sliceEnd > dimTotal)
791 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
792
793 // Compute updated slice start parameter
794 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
795 newSliceStarts[i] = newSliceStart;
796 updated |= newSliceStart != sliceStart;
797
798 // Compute updated pad parameters
799 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
800 const int64_t newPadHi =
801 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
802 newPadPaddings[i * 2] = newPadLo;
803 newPadPaddings[i * 2 + 1] = newPadHi;
804 updated |= (newPadLo != padLo) || (newPadHi != padHi);
805
806 // Calculate new pad output shape
807 newPadShape[i] =
808 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
809 }
810
811 // Check that we actually need to proceed with the rewrite
812 if (!updated)
813 return rewriter.notifyMatchFailure(
814 sliceOp, "terminate condition; nothing to rewrite");
815
816 // Create a PadOp with updated padding
817 auto newPaddingsOp =
818 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
819 auto newPadTy =
820 RankedTensorType::get(newPadShape, inputTy.getElementType());
821 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
822 padOp.getInput1(), newPaddingsOp,
823 padOp.getPadConst());
824
825 // Update SliceOp and point to new PadOp
826 auto newStartOp =
827 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
828 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
829 newPadOp.getResult(), newStartOp,
830 sliceOp.getSize());
831
832 return success();
833 }
834};
835
836// Update size operand of tosa.slice if size has dynamic dims but corresponding
837// output dim is static
838struct SliceDynamicSizeCanonicalization
839 : public OpRewritePattern<tosa::SliceOp> {
840 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
841
842 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
843 PatternRewriter &rewriter) const override {
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
845
846 ElementsAttr sizeElems;
847 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
848 return rewriter.notifyMatchFailure(
849 sliceOp, "size of slice must be a static ranked shape");
850 }
851
852 llvm::SmallVector<int64_t> sliceSizes =
853 llvm::to_vector(sizeElems.getValues<int64_t>());
854
855 bool replaceSliceSize{false};
856 // If the size operand has a -1 (dynamic) entry whose corresponding output
857 // dimension is statically known, update that entry to the known output
858 // dimension size.
859 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(index)) {
861 sliceSizes[index] = resultType.getDimSize(index);
862 replaceSliceSize = true;
863 }
864 }
865
866 if (!replaceSliceSize) {
867 return rewriter.notifyMatchFailure(
868 sliceOp, "no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
870 }
871
872 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
873 auto newSliceOp =
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
876
877 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
878 return success();
879 }
880};
881
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
886}
887
888//===----------------------------------------------------------------------===//
889// Operator Folders.
890//===----------------------------------------------------------------------===//
891
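// binaryFolder only folds when both operands are splat constants with the
// same element type; the two splat scalars are combined with IntFolder or
// FloatFolder (for example std::plus<APInt>) and the result is rebroadcast
// to the return type.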
892template <typename IntFolder, typename FloatFolder>
893static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
894 DenseElementsAttr rhs,
895 RankedTensorType returnTy) {
896 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
897 auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
898 auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
899 if (lETy != rETy)
900 return {};
901
902 if (llvm::isa<IntegerType>(lETy)) {
903 APInt l = lhs.getSplatValue<APInt>();
904 APInt r = rhs.getSplatValue<APInt>();
905 auto result = IntFolder()(l, r);
906 return DenseElementsAttr::get(returnTy, result);
907 }
908
909 if (llvm::isa<FloatType>(lETy)) {
910 APFloat l = lhs.getSplatValue<APFloat>();
911 APFloat r = rhs.getSplatValue<APFloat>();
912 auto result = FloatFolder()(l, r);
913 return DenseElementsAttr::get(returnTy, result);
914 }
915 }
916
917 return {};
918}
919
920static bool isSplatZero(Type elemType, DenseElementsAttr val) {
921 if (llvm::isa<FloatType>(elemType))
922 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
923 if (llvm::isa<IntegerType>(elemType))
924 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
925 return false;
926}
927
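// For tosa.mul on i32, the multiplicative identity depends on the result
// right shift: with a shift of `shift`, the constant that acts as one is
// 1 << shift.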
928static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
929 if (llvm::isa<FloatType>(elemType))
930 return val && val.isSplat() &&
931 val.getSplatValue<APFloat>().isExactlyValue(1.0);
932 if (llvm::isa<IntegerType>(elemType)) {
933 const int64_t shifted = 1LL << shift;
934 return val && val.isSplat() &&
935 val.getSplatValue<APInt>().getSExtValue() == shifted;
936 }
937 return false;
938}
939
940OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
941 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
942 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
943 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
944 if (!lhsTy || !rhsTy || !resultTy)
945 return {};
946
947 // Cannot create an ElementsAttr from non-int/float/index types
948 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
949 !rhsTy.getElementType().isIntOrIndexOrFloat())
950 return {};
951
952 auto resultETy = resultTy.getElementType();
953 auto lhsAttr =
954 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
955 auto rhsAttr =
956 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
957
958 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
959 return getInput1();
960 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
961 return getInput2();
962
963 if (!lhsAttr || !rhsAttr)
964 return {};
965
966 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
967 resultTy);
968}
969
970OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
971 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
972 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
973 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
974 !outputTy.hasStaticShape())
975 return {};
976
977 const Type outputElementTy = getElementTypeOrSelf(outputTy);
978 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
979 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
980 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
981 return DenseElementsAttr::get(outputTy, zero);
982 }
983
984 return {};
985}
986
987OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
988 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
989 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
990 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
991 if (!lhsTy || !rhsTy || !resultTy)
992 return {};
993 if (lhsTy != rhsTy)
994 return {};
995
996 // IntDivOp inputs must be integer type, no need to check for quantized type
997 auto resultETy = resultTy.getElementType();
998 auto lhsAttr =
999 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1000 auto rhsAttr =
1001 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1002 if (lhsAttr && lhsAttr.isSplat()) {
1003 if (llvm::isa<IntegerType>(resultETy) &&
1004 lhsAttr.getSplatValue<APInt>().isZero())
1005 return lhsAttr;
1006 }
1007
1008 if (rhsAttr && rhsAttr.isSplat()) {
1009 if (llvm::isa<IntegerType>(resultETy) &&
1010 rhsAttr.getSplatValue<APInt>().isOne())
1011 return getInput1();
1012 }
1013
1014 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1015 llvm::isa<IntegerType>(resultETy)) {
1016 APInt l = lhsAttr.getSplatValue<APInt>();
1017 APInt r = rhsAttr.getSplatValue<APInt>();
1018 if (!r.isZero()) {
1019 APInt result = l.sdiv(r);
1020 return DenseElementsAttr::get(resultTy, result);
1021 }
1022 }
1023
1024 return {};
1025}
1026
1027namespace {
1028// Calculate (lhs * rhs) >> shift with rounding, per the TOSA specification.
1029// Returns std::nullopt if the result is out of int32_t range when shift > 0.
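// For example (illustrative): lhs = 5, rhs = 3, shift = 1 gives a 64-bit
// product of 15; the rounding term is 1 << 0 = 1, and (15 + 1) >> 1 = 8,
// which is then truncated back to the operand bitwidth.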
1030std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1031 unsigned bitwidth) {
1032 APInt result = lhs.sext(64) * rhs.sext(64);
1033
1034 if (shift > 0) {
1035 auto round = APInt(64, 1) << (shift - 1);
1036 result += round;
1037 result.ashrInPlace(shift);
1038 // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1039 if (!(result.getSExtValue() >= INT32_MIN &&
1040 result.getSExtValue() <= INT32_MAX)) {
1041 // REQUIRE failed
1042 return std::nullopt;
1043 }
1044 }
1045
1046 return result.trunc(bitwidth);
1047}
1048
1049DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1050 RankedTensorType ty, int32_t shift) {
1051 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1052 if (llvm::isa<IntegerType>(ty.getElementType())) {
1053 APInt l = lhs.getSplatValue<APInt>();
1054 APInt r = rhs.getSplatValue<APInt>();
1055
1056 if (shift == 0) {
1057 return DenseElementsAttr::get(ty, l * r);
1058 }
1059
1060 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1061 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1062 if (!result)
1063 return {};
1064 return DenseElementsAttr::get(ty, result.value());
1065 }
1066
1067 if (llvm::isa<FloatType>(ty.getElementType())) {
1068 APFloat l = lhs.getSplatValue<APFloat>();
1069 APFloat r = rhs.getSplatValue<APFloat>();
1070 APFloat result = l * r;
1071 return DenseElementsAttr::get(ty, result);
1072 }
1073 }
1074
1075 return {};
1076}
1077} // namespace
1078
1079OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1080 auto lhs = getInput1();
1081 auto rhs = getInput2();
1082 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1083 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1084 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1085 if (!lhsTy || !rhsTy || !resultTy)
1086 return {};
1087
1088 auto resultETy = resultTy.getElementType();
1089 auto lhsAttr =
1090 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1091 auto rhsAttr =
1092 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1093
1094 // The result right shift applies to the i32 data type only. For
1095 // simplicity, synthesize a zero shift for other data types.
1096 int32_t shift = 0;
1097 if (resultETy.isInteger(32)) {
1098 ElementsAttr shift_elem;
1099 if (getShift().getImpl()) {
1100 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1101 // cannot be folded when the shift value is unknown.
1102 return {};
1103 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1104 }
1105 }
1106
1107 if (rhsTy == resultTy) {
1108 if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1109 // constant values can only be resized if resulting type is static
1110 return lhsAttr.resizeSplat(resultTy);
1111 if (isSplatOne(resultETy, lhsAttr, shift))
1112 return rhs;
1113 }
1114 if (lhsTy == resultTy) {
1115 if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1116 return rhsAttr.resizeSplat(resultTy);
1117 if (isSplatOne(resultETy, rhsAttr, shift))
1118 return lhs;
1119 }
1120
1121 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1122}
1123
1124OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1125 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1126 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1127 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1128 if (!lhsTy || !rhsTy || !resultTy)
1129 return {};
1130
1131 // Cannot create an ElementsAttr from non-int/float/index types
1132 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1133 !rhsTy.getElementType().isIntOrIndexOrFloat())
1134 return {};
1135
1136 auto resultETy = resultTy.getElementType();
1137 auto lhsAttr =
1138 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1139 auto rhsAttr =
1140 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1141
1142 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1143 return getInput1();
1144
1145 if (!lhsAttr || !rhsAttr)
1146 return {};
1147
1148 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1149 resultTy);
1150}
1151
1152namespace {
1153template <typename Cmp>
1154struct ComparisonFold {
1155 ComparisonFold() = default;
1156 APInt operator()(const APInt &l, const APInt &r) {
1157 return APInt(1, Cmp()(l, r));
1158 }
1159
1160 APInt operator()(const APFloat &l, const APFloat &r) {
1161 return APInt(1, Cmp()(l, r));
1162 }
1163};
1164
1165struct APIntFoldGreater {
1166 APIntFoldGreater() = default;
1167 APInt operator()(const APInt &l, const APInt &r) {
1168 return APInt(1, l.sgt(r));
1169 }
1170};
1171
1172struct APIntFoldGreaterEqual {
1173 APIntFoldGreaterEqual() = default;
1174 APInt operator()(const APInt &l, const APInt &r) {
1175 return APInt(1, l.sge(r));
1176 }
1177};
1178} // namespace
1179
1180OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1181 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1182 auto lhsAttr =
1183 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1184 auto rhsAttr =
1185 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1186
1187 if (!lhsAttr || !rhsAttr)
1188 return {};
1189
1190 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1191 lhsAttr, rhsAttr, resultTy);
1192}
1193
1194OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1195 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1196 auto lhsAttr =
1197 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1198 auto rhsAttr =
1199 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1200
1201 if (!lhsAttr || !rhsAttr)
1202 return {};
1203
1204 return binaryFolder<APIntFoldGreaterEqual,
1205 ComparisonFold<std::greater_equal<APFloat>>>(
1206 lhsAttr, rhsAttr, resultTy);
1207}
1208
1209OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1210 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1211 auto lhsAttr =
1212 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1213 auto rhsAttr =
1214 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1215 Value lhs = getInput1();
1216 Value rhs = getInput2();
1217 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1218
1219 // If we are comparing an integer value to itself it is always true. We
1220 // cannot do this for floats because NaN does not compare equal to itself.
1221 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1222 resultTy.hasStaticShape() && lhs == rhs) {
1223 return DenseElementsAttr::get(resultTy, true);
1224 }
1225
1226 if (!lhsAttr || !rhsAttr)
1227 return {};
1228
1229 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1230 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1231 resultTy);
1232}
1233
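// Folds tosa.cast of a splat constant by converting the splat scalar to the
// result element type: float-to-float rounds to nearest-even, int-to-float
// and float-to-int honor signedness, and int-to-int truncates or extends,
// with i1 treated as a boolean.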
1234OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1235 if (getInput().getType() == getType())
1236 return getInput();
1237
1238 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1239 if (!operand)
1240 return {};
1241
1242 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1243 auto outTy = llvm::cast<ShapedType>(getType());
1244 auto inETy = inTy.getElementType();
1245 auto outETy = outTy.getElementType();
1246
1247 if (operand.isSplat()) {
1248 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1249 bool overflow;
1250 auto splatVal = operand.getSplatValue<APFloat>();
1251 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1252 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1253 &overflow);
1254 return SplatElementsAttr::get(outTy, splatVal);
1255 }
1256
1257 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1258 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1259 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1260 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1261 llvm::RoundingMode::NearestTiesToEven);
1262 return SplatElementsAttr::get(outTy, splatVal);
1263 }
1264
1265 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1266 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1267 auto intVal = APSInt(
1268 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1269 auto floatVal = operand.getSplatValue<APFloat>();
1270 bool exact;
1271 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1272 &exact);
1273 return SplatElementsAttr::get(outTy, intVal);
1274 }
1275
1276 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1277 const auto inIntType = llvm::cast<IntegerType>(inETy);
1278 auto unsignIn = inIntType.isUnsignedInteger();
1279 bool trunc =
1280 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1281 auto intVal = operand.getSplatValue<APInt>();
1282 auto bitwidth = outETy.getIntOrFloatBitWidth();
1283
1284 // i1 types are boolean in TOSA
1285 if (outETy.isInteger(1)) {
1286 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1287 } else if (trunc) {
1288 intVal = intVal.trunc(bitwidth);
1289 } else if (unsignIn || inIntType.isInteger(1)) {
1290 intVal = intVal.zext(bitwidth);
1291 } else {
1292 intVal = intVal.sext(bitwidth);
1293 }
1294
1295 return SplatElementsAttr::get(outTy, intVal);
1296 }
1297 }
1298
1299 return {};
1300}
1301
1302OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1303
1304OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1305
1306#define REDUCE_FOLDER(OP) \
1307 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1308 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1309 if (!inputTy.hasRank()) \
1310 return {}; \
1311 if (inputTy != getType()) \
1312 return {}; \
1313 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1314 return getInput(); \
1315 return {}; \
1316 }
1317
1318REDUCE_FOLDER(ReduceAllOp)
1319REDUCE_FOLDER(ReduceAnyOp)
1320REDUCE_FOLDER(ReduceMaxOp)
1321REDUCE_FOLDER(ReduceMinOp)
1322REDUCE_FOLDER(ReduceProductOp)
1323REDUCE_FOLDER(ReduceSumOp)
1324#undef REDUCE_FOLDER
1325
1326OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1327 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1328 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1329
1330 if (!inputTy || !outputTy)
1331 return {};
1332
1333 // Fold when the input and output types are the same. This is only safe when
1334 // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1335 // there may still be a productive reshape.
1336 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1337 return getInput1();
1338
1339 // reshape(reshape(x)) -> reshape(x)
1340 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1341 getInput1().getDefiningOp())) {
1342 getInput1Mutable().assign(reshapeOp.getInput1());
1343 return getResult();
1344 }
1345
1346 // Cannot create an ElementsAttr from non-int/float/index types
1347 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1348 return {};
1349
1350 // reshape(const(x)) -> const(reshape-attr(x))
1351 if (auto operand =
1352 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1353 // Constants must have static shape.
1354 if (!outputTy.hasStaticShape())
1355 return {};
1356
1357 // Okay to duplicate splat constants.
1358 if (operand.isSplat())
1359 return SplatElementsAttr::get(outputTy,
1360 operand.getSplatValue<Attribute>());
1361
1362 // Don't duplicate other constants.
1363 if (!getInput1().hasOneUse())
1364 return {};
1365
1366 llvm::SmallVector<int64_t> shapeVec;
1367 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1368 return {};
1369
1370 return operand.reshape(
1371 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1372 }
1373
1374 return {};
1375}
1376
1377OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1378 // If the pad is all zeros we can fold this operation away.
1379 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1380 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1381 if (densePad && densePad.isSplat() &&
1382 densePad.getSplatValue<APInt>().isZero()) {
1383 return getInput1();
1384 }
1385 }
1386
1387 return {};
1388}
1389
1390// Fold away cases where a tosa.resize operation returns a copy
1391// of the input image.
1392OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1393 auto scaleAttr =
1394 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1395 auto offsetAttr =
1396 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1397 auto borderAttr =
1398 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1399 if (!scaleAttr || !offsetAttr || !borderAttr) {
1400 return {};
1401 }
1402
1403 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1404 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1405 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1406 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1407 return {};
1408 }
1409
1410 // Check unit scaling.
1411 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1412 return {};
1413 }
1414
1415 // There should be no offset.
1416 if (offset[0] != 0 || offset[1] != 0) {
1417 return {};
1418 }
1419
1420 // There should be no border.
1421 if (border[0] != 0 || border[1] != 0) {
1422 return {};
1423 }
1424
1425 auto input = getInput();
1426 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1427 auto resultTy = llvm::cast<RankedTensorType>(getType());
1428 if (inputTy != resultTy)
1429 return {};
1430
1431 return input;
1432}
1433
1434OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1435 auto operand = getInput1();
1436 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1437 auto axis = getAxis();
1438 auto operandAttr =
1439 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1440 if (operandAttr)
1441 return operandAttr;
1442
1443 // If the dim-length is 1, tosa.reverse is a no-op.
1444 if (operandTy.hasRank() &&
1445 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1446 return operand;
1447
1448 return {};
1449}
1450
1451OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1452 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1453 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1454
1455 if (!inputTy || !outputTy)
1456 return {};
1457
1458 if (inputTy == outputTy && inputTy.hasStaticShape())
1459 return getInput1();
1460
1461 if (!adaptor.getInput1())
1462 return {};
1463
1464 // Cannot create an ElementsAttr from non-int/float/index types
1465 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1466 !outputTy.getElementType().isIntOrIndexOrFloat())
1467 return {};
1468
1469 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1470 if (operand.isSplat() && outputTy.hasStaticShape()) {
1471 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1472 }
1473
1474 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1475 outputTy.getNumElements() == 1) {
1476 DenseElementsAttr startElems;
1477 if (!matchPattern(getStart(), m_Constant(&startElems)))
1478 return {};
1479
1480 llvm::SmallVector<uint64_t> indices =
1481 llvm::to_vector(startElems.getValues<uint64_t>());
1482 auto value = operand.getValues<Attribute>()[indices];
1483 return SplatElementsAttr::get(outputTy, value);
1484 }
1485
1486 return {};
1487}
1488
1489static bool
1490mayRequireBroadcast(ValueTypeRange<OperandRange> operandTypes) {
1491 const auto isDynamic = [](Type ty) {
1492 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1493 return !shapedTy || !shapedTy.hasStaticShape();
1494 };
1495
1496 return llvm::any_of(operandTypes, isDynamic) ||
1497 failed(verifyCompatibleShapes(operandTypes));
1498}
1499
1500OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1501 // Select allows operand shapes to be broadcast to the output shape. For
1502 // now, don't support folding when we cannot prove no broadcasting is
1503 // involved.
1504 if (mayRequireBroadcast(getOperandTypes()))
1505 return {};
1506
1507 if (getOnTrue() == getOnFalse())
1508 return getOnTrue();
1509
1510 auto predicate =
1511 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1512 if (!predicate)
1513 return {};
1514
1515 if (!predicate.isSplat())
1516 return {};
1517 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1518 : getOnFalse();
1519}
1520
1521OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1522 if (getInput1().getType() == getType()) {
1523 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1524 adaptor.getMultiples())) {
1525 if (multiples.isSplat() &&
1526 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1527 return getInput1();
1528 if (auto int_array_attr =
1529 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1530 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1531 [](APInt v) { return v.getSExtValue() == 1; }))
1532 return getInput1();
1533 }
1534 }
1535 }
1536 return {};
1537}
1538
1539OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1540 auto resultTy = llvm::cast<ShapedType>(getType());
1541
1542 // Transposing splat values just means reshaping.
1543 if (auto input =
1544 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1545 if (input.isSplat() && resultTy.hasStaticShape() &&
1546 input.getType().getElementType() == resultTy.getElementType())
1547 return input.reshape(resultTy);
1548 }
1549
1550 // The transpose folds to its input only if it is the identity permutation.
1551 const llvm::ArrayRef<int32_t> perms = getPerms();
1552
1553 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1554 return {};
1555
1556 return getInput1();
1557}
1558
1559OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1560 // Element-wise negate(negate(x)) = x
1561 // iff all zero points are constant 0
1562 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1563 if (!definingOp) {
1564 // defining op of input1 is not a negate, cannot fold
1565 return {};
1566 }
1567
1568 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1569 failed(maybeIZp) || *maybeIZp != 0) {
1570 // input1 zero point is not constant 0, cannot fold
1571 return {};
1572 }
1573 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1574 failed(maybeOZp) || *maybeOZp != 0) {
1575 // output zero point is not constant 0, cannot fold
1576 return {};
1577 }
1578 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1579 failed(maybeIZp) || *maybeIZp != 0) {
1580 // definingOp's input1 zero point is not constant 0, cannot fold
1581 return {};
1582 }
1583 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1584 failed(maybeOZp) || *maybeOZp != 0) {
1585 // definingOp's output zero point is not constant 0, cannot fold
1586 return {};
1587 }
1588
1589 return definingOp.getInput1();
1590}
1591
1592OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1593 auto input = getInput1();
1594 // Element-wise abs(abs(x)) = abs(x)
1595 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1596 return input;
1597 }
1598
1599 return {};
1600}
1601
1602OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1603 // Fold consecutive concats on the same axis into a single op.
1604 // Keep track of the operands so we are able to construct a new concat
1605 // later. Conservatively reserve space for twice the current number of
1606 // operands.
1607 SmallVector<Value, 8> concatOperands;
1608 concatOperands.reserve(2 * getNumOperands());
1609
1610 // Find all operands that are foldable concats
1611 bool foundFoldableConcat = false;
1612 for (Value operand : getOperands()) {
1613 concatOperands.emplace_back(operand);
1614
1615 auto producer = operand.getDefiningOp<ConcatOp>();
1616 if (!producer)
1617 continue;
1618
1619 // Not foldable if axes are not the same
1620 if (getAxis() != producer.getAxis())
1621 continue;
1622
1623 // Replace the original operand with all incoming operands
1624 foundFoldableConcat = true;
1625 concatOperands.pop_back();
1626 llvm::append_range(concatOperands, producer->getOperands());
1627 }
1628
1629 if (!foundFoldableConcat)
1630 return {};
1631
1632 getOperation()->setOperands(concatOperands);
1633 return getResult();
1634}
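// For illustration only: a schematic sketch of the fold above, with both
// concats on the same axis (abbreviated pseudo-assembly):
//
//   %c0 = tosa.concat %a, %b  {axis = 1}
//   %c1 = tosa.concat %c0, %d {axis = 1}  ==>  tosa.concat %a, %b, %d {axis = 1}
//
// The fold updates %c1's operand list in place; producers that concatenate
// along a different axis are left untouched.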
1635
1636OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1637 auto input = adaptor.getInput1();
1638
1639 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1640 // Fold splat inputs only.
1641 if (!inputAttr || !inputAttr.isSplat())
1642 return {};
1643
1644 auto shapeType = llvm::cast<ShapedType>(getType());
1645 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1646 auto floatVal = inputAttr.getSplatValue<APFloat>();
1647 return DenseElementsAttr::get(shapeType,
1648 ReciprocalOp::calcOneElement(floatVal));
1649 }
1650
1651 return {};
1652}
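// For illustration only: a schematic sketch of the fold above for a splat
// floating-point input, assuming calcOneElement computes the element-wise
// reciprocal as its name suggests (abbreviated pseudo-assembly):
//
//   %c = tosa.const <splat 4.0 : tensor<2x2xf32>>
//   %r = tosa.reciprocal %c   ==>  constant splat 0.25 : tensor<2x2xf32>
//
// Non-splat and non-floating-point inputs are not folded here.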