MLIR 22.0.0git
TosaCanonicalizations.cpp
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/IR/Matchers.h"
24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
26
27#include <functional>
28
29using namespace mlir;
30using namespace mlir::tosa;
31
32//===----------------------------------------------------------------------===//
33// Operator Canonicalizers.
34//===----------------------------------------------------------------------===//
35
36//===----------------------------------------------------------------------===//
37// Tensor Data Engine Operators.
38//===----------------------------------------------------------------------===//
39
40// Check that the zero point of the tensor and padding operations are aligned.
41static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
42 // Check that padConst is a constant value and a scalar tensor
43 DenseElementsAttr padConstAttr;
44 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
45 (padConstAttr.size() != 1)) {
46 return false;
47 }
48
49 // Check that floating point pad is zero
50 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
51 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52 return padConstVal == 0.0f;
53 }
54
55 // Check that the zp and padConst align for the integer (quantized) case
56 if (auto padConstIntAttr =
57 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
58 DenseIntElementsAttr zpAttr;
59 // Check that zp is a constant value and a scalar tensor
60 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
61 return false;
62 }
63
64 // Check equality
65 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
66 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67 return zpVal == padConstVal;
68 }
69
70 // Bail-out on unsupported type
71 return false;
72}
73
74namespace {
75template <typename OpTy>
76struct PoolPadFoldAdaptor;
77
78template <>
79struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
80 using OpTy = tosa::MaxPool2dOp;
81 static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
82 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
85 return false;
86 return true;
87 }
88 static bool checkPadConstCompliance(OpTy, Value padConst) {
89 // Check that padConst is a constant value and a scalar tensor
90 DenseElementsAttr padConstAttr;
91 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
92 padConstAttr.size() != 1) {
93 return false;
94 }
95
96 // The pad constant needs to be the lowest value of the type to be able to merge
97 if (auto padConstFpAttr =
98 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
99 const APFloat padConstVal = *padConstFpAttr.begin();
100 const APFloat lowestVal =
101 APFloat::getLargest(padConstVal.getSemantics(), true);
102 return padConstVal == lowestVal;
103 }
104 if (auto padConstIntAttr =
105 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
106 const APInt padConstVal = *padConstIntAttr.begin();
107 const unsigned int bitWidth = padConstVal.getBitWidth();
108 const APInt lowestVal =
109 padConstIntAttr.getElementType().isUnsignedInteger()
110 ? APInt::getZero(bitWidth)
111 : APInt::getSignedMinValue(bitWidth);
112 return padConstVal == lowestVal;
113 }
114
115 // Bail-out on unsupported type
116 return false;
117 }
118 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
119 Value padInput, ArrayRef<int64_t> newPad) {
120 rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
121 op, op.getType(), padInput, op.getKernel(), op.getStride(),
122 rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
123 }
124};
125
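// Illustrative sketch (hypothetical helper, not from this file): the adaptor
// above only folds a tosa.pad into tosa.max_pool2d when the pad constant is
// the lowest representable value, because max(pad, x) then leaves any finite
// input unchanged and the explicit padding behaves like the pool's implicit
// padding.
#include "llvm/ADT/APFloat.h"

static bool isNeutralForMax(const llvm::APFloat &padVal,
                            const llvm::APFloat &x) {
  const llvm::APFloat lowest =
      llvm::APFloat::getLargest(padVal.getSemantics(), /*Negative=*/true);
  // max(padVal, x) equals x for any finite x exactly when padVal is the most
  // negative finite value of the type.
  const llvm::APFloat maxOf = (padVal < x) ? x : padVal;
  return padVal == lowest && maxOf == x;
}
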
126template <typename OpTy>
127struct ConvPadFoldAdaptor {
128 static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
129 return true;
130 }
131 static bool checkPadConstCompliance(OpTy op, Value padConst) {
132 return checkMatchingPadConstAndZp(padConst, op.getInputZp());
133 }
134 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
135 Value padInput, ArrayRef<int64_t> newPad) {
136 rewriter.replaceOpWithNewOp<OpTy>(
137 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
138 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
139 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
140 }
141};
142
143// Pattern attempts to fold a `tosa.pad` operator to a following tensor
144// operation like `tosa.conv2d` by merging the padding associated with the
145// pad operator directly to the implicit padding of the tensor operation.
146// This helps eliminate the explicit padding operator if unused.
147template <typename OpTy, typename AdaptorTy>
148struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
149 using OpRewritePattern<OpTy>::OpRewritePattern;
150
151 LogicalResult matchAndRewrite(OpTy tensorOp,
152 PatternRewriter &rewriter) const override {
153 // Check producer is a tosa::PadOp
154 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
155 if (!padOp)
156 return rewriter.notifyMatchFailure(tensorOp,
157 "Producer must be a tosa::PadOp.");
158
159 // Validate that tensor operation has sane padding
160 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
161 if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
162 return rewriter.notifyMatchFailure(
163 tensorOp, "Tensor operation padding shall have 4 elements.");
164
165 // Validate tosa::PadOp padding
166 DenseIntElementsAttr padOpPadding;
167 if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
168 return rewriter.notifyMatchFailure(
169 tensorOp,
170 "The `padding` input specified on the tosa::PadOp must be constant.");
171 }
172 // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
173 // C_after
174 if (padOpPadding.size() != 8)
175 return rewriter.notifyMatchFailure(tensorOp,
176 "Pad padding should have 8 elements.");
177 int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
178 int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
179 int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
180 int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
181 int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
182 int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
183 int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
184 int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
185
186 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
187 return rewriter.notifyMatchFailure(
188 tensorOp, "Folding padding in N or C dimensions is not supported.");
189
190 // Fold padding from Pad into the tensor operation
191 // 4 elements - pad_top, pad_bottom, pad_left, pad_right
192 SmallVector<int64_t> foldedPad(tensorOpPad.size());
193 foldedPad[0] = padHBefore + tensorOpPad[0];
194 foldedPad[1] = padHAfter + tensorOpPad[1];
195 foldedPad[2] = padWBefore + tensorOpPad[2];
196 foldedPad[3] = padWAfter + tensorOpPad[3];
197
198 // Check kernel related restrictions
199 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
200 return rewriter.notifyMatchFailure(
201 tensorOp, "Padding size not aligned with kernel restrictions.");
202 }
203
204 // Check padding constant restrictions
205 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
206 return rewriter.notifyMatchFailure(
207 tensorOp,
208 "Padding constant is not aligned with operator zero-point.");
209 }
210
211 // Check that the folded padding does not exceed the 8K level limit (8192) for now
212 if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
213 return rewriter.notifyMatchFailure(
214 tensorOp, "Padding size more than the 8K level limit.");
215 }
216
217 // Create operator
218 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
219 foldedPad);
220
221 return success();
222 }
223};
224} // namespace
225
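// Illustrative sketch (hypothetical helper, not from this file) of the index
// mapping FoldPadToTensorOp performs: tosa.pad on an NHWC tensor carries
// 8 entries [N_before, N_after, H_before, H_after, W_before, W_after,
// C_before, C_after], while conv/pool padding carries 4 entries
// [top, bottom, left, right]; only the H and W entries are merged.
#include <array>
#include <cstdint>

static std::array<int64_t, 4>
foldExplicitPad(const std::array<int64_t, 8> &padOpPad,
                const std::array<int64_t, 4> &tensorOpPad) {
  return {padOpPad[2] + tensorOpPad[0],  // top    += H_before
          padOpPad[3] + tensorOpPad[1],  // bottom += H_after
          padOpPad[4] + tensorOpPad[2],  // left   += W_before
          padOpPad[5] + tensorOpPad[3]}; // right  += W_after
}
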
226void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
227 MLIRContext *context) {
228 results.add<
229 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
230 context);
231}
232
233void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
234 MLIRContext *context) {
235 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
236 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
237 context);
238}
239
240struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
241 using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
242
243 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
244 PatternRewriter &rewriter) const override {
245 Value input = op.getInput();
246 Value output = op.getOutput();
247 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
248 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
249
250 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
251 return failure();
252 }
253
254 // If the output and input shapes are 1x1, then this is a no op.
255 ArrayRef<int64_t> outputShape = outputType.getShape();
256 if (outputShape[1] != 1 || outputShape[2] != 1) {
257 return failure();
258 }
259
260 ArrayRef<int64_t> inputShape = inputType.getShape();
261 if (inputShape[1] != 1 || inputShape[2] != 1) {
262 return failure();
263 }
264
265 rewriter.replaceOp(op, input);
266 return success();
267 }
268};
269
270void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
271 MLIRContext *context) {
272 results.add<MaxPool2dIsNoOp,
273 FoldPadToTensorOp<tosa::MaxPool2dOp,
274 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
275 context);
276}
277
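// Illustrative sketch (not from this file) of how the registration hooks above
// are typically consumed: a canonicalization-style pass collects patterns from
// every registered op and applies them greedily. Assumes the standard MLIR
// headers named below; applyPatternsGreedily was spelled
// applyPatternsAndFoldGreedily in older MLIR releases.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult canonicalizeTosa(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Each op contributes the patterns defined in this file through its
  // getCanonicalizationPatterns() hook.
  for (mlir::RegisteredOperationName opName : ctx->getRegisteredOperations())
    opName.getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
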
278//===----------------------------------------------------------------------===//
279// Data Layout / Memory Reinterpretation.
280//===----------------------------------------------------------------------===//
281
282struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
283 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
284
285 LogicalResult matchAndRewrite(tosa::ConcatOp op,
286 PatternRewriter &rewriter) const override {
287 if (op.getInput1().size() != 1)
288 return failure();
289 if (op.getInput1().front().getType() != op.getType()) {
290 rewriter
291 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
292 op.getInput1().front())
293 .getResult();
294 return success();
295 }
296
297 rewriter.replaceOp(op, op.getInput1().front());
298 return success();
299 }
300};
301
302void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
303 MLIRContext *context) {
304 results.add<ConcatOptimization>(context);
305}
306
307LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
308 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
309 if (!notOp)
310 return failure();
311 rewriter.modifyOpInPlace(op, [&]() {
312 op.getOperation()->setOperands(
313 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
314 });
315 return success();
316}
317
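// Illustrative scalar sketch (not from this file) of the identity the
// canonicalization above relies on: select(!p, a, b) == select(p, b, a), so a
// logical_not producer can be dropped by swapping the two value operands.
static int selectScalar(bool pred, int onTrue, int onFalse) {
  return pred ? onTrue : onFalse;
}
// For every pred, a, b: selectScalar(!pred, a, b) == selectScalar(pred, b, a).
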
318 struct ConsolidateTransposeOptimization
319 : public OpRewritePattern<tosa::TransposeOp> {
320 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
321
322 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
323 PatternRewriter &rewriter) const override {
324 // Input is also TransposeOp - transpose(transpose(A)).
325 auto innerTranspose =
326 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
327 if (!innerTranspose)
328 return rewriter.notifyMatchFailure(transposeOp,
329 "input must be transpose operation");
330
331 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
332 const llvm::ArrayRef<int32_t> innerTransposePerms =
333 innerTranspose.getPerms();
334
335 if (transposePerms.size() != innerTransposePerms.size())
336 return rewriter.notifyMatchFailure(
337 transposeOp,
338 "transpose and inner transpose perms sizes must be equal");
339 if (transposePerms.empty())
340 return rewriter.notifyMatchFailure(
341 transposeOp, "transpose perms sizes must be positive");
342
343 // Consolidate transposes into one transpose.
344 SmallVector<int32_t> perms(transposePerms.size());
345 for (int i = 0, s = transposePerms.size(); i < s; ++i)
346 perms[i] = innerTransposePerms[transposePerms[i]];
347
348 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
349 transposeOp, transposeOp.getResult().getType(),
350 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
351
352 return success();
353 }
354};
355
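// Illustrative sketch (hypothetical helper, not from this file) of the
// composition rule used above: transposing with `inner` and then with `outer`
// equals a single transpose whose permutation is perms[i] = inner[outer[i]].
#include <array>
#include <cstdint>

static std::array<int32_t, 3> composePerms(const std::array<int32_t, 3> &outer,
                                           const std::array<int32_t, 3> &inner) {
  std::array<int32_t, 3> composed{};
  for (size_t i = 0; i < composed.size(); ++i)
    composed[i] = inner[outer[i]];
  return composed;
}
// Example: inner = {1, 2, 0} and outer = {2, 0, 1} compose to {0, 1, 2}, so the
// pair collapses to an identity transpose (later removed by TransposeOp::fold).
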
356// Determines the case when tosa.transpose is a tosa.reshape operation.
357struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
358 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
359
360 LogicalResult matchAndRewrite(tosa::TransposeOp op,
361 PatternRewriter &rewriter) const override {
362 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
363 return rewriter.notifyMatchFailure(
364 op, "Src is from transpose, can compose transposes");
365
366 Value result = op.getResult();
367 for (Operation *subop : result.getUsers()) {
368 if (isa_and_nonnull<tosa::TransposeOp>(subop))
369 return rewriter.notifyMatchFailure(
370 op, "Dest is used by transpose, can compose transposes");
371 }
372
373 auto input = op.getInput1();
374 auto inputTy = llvm::cast<ShapedType>(input.getType());
375 if (!inputTy.hasRank())
376 return rewriter.notifyMatchFailure(op, "Unranked input.");
377
378 int64_t numDynDims = 0;
379 for (int i = 0; i < inputTy.getRank(); ++i)
380 if (inputTy.isDynamicDim(i))
381 numDynDims++;
382
383 if (numDynDims > 1)
384 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
385
386 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
387
388 SmallVector<int64_t> nonZeroPerms;
389 nonZeroPerms.reserve(permValues.size());
390 for (auto idx : permValues) {
391 auto sz = inputTy.getDimSize(idx);
392 if (sz != 1)
393 nonZeroPerms.push_back(idx);
394 }
395
396 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
397 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
398 return rewriter.notifyMatchFailure(op,
399 "Transpose changes memory layout.");
400
401 SmallVector<int64_t> newShape;
402 newShape.reserve(inputTy.getRank());
403 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
404 newShape.push_back(inputTy.getDimSize(permValues[i]));
405
406 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
407 op, op.getType(), op.getInput1(),
408 getTosaConstShape(rewriter, op.getLoc(), newShape));
409 return success();
410 }
411};
412
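// Illustrative sketch (hypothetical helper, not from this file): a transpose
// that only moves unit dimensions is a pure reshape. For example, a
// tensor<1x16x1x32xf32> with perms = {1, 0, 3, 2} keeps the non-unit dims 16
// and 32 in their original order and simply becomes tensor<16x1x32x1xf32>.
#include <cstdint>
#include <vector>

static std::vector<int64_t> permutedShape(const std::vector<int64_t> &shape,
                                          const std::vector<int32_t> &perms) {
  std::vector<int64_t> newShape;
  newShape.reserve(perms.size());
  for (int32_t p : perms)
    newShape.push_back(shape[p]); // {1,16,1,32} with {1,0,3,2} -> {16,1,32,1}
  return newShape;
}
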
413void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
414 MLIRContext *context) {
415 results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
416}
417
418struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
419 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
420
421 LogicalResult matchAndRewrite(tosa::ClampOp op,
422 PatternRewriter &rewriter) const override {
423 Value input = op.getInput();
424 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
425 auto inputElementType = inputType.getElementType();
426
427 if (isa<FloatType>(inputElementType)) {
428 // Unlike integer types, floating point types can represent infinity.
429 const auto minClamp =
430 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
431 const auto maxClamp =
432 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
433 const bool isMin = minClamp.isNegInfinity();
434 const bool isMax = maxClamp.isInfinity();
435
436 if (isMin && isMax) {
437 rewriter.replaceOp(op, input);
438 return success();
439 }
440 return failure();
441 }
442
443 // i1 types are boolean in TOSA
444 const bool isBoolean = inputElementType.isInteger(1);
445 if (inputElementType.isUnsignedInteger() || isBoolean) {
446 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
447 .getValue()
448 .getZExtValue();
449 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
450 .getValue()
451 .getZExtValue();
452
453 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
454 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
455 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
456
457 if (minClamp <= intMin && maxClamp >= intMax) {
458 rewriter.replaceOp(op, input);
459 return success();
460 }
461 return failure();
462 }
463
464 if (llvm::isa<IntegerType>(inputElementType)) {
465 const int64_t minClamp =
466 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
467 const int64_t maxClamp =
468 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
469
470 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
471 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
472 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
473
474 if (minClamp <= intMin && maxClamp >= intMax) {
475 rewriter.replaceOp(op, input);
476 return success();
477 }
478 return failure();
479 }
480
481 return failure();
482 }
483};
484
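// Illustrative sketch (hypothetical helper, not from this file): for a signed
// i8 clamp, ClampIsNoOp only fires when the clamp spans the whole representable
// range, i.e. min_val <= -128 and max_val >= 127.
#include "llvm/ADT/APInt.h"
#include <cstdint>

static bool clampCoversSignedInt(unsigned bitWidth, int64_t minVal,
                                 int64_t maxVal) {
  const int64_t typeMin = llvm::APInt::getSignedMinValue(bitWidth).getSExtValue();
  const int64_t typeMax = llvm::APInt::getSignedMaxValue(bitWidth).getSExtValue();
  return minVal <= typeMin && maxVal >= typeMax; // (8, -128, 127) -> true
}
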
485// Attempts the following transformation:
486//
487// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
488// tensor X the following identity holds:
489//
490// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
491//
492// subject to the following valid NaN propagation semantics:
493// --------------------------------------------
494// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
495// |-------------|--------------|-------------|
496// | PROPAGATE | PROPAGATE | PROPAGATE |
497// | PROPAGATE | IGNORE | IGNORE |
498// | IGNORE | PROPAGATE | INVALID |
499// | IGNORE | IGNORE | IGNORE |
500// |------------------------------------------|
501
502struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
503 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
504
505 // Helper structure to describe the range of a clamp operation.
506 template <typename T>
507 struct ClampRange {
508 ClampRange(const T &start, const T &end) : start(start), end(end) {}
509 T start;
510 T end;
511
512 // Helper function to determine if two Clamp ranges intersect.
513 bool intersects(const ClampRange<T> &otherRange) {
514 return start < otherRange.end && otherRange.start < end;
515 }
516 };
517
518 LogicalResult matchAndRewrite(tosa::ClampOp op,
519 PatternRewriter &rewriter) const override {
520 Value input = op.getInput();
521
522 // Check the input to the CLAMP op is itself a CLAMP.
523 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
524 if (!clampOp)
525 return failure();
526
527 // Check we have a valid NaN propagation combination.
528 const auto opNanMode = op.getNanMode();
529 const auto clampNanMode = clampOp.getNanMode();
530 if (opNanMode == NanPropagationMode::IGNORE &&
531 clampNanMode == NanPropagationMode::PROPAGATE)
532 return failure();
533
534 auto maxValAttr = op.getMaxValAttr();
535 auto minValAttr = op.getMinValAttr();
536 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
537 auto clampOpMinValAttr = clampOp.getMinValAttr();
538
539 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
540 if (auto quantType =
541 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
542 inputEType = quantType.getStorageType();
543 }
544
545 Attribute newMinValAttr, newMaxValAttr;
546 if (mlir::isa<FloatType>(inputEType)) {
547 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
548 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
549 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
550 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
551
552 // Check we have intersecting ranges.
553 const auto opMinFloat = floatMinValAttr.getValue();
554 const auto opMaxFloat = floatMaxValAttr.getValue();
555 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
556 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
557 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
558 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
559 clampOpMaxFloat);
560 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
561 return failure();
562
563 // Run the transformation.
564 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
565 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
566 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
567 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
568 } else {
569 assert(mlir::isa<IntegerType>(inputEType));
570 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
571 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
572 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
573 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
574
575 if (inputEType.isUnsignedInteger()) {
576 // Check we have intersecting ranges.
577 const auto opMinInt = intMinValAttr.getUInt();
578 const auto opMaxInt = intMaxValAttr.getUInt();
579 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
580 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
581 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
582 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
583 clampOpMaxInt);
584 if (!opRangeIntRange.intersects(clampRangeIntRange))
585 return failure();
586
587 // Run the transformation.
588 auto newMinVal = std::max(opMinInt, clampOpMinInt);
589 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
590 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
591 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
592 } else {
593 // Check we have intersecting ranges.
594 const auto opMinInt = intMinValAttr.getInt();
595 const auto opMaxInt = intMaxValAttr.getInt();
596 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
597 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
598 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
599 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
600 clampOpMaxInt);
601 if (!opRangeIntRange.intersects(clampRangeIntRange))
602 return failure();
603
604 // Run the transformation.
605 auto newMinVal = std::max(opMinInt, clampOpMinInt);
606 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
607 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
608 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
609 }
610 }
611
612 auto newMode = (opNanMode != clampNanMode)
613 ? tosa::NanPropagationMode::IGNORE
614 : opNanMode;
615
616 auto newModeAttr =
617 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
618
619 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
620 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
621 newModeAttr);
622 return success();
623 }
624};
625
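// Illustrative scalar sketch (not from this file) of the identity merged
// above: clamp(clamp(x, 0, 6), -1, 4) == clamp(x, max(0, -1), min(6, 4))
// == clamp(x, 0, 4), valid because the ranges [0, 6] and [-1, 4] intersect.
#include <algorithm>
#include <cstdint>

static int64_t clampScalar(int64_t x, int64_t lo, int64_t hi) {
  return std::min(std::max(x, lo), hi);
}
// clampScalar(clampScalar(x, 0, 6), -1, 4) == clampScalar(x, 0, 4) for all x.
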
626void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
627 MLIRContext *context) {
628 results.add<ClampIsNoOp>(context);
629 results.add<ClampClampOptimization>(context);
630}
631
632struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
633 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
634
635 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
636 PatternRewriter &rewriter) const override {
637 Value sliceInput = sliceOp.getInput1();
638 auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
639 if (!concatOp)
640 return rewriter.notifyMatchFailure(
641 sliceOp, "slice input must be concat operation");
642
643 OperandRange inputs = concatOp.getInput1();
644 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
645 if (!concatType || !concatType.hasStaticShape())
646 return rewriter.notifyMatchFailure(
647 sliceOp, "slice input must be a static ranked tensor");
648 int32_t axis = concatOp.getAxis();
649
650 DenseElementsAttr startElems;
651 DenseElementsAttr sizeElems;
652
653 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
654 return rewriter.notifyMatchFailure(
655 sliceOp, "start of slice must be a static ranked shape");
656
657 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
658 return rewriter.notifyMatchFailure(
659 sliceOp, "size of slice must be a static ranked shape");
660
661 llvm::SmallVector<int64_t> sliceStarts =
662 llvm::to_vector(startElems.getValues<int64_t>());
663 llvm::SmallVector<int64_t> sliceSizes =
664 llvm::to_vector(sizeElems.getValues<int64_t>());
665
666 // Validate slice on the concatenated axis. Slicing along this
667 // axis should span only one of the inputs to the concatenate
668 // operation.
669 std::optional<Value> replaceWithSlice;
670 for (auto input : inputs) {
671 auto inputType = dyn_cast<RankedTensorType>(input.getType());
672 if (!inputType || !inputType.hasStaticShape())
673 return rewriter.notifyMatchFailure(
674 sliceOp, "concat input must be a static ranked tensor");
675
676 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
677 inputType.getDimSize(axis)) {
678 auto start_op =
679 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
680 auto size_op =
681 getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
682 replaceWithSlice =
683 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
684 input, start_op, size_op)
685 .getResult();
686 break;
687 }
688 sliceStarts[axis] -= inputType.getDimSize(axis);
689 }
690
691 if (!replaceWithSlice)
692 return rewriter.notifyMatchFailure(
693 sliceOp, "corresponding concat input not found for slice");
694
695 rewriter.replaceOp(sliceOp, replaceWithSlice.value());
696 return success();
697 }
698};
699
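// Illustrative sketch (hypothetical helper, not from this file) of the
// bookkeeping above: walk the concat inputs along the concat axis, rebasing
// the slice start, until the slice fits entirely inside one input; a slice
// that straddles an input boundary is left untouched.
#include <cstdint>
#include <optional>
#include <vector>

struct RebasedSlice {
  size_t inputIndex;
  int64_t start;
};

static std::optional<RebasedSlice>
rebaseSliceOnConcat(const std::vector<int64_t> &inputAxisSizes, int64_t start,
                    int64_t size) {
  for (size_t i = 0; i < inputAxisSizes.size(); ++i) {
    if (start >= 0 && start + size <= inputAxisSizes[i])
      return RebasedSlice{i, start};
    start -= inputAxisSizes[i];
  }
  return std::nullopt;
}
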
700struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
701 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
702
703 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
704 PatternRewriter &rewriter) const override {
705 Value sliceInput = sliceOp.getInput1();
706
707 // Check if producer is a PadOp
708 auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
709 if (!padOp)
710 return rewriter.notifyMatchFailure(sliceOp,
711 "slice input must be a pad operation");
712
713 // Check PadOp has a single consumer
714 if (!padOp->hasOneUse())
715 return rewriter.notifyMatchFailure(sliceOp,
716 "pad shall have a single consumer");
717
718 // Check input is statically ranked
719 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
720 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
721 if (!inputTy || !padTy || !inputTy.hasRank())
722 return rewriter.notifyMatchFailure(sliceOp,
723 "slice input must be a ranked tensor");
724
725 // Validate and extract tosa::PadOp padding
726 DenseIntElementsAttr paddingElems;
727 if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
728 return rewriter.notifyMatchFailure(
729 sliceOp,
730 "`padding` input specified on the tosa::PadOp must be constant.");
731 }
732 llvm::SmallVector<int64_t> padPaddings =
733 llvm::to_vector(paddingElems.getValues<int64_t>());
734
735 // Extract slice parameters
736 DenseElementsAttr startElems;
737 if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
738 return rewriter.notifyMatchFailure(
739 sliceOp, "start of slice must be a static ranked shape");
740 llvm::SmallVector<int64_t> sliceStarts =
741 llvm::to_vector(startElems.getValues<int64_t>());
742
743 DenseElementsAttr sizeElems;
744 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
745 return rewriter.notifyMatchFailure(
746 sliceOp, "size of slice must be a static ranked shape");
747 llvm::SmallVector<int64_t> sliceSizes =
748 llvm::to_vector(sizeElems.getValues<int64_t>());
749
750 // Check if dynamic dimensions are sliced
751 const int64_t rank = inputTy.getRank();
752 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
753 const bool isDimDynamic = inputTy.isDynamicDim(i);
754 const bool isDimSliced =
755 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
756
757 return isDimDynamic && isDimSliced;
758 })) {
759 return rewriter.notifyMatchFailure(
760 sliceOp, "axis that are sliced shall be statically known.");
761 }
762
763 // Update the parameters
764 llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
765 llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
766 llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
767 bool updated = false;
768
769 for (int64_t i = 0; i < rank; ++i) {
770 const int64_t padLo = padPaddings[i * 2];
771 const int64_t padHi = padPaddings[i * 2 + 1];
772 const int64_t sliceStart = sliceStarts[i];
773 const int64_t sliceSize = sliceSizes[i];
774 const int64_t sliceEnd = sliceStart + sliceSize;
775
776 // If dimension is dynamic pass-through
777 if (inputTy.isDynamicDim(i)) {
778 newPadPaddings[i * 2] = padLo;
779 newPadPaddings[i * 2 + 1] = padHi;
780 newSliceStarts[i] = sliceStart;
781 continue;
782 }
783
784 // Handle static dimensions
785 const int64_t dimSize = inputTy.getShape()[i];
786 const int64_t dimTotal = padLo + dimSize + padHi;
787
788 // Check slice within bounds
789 if (sliceStart < 0 || sliceEnd > dimTotal)
790 return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
791
792 // Compute updated slice start parameter
793 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
794 newSliceStarts[i] = newSliceStart;
795 updated |= newSliceStart != sliceStart;
796
797 // Compute updated pad parameters
798 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
799 const int64_t newPadHi =
800 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
801 newPadPaddings[i * 2] = newPadLo;
802 newPadPaddings[i * 2 + 1] = newPadHi;
803 updated |= (newPadLo != padLo) || (newPadHi != padHi);
804
805 // Calculate new pad output shape
806 newPadShape[i] =
807 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
808 }
809
810 // Check that we actually need to proceed with the rewrite
811 if (!updated)
812 return rewriter.notifyMatchFailure(
813 sliceOp, "terminate condition; nothing to rewrite");
814
815 // Create a PadOp with updated padding
816 auto newPaddingsOp =
817 getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
818 auto newPadTy =
819 RankedTensorType::get(newPadShape, inputTy.getElementType());
820 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
821 padOp.getInput1(), newPaddingsOp,
822 padOp.getPadConst());
823
824 // Update SliceOp and point to new PadOp
825 auto newStartOp =
826 getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
827 rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
828 newPadOp.getResult(), newStartOp,
829 sliceOp.getSize());
830
831 return success();
832 }
833};
834
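// Illustrative sketch (hypothetical helper, not from this file) isolating the
// per-dimension arithmetic above: the slice start is rebased into the unpadded
// input, and only the padding the slice actually reads is kept.
#include <algorithm>
#include <cstdint>

struct DimRewrite {
  int64_t sliceStart;
  int64_t padLo;
  int64_t padHi;
};

static DimRewrite rebasePaddedDim(int64_t padLo, int64_t dimSize,
                                  int64_t sliceStart, int64_t sliceSize) {
  const int64_t sliceEnd = sliceStart + sliceSize;
  return {std::max<int64_t>(sliceStart - padLo, 0),             // new slice start
          std::max<int64_t>(padLo - sliceStart, 0),             // kept low padding
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0)};  // kept high padding
}
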
835// Update size operand of tosa.slice if size has dynamic dims but corresponding
836// output dim is static
837 struct SliceDynamicSizeCanonicalization
838 : public OpRewritePattern<tosa::SliceOp> {
839 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
840
841 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
842 PatternRewriter &rewriter) const override {
843 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
844
845 ElementsAttr sizeElems;
846 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
847 return rewriter.notifyMatchFailure(
848 sliceOp, "size of slice must be a static ranked shape");
849 }
850
851 llvm::SmallVector<int64_t> sliceSizes =
852 llvm::to_vector(sizeElems.getValues<int64_t>());
853
854 bool replaceSliceSize{false};
855 // if size op has -1 indicating dynamic shape but corresponding dim on the
856 // output is statically known, update size to match with known output dim
857 // shape
858 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
859 if (size == -1 && !resultType.isDynamicDim(index)) {
860 sliceSizes[index] = resultType.getDimSize(index);
861 replaceSliceSize = true;
862 }
863 }
864
865 if (!replaceSliceSize) {
866 return rewriter.notifyMatchFailure(
867 sliceOp, "no dimension of size of slice is dynamic that resolves "
868 "to static output shape");
869 }
870
871 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
872 auto newSliceOp =
873 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
874 sliceOp.getInput1(), sliceOp.getStart(), size_op);
875
876 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
877 return success();
878 }
879};
880
881void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
882 MLIRContext *context) {
883 results.add<ConcatSliceOptimization, PadSliceOptimization,
884 SliceDynamicSizeCanonicalization>(context);
885}
886
887//===----------------------------------------------------------------------===//
888// Operator Folders.
889//===----------------------------------------------------------------------===//
890
891template <typename IntFolder, typename FloatFolder>
892 static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
893 DenseElementsAttr rhs,
894 RankedTensorType returnTy) {
895 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
896 auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
897 auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
898 if (lETy != rETy)
899 return {};
900
901 if (llvm::isa<IntegerType>(lETy)) {
902 APInt l = lhs.getSplatValue<APInt>();
903 APInt r = rhs.getSplatValue<APInt>();
904 auto result = IntFolder()(l, r);
905 return DenseElementsAttr::get(returnTy, result);
906 }
907
908 if (llvm::isa<FloatType>(lETy)) {
909 APFloat l = lhs.getSplatValue<APFloat>();
910 APFloat r = rhs.getSplatValue<APFloat>();
911 auto result = FloatFolder()(l, r);
912 return DenseElementsAttr::get(returnTy, result);
913 }
914 }
915
916 return {};
917}
918
919static bool isSplatZero(Type elemType, DenseElementsAttr val) {
920 if (llvm::isa<FloatType>(elemType))
921 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
922 if (llvm::isa<IntegerType>(elemType))
923 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
924 return false;
925}
926
927static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
928 if (llvm::isa<FloatType>(elemType))
929 return val && val.isSplat() &&
930 val.getSplatValue<APFloat>().isExactlyValue(1.0);
931 if (llvm::isa<IntegerType>(elemType)) {
932 const int64_t shifted = 1LL << shift;
933 return val && val.isSplat() &&
934 val.getSplatValue<APInt>().getSExtValue() == shifted;
935 }
936 return false;
937}
938
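// Illustrative note (not from this file): tosa.mul on i32 right-shifts the
// product by `shift`, so in that fixed-point view the multiplicative identity
// is 1 << shift rather than 1; that is what isSplatOne checks for integers.
#include <cstdint>

static bool isFixedPointOne(int64_t splatValue, int64_t shift) {
  return splatValue == (int64_t(1) << shift); // e.g. shift = 4 -> 16
}
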
939OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
940 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
941 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
942 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
943 if (!lhsTy || !rhsTy || !resultTy)
944 return {};
945
946 // Cannot create an ElementsAttr from non-int/float/index types
947 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
948 !rhsTy.getElementType().isIntOrIndexOrFloat())
949 return {};
950
951 auto resultETy = resultTy.getElementType();
952 auto lhsAttr =
953 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
954 auto rhsAttr =
955 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
956
957 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
958 return getInput1();
959 if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
960 return getInput2();
961
962 if (!lhsAttr || !rhsAttr)
963 return {};
964
965 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
966 resultTy);
967}
968
969OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
970 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
971 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
972 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
973 !outputTy.hasStaticShape())
974 return {};
975
976 const Type outputElementTy = getElementTypeOrSelf(outputTy);
977 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
978 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
979 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
980 return DenseElementsAttr::get(outputTy, zero);
981 }
982
983 return {};
984}
985
986OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
987 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
988 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
989 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
990 if (!lhsTy || !rhsTy || !resultTy)
991 return {};
992 if (lhsTy != rhsTy)
993 return {};
994
995 // IntDivOp inputs must be integer type, no need to check for quantized type
996 auto resultETy = resultTy.getElementType();
997 auto lhsAttr =
998 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
999 auto rhsAttr =
1000 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1001 if (lhsAttr && lhsAttr.isSplat()) {
1002 if (llvm::isa<IntegerType>(resultETy) &&
1003 lhsAttr.getSplatValue<APInt>().isZero())
1004 return lhsAttr;
1005 }
1006
1007 if (rhsAttr && rhsAttr.isSplat()) {
1008 if (llvm::isa<IntegerType>(resultETy) &&
1009 rhsAttr.getSplatValue<APInt>().isOne())
1010 return getInput1();
1011 }
1012
1013 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1014 llvm::isa<IntegerType>(resultETy)) {
1015 APInt l = lhsAttr.getSplatValue<APInt>();
1016 APInt r = rhsAttr.getSplatValue<APInt>();
1017 if (!r.isZero()) {
1018 APInt result = l.sdiv(r);
1019 return DenseElementsAttr::get(resultTy, result);
1020 }
1021 }
1022
1023 return {};
1024}
1025
1026namespace {
1027 // Calculate lhs * rhs >> shift according to the TOSA spec.
1028 // Return nullopt if the result is out of int32_t range when shift > 0.
1029std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1030 unsigned bitwidth) {
1031 APInt result = lhs.sext(64) * rhs.sext(64);
1032
1033 if (shift > 0) {
1034 auto round = APInt(64, 1) << (shift - 1);
1035 result += round;
1036 result.ashrInPlace(shift);
1037 // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
1038 if (!(result.getSExtValue() >= INT32_MIN &&
1039 result.getSExtValue() <= INT32_MAX)) {
1040 // REQUIRE failed
1041 return std::nullopt;
1042 }
1043 }
1044
1045 return result.trunc(bitwidth);
1046}
1047
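// Illustrative scalar sketch (not from this file) of the rounding performed by
// mulInt: half of the shift step is added before the arithmetic right shift,
// e.g. 3 * 5 with shift = 1 folds to (15 + 1) >> 1 = 8.
#include <cstdint>

static int64_t mulShiftRound(int64_t lhs, int64_t rhs, int32_t shift) {
  int64_t product = lhs * rhs;
  if (shift > 0)
    product = (product + (int64_t(1) << (shift - 1))) >> shift;
  return product; // mulShiftRound(3, 5, 1) == 8
}
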
1048DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1049 RankedTensorType ty, int32_t shift) {
1050 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1051 if (llvm::isa<IntegerType>(ty.getElementType())) {
1052 APInt l = lhs.getSplatValue<APInt>();
1053 APInt r = rhs.getSplatValue<APInt>();
1054
1055 if (shift == 0) {
1056 return DenseElementsAttr::get(ty, l * r);
1057 }
1058
1059 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1060 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1061 if (!result)
1062 return {};
1063 return DenseElementsAttr::get(ty, result.value());
1064 }
1065
1066 if (llvm::isa<FloatType>(ty.getElementType())) {
1067 APFloat l = lhs.getSplatValue<APFloat>();
1068 APFloat r = rhs.getSplatValue<APFloat>();
1069 APFloat result = l * r;
1070 return DenseElementsAttr::get(ty, result);
1071 }
1072 }
1073
1074 return {};
1075}
1076} // namespace
1077
1078OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1079 auto lhs = getInput1();
1080 auto rhs = getInput2();
1081 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
1082 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
1083 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1084 if (!lhsTy || !rhsTy || !resultTy)
1085 return {};
1086
1087 auto resultETy = resultTy.getElementType();
1088 auto lhsAttr =
1089 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1090 auto rhsAttr =
1091 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1092
1093 // The result right shift applies to the i32_t data type only. For
1094 // simplification, synthesize a zero shift for other data types.
1095 int32_t shift = 0;
1096 if (resultETy.isInteger(32)) {
1097 ElementsAttr shift_elem;
1098 if (getShift().getImpl()) {
1099 if (!matchPattern(getShift(), m_Constant(&shift_elem)))
1100 // cannot be folded when the shift value is unknown.
1101 return {};
1102 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1103 }
1104 }
1105
1106 if (rhsTy == resultTy) {
1107 if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1108 // constant values can only be resized if resulting type is static
1109 return lhsAttr.resizeSplat(resultTy);
1110 if (isSplatOne(resultETy, lhsAttr, shift))
1111 return rhs;
1112 }
1113 if (lhsTy == resultTy) {
1114 if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1115 return rhsAttr.resizeSplat(resultTy);
1116 if (isSplatOne(resultETy, rhsAttr, shift))
1117 return lhs;
1118 }
1119
1120 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1121}
1122
1123OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1124 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1125 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1126 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1127 if (!lhsTy || !rhsTy || !resultTy)
1128 return {};
1129
1130 // Cannot create an ElementsAttr from non-int/float/index types
1131 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1132 !rhsTy.getElementType().isIntOrIndexOrFloat())
1133 return {};
1134
1135 auto resultETy = resultTy.getElementType();
1136 auto lhsAttr =
1137 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1138 auto rhsAttr =
1139 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1140
1141 if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1142 return getInput1();
1143
1144 if (!lhsAttr || !rhsAttr)
1145 return {};
1146
1147 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1148 resultTy);
1149}
1150
1151namespace {
1152template <typename Cmp>
1153struct ComparisonFold {
1154 ComparisonFold() = default;
1155 APInt operator()(const APInt &l, const APInt &r) {
1156 return APInt(1, Cmp()(l, r));
1157 }
1158
1159 APInt operator()(const APFloat &l, const APFloat &r) {
1160 return APInt(1, Cmp()(l, r));
1161 }
1162};
1163
1164struct APIntFoldGreater {
1165 APIntFoldGreater() = default;
1166 APInt operator()(const APInt &l, const APInt &r) {
1167 return APInt(1, l.sgt(r));
1168 }
1169};
1170
1171struct APIntFoldGreaterEqual {
1172 APIntFoldGreaterEqual() = default;
1173 APInt operator()(const APInt &l, const APInt &r) {
1174 return APInt(1, l.sge(r));
1175 }
1176};
1177} // namespace
1178
1179OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1180 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1181 auto lhsAttr =
1182 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1183 auto rhsAttr =
1184 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1185
1186 if (!lhsAttr || !rhsAttr)
1187 return {};
1188
1189 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1190 lhsAttr, rhsAttr, resultTy);
1191}
1192
1193OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1194 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1195 auto lhsAttr =
1196 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1197 auto rhsAttr =
1198 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1199
1200 if (!lhsAttr || !rhsAttr)
1201 return {};
1202
1203 return binaryFolder<APIntFoldGreaterEqual,
1204 ComparisonFold<std::greater_equal<APFloat>>>(
1205 lhsAttr, rhsAttr, resultTy);
1206}
1207
1208OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1209 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1210 auto lhsAttr =
1211 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1212 auto rhsAttr =
1213 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1214 Value lhs = getInput1();
1215 Value rhs = getInput2();
1216 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1217
1218 // If we are comparing an integer value to itself it is always true. We
1219 // cannot do this for floats because of NaN values (NaN != NaN).
1220 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1221 resultTy.hasStaticShape() && lhs == rhs) {
1222 return DenseElementsAttr::get(resultTy, true);
1223 }
1224
1225 if (!lhsAttr || !rhsAttr)
1226 return {};
1227
1228 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1229 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1230 resultTy);
1231}
1232
1233OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1234 if (getInput().getType() == getType())
1235 return getInput();
1236
1237 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1238 if (!operand)
1239 return {};
1240
1241 auto inTy = llvm::cast<ShapedType>(getInput().getType());
1242 auto outTy = llvm::cast<ShapedType>(getType());
1243 auto inETy = inTy.getElementType();
1244 auto outETy = outTy.getElementType();
1245
1246 if (operand.isSplat()) {
1247 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1248 bool overflow;
1249 auto splatVal = operand.getSplatValue<APFloat>();
1250 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1251 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1252 &overflow);
1253 return SplatElementsAttr::get(outTy, splatVal);
1254 }
1255
1256 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1257 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1258 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1259 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1260 llvm::RoundingMode::NearestTiesToEven);
1261 return SplatElementsAttr::get(outTy, splatVal);
1262 }
1263
1264 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1265 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1266 auto intVal = APSInt(
1267 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1268 auto floatVal = operand.getSplatValue<APFloat>();
1269 bool exact;
1270 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1271 &exact);
1272 return SplatElementsAttr::get(outTy, intVal);
1273 }
1274
1275 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1276 const auto inIntType = llvm::cast<IntegerType>(inETy);
1277 auto unsignIn = inIntType.isUnsignedInteger();
1278 bool trunc =
1279 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1280 auto intVal = operand.getSplatValue<APInt>();
1281 auto bitwidth = outETy.getIntOrFloatBitWidth();
1282
1283 // i1 types are boolean in TOSA
1284 if (outETy.isInteger(1)) {
1285 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1286 } else if (trunc) {
1287 intVal = intVal.trunc(bitwidth);
1288 } else if (unsignIn || inIntType.isInteger(1)) {
1289 intVal = intVal.zext(bitwidth);
1290 } else {
1291 intVal = intVal.sext(bitwidth);
1292 }
1293
1294 return SplatElementsAttr::get(outTy, intVal);
1295 }
1296 }
1297
1298 return {};
1299}
1300
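// Illustrative sketch (hypothetical helper, not from this file) of the
// integer-to-integer splat rule in CastOp::fold above: i1 destinations
// collapse to 0/1, narrower destinations truncate, unsigned or i1 sources
// zero-extend, and signed sources sign-extend.
#include "llvm/ADT/APInt.h"

static llvm::APInt castIntSplat(llvm::APInt val, unsigned outWidth,
                                bool srcIsUnsignedOrBool, bool dstIsBool) {
  if (dstIsBool)
    return llvm::APInt(outWidth, val.isZero() ? 0 : 1);
  if (val.getBitWidth() > outWidth)
    return val.trunc(outWidth);
  if (val.getBitWidth() == outWidth)
    return val;
  return srcIsUnsignedOrBool ? val.zext(outWidth) : val.sext(outWidth);
}
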
1301OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1302
1303OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1304
1305#define REDUCE_FOLDER(OP) \
1306 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1307 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1308 if (!inputTy.hasRank()) \
1309 return {}; \
1310 if (inputTy != getType()) \
1311 return {}; \
1312 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1313 return getInput(); \
1314 return {}; \
1315 }
1316
1317REDUCE_FOLDER(ReduceAllOp)
1318REDUCE_FOLDER(ReduceAnyOp)
1319REDUCE_FOLDER(ReduceMaxOp)
1320REDUCE_FOLDER(ReduceMinOp)
1321REDUCE_FOLDER(ReduceProductOp)
1322REDUCE_FOLDER(ReduceSumOp)
1323#undef REDUCE_FOLDER
1324
1325OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1326 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1327 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1328
1329 if (!inputTy || !outputTy)
1330 return {};
1331
1332 // Fold when the input and output types are the same. This is only safe when
1333 // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
1334 // there may still be a productive reshape.
1335 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1336 return getInput1();
1337
1338 // reshape(reshape(x)) -> reshape(x)
1339 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1340 getInput1().getDefiningOp())) {
1341 getInput1Mutable().assign(reshapeOp.getInput1());
1342 return getResult();
1343 }
1344
1345 // Cannot create an ElementsAttr from non-int/float/index types
1346 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1347 return {};
1348
1349 // reshape(const(x)) -> const(reshape-attr(x))
1350 if (auto operand =
1351 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1352 // Constants must have static shape.
1353 if (!outputTy.hasStaticShape())
1354 return {};
1355
1356 // Okay to duplicate splat constants.
1357 if (operand.isSplat())
1358 return SplatElementsAttr::get(outputTy,
1359 operand.getSplatValue<Attribute>());
1360
1361 // Don't duplicate other constants.
1362 if (!getInput1().hasOneUse())
1363 return {};
1364
1365 llvm::SmallVector<int64_t> shapeVec;
1366 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1367 return {};
1368
1369 return operand.reshape(
1370 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1371 }
1372
1373 return {};
1374}
1375
1376OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1377 // If the pad is all zeros we can fold this operation away.
1378 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1379 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1380 if (densePad && densePad.isSplat() &&
1381 densePad.getSplatValue<APInt>().isZero()) {
1382 return getInput1();
1383 }
1384 }
1385
1386 return {};
1387}
1388
1389// Fold away cases where a tosa.resize operation returns a copy
1390// of the input image.
1391OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1392 auto scaleAttr =
1393 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1394 auto offsetAttr =
1395 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1396 auto borderAttr =
1397 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1398 if (!scaleAttr || !offsetAttr || !borderAttr) {
1399 return {};
1400 }
1401
1402 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1403 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1404 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1405 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1406 return {};
1407 }
1408
1409 // Check unit scaling.
1410 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1411 return {};
1412 }
1413
1414 // There should be no offset.
1415 if (offset[0] != 0 || offset[1] != 0) {
1416 return {};
1417 }
1418
1419 // There should be no border.
1420 if (border[0] != 0 || border[1] != 0) {
1421 return {};
1422 }
1423
1424 auto input = getInput();
1425 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1426 auto resultTy = llvm::cast<RankedTensorType>(getType());
1427 if (inputTy != resultTy)
1428 return {};
1429
1430 return input;
1431}
1432
1433OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1434 auto operand = getInput1();
1435 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1436 auto axis = getAxis();
1437 auto operandAttr =
1438 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1439 if (operandAttr)
1440 return operandAttr;
1441
1442 // If the dim-length is 1, tosa.reverse is a no-op.
1443 if (operandTy.hasRank() &&
1444 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1445 return operand;
1446
1447 return {};
1448}
1449
1450OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1451 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1452 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1453
1454 if (!inputTy || !outputTy)
1455 return {};
1456
1457 if (inputTy == outputTy && inputTy.hasStaticShape())
1458 return getInput1();
1459
1460 if (!adaptor.getInput1())
1461 return {};
1462
1463 // Cannot create an ElementsAttr from non-int/float/index types
1464 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1465 !outputTy.getElementType().isIntOrIndexOrFloat())
1466 return {};
1467
1468 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1469 if (operand.isSplat() && outputTy.hasStaticShape()) {
1470 return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
1471 }
1472
1473 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1474 outputTy.getNumElements() == 1) {
1475 DenseElementsAttr startElems;
1476 if (!matchPattern(getStart(), m_Constant(&startElems)))
1477 return {};
1478
1479 llvm::SmallVector<uint64_t> indices =
1480 llvm::to_vector(startElems.getValues<uint64_t>());
1481 auto value = operand.getValues<Attribute>()[indices];
1482 return SplatElementsAttr::get(outputTy, value);
1483 }
1484
1485 return {};
1486}
1487
1488OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1489 if (getOnTrue() == getOnFalse())
1490 return getOnTrue();
1491
1492 auto predicate =
1493 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1494 if (!predicate)
1495 return {};
1496
1497 if (!predicate.isSplat())
1498 return {};
1499 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1500 : getOnFalse();
1501}
1502
1503OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1504 if (getInput1().getType() == getType()) {
1505 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1506 adaptor.getMultiples())) {
1507 if (multiples.isSplat() &&
1508 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1509 return getInput1();
1510 if (auto int_array_attr =
1511 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1512 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1513 [](APInt v) { return v.getSExtValue() == 1; }))
1514 return getInput1();
1515 }
1516 }
1517 }
1518 return {};
1519}
1520
1521OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1522 auto resultTy = llvm::cast<ShapedType>(getType());
1523
1524 // Transposing splat values just means reshaping.
1525 if (auto input =
1526 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1527 if (input.isSplat() && resultTy.hasStaticShape() &&
1528 input.getType().getElementType() == resultTy.getElementType())
1529 return input.reshape(resultTy);
1530 }
1531
1532 // Transpose is not the identity transpose.
1533 const llvm::ArrayRef<int32_t> perms = getPerms();
1534
1535 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1536 return {};
1537
1538 return getInput1();
1539}
1540
1541OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1542 // Element-wise negate(negate(x)) = x
1543 // iff all zero points are constant 0
1544 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1545 if (!definingOp) {
1546 // defining op of input1 is not a negate, cannot fold
1547 return {};
1548 }
1549
1550 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1551 failed(maybeIZp) || *maybeIZp != 0) {
1552 // input1 zero point is not constant 0, cannot fold
1553 return {};
1554 }
1555 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1556 failed(maybeOZp) || *maybeOZp != 0) {
1557 // output zero point is not constant 0, cannot fold
1558 return {};
1559 }
1560 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1561 failed(maybeIZp) || *maybeIZp != 0) {
1562 // definingOp's input1 zero point is not constant 0, cannot fold
1563 return {};
1564 }
1565 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1566 failed(maybeOZp) || *maybeOZp != 0) {
1567 // definingOp's output zero point is not constant 0, cannot fold
1568 return {};
1569 }
1570
1571 return definingOp.getInput1();
1572}
1573
1574OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1575 auto input = getInput1();
1576 // Element-wise abs(abs(x)) = abs(x)
1577 if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
1578 return input;
1579 }
1580
1581 return {};
1582}
1583
1584OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1585 // Fold consecutive concats on the same axis into a single op.
1586 // Keep track of the operands so we are able to construct a new concat
1587 // later. Conservatively assume that we double the number of operands when
1588 // folding
1589 SmallVector<Value, 8> concatOperands;
1590 concatOperands.reserve(2 * getNumOperands());
1591
1592 // Find all operands that are foldable concats
1593 bool foundFoldableConcat = false;
1594 for (Value operand : getOperands()) {
1595 concatOperands.emplace_back(operand);
1596
1597 auto producer = operand.getDefiningOp<ConcatOp>();
1598 if (!producer)
1599 continue;
1600
1601 // Not foldable if axes are not the same
1602 if (getAxis() != producer.getAxis())
1603 continue;
1604
1605 // Replace the original operand with all incoming operands
1606 foundFoldableConcat = true;
1607 concatOperands.pop_back();
1608 llvm::append_range(concatOperands, producer->getOperands());
1609 }
1610
1611 if (!foundFoldableConcat)
1612 return {};
1613
1614 getOperation()->setOperands(concatOperands);
1615 return getResult();
1616}
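// Editorial note (not part of the upstream file): the fold above flattens
// nested concatenations along the same axis by splicing a producing concat's
// operands into the consumer's operand list in place. Schematically:
//   concat(concat(%a, %b) {axis = 0}, %c) {axis = 0}
//     ==> concat(%a, %b, %c) {axis = 0}
// Producers that concatenate along a different axis are left untouched.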
1617
1618OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1619 auto input = adaptor.getInput1();
1620
1621 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1622 // Fold splat inputs only.
1623 if (!inputAttr || !inputAttr.isSplat())
1624 return {};
1625
1626 auto shapeType = llvm::cast<ShapedType>(getType());
1627 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1628 auto floatVal = inputAttr.getSplatValue<APFloat>();
1629 return DenseElementsAttr::get(shapeType,
1630 ReciprocalOp::calcOneElement(floatVal));
1631 }
1632
1633 return {};
1634}
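// Editorial sketch (not part of the upstream file): the fold above constant-
// folds tosa.reciprocal only for splat float inputs, producing a splat result
// via ReciprocalOp::calcOneElement (defined elsewhere in the dialect). One
// plausible way to compute such a per-element reciprocal with APFloat is
// sketched below; the helper name reciprocalOf is illustrative only and is
// not the upstream implementation.

#include "llvm/ADT/APFloat.h"

static llvm::APFloat reciprocalOf(const llvm::APFloat &v) {
  // Build 1.0 in the operand's semantics and divide, rounding to nearest-even.
  llvm::APFloat recip = llvm::APFloat::getOne(v.getSemantics());
  recip.divide(v, llvm::APFloat::rmNearestTiesToEven);
  return recip;
}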