MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
19#include "mlir/Dialect/Traits.h"
22#include "mlir/IR/Matchers.h"
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28
29#include <functional>
30
31using namespace mlir;
32using namespace mlir::tosa;
33
34namespace {
35OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
36 return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
37}
38} // namespace
39
40//===----------------------------------------------------------------------===//
41// Operator Canonicalizers.
42//===----------------------------------------------------------------------===//
43
44//===----------------------------------------------------------------------===//
45// Tensor Data Engine Operators.
46//===----------------------------------------------------------------------===//
47
48// Check that the zero point of the tensor and padding operations are aligned.
49static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
50 // Check that padConst is a constant value and a scalar tensor
51 DenseElementsAttr padConstAttr;
52 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
53 (padConstAttr.size() != 1)) {
54 return false;
55 }
56
57 // Check that floating point pad is zero
58 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
59 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
60 return padConstVal == 0.0f;
61 }
62
63 // Check that the zp and padConst align for the integer (quantized) case
64 if (auto padConstIntAttr =
65 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 // Check that zp is a constant value and a scalar tensor
68 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
69 return false;
70 }
71
72 // Check equality
73 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
74 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
75 return zpVal == padConstVal;
76 }
77
78 // Bail-out on unsupported type
79 return false;
80}
81
82namespace {
// Adaptor describing, per pooling op type, whether and how a producing
// tosa.pad may be folded into the op's own implicit padding. Only explicit
// specializations are defined.
template <typename OpTy>
struct PoolPadFoldAdaptor;

// tosa.max_pool2d specialization.
template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  // The merged padding must stay strictly smaller than the kernel in each
  // spatial dimension. newPad is {top, bottom, left, right}; kernel is {y, x}.
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  // Folding into max_pool2d is only sound when the pad constant is the
  // identity element of `max`, i.e. the smallest representable value of the
  // element type — then padded elements can never win the max.
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // Pad needs to be in the minimum value to be able to merge
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      // getLargest(..., /*Negative=*/true) is the most negative finite value.
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    }
    if (auto padConstIntAttr =
            mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      // Unsigned minimum is zero; signed minimum is the most negative value.
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getZero(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail-out on unsupported type
    return false;
  }
  // Re-create the max_pool2d reading directly from the pad's input, with the
  // merged explicit padding.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};
133
// Adaptor shared by convolution-like ops (conv2d, depthwise_conv2d): any
// folded padding amount is acceptable, but the pad constant must match the
// op's input zero-point (or be 0.0 for floats) for the fold to be sound.
template <typename OpTy>
struct ConvPadFoldAdaptor {
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  // Re-create the convolution reading directly from the pad's input, with the
  // merged explicit padding; all other attributes are carried over verbatim.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};
150
151// Pattern attempts to fold a `tosa.pad` operator to a following tensor
152// operation like `tosa.conv2d` by merging the padding associated with the
153// pad operator directly to the implicit padding of the tensor operation.
154// This helps eliminate the explicit padding operator if unused.
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check producer is a tosa::PadOp
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate that tensor operation has sane padding
    // (const ref binds to the temporary vector; lifetime is extended).
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate tosa::PadOp padding
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }
    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
    // C_after
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    // The tensor op's implicit padding only covers the spatial (H/W) dims.
    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold padding from Pad into the tensor operation
    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel related restrictions
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check padding constant restrictions
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check that padding doesn't grow more than 8K level (8192) for now
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Create operator
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
232} // namespace
233
234void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
235 MLIRContext *context) {
236 results.add<
237 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
238 context);
239}
240
241void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
242 MLIRContext *context) {
243 results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
244 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
245 context);
246}
247
248struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
250
251 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
252 PatternRewriter &rewriter) const override {
253 Value input = op.getInput();
254 Value output = op.getOutput();
255 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
256 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
257
258 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
259 return failure();
260 }
261
262 // If the output and input shapes are 1x1, then this is a no op.
263 ArrayRef<int64_t> outputShape = outputType.getShape();
264 if (outputShape[1] != 1 || outputShape[2] != 1) {
265 return failure();
266 }
267
268 ArrayRef<int64_t> inputShape = inputType.getShape();
269 if (inputShape[1] != 1 || inputShape[2] != 1) {
270 return failure();
271 }
272
273 rewriter.replaceOp(op, input);
274 return success();
275 }
276};
277
278void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
279 MLIRContext *context) {
280 results.add<MaxPool2dIsNoOp,
281 FoldPadToTensorOp<tosa::MaxPool2dOp,
282 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
283 context);
284}
285
286//===----------------------------------------------------------------------===//
287// Data Layout / Memory Reinterpretation.
288//===----------------------------------------------------------------------===//
289
290struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
291 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
292
293 LogicalResult matchAndRewrite(tosa::ConcatOp op,
294 PatternRewriter &rewriter) const override {
295 if (op.getInput1().size() != 1)
296 return failure();
297 if (op.getInput1().front().getType() != op.getType()) {
298 rewriter
299 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
300 op.getInput1().front())
301 .getResult();
302 return success();
303 }
304
305 rewriter.replaceOp(op, op.getInput1().front());
306 return success();
307 }
308};
309
// Register tosa.concat canonicalization patterns.
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}
314
315LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
316 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
317 if (!notOp)
318 return failure();
319 rewriter.modifyOpInPlace(op, [&]() {
320 op.getOperation()->setOperands(
321 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
322 });
323 return success();
324}
325
327 : public OpRewritePattern<tosa::TransposeOp> {
329
330 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
331 PatternRewriter &rewriter) const override {
332 // Input is also TransposeOp - transpose(transpose(A)).
333 auto innerTranspose =
334 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
335 if (!innerTranspose)
336 return rewriter.notifyMatchFailure(transposeOp,
337 "input must be transpose operation");
338
339 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
340 const llvm::ArrayRef<int32_t> innerTransposePerms =
341 innerTranspose.getPerms();
342
343 if (transposePerms.size() != innerTransposePerms.size())
344 return rewriter.notifyMatchFailure(
345 transposeOp,
346 "transpose and inner transpose perms sizes must be equal");
347 if (transposePerms.empty())
348 return rewriter.notifyMatchFailure(
349 transposeOp, "transpose perms sizes must be positive");
350
351 // Consolidate transposes into one transpose.
352 SmallVector<int32_t> perms(transposePerms.size());
353 for (int i = 0, s = transposePerms.size(); i < s; ++i)
354 perms[i] = innerTransposePerms[transposePerms[i]];
355
356 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
357 transposeOp, transposeOp.getResult().getType(),
358 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
359
360 return success();
361 }
362};
363
364// Determines the case when tosa.transpose is a tosa.reshape operation.
365struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
367
368 LogicalResult matchAndRewrite(tosa::TransposeOp op,
369 PatternRewriter &rewriter) const override {
370 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
371 return rewriter.notifyMatchFailure(
372 op, "Src is from transpose, can compose transposes");
373
374 Value result = op.getResult();
375 for (Operation *subop : result.getUsers()) {
376 if (isa_and_nonnull<tosa::TransposeOp>(subop))
377 return rewriter.notifyMatchFailure(
378 op, "Dest is used by transpose, can compose transposes");
379 }
380
381 auto input = op.getInput1();
382 auto inputTy = llvm::cast<ShapedType>(input.getType());
383 if (!inputTy.hasRank())
384 return rewriter.notifyMatchFailure(op, "Unranked input.");
385
386 int64_t numDynDims = 0;
387 for (int i = 0; i < inputTy.getRank(); ++i)
388 if (inputTy.isDynamicDim(i))
389 numDynDims++;
390
391 if (numDynDims > 1)
392 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
393
394 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
395
396 SmallVector<int64_t> nonZeroPerms;
397 nonZeroPerms.reserve(permValues.size());
398 for (auto idx : permValues) {
399 auto sz = inputTy.getDimSize(idx);
400 if (sz != 1)
401 nonZeroPerms.push_back(idx);
402 }
403
404 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
405 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
406 return rewriter.notifyMatchFailure(op,
407 "Transpose changes memory layout.");
408
409 SmallVector<int64_t> newShape;
410 newShape.reserve(inputTy.getRank());
411 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
412 newShape.push_back(inputTy.getDimSize(permValues[i]));
413
414 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
415 op, op.getType(), op.getInput1(),
416 getTosaConstShape(rewriter, op.getLoc(), newShape));
417 return success();
418 }
419};
420
// Register tosa.transpose canonicalization patterns.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
425
426struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
428
429 LogicalResult matchAndRewrite(tosa::ClampOp op,
430 PatternRewriter &rewriter) const override {
431 Value input = op.getInput();
432 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
433 auto inputElementType = inputType.getElementType();
434
435 if (isa<FloatType>(inputElementType)) {
436 // Unlike integer types, floating point types can represent infinity.
437 const auto minClamp =
438 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
439 const auto maxClamp =
440 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
441 const bool isMin = minClamp.isNegInfinity();
442 const bool isMax = maxClamp.isInfinity();
443
444 if (isMin && isMax) {
445 rewriter.replaceOp(op, input);
446 return success();
447 }
448 return failure();
449 }
450
451 // i1 types are boolean in TOSA
452 const bool isBoolean = inputElementType.isInteger(1);
453 if (inputElementType.isUnsignedInteger() || isBoolean) {
454 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
455 .getValue()
456 .getZExtValue();
457 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
458 .getValue()
459 .getZExtValue();
460
461 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
462 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
463 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
464
465 if (minClamp <= intMin && maxClamp >= intMax) {
466 rewriter.replaceOp(op, input);
467 return success();
468 }
469 return failure();
470 }
471
472 if (llvm::isa<IntegerType>(inputElementType)) {
473 const int64_t minClamp =
474 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
475 const int64_t maxClamp =
476 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
477
478 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
479 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
480 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
481
482 if (minClamp <= intMin && maxClamp >= intMax) {
483 rewriter.replaceOp(op, input);
484 return success();
485 }
486 return failure();
487 }
488
489 return failure();
490 }
491};
492
493// Attempts the following transformation:
494//
495// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
496// tensor X the following identity holds:
497//
498// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
499//
500// subject to the following valid NaN propagation semantics:
501// --------------------------------------------
502// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
503// |-------------|--------------|-------------|
504// | PROPAGATE | PROPAGATE | PROPAGATE |
505// | PROPAGATE | IGNORE | IGNORE |
506// | IGNORE | PROPAGATE | INVALID |
507// | IGNORE | IGNORE | IGNORE |
508// |------------------------------------------|
509
510struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
511 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
512
513 // Helper structure to describe the range of a clamp operation.
514 template <typename T>
515 struct ClampRange {
516 ClampRange(const T &start, const T &end) : start(start), end(end) {}
519
520 // Helper function to determine if two Clamp ranges intersect.
521 bool intersects(const ClampRange<T> &otherRange) {
522 return start < otherRange.end && otherRange.start < end;
523 }
524 };
525
526 LogicalResult matchAndRewrite(tosa::ClampOp op,
527 PatternRewriter &rewriter) const override {
528 Value input = op.getInput();
529
530 // Check the input to the CLAMP op is itself a CLAMP.
531 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
532 if (!clampOp)
533 return failure();
534
535 // Check we have a valid NaN propagation combination.
536 const auto opNanMode = op.getNanMode();
537 const auto clampNanMode = clampOp.getNanMode();
538 if (opNanMode == NanPropagationMode::IGNORE &&
539 clampNanMode == NanPropagationMode::PROPAGATE)
540 return failure();
541
542 auto maxValAttr = op.getMaxValAttr();
543 auto minValAttr = op.getMinValAttr();
544 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
545 auto clampOpMinValAttr = clampOp.getMinValAttr();
546
547 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
548 if (auto quantType =
549 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
550 inputEType = getStorageElementTypeFromQuantized(quantType);
551 }
552
553 Attribute newMinValAttr, newMaxValAttr;
554 if (mlir::isa<FloatType>(inputEType)) {
555 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
556 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
557 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
558 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
559
560 // Check we have intersecting ranges.
561 const auto opMinFloat = floatMinValAttr.getValue();
562 const auto opMaxFloat = floatMaxValAttr.getValue();
563 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
564 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
565 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
566 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
567 clampOpMaxFloat);
568 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
569 return failure();
570
571 // Run the transformation.
572 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
573 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
574 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
575 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
576 } else {
577 assert(mlir::isa<IntegerType>(inputEType));
578 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
579 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
580 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
581 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
582
583 if (inputEType.isUnsignedInteger()) {
584 // Check we have intersecting ranges.
585 const auto opMinInt = intMinValAttr.getUInt();
586 const auto opMaxInt = intMaxValAttr.getUInt();
587 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
588 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
589 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
590 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
591 clampOpMaxInt);
592 if (!opRangeIntRange.intersects(clampRangeIntRange))
593 return failure();
594
595 // Run the transformation.
596 auto newMinVal = std::max(opMinInt, clampOpMinInt);
597 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
598 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
599 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
600 } else {
601 // Check we have intersecting ranges.
602 const auto opMinInt = intMinValAttr.getInt();
603 const auto opMaxInt = intMaxValAttr.getInt();
604 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
605 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
606 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
607 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
608 clampOpMaxInt);
609 if (!opRangeIntRange.intersects(clampRangeIntRange))
610 return failure();
611
612 // Run the transformation.
613 auto newMinVal = std::max(opMinInt, clampOpMinInt);
614 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
615 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
616 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
617 }
618 }
619
620 auto newMode = (opNanMode != clampNanMode)
621 ? tosa::NanPropagationMode::IGNORE
622 : opNanMode;
623
624 auto newModeAttr =
625 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
626
627 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
628 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
629 newModeAttr);
630 return success();
631 }
632};
633
634void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.add<ClampIsNoOp>(context);
637 results.add<ClampClampOptimization>(context);
638}
639
// Rewrites slice(concat(x0, x1, ...)) into a slice of the single concat
// operand that fully contains the sliced region, dropping the concat.
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    // Slice start/size operands must be compile-time constants.
    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      // If the slice lies entirely within this operand, re-root the slice on
      // it (sliceStarts[axis] has already been rebased to this operand).
      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                                  input, start_op, size_op)
                .getResult();
        break;
      }
      // Rebase the slice start against the next operand along the axis.
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};
707
// Rewrites slice(pad(x)) by shrinking the pad to only the region the slice
// actually reads: per dimension, the slice start is rebased past any fully
// sliced-away low padding, and the pad amounts are reduced to what remains
// inside the sliced window.
struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check if producer is a PadOp
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a pad operation");

    // Check PadOp has a single consumer
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "pad shall have a single consumer");

    // Check input is statically ranked
    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a ranked tensor");

    // Validate and extract tosa::PadOp padding
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(paddingElems.getValues<int64_t>());

    // Extract slice parameters
    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Check if dynamic dimensions are sliced
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(i);
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          sliceOp, "axis that are sliced shall be statically known.");
    }

    // Update the parameters
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If dimension is dynamic pass-through
      if (inputTy.isDynamicDim(i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // Handle static dimensions
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // Check slice within bounds
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

      // Compute updated slice start parameter
      // (rebase past any low padding that the slice skips entirely).
      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute updated pad parameters
      // (only the padding that falls inside the sliced window survives).
      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
      const int64_t newPadHi =
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Calculate new pad output shape
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Check that we actually need to proceed with the rewrite
    if (!updated)
      return rewriter.notifyMatchFailure(
          sliceOp, "terminate condition; nothing to rewrite");

    // Create a PadOp with updated padding
    auto newPaddingsOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(newPadShape, inputTy.getElementType());
    auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
                                        padOp.getInput1(), newPaddingsOp,
                                        padOp.getPadConst());

    // Update SliceOp and point to new PadOp
    auto newStartOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
                                               newPadOp.getResult(), newStartOp,
                                               sliceOp.getSize());

    return success();
  }
};
842
843// Update size operand of tosa.slice if size has dynamic dims but corresponding
844// output dim is static
846 : public OpRewritePattern<tosa::SliceOp> {
847 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
850 PatternRewriter &rewriter) const override {
851 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
852 if (!resultType.hasRank())
853 return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
854
855 ElementsAttr sizeElems;
856 if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
857 return rewriter.notifyMatchFailure(
858 sliceOp, "size of slice must be a static ranked shape");
859 }
860
861 llvm::SmallVector<int64_t> sliceSizes =
862 llvm::to_vector(sizeElems.getValues<int64_t>());
863
864 bool replaceSliceSize{false};
865 // if size op has kInferableDimSize indicating dynamic shape but
866 // corresponding dim on the output is statically known, update size to match
867 // with known output dim shape
868 for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
869 if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
870 sliceSizes[index] = resultType.getDimSize(index);
871 replaceSliceSize = true;
872 }
873 }
874
875 if (!replaceSliceSize) {
876 return rewriter.notifyMatchFailure(
877 sliceOp, "no dimension of size of slice is dynamic that resolves "
878 "to static output shape");
879 }
880
881 auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
882 auto newSliceOp =
883 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
884 sliceOp.getInput1(), sliceOp.getStart(), size_op);
885
886 rewriter.replaceOp(sliceOp, newSliceOp.getResult());
887 return success();
888 }
889};
890
891void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
892 MLIRContext *context) {
893 results.add<ConcatSliceOptimization, PadSliceOptimization,
894 SliceDynamicSizeCanonicalization>(context);
895}
896
897struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
898 using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
899
900 LogicalResult matchAndRewrite(tosa::CastOp castOp,
901 PatternRewriter &rewriter) const override {
902 const Value castInput = castOp.getInput();
903 auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
904 if (!innerCastOp)
905 return rewriter.notifyMatchFailure(castOp,
906 "input must be cast operation");
907
908 const Value innerCastInput = innerCastOp.getInput();
909
910 const auto innerInputType =
911 llvm::cast<ShapedType>(innerCastInput.getType());
912 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
913 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
914
915 const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
916 outerOutputType};
917 if (llvm::any_of(types, [](const ShapedType type) {
918 return !type.getElementType().isInteger();
919 }))
920 return rewriter.notifyMatchFailure(castOp,
921 "only integer types are supported");
922
923 // Check inner cast is non-narrowing
924 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
925 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
926 return rewriter.notifyMatchFailure(castOp,
927 "inner cast operation is narrowing");
928
929 // Check outer cast is non-narrowing from the inner cast input
930 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
931 return rewriter.notifyMatchFailure(castOp,
932 "outer cast operation is narrowing");
933
934 rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
935 innerCastInput);
936
937 return success();
938 }
939};
940
// Registers tosa.cast canonicalization patterns: currently only the
// cast(cast(x)) -> cast(x) simplification for non-narrowing integer casts.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
}
945
947 : public OpRewritePattern<tosa::CastToBlockScaledOp> {
948 using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
949
950 LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
951 PatternRewriter &rewriter) const override {
952 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
953 auto castFromBlockScaledOp =
954 castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
955 if (!castFromBlockScaledOp)
956 return rewriter.notifyMatchFailure(
957 castToBlockScaledOp,
958 "input must be cast_from_block_scaled operation");
959
960 const Value innerData = castFromBlockScaledOp.getInputData();
961 const Value innerScale = castFromBlockScaledOp.getInputScale();
962 const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
963 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());
964
965 const Value outerData = castToBlockScaledOp.getOutputData();
966 const Value outerScale = castToBlockScaledOp.getOutputScale();
967 const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
968 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());
969
970 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
971 return rewriter.notifyMatchFailure(
972 castToBlockScaledOp,
973 "inputs types to cast_from_block_scaled operation must match output "
974 "types to cast_to_block_scaled");
975 }
976
977 if (castFromBlockScaledOp.getBlockSize() !=
978 castToBlockScaledOp.getBlockSize()) {
979 return rewriter.notifyMatchFailure(
980 castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
981 "cast_to_block_scaled must match");
982 }
983
984 rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});
985
986 return success();
987 }
988};
989
// Registers tosa.cast_to_block_scaled canonicalization patterns: currently
// only the cancelling cast_from/cast_to pair elimination.
void CastToBlockScaledOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CancellingBlockScaledCastsOptimization>(context);
}
994
995//===----------------------------------------------------------------------===//
996// Operator Folders.
997//===----------------------------------------------------------------------===//
998
999template <typename Folder>
1000static DenseElementsAttr
1002 bool foldDenseValues = false) {
1003 if (!lhs || !rhs)
1004 return {};
1005
1006 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1007 return {};
1008
1009 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
1010 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
1011 if (lETy != rETy)
1012 return {};
1013
1014 if (lhs.isSplat() && rhs.isSplat()) {
1015 if (isa<FloatType>(lETy)) {
1016 const APFloat l = lhs.getSplatValue<APFloat>();
1017 const APFloat r = rhs.getSplatValue<APFloat>();
1018 const auto maybeResult = Folder::fold(l, r);
1019 if (failed(maybeResult))
1020 return {};
1021 return DenseElementsAttr::get(returnTy, maybeResult.value());
1022 }
1023
1024 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1025 const APInt l = lhs.getSplatValue<APInt>();
1026 const APInt r = rhs.getSplatValue<APInt>();
1027 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1028 if (failed(maybeResult))
1029 return {};
1030 return DenseElementsAttr::get(returnTy, maybeResult.value());
1031 }
1032 }
1033
1034 if (foldDenseValues) {
1035 assert(lETy.isIntOrIndex() &&
1036 "Only integer types are currently supported.");
1037 SmallVector<APInt> resultValues;
1038 for (auto [l, r] :
1039 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
1040 const auto maybeResult = Folder::fold(l, r, false);
1041 if (failed(maybeResult))
1042 return {};
1043 resultValues.push_back(maybeResult.value());
1044 }
1045 return DenseElementsAttr::get(returnTy, resultValues);
1046 }
1047
1048 return {};
1049}
1050
1051template <typename Folder>
1052static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
1053 bool foldDenseValues = false) {
1054 if (!val)
1055 return {};
1056
1057 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1058 return {};
1059
1060 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
1061
1062 if (val.isSplat()) {
1063 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1064 const APInt v = val.getSplatValue<APInt>();
1065 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1066 if (failed(maybeResult))
1067 return {};
1068 return DenseElementsAttr::get(returnTy, maybeResult.value());
1069 }
1070 }
1071
1072 if (foldDenseValues) {
1073 mlir::Type elemTy = val.getElementType();
1074 if (elemTy.isIntOrIndex()) {
1075 SmallVector<APInt> resultValues;
1076 for (auto const &v : val.getValues<APInt>()) {
1077 const auto maybeResult = Folder::fold(v, false);
1078 if (failed(maybeResult))
1079 return {};
1080 resultValues.push_back(maybeResult.value());
1081 }
1082 return DenseElementsAttr::get(returnTy, resultValues);
1083 }
1084 }
1085
1086 // Folding arbitrarily sized tensor operations is not supported
1087 return {};
1088}
1089
1090static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
1091 DenseIntElementsAttr dense{};
1092 if (!matchPattern(v, m_Constant(&dense)))
1093 return failure();
1094
1095 assert(dense.isSplat());
1096 APInt a = dense.getSplatValue<APInt>();
1097 return a.getSExtValue();
1098}
1099
1101 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1102 const bool isUnsigned) {
1103 bool overflow;
1104 const APInt result =
1105 isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
1106 if (overflow)
1107 return failure();
1108 return result;
1109 }
1110
1111 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1112 return lhs + rhs;
1113 }
1114};
1115
1117 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1118 const bool isUnsigned) {
1119 bool overflow;
1120 const APInt result =
1121 isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
1122 if (overflow)
1123 return failure();
1124 return result;
1125 }
1126
1127 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1128 return lhs - rhs;
1129 }
1130};
1131
  // Integer multiplication with overflow detection; refuses to fold (failure)
  // when the product wraps. Zero operands short-circuit to a zero result.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {

    const unsigned originalWidth = lhs.getBitWidth();

    // Check same type
    if (lhs.getBitWidth() != rhs.getBitWidth()) {
      return failure();
    }

    // If either is `0`
    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
    APInt const result =
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    if (overflow)
      return failure();

    // NOTE(review): umul_ov/smul_ov return a result of the operand width, so
    // this trunc appears to be a no-op kept for safety — confirm.
    return result.trunc(originalWidth);
  }

  // Float multiplication always folds.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs * rhs;
  }
};
1161
1162static bool signsDiffer(const APInt &a, const APInt &b) {
1163 return a.isNegative() != b.isNegative();
1164}
1165
1166template <bool Ceil>
1168 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1169 bool isUnsigned) {
1170 if (lhs.getBitWidth() != rhs.getBitWidth())
1171 return failure();
1172 if (rhs.isZero())
1173 return failure();
1174
1175 if (isUnsigned) {
1176 APInt q{};
1177 APInt r{};
1178 APInt::udivrem(lhs, rhs, q, r);
1179 if (!r.isZero() && Ceil) {
1180 return q + 1;
1181 }
1182 return q;
1183 }
1184
1185 // Signed: start from trunc-toward-zero, then adjust to ceil.
1186 bool overflow{false};
1187 APInt const q = lhs.sdiv_ov(rhs, overflow);
1188 if (overflow)
1189 return failure();
1190 APInt const r = lhs.srem(rhs);
1191
1192 if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
1193 // Same sign => exact quotient is positive; trunc is below ceil =>
1194 // increment q.
1195 return q + 1;
1196 }
1197 return q;
1198 }
1199
1200 static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
1201 return lhs / rhs;
1202 }
1203};
1204
  // Integer remainder, restricted to non-negative lhs and strictly positive
  // rhs so signed/unsigned semantics agree and division by zero is excluded.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    // NOTE(review): isNegative()/isStrictlyPositive() inspect the sign bit
    // even for unsigned operands, so large unsigned values are
    // conservatively rejected here — confirm this is intended.
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
      return failure();

    if (isUnsigned) {
      return lhs.urem(rhs);
    }

    return lhs.srem(rhs);
  }

  // Float modulus via APFloat::mod; only folds when the operation reports
  // opOK (rejects NaN/infinity cases that raise a status flag).
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto t = lhs;
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {
      return t;
    }
    return failure();
  }
};
1229
  // Returns the larger integer operand using a signed comparison.
  // NOTE(review): `isUnsigned` is ignored, and getSExtValue() asserts for
  // widths > 64 bits. Fine while only signed, <=64-bit types reach this
  // folder, but an unsigned path (ugt) would be needed otherwise — confirm
  // against callers.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
  }

  // Returns the larger float operand; NaN comparisons are unordered so a NaN
  // operand makes the comparison false and rhs is returned — confirm this
  // matches the op's intended NaN propagation mode.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs >= rhs ? lhs : rhs;
  }
};
1242
  // Returns the smaller integer operand using a signed comparison.
  // NOTE(review): `isUnsigned` is ignored, and getSExtValue() asserts for
  // widths > 64 bits — same caveat as the max folder above; confirm against
  // callers.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
  }

  // Returns the smaller float operand; NaN comparisons are unordered so a
  // NaN operand makes the comparison false and rhs is returned — confirm
  // intended NaN behavior.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs <= rhs ? lhs : rhs;
  }
};
1255
  // Computes a value with exactly bit `value` set, i.e. (1 << value) in the
  // operand's bit-width. Fails when the shift amount is negative or >= the
  // bit-width. NOTE(review): presumably backs a shift/exp2-style fold —
  // confirm the enclosing adaptor's users.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
    if (isUnsigned) {
      // getZExtValue asserts for widths > 64 bits — assumes narrow TOSA
      // integer types; TODO confirm.
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
        return failure();
      return APInt::getOneBitSet(numBits, zextv);
    }
    auto const sextv = value.getSExtValue();
    // Reject negative shift amounts and shifts as wide as the type.
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
      return failure();
    return APInt::getOneBitSet(numBits, sextv);
  }
};
1271
  // Computes ceil(log2(value)) for strictly positive inputs; fails for zero
  // or negative values where the logarithm is undefined.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
  }
};
1279
  // Computes floor(log2(value)) for strictly positive inputs; fails for zero
  // or negative values where the logarithm is undefined.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
  }
};
1287
1289 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1290 const bool isUnsigned) {
1291 return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
1292 }
1293
1294 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1295 return APInt(1, lhs > rhs);
1296 }
1297};
1298
1300 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1301 const bool isUnsigned) {
1302 return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
1303 }
1304
1305 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1306 return APInt(1, lhs >= rhs);
1307 }
1308};
1309
1311 static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
1312 const bool isUnsigned) {
1313 return APInt(1, lhs == rhs);
1314 }
1315
1316 static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
1317 return APInt(1, lhs == rhs);
1318 }
1319};
1320
1321static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1322 if (llvm::isa<FloatType>(elemType))
1323 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1324 if (llvm::isa<IntegerType>(elemType))
1325 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1326 return false;
1327}
1328
1329static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1330 if (llvm::isa<FloatType>(elemType))
1331 return val && val.isSplat() &&
1332 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1333 if (llvm::isa<IntegerType>(elemType)) {
1334 const int64_t shifted = 1LL << shift;
1335 return val && val.isSplat() &&
1336 val.getSplatValue<APInt>().getSExtValue() == shifted;
1337 }
1338 return false;
1339}
1340
1341OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1342 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1343 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1344 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1345 if (!lhsTy || !rhsTy || !resultTy)
1346 return {};
1347
1348 // Cannot create an ElementsAttr from non-int/float/index types
1349 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1350 !rhsTy.getElementType().isIntOrIndexOrFloat())
1351 return {};
1352
1353 auto resultETy = resultTy.getElementType();
1354 auto lhsAttr =
1355 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1356 auto rhsAttr =
1357 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1358
1359 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1360 lhsTy.getShape(), rhsTy.getShape());
1361 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1362 return getInput1();
1363 if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
1364 return getInput2();
1365
1366 if (!lhsAttr || !rhsAttr)
1367 return {};
1368
1369 return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1370}
1371
1372OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1373 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1374 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1375 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1376 !outputTy.hasStaticShape())
1377 return {};
1378
1379 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1380 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1381 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1382 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1383 return DenseElementsAttr::get(outputTy, zero);
1384 }
1385
1386 return {};
1387}
1388
1389OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1390 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1391 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1392 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1393 if (!lhsTy || !rhsTy || !resultTy)
1394 return {};
1395 if (lhsTy.getElementType() != rhsTy.getElementType())
1396 return {};
1397
1398 // IntDivOp inputs must be integer type, no need to check for quantized
1399 // type
1400 auto resultETy = resultTy.getElementType();
1401 auto lhsAttr =
1402 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1403 auto rhsAttr =
1404 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1405 if (lhsAttr && lhsAttr.isSplat()) {
1406 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1407 lhsAttr.getSplatValue<APInt>().isZero())
1408 return lhsAttr.resizeSplat(resultTy);
1409 }
1410
1411 if (rhsAttr && rhsAttr.isSplat()) {
1412 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1413 lhsTy.getShape(), rhsTy.getShape());
1414 if (isBroadcastable && lhsTy == resultTy &&
1415 llvm::isa<IntegerType>(resultETy) &&
1416 rhsAttr.getSplatValue<APInt>().isOne())
1417 return getInput1();
1418 }
1419
1420 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1421 llvm::isa<IntegerType>(resultETy)) {
1422 APInt l = lhsAttr.getSplatValue<APInt>();
1423 APInt r = rhsAttr.getSplatValue<APInt>();
1424 if (!r.isZero()) {
1425 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1426 auto const result =
1427 DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
1428 if (failed(result))
1429 return {};
1430 return DenseElementsAttr::get(resultTy, result.value());
1431 }
1432 }
1433
1434 return {};
1435}
1436
1437namespace {
1438// calculate lhs * rhs >> shift according to TOSA Spec
1439// return nullopt if result is not in range of int32_t when shift > 0
1440std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1441 unsigned bitwidth) {
1442 bool overflow = false;
1443 APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);
1444
1445 if (overflow)
1446 return std::nullopt;
1447
1448 if (shift > 0) {
1449 auto round = APInt(64, 1) << (shift - 1);
1450 result += round;
1451 result.ashrInPlace(shift);
1452 // REQUIRE(product >= minimum_s<i32_t>() && product <=
1453 // maximum_s<i32_t>())
1454 if (!(result.getSExtValue() >= INT32_MIN &&
1455 result.getSExtValue() <= INT32_MAX)) {
1456 // REQUIRE failed
1457 return std::nullopt;
1458 }
1459 }
1460
1461 return result.trunc(bitwidth);
1462}
1463
1464DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
1465 RankedTensorType ty, int32_t shift) {
1466 if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
1467 if (llvm::isa<IntegerType>(ty.getElementType())) {
1468 APInt l = lhs.getSplatValue<APInt>();
1469 APInt r = rhs.getSplatValue<APInt>();
1470
1471 if (shift == 0) {
1472 return DenseElementsAttr::get(ty, l * r);
1473 }
1474
1475 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1476 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1477 if (!result)
1478 return {};
1479 return DenseElementsAttr::get(ty, result.value());
1480 }
1481
1482 if (llvm::isa<FloatType>(ty.getElementType())) {
1483 APFloat l = lhs.getSplatValue<APFloat>();
1484 APFloat r = rhs.getSplatValue<APFloat>();
1485 APFloat result = l * r;
1486 return DenseElementsAttr::get(ty, result);
1487 }
1488 }
1489
1490 return {};
1491}
1492} // namespace
1493
// Folds tosa.mul: x * 0 -> 0, x * 1 -> x (where "one" accounts for the i32
// shift encoding), and splat * splat via mulBinaryFolder. The ordering of
// checks matters: the shift must be resolved before any identity fold since
// it changes what "one" means.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification,
  // synthesize a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // x * 0 == 0 (and 0 * x == 0): replace with a zero splat of the result
  // type.
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
      resultTy.hasStaticShape())
    // constant values can only be resized if resulting type is static
    return lhsAttr.resizeSplat(resultTy);
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
      resultTy.hasStaticShape())
    return rhsAttr.resizeSplat(resultTy);

  // x * 1 == x (and 1 * x == x), with "one" meaning (1 << shift) for
  // shifted i32 multiplies; requires broadcast-compatible shapes and that
  // the surviving operand already has the result type.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && rhsTy == resultTy &&
      isSplatOne(resultETy, lhsAttr, shift))
    return rhs;
  if (isBroadcastable && lhsTy == resultTy &&
      isSplatOne(resultETy, rhsAttr, shift))
    return lhs;

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
1541
1542OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1543 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1544 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
1545 auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
1546 if (!lhsTy || !rhsTy || !resultTy)
1547 return {};
1548
1549 // Cannot create an ElementsAttr from non-int/float/index types
1550 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1551 !rhsTy.getElementType().isIntOrIndexOrFloat())
1552 return {};
1553
1554 auto resultETy = resultTy.getElementType();
1555 auto lhsAttr =
1556 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1557 auto rhsAttr =
1558 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1559
1560 const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
1561 lhsTy.getShape(), rhsTy.getShape());
1562 if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
1563 return getInput1();
1564
1565 if (!lhsAttr || !rhsAttr)
1566 return {};
1567
1568 return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1569}
1570
1571OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1572 auto resultTy = llvm::cast<ShapedType>(getType());
1573 auto lhsAttr =
1574 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1575 auto rhsAttr =
1576 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1577
1578 if (!lhsAttr || !rhsAttr)
1579 return {};
1580
1581 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1582}
1583
1584OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1585 auto resultTy = llvm::cast<ShapedType>(getType());
1586 auto lhsAttr =
1587 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1588 auto rhsAttr =
1589 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1590
1591 if (!lhsAttr || !rhsAttr)
1592 return {};
1593
1594 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1595}
1596
1597OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1598 auto resultTy = llvm::cast<ShapedType>(getType());
1599 auto lhsAttr =
1600 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1601 auto rhsAttr =
1602 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1603 Value lhs = getInput1();
1604 Value rhs = getInput2();
1605 auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
1606
1607 // If we are comparing an integer value to itself it is always true. We
1608 // can not do this with float due to float values.
1609 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1610 resultTy.hasStaticShape() && lhs == rhs) {
1611 return DenseElementsAttr::get(resultTy, true);
1612 }
1613
1614 if (!lhsAttr || !rhsAttr)
1615 return {};
1616
1617 return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1618}
1619
// Folds tosa.cast in two cases: (1) the cast is an identity (input type ==
// result type); (2) the input is a splat constant, whose value is converted
// to the result element type (float<->float, int->float, float->int,
// int<->int).
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  // Identity cast.
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  // Building a constant result requires a ranked, static result type.
  if (!outTy.hasRank() || !outTy.hasStaticShape())
    return {};
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  // Only splat constants are folded; converting full dense tensors is left
  // to other transforms.
  if (operand.isSplat()) {
    // float -> float: convert between semantics with round-to-nearest-even;
    // the overflow flag is intentionally ignored.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float: signedness of the source type selects signed/unsigned
    // conversion.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round-to-nearest-even; inexact conversions are accepted
    // (the `exact` flag is ignored). Out-of-range behavior follows
    // APFloat::convertToInteger — presumably clamps via APSInt; TODO confirm.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate when narrowing; zero- or sign-extend when
    // widening, based on source signedness (i1 widens as unsigned).
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      // i1 types are boolean in TOSA
      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
      } else if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
1689
// A tosa.const trivially folds to its value attribute.
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1691
// A tosa.const_shape trivially folds to its value attribute.
OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1693
// Folds reduce ops that are no-ops: when the reduction axis already has
// extent 1 (or the input is rank 0) and the input/output types match, the
// result equals the input.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
1713
1714OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1715 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1716 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1717
1718 if (!inputTy || !outputTy)
1719 return {};
1720
1721 // Fold when the input and output types are the same. This is only safe
1722 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1723 // dimensions, there may still be a productive reshape.
1724 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1725 return getInput1();
1726
1727 // reshape(reshape(x)) -> reshape(x)
1728 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1729 getInput1().getDefiningOp())) {
1730 getInput1Mutable().assign(reshapeOp.getInput1());
1731 return getResult();
1732 }
1733
1734 // Cannot create an ElementsAttr from non-int/float/index types
1735 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1736 return {};
1737
1738 // reshape(const(x)) -> const(reshape-attr(x))
1739 if (auto operand =
1740 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1741 // Constants must have static shape.
1742 if (!outputTy.hasStaticShape())
1743 return {};
1744
1745 // Okay to duplicate splat constants.
1746 if (operand.isSplat())
1747 return SplatElementsAttr::get(outputTy,
1748 operand.getSplatValue<Attribute>());
1749
1750 // Don't duplicate other constants.
1751 if (!getInput1().hasOneUse())
1752 return {};
1753
1755 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1756 return {};
1757
1758 return operand.reshape(
1759 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1760 }
1761
1762 return {};
1763}
1764
1765OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1766 // If the pad is all zeros we can fold this operation away.
1767 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1768 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1769 if (densePad && densePad.isSplat() &&
1770 densePad.getSplatValue<APInt>().isZero()) {
1771 return getInput1();
1772 }
1773 }
1774
1775 return {};
1776}
1777
1778// Fold away cases where a tosa.resize operation returns a copy
1779// of the input image.
1780OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1781 auto scaleAttr =
1782 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1783 auto offsetAttr =
1784 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1785 auto borderAttr =
1786 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1787 if (!scaleAttr || !offsetAttr || !borderAttr) {
1788 return {};
1789 }
1790
1791 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1792 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1793 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1794 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1795 return {};
1796 }
1797
1798 // Check unit scaling.
1799 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1800 return {};
1801 }
1802
1803 // There should be no offset.
1804 if (offset[0] != 0 || offset[1] != 0) {
1805 return {};
1806 }
1807
1808 // There should be no border.
1809 if (border[0] != 0 || border[1] != 0) {
1810 return {};
1811 }
1812
1813 return foldToInputIfTypeMatches(getType(), getInput());
1814}
1815
1816OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1817 auto operand = getInput1();
1818 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1819 auto axis = getAxis();
1820 // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
1821 const bool isSplatInput =
1822 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
1823 if (!operandTy.hasRank() ||
1824 (!isSplatInput && operandTy.getDimSize(axis) != 1))
1825 return {};
1826 return foldToInputIfTypeMatches(getType(), operand);
1827}
1828
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  // Fold tosa.slice in three situations:
  //   1. Input and output have the same static type — the slice is trivially
  //      the whole input.
  //   2. All starts are 0 and the sizes cover the full input (dynamic-aware).
  //   3. The input is a constant and the result is a splat or a single
  //      element, so a constant attribute can be materialized directly.
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Case 1: identical static types mean the slice returns the whole input.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  // Check if this is a no-op slice (starts at 0 and size matches input)

  DenseElementsAttr startElems;
  if (!matchPattern(getStart(), m_Constant(&startElems)))
    return {};

  // Check if all start values are zero
  bool startIsZeros =
      llvm::all_of(startElems.getValues<APInt>(),
                   [](const APInt &val) { return val.isZero(); });

  if (startIsZeros) {

    // Check if size matches input shape
    DenseElementsAttr sizeElems;
    if (!matchPattern(getSize(), m_Constant(&sizeElems)))
      return {};

    auto inputShape = inputTy.getShape();
    auto sizeValues = sizeElems.getValues<APInt>();

    bool sizeMatchesInput = true;
    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
      int64_t size = sizeVal.getSExtValue();

      if (inputTy.isDynamicDim(i)) {
        // For dynamic dimensions, check for kInferableDimSize indicating full
        // dimension is sliced
        if (size != kInferableDimSize) {
          sizeMatchesInput = false;
          break;
        }
      } else {
        // For static dimensions, check that size must match exactly or be
        // kInferableDimSize indicating full dimension is sliced
        if (size != kInferableDimSize && size != inputShape[i]) {
          sizeMatchesInput = false;
          break;
        }
      }
    }

    if (sizeMatchesInput)
      return getInput1();
  }

  // The following checks require the input to be a constant
  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  // A splat input remains a splat regardless of which region is sliced.
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // Single-element result: the start indices address exactly the element
  // being extracted, so index the constant values directly.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    if (auto values = operand.tryGetValues<Attribute>())
      return SplatElementsAttr::get(outputTy, (*values)[indices]);
  }

  return {};
}
1909
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  // Fold tosa.select when either (a) both branches are the same value, or
  // (b) the predicate is a constant splat so one branch is statically taken.
  const Value pred = getPred();
  const Value onTrue = getOnTrue();
  const Value onFalse = getOnFalse();

  const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
  const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
  const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
  if (!predTy || !onTrueTy || !onFalseTy)
    return {};

  const Type resultTy = getType();

  const ArrayRef<int64_t> predShape = predTy.getShape();
  const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();

  // Case (a): identical branches — the predicate value is irrelevant, but
  // broadcasting must still be statically known to preserve the result shape.
  if (onTrue == onFalse && onTrueTy == resultTy &&
      OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
    return onTrue;

  // Case (b): needs a constant splat predicate.
  // NOTE(review): `adaptor.getInput1()` presumably maps to the predicate
  // operand in the generated adaptor — confirm against the ODS definition.
  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate)
    return {};
  if (!predicate.isSplat())
    return {};

  const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();

  SmallVector<SmallVector<int64_t>, 3> shapes;
  shapes.emplace_back(predShape);
  shapes.emplace_back(onTrueShape);
  shapes.emplace_back(onFalseTy.getShape());
  const bool isBroadcastable =
  // NOTE(review): the initializer expression is missing from this extraction
  // (likely OpTrait::util::staticallyKnownBroadcastable(shapes)); restore it
  // from upstream before building.

  if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
    return onTrue;
  if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
    return onFalse;
  return {};
}
1952
1953OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1954 if (getInput1().getType() == getType()) {
1955 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1956 adaptor.getMultiples())) {
1957 if (multiples.isSplat() &&
1958 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1959 return getInput1();
1960 if (auto int_array_attr =
1961 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1962 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1963 [](APInt v) { return v.getSExtValue() == 1; }))
1964 return getInput1();
1965 }
1966 }
1967 }
1968 return {};
1969}
1970
1971OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1972 auto resultTy = llvm::cast<ShapedType>(getType());
1973
1974 // Transposing splat values just means reshaping.
1975 if (auto input =
1976 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1977 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
1978 input.getType().getElementType() == resultTy.getElementType())
1979 return input.reshape(resultTy);
1980 }
1981
1982 // Transpose is not the identity transpose.
1983 const llvm::ArrayRef<int32_t> perms = getPerms();
1984
1985 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1986 return {};
1987
1988 return foldToInputIfTypeMatches(getType(), getInput1());
1989}
1990
1991OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1992 // Element-wise negate(negate(x)) = x
1993 // iff all zero points are constant 0
1994 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1995 if (!definingOp) {
1996 // defining op of input1 is not a negate, cannot fold
1997 return {};
1998 }
1999
2000 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2001 failed(maybeIZp) || *maybeIZp != 0) {
2002 // input1 zero point is not constant 0, cannot fold
2003 return {};
2004 }
2005 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2006 failed(maybeOZp) || *maybeOZp != 0) {
2007 // output zero point is not constant 0, cannot fold
2008 return {};
2009 }
2010 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2011 failed(maybeIZp) || *maybeIZp != 0) {
2012 // definingOp's input1 zero point is not constant 0, cannot fold
2013 return {};
2014 }
2015 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2016 failed(maybeOZp) || *maybeOZp != 0) {
2017 // definingOp's output zero point is not constant 0, cannot fold
2018 return {};
2019 }
2020
2021 return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
2022}
2023
2024OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2025 auto input = getInput1();
2026 // Element-wise abs(abs(x)) = abs(x)
2027 if (input.getDefiningOp<tosa::AbsOp>())
2028 return foldToInputIfTypeMatches(getType(), input);
2029
2030 return {};
2031}
2032
2033OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
2034 // Fold consecutive concats on the same axis into a single op.
2035 // Keep track of the operands so we are able to construct a new concat
2036 // later. Conservatively assume that we double the number of operands when
2037 // folding
2038 SmallVector<Value, 8> concatOperands;
2039 concatOperands.reserve(2 * getNumOperands());
2040
2041 // Find all operands that are foldable concats
2042 bool foundFoldableConcat = false;
2043 for (Value operand : getOperands()) {
2044 concatOperands.emplace_back(operand);
2045
2046 auto producer = operand.getDefiningOp<ConcatOp>();
2047 if (!producer)
2048 continue;
2049
2050 // Not foldable if axes are not the same
2051 if (getAxis() != producer.getAxis())
2052 continue;
2053
2054 // Replace the original operand with all incoming operands
2055 foundFoldableConcat = true;
2056 concatOperands.pop_back();
2057 llvm::append_range(concatOperands, producer->getOperands());
2058 }
2059
2060 if (!foundFoldableConcat)
2061 return {};
2062
2063 getOperation()->setOperands(concatOperands);
2064 return getResult();
2065}
2066
2067OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2068 auto input = adaptor.getInput1();
2069
2070 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2071 // Fold splat inputs only.
2072 if (!inputAttr || !inputAttr.isSplat())
2073 return {};
2074
2075 auto shapeType = llvm::cast<ShapedType>(getType());
2076 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2077 return {};
2078 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2079 auto floatVal = inputAttr.getSplatValue<APFloat>();
2080 return DenseElementsAttr::get(shapeType,
2081 ReciprocalOp::calcOneElement(floatVal));
2082 }
2083
2084 return {};
2085}
2086
template <typename Op, typename OpFoldAdaptor>
// NOTE(review): the function signature line is missing from this extraction;
// per the declaration index it reads `OpFoldResult unaryShapeFold(Op *op) {`.
// Folds a unary shape op whose single operand is a tosa.const_shape by
// applying OpFoldAdaptor element-wise over the constant shape values.
  auto input1ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
  if (!input1ConstShape)
    return {};

  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());

  // Delegate the per-element computation to the generic unary folder.
  return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
                                    /*foldDenseValues=*/true);
}
2099
template <typename Op, typename OpFoldAdaptor>
// NOTE(review): the function signature line is missing from this extraction;
// per the declaration index it reads `OpFoldResult binaryFold(Op *op) {`.
// Folds a binary shape op whose two operands are both tosa.const_shape by
// applying OpFoldAdaptor pairwise over the constant shape values.
  auto input1ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
  auto input2ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
  if (!input1ConstShape || !input2ConstShape)
    return {};

  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());

  // Delegate the per-element computation to the generic binary folder.
  return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
                                     input1Attr.getType(),
                                     /*foldDenseValues=*/true);
}
2116
2117OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2118 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2119 if (!inputTy || !inputTy.hasRank())
2120 return {};
2121 const int32_t axis = getAxis();
2122 const int64_t dimSize = inputTy.getDimSize(axis);
2123 if (ShapedType::isDynamic(dimSize))
2124 return {};
2125
2126 OpBuilder builder(getContext());
2127 const auto resultAttrTy =
2128 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2129 return DenseElementsAttr::get(resultAttrTy, dimSize);
2130}
2131
2132OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
2133 auto const inputs = op->getInput();
2134
2135 if (inputs.empty())
2136 return {};
2137
2138 SmallVector<APInt> concatDims;
2139 concatDims.reserve(/*max elem*/ 64);
2140 for (auto const &v : inputs) {
2141 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2142 if (!vConstShape)
2143 return {};
2144
2145 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2146 assert(vAttr);
2147
2148 auto const vAttrVals = vAttr.getValues<APInt>();
2149 for (auto const &v : vAttrVals) {
2150 concatDims.push_back(v);
2151 }
2152 }
2153
2154 auto *ctx = op->getContext();
2155 assert(ctx != nullptr && "ctx is nullptr");
2156 auto const rankedTy = RankedTensorType::get(
2157 {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2158
2159 return DenseElementsAttr::get(rankedTy, concatDims);
2160}
2161
2162OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
2163 auto const input1 = op->getInput();
2164 auto const input2 = op->getStart();
2165 auto const input3 = op->getSize();
2166
2167 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2168
2169 if (!input1ConstShape)
2170 return {};
2171
2172 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2173 if (!input1Attr)
2174 return {};
2175
2176 auto const input1Vals = input1Attr.getValues<APInt>();
2177 auto const totalInput1 = input1Vals.size();
2178
2179 auto const start = getSingleI64From1ElementTensor(input2);
2180 auto const size = getSingleI64From1ElementTensor(input3);
2181
2182 if (failed(start) || failed(size))
2183 return {};
2184
2185 auto const startV = static_cast<int32_t>(start.value());
2186 auto const sizeV = static_cast<int32_t>(size.value());
2187
2188 if ((sizeV <= 0) || (startV < 0) ||
2189 (static_cast<size_t>(startV + sizeV) > totalInput1))
2190 return {};
2191
2192 SmallVector<APInt> sliceOfInput;
2193 sliceOfInput.reserve(totalInput1);
2194
2195 for (auto i = startV; i < (startV + sizeV); i++) {
2196 sliceOfInput.push_back(input1Vals[i]);
2197 }
2198
2199 auto *ctx = op->getContext();
2200 assert(ctx != nullptr && "ctx is nullptr");
2201
2202 auto const rankedTy = RankedTensorType::get(
2203 {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2204
2205 return DenseElementsAttr::get(rankedTy, sliceOfInput);
2206}
2207
OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): the body statement is missing from this extraction — by
  // analogy with DivCeilShapeOp below, presumably a binaryFold call with an
  // add adaptor; restore from upstream before building.
}

OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // binaryFold call with a subtract adaptor — confirm against upstream.
}

OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // binaryFold call with a multiply adaptor — confirm against upstream.
}
2219
2220OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2221 return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
2222}
2223
2224OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2225 return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
2226}
2227
OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // binaryFold call with a modulo adaptor — confirm against upstream.
}

OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // binaryFold call with a max adaptor — confirm against upstream.
}

OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // binaryFold call with a min adaptor — confirm against upstream.
}

OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; as a unary op
  // this presumably calls unaryShapeFold — confirm against upstream.
}

OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // unaryShapeFold call with a ceil-log2 adaptor — confirm against upstream.
}

OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
  // NOTE(review): body statement missing from this extraction; presumably a
  // unaryShapeFold call with a floor-log2 adaptor — confirm against upstream.
}
2251
OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
  // Delegates to the shared constant-shape concatenation folder above.
  return concatShapeFold(this);
}
2255
OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
  // Delegates to the shared constant-shape slice folder above.
  return sliceShapeFold(this);
}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...