MLIR 23.0.0git
TosaCanonicalizations.cpp
Go to the documentation of this file.
1//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// TOSA canonicalization patterns and folders.
11//
12//===----------------------------------------------------------------------===//
13
19#include "mlir/Dialect/Traits.h"
22#include "mlir/IR/Matchers.h"
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28
29#include <functional>
30
31using namespace mlir;
32using namespace mlir::tosa;
33
34namespace {
35OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
36 return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
37}
38} // namespace
39
40//===----------------------------------------------------------------------===//
41// Operator Canonicalizers.
42//===----------------------------------------------------------------------===//
43
44//===----------------------------------------------------------------------===//
45// Tensor Data Engine Operators.
46//===----------------------------------------------------------------------===//
47
48// Check that the zero point of the tensor and padding operations are aligned.
49static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
50 // Check that padConst is a constant value and a scalar tensor
51 DenseElementsAttr padConstAttr;
52 if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
53 (padConstAttr.size() != 1)) {
54 return false;
55 }
56
57 // Check that floating point pad is zero
58 if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
59 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
60 return padConstVal == 0.0f;
61 }
62
63 // Check that the zp and padConst align for the integer (quantized) case
64 if (auto padConstIntAttr =
65 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 // Check that zp is a constant value and a scalar tensor
68 if (!matchPattern(zp, m_Constant(&zpAttr)) || (padConstAttr.size() != 1)) {
69 return false;
70 }
71
72 // Check equality
73 int64_t zpVal = (*zpAttr.begin()).getSExtValue();
74 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
75 return zpVal == padConstVal;
76 }
77
78 // Bail-out on unsupported type
79 return false;
80}
81
82namespace {
// Per-operation adaptor describing how a producing tosa.pad may be folded
// into the operation's own implicit padding.
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  // Folding is only legal while every merged pad value stays strictly smaller
  // than the kernel dimension it applies to. `newPad` is ordered
  // {top, bottom, left, right}; `kernel` is {height, width}.
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;
    return true;
  }
  // The pad constant can be absorbed by a max-pool only when it can never win
  // the max, i.e. it is the lowest value representable by the element type.
  static bool checkPadConstCompliance(OpTy, Value padConst) {
    // Check that padConst is a constant value and a scalar tensor
    DenseElementsAttr padConstAttr;
    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
        padConstAttr.size() != 1) {
      return false;
    }

    // Pad needs to be in the minimum value to be able to merge
    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      // getLargest(..., /*Negative=*/true) is the most negative finite value.
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    }
    if (auto padConstIntAttr =
            mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      // Unsigned lowest is zero; signed lowest is the signed minimum.
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getZero(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

    // Bail-out on unsupported type
    return false;
  }
  // Rebuild the max-pool on the pad's input with the merged padding amounts,
  // dropping the explicit tosa.pad from the producer chain.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
        op, op.getType(), padInput, op.getKernel(), op.getStride(),
        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
  }
};
133
// Adaptor for convolution-like operations (conv2d, depthwise_conv2d): the
// merged padding amount is unrestricted, but the pad constant must equal the
// operation's input zero point so padded elements contribute identically.
template <typename OpTy>
struct ConvPadFoldAdaptor {
  // Convolutions impose no kernel-relative limit on the padding size.
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
    return true;
  }
  // Require pad-const == input zero point (see checkMatchingPadConstAndZp).
  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }
  // Rebuild the convolution on the pad's input with the merged padding.
  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
    rewriter.replaceOpWithNewOp<OpTy>(
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
  }
};
150
// Pattern attempts to fold a `tosa.pad` operator to a following tensor
// operation like `tosa.conv2d` by merging the padding associated with the
// pad operator directly to the implicit padding of the tensor operation.
// This helps eliminate the explicit padding operator if unused.
//
// AdaptorTy supplies the op-specific legality checks and the final rewrite
// (see PoolPadFoldAdaptor / ConvPadFoldAdaptor above).
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    // Check producer is a tosa::PadOp
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    // Validate that tensor operation has sane padding
    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

    // Validate tosa::PadOp padding
    DenseIntElementsAttr padOpPadding;
    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "The `padding` input specified on the tosa::PadOp must be constant.");
    }
    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
    // C_after
    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    // Only spatial (H/W) padding can be expressed by the tensor op itself.
    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    // Fold padding from Pad into the tensor operation
    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    // Check kernel related restrictions
    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    // Check padding constant restrictions
    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp,
          "Padding constant is not aligned with operator zero-point.");
    }

    // Check that padding doesn't grow more than 8K level (8192) for now
    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    // Create operator
    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);

    return success();
  }
};
232} // namespace
233
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Merge a producing tosa.pad into conv2d's implicit padding when legal.
  results.add<
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
      context);
}
240
void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // Merge a producing tosa.pad into depthwise_conv2d's implicit padding.
  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
      context);
}
247
248struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
250
251 LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
252 PatternRewriter &rewriter) const override {
253 Value input = op.getInput();
254 Value output = op.getOutput();
255 ShapedType inputType = llvm::cast<ShapedType>(input.getType());
256 ShapedType outputType = llvm::cast<ShapedType>(output.getType());
257
258 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
259 return failure();
260 }
261
262 // If the output and input shapes are 1x1, then this is a no op.
263 ArrayRef<int64_t> outputShape = outputType.getShape();
264 if (outputShape[1] != 1 || outputShape[2] != 1) {
265 return failure();
266 }
267
268 ArrayRef<int64_t> inputShape = inputType.getShape();
269 if (inputShape[1] != 1 || inputShape[2] != 1) {
270 return failure();
271 }
272
273 rewriter.replaceOp(op, input);
274 return success();
275 }
276};
277
void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Remove trivially-identity pools and merge producing tosa.pad ops.
  results.add<MaxPool2dIsNoOp,
              FoldPadToTensorOp<tosa::MaxPool2dOp,
                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
      context);
}
285
286//===----------------------------------------------------------------------===//
287// Data Layout / Memory Reinterpretation.
288//===----------------------------------------------------------------------===//
289
290struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
291 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
292
293 LogicalResult matchAndRewrite(tosa::ConcatOp op,
294 PatternRewriter &rewriter) const override {
295 if (op.getInput1().size() != 1)
296 return failure();
297 if (op.getInput1().front().getType() != op.getType()) {
298 rewriter
299 .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
300 op.getInput1().front())
301 .getResult();
302 return success();
303 }
304
305 rewriter.replaceOp(op, op.getInput1().front());
306 return success();
307 }
308};
309
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Collapse degenerate single-input concats.
  results.add<ConcatOptimization>(context);
}
314
315LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
316 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
317 if (!notOp)
318 return failure();
319 rewriter.modifyOpInPlace(op, [&]() {
320 op.getOperation()->setOperands(
321 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
322 });
323 return success();
324}
325
327 : public OpRewritePattern<tosa::TransposeOp> {
329
330 LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
331 PatternRewriter &rewriter) const override {
332 // Input is also TransposeOp - transpose(transpose(A)).
333 auto innerTranspose =
334 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
335 if (!innerTranspose)
336 return rewriter.notifyMatchFailure(transposeOp,
337 "input must be transpose operation");
338
339 const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
340 const llvm::ArrayRef<int32_t> innerTransposePerms =
341 innerTranspose.getPerms();
342
343 if (transposePerms.size() != innerTransposePerms.size())
344 return rewriter.notifyMatchFailure(
345 transposeOp,
346 "transpose and inner transpose perms sizes must be equal");
347 if (transposePerms.empty())
348 return rewriter.notifyMatchFailure(
349 transposeOp, "transpose perms sizes must be positive");
350
351 // Consolidate transposes into one transpose.
352 SmallVector<int32_t> perms(transposePerms.size());
353 for (int i = 0, s = transposePerms.size(); i < s; ++i)
354 perms[i] = innerTransposePerms[transposePerms[i]];
355
356 rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
357 transposeOp, transposeOp.getResult().getType(),
358 innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
359
360 return success();
361 }
362};
363
364// Determines the case when tosa.transpose is a tosa.reshape operation.
365struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
367
368 LogicalResult matchAndRewrite(tosa::TransposeOp op,
369 PatternRewriter &rewriter) const override {
370 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
371 return rewriter.notifyMatchFailure(
372 op, "Src is from transpose, can compose transposes");
373
374 Value result = op.getResult();
375 for (Operation *subop : result.getUsers()) {
376 if (isa_and_nonnull<tosa::TransposeOp>(subop))
377 return rewriter.notifyMatchFailure(
378 op, "Dest is used by transpose, can compose transposes");
379 }
380
381 auto input = op.getInput1();
382 auto inputTy = llvm::cast<ShapedType>(input.getType());
383 if (!inputTy.hasRank())
384 return rewriter.notifyMatchFailure(op, "Unranked input.");
385
386 int64_t numDynDims = 0;
387 for (int i = 0; i < inputTy.getRank(); ++i)
388 if (inputTy.isDynamicDim(i))
389 numDynDims++;
390
391 if (numDynDims > 1)
392 return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
393
394 const llvm::ArrayRef<int32_t> permValues = op.getPerms();
395
396 SmallVector<int64_t> nonZeroPerms;
397 nonZeroPerms.reserve(permValues.size());
398 for (auto idx : permValues) {
399 auto sz = inputTy.getDimSize(idx);
400 if (sz != 1)
401 nonZeroPerms.push_back(idx);
402 }
403
404 for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
405 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
406 return rewriter.notifyMatchFailure(op,
407 "Transpose changes memory layout.");
408
409 SmallVector<int64_t> newShape;
410 newShape.reserve(inputTy.getRank());
411 for (int i = 0, s = inputTy.getRank(); i < s; ++i)
412 newShape.push_back(inputTy.getDimSize(permValues[i]));
413
414 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
415 op, op.getType(), op.getInput1(),
416 getTosaConstShape(rewriter, op.getLoc(), newShape));
417 return success();
418 }
419};
420
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Compose chained transposes, then turn layout-preserving ones into
  // reshapes.
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
425
426struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
428
429 LogicalResult matchAndRewrite(tosa::ClampOp op,
430 PatternRewriter &rewriter) const override {
431 Value input = op.getInput();
432 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
433 auto inputElementType = inputType.getElementType();
434
435 if (isa<FloatType>(inputElementType)) {
436 // Unlike integer types, floating point types can represent infinity.
437 const auto minClamp =
438 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
439 const auto maxClamp =
440 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
441 const bool isMin = minClamp.isNegInfinity();
442 const bool isMax = maxClamp.isInfinity();
443
444 if (isMin && isMax) {
445 rewriter.replaceOp(op, input);
446 return success();
447 }
448 return failure();
449 }
450
451 // i1 types are boolean in TOSA
452 const bool isBoolean = inputElementType.isInteger(1);
453 if (inputElementType.isUnsignedInteger() || isBoolean) {
454 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
455 .getValue()
456 .getZExtValue();
457 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
458 .getValue()
459 .getZExtValue();
460
461 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
462 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
463 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
464
465 if (minClamp <= intMin && maxClamp >= intMax) {
466 rewriter.replaceOp(op, input);
467 return success();
468 }
469 return failure();
470 }
471
472 if (llvm::isa<IntegerType>(inputElementType)) {
473 const int64_t minClamp =
474 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
475 const int64_t maxClamp =
476 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
477
478 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
479 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
480 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
481
482 if (minClamp <= intMin && maxClamp >= intMax) {
483 rewriter.replaceOp(op, input);
484 return success();
485 }
486 return failure();
487 }
488
489 return failure();
490 }
491};
492
493// Attempts the following transformation:
494//
495// For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
496// tensor X the following identity holds:
497//
498// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
499//
500// subject to the following valid NaN propagation semantics:
501// --------------------------------------------
502// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
503// |-------------|--------------|-------------|
504// | PROPAGATE | PROPAGATE | PROPAGATE |
505// | PROPAGATE | IGNORE | IGNORE |
506// | IGNORE | PROPAGATE | INVALID |
507// | IGNORE | IGNORE | IGNORE |
508// |------------------------------------------|
509
510struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
511 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
512
513 // Helper structure to describe the range of a clamp operation.
514 template <typename T>
515 struct ClampRange {
516 ClampRange(const T &start, const T &end) : start(start), end(end) {}
519
520 // Helper function to determine if two Clamp ranges intersect.
521 bool intersects(const ClampRange<T> &otherRange) {
522 return start < otherRange.end && otherRange.start < end;
523 }
524 };
525
526 LogicalResult matchAndRewrite(tosa::ClampOp op,
527 PatternRewriter &rewriter) const override {
528 Value input = op.getInput();
529
530 // Check the input to the CLAMP op is itself a CLAMP.
531 auto clampOp = input.getDefiningOp<tosa::ClampOp>();
532 if (!clampOp)
533 return failure();
534
535 // Check we have a valid NaN propagation combination.
536 const auto opNanMode = op.getNanMode();
537 const auto clampNanMode = clampOp.getNanMode();
538 if (opNanMode == NanPropagationMode::IGNORE &&
539 clampNanMode == NanPropagationMode::PROPAGATE)
540 return failure();
541
542 auto maxValAttr = op.getMaxValAttr();
543 auto minValAttr = op.getMinValAttr();
544 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
545 auto clampOpMinValAttr = clampOp.getMinValAttr();
546
547 auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
548 if (auto quantType =
549 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
550 inputEType = getStorageElementTypeFromQuantized(quantType);
551 }
552
553 Attribute newMinValAttr, newMaxValAttr;
554 if (mlir::isa<FloatType>(inputEType)) {
555 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
556 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
557 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
558 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
559
560 // Check we have intersecting ranges.
561 const auto opMinFloat = floatMinValAttr.getValue();
562 const auto opMaxFloat = floatMaxValAttr.getValue();
563 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
564 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
565 ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
566 ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
567 clampOpMaxFloat);
568 if (!opRangeFloatRange.intersects(clampRangeFloatRange))
569 return failure();
570
571 // Run the transformation.
572 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
573 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
574 newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
575 newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
576 } else {
577 assert(mlir::isa<IntegerType>(inputEType));
578 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
579 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
580 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
581 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
582
583 if (inputEType.isUnsignedInteger()) {
584 // Check we have intersecting ranges.
585 const auto opMinInt = intMinValAttr.getUInt();
586 const auto opMaxInt = intMaxValAttr.getUInt();
587 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
588 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
589 ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
590 ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
591 clampOpMaxInt);
592 if (!opRangeIntRange.intersects(clampRangeIntRange))
593 return failure();
594
595 // Run the transformation.
596 auto newMinVal = std::max(opMinInt, clampOpMinInt);
597 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
598 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
599 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
600 } else {
601 // Check we have intersecting ranges.
602 const auto opMinInt = intMinValAttr.getInt();
603 const auto opMaxInt = intMaxValAttr.getInt();
604 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
605 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
606 ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
607 ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
608 clampOpMaxInt);
609 if (!opRangeIntRange.intersects(clampRangeIntRange))
610 return failure();
611
612 // Run the transformation.
613 auto newMinVal = std::max(opMinInt, clampOpMinInt);
614 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
615 newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
616 newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
617 }
618 }
619
620 auto newMode = (opNanMode != clampNanMode)
621 ? tosa::NanPropagationMode::IGNORE
622 : opNanMode;
623
624 auto newModeAttr =
625 NanPropagationModeAttr::get(rewriter.getContext(), newMode);
626
627 rewriter.replaceOpWithNewOp<tosa::ClampOp>(
628 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
629 newModeAttr);
630 return success();
631 }
632};
633
634void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.add<ClampIsNoOp>(context);
637 results.add<ClampClampOptimization>(context);
638}
639
// Rewrite slice(concat(x0, x1, ...)) as slice(xi) when the sliced region
// falls entirely within one concat operand. The slice start along the concat
// axis is shifted by each preceding operand's extent until it lands inside
// an operand.
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    // Both slice parameters must be compile-time constants.
    DenseElementsAttr startElems;
    DenseElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      // The slice fits inside this operand: re-slice it directly.
      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                                  input, start_op, size_op)
                .getResult();
        break;
      }
      // Otherwise shift the start past this operand and keep looking.
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
  }
};
707
// Rewrite slice(pad(x)) so the slice consumes as much of the padding as
// possible directly from x, shrinking (or eliminating) the pad: per
// dimension, the new slice start is clamped into the unpadded data and the
// pad amounts are reduced to only what the slice still needs.
struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();

    // Check if producer is a PadOp
    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a pad operation");

    // Check PadOp has a single consumer so shrinking it cannot affect others.
    if (!padOp->hasOneUse())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "pad shall have a single consumer");

    // Check input is statically ranked
    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
    if (!inputTy || !padTy || !inputTy.hasRank())
      return rewriter.notifyMatchFailure(sliceOp,
                                         "slice input must be a ranked tensor");

    // Validate and extract tosa::PadOp padding
    DenseIntElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "`padding` input specified on the tosa::PadOp must be constant.");
    }
    llvm::SmallVector<int64_t> padPaddings =
        llvm::to_vector(paddingElems.getValues<int64_t>());

    // Extract slice parameters
    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Check if dynamic dimensions are sliced
    const int64_t rank = inputTy.getRank();
    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
          const bool isDimDynamic = inputTy.isDynamicDim(i);
          // A dim is "sliced" unless start is 0 and the size is inferable.
          const bool isDimSliced =
              (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);

          return isDimDynamic && isDimSliced;
        })) {
      return rewriter.notifyMatchFailure(
          sliceOp, "axis that are sliced shall be statically known.");
    }

    // Update the parameters
    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
    bool updated = false;

    for (int64_t i = 0; i < rank; ++i) {
      const int64_t padLo = padPaddings[i * 2];
      const int64_t padHi = padPaddings[i * 2 + 1];
      const int64_t sliceStart = sliceStarts[i];
      const int64_t sliceSize = sliceSizes[i];
      const int64_t sliceEnd = sliceStart + sliceSize;

      // If dimension is dynamic pass-through
      if (inputTy.isDynamicDim(i)) {
        newPadPaddings[i * 2] = padLo;
        newPadPaddings[i * 2 + 1] = padHi;
        newSliceStarts[i] = sliceStart;
        continue;
      }

      // Handle static dimensions
      const int64_t dimSize = inputTy.getShape()[i];
      const int64_t dimTotal = padLo + dimSize + padHi;

      // Check slice within bounds
      if (sliceStart < 0 || sliceEnd > dimTotal)
        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

      // Compute updated slice start parameter: skip any low padding the
      // slice no longer has to step over.
      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
      newSliceStarts[i] = newSliceStart;
      updated |= newSliceStart != sliceStart;

      // Compute updated pad parameters: keep only the padding the slice
      // actually reads on each side.
      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
      const int64_t newPadHi =
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
      newPadPaddings[i * 2] = newPadLo;
      newPadPaddings[i * 2 + 1] = newPadHi;
      updated |= (newPadLo != padLo) || (newPadHi != padHi);

      // Calculate new pad output shape
      newPadShape[i] =
          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
    }

    // Check that we actually need to proceed with the rewrite
    if (!updated)
      return rewriter.notifyMatchFailure(
          sliceOp, "terminate condition; nothing to rewrite");

    // Create a PadOp with updated padding
    auto newPaddingsOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
    auto newPadTy =
        RankedTensorType::get(newPadShape, inputTy.getElementType());
    auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
                                        padOp.getInput1(), newPaddingsOp,
                                        padOp.getPadConst());

    // Update SliceOp and point to new PadOp
    auto newStartOp =
        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
                                               newPadOp.getResult(), newStartOp,
                                               sliceOp.getSize());

    return success();
  }
};
842
843// Update size operand of tosa.slice if size has dynamic dims but corresponding
844// output dim is static
846 : public OpRewritePattern<tosa::SliceOp> {
847 using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
848
  // Replace dynamic (kInferableDimSize) entries in a tosa.slice `size`
  // operand with the statically known extents of the op's result type, so
  // the slice carries a fully constant size shape.
  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (!resultType.hasRank())
      return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");

    // The size operand must already be a compile-time constant shape.
    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");
    }

    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    bool replaceSliceSize{false};
    // if size op has kInferableDimSize indicating dynamic shape but
    // corresponding dim on the output is statically known, update size to match
    // with known output dim shape
    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
        sliceSizes[index] = resultType.getDimSize(index);
        replaceSliceSize = true;
      }
    }

    // Nothing could be inferred; leave the op untouched.
    if (!replaceSliceSize) {
      return rewriter.notifyMatchFailure(
          sliceOp, "no dimension of size of slice is dynamic that resolves "
                   "to static output shape");
    }

    // Re-create the slice with the refined, fully constant size operand.
    auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
    auto newSliceOp =
        tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                              sliceOp.getInput1(), sliceOp.getStart(), size_op);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
889};
890
891void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
892 MLIRContext *context) {
893 results.add<ConcatSliceOptimization, PadSliceOptimization,
894 SliceDynamicSizeCanonicalization>(context);
895}
896
// Collapse cast(cast(x)) into cast(x) when the inner cast does not lose
// information (is non-narrowing), so the outer cast alone yields the same
// result.
struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    const Value castInput = castOp.getInput();
    auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
    if (!innerCastOp)
      return rewriter.notifyMatchFailure(castOp,
                                         "input must be cast operation");

    const Value innerCastInput = innerCastOp.getInput();

    const auto innerInputType =
        llvm::cast<ShapedType>(innerCastInput.getType());
    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());

    const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
                                              outerOutputType};

    if (llvm::any_of(types, [](const ShapedType type) {
          const auto elemTy = type.getElementType();
          // Support a specific set of floating point types since we need to be
          // careful in not introducing unsupported type combinations
          return !(elemTy.isInteger() ||
                   llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
                             Float16Type, Float32Type>(elemTy));
        }))
      return rewriter.notifyMatchFailure(
          castOp, "only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
                  "supported");

    // Never fold into a direct cast between the two FP8 formats.
    if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
        llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
      return rewriter.notifyMatchFailure(
          castOp, "avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
                  "legal in TOSA");
    }

    if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
        llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
      return rewriter.notifyMatchFailure(
          castOp, "avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
                  "legal in TOSA");
    }

    // Check that the cast we're considering for removal is non-narrowing
    if (isNarrowingCast(innerInputType, innerOutputType))
      return rewriter.notifyMatchFailure(castOp,
                                         "inner cast operation is narrowing");

    // The intermediate type can be skipped entirely.
    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
                                              innerCastInput);

    return success();
  }

  // True when the format can represent NaN at all.
  bool supportsNaN(const llvm::fltSemantics &semantics) const {
    return semantics.nonFiniteBehavior !=
           llvm::fltNonfiniteBehavior::FiniteOnly;
  }

  // True when the format has IEEE-754 style infinities.
  bool supportsInf(const llvm::fltSemantics &semantics) const {
    return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
  }

  // A cast counts as narrowing when some value representable in `inType`
  // is not exactly representable in `outType`.
  bool isNarrowingCast(const ShapedType inType,
                       const ShapedType outType) const {

    if (inType.getElementType().isInteger() &&
        outType.getElementType().isInteger()) {

      const auto inTypeSignedness =
          cast<IntegerType>(inType.getElementType()).getSignedness();
      const auto outTypeSignedness =
          cast<IntegerType>(outType.getElementType()).getSignedness();

      // Any signedness change or loss of bit width is treated as narrowing.
      return (inTypeSignedness != outTypeSignedness ||
              inType.getElementTypeBitWidth() >
                  outType.getElementTypeBitWidth());
    }

    if (inType.getElementType().isFloat() &&
        outType.getElementType().isFloat()) {

      FloatType inElemTy = cast<FloatType>(inType.getElementType());
      FloatType outElemTy = cast<FloatType>(outType.getElementType());
      llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
      llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();

      // If the list of supported types needs to be updated in the future, the
      // check down below will need to be revised, for example to account for
      // unsigned floating point types, or types that use negative zero as the
      // representation for NaN.
      [[maybe_unused]] const auto isSupported = [](Type elemType) {
        return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
                         Float16Type, Float32Type>(elemType);
      };

      assert(isSupported(inElemTy) &&
             "unsupported input element type in isNarrowingCast");
      assert(isSupported(outElemTy) &&
             "unsupported output element type in isNarrowingCast");

      // Narrowing if the exponent range shrinks, precision drops, or
      // NaN/Inf support is lost.
      return (
          inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
          inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
          inTypeSemantics.precision > outTypeSemantics.precision ||
          (supportsNaN(inTypeSemantics) && !supportsNaN(outTypeSemantics)) ||
          (supportsInf(inTypeSemantics) && !supportsInf(outTypeSemantics)));
    }

    // While some cases of int -> float casts can be non-narrowing, consider
    // them narrowing for the purposes of this optimization
    return true;
  }
};
1015
// Register canonicalization patterns for tosa.cast.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
}
1020
1022 : public OpRewritePattern<tosa::CastToBlockScaledOp> {
1023 using OpRewritePattern<tosa::CastToBlockScaledOp>::OpRewritePattern;
1024
  // Fold cast_to_block_scaled(cast_from_block_scaled(x)) back to x when the
  // round-trip is exact: identical data/scale types and equal block sizes.
  LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp,
                                PatternRewriter &rewriter) const override {
    const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
    auto castFromBlockScaledOp =
        castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
    if (!castFromBlockScaledOp)
      return rewriter.notifyMatchFailure(
          castToBlockScaledOp,
          "input must be cast_from_block_scaled operation");

    // Types consumed by the inner cast must match the types produced by
    // this op for the pair to be a pure round-trip.
    const Value innerData = castFromBlockScaledOp.getInputData();
    const Value innerScale = castFromBlockScaledOp.getInputScale();
    const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
    const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());

    const Value outerData = castToBlockScaledOp.getOutputData();
    const Value outerScale = castToBlockScaledOp.getOutputScale();
    const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
    const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());

    if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
      return rewriter.notifyMatchFailure(
          castToBlockScaledOp,
          "inputs types to cast_from_block_scaled operation must match output "
          "types to cast_to_block_scaled");
    }

    if (castFromBlockScaledOp.getBlockSize() !=
        castToBlockScaledOp.getBlockSize()) {
      return rewriter.notifyMatchFailure(
          castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
                               "cast_to_block_scaled must match");
    }

    // Both results are replaced by the values feeding the inner cast.
    rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});

    return success();
  }
1063};
1064
// Register canonicalization patterns for tosa.cast_to_block_scaled.
void CastToBlockScaledOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CancellingBlockScaledCastsOptimization>(context);
}
1069
1070//===----------------------------------------------------------------------===//
1071// Operator Folders.
1072//===----------------------------------------------------------------------===//
1073
1074template <typename Folder>
1075static DenseElementsAttr
1077 bool foldDenseValues = false) {
1078 if (!lhs || !rhs)
1079 return {};
1080
1081 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1082 return {};
1083
1084 const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
1085 const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
1086 if (lETy != rETy)
1087 return {};
1088
1089 if (lhs.isSplat() && rhs.isSplat()) {
1090 if (isa<FloatType>(lETy)) {
1091 const APFloat l = lhs.getSplatValue<APFloat>();
1092 const APFloat r = rhs.getSplatValue<APFloat>();
1093 const auto maybeResult = Folder::fold(l, r);
1094 if (failed(maybeResult))
1095 return {};
1096 return DenseElementsAttr::get(returnTy, maybeResult.value());
1097 }
1098
1099 if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1100 const APInt l = lhs.getSplatValue<APInt>();
1101 const APInt r = rhs.getSplatValue<APInt>();
1102 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1103 if (failed(maybeResult))
1104 return {};
1105 return DenseElementsAttr::get(returnTy, maybeResult.value());
1106 }
1107 }
1108
1109 if (foldDenseValues) {
1110 assert(lETy.isIntOrIndex() &&
1111 "Only integer types are currently supported.");
1112 SmallVector<APInt> resultValues;
1113 for (auto [l, r] :
1114 llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
1115 const auto maybeResult = Folder::fold(l, r, false);
1116 if (failed(maybeResult))
1117 return {};
1118 resultValues.push_back(maybeResult.value());
1119 }
1120 return DenseElementsAttr::get(returnTy, resultValues);
1121 }
1122
1123 return {};
1124}
1125
1126template <typename Folder>
1127static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
1128 bool foldDenseValues = false) {
1129 if (!val)
1130 return {};
1131
1132 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1133 return {};
1134
1135 const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
1136
1137 if (val.isSplat()) {
1138 if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1139 const APInt v = val.getSplatValue<APInt>();
1140 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1141 if (failed(maybeResult))
1142 return {};
1143 return DenseElementsAttr::get(returnTy, maybeResult.value());
1144 }
1145 }
1146
1147 if (foldDenseValues) {
1148 mlir::Type elemTy = val.getElementType();
1149 if (elemTy.isIntOrIndex()) {
1150 SmallVector<APInt> resultValues;
1151 for (auto const &v : val.getValues<APInt>()) {
1152 const auto maybeResult = Folder::fold(v, false);
1153 if (failed(maybeResult))
1154 return {};
1155 resultValues.push_back(maybeResult.value());
1156 }
1157 return DenseElementsAttr::get(returnTy, resultValues);
1158 }
1159 }
1160
1161 // Folding arbitrarily sized tensor operations is not supported
1162 return {};
1163}
1164
1165static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
1166 DenseIntElementsAttr dense{};
1167 if (!matchPattern(v, m_Constant(&dense)))
1168 return failure();
1169
1170 assert(dense.isSplat());
1171 APInt a = dense.getSplatValue<APInt>();
1172 return a.getSExtValue();
1173}
1174
  // Integer add with overflow detection: the fold is rejected (failure)
  // when the sum wraps, so invalid constants are never materialized.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }

  // Floating-point add never fails; IEEE semantics cover all inputs.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs + rhs;
  }
1189};
1190
  // Integer subtract with overflow detection: the fold is rejected
  // (failure) when the difference wraps.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    bool overflow;
    const APInt result =
        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
    if (overflow)
      return failure();
    return result;
  }

  // Floating-point subtract never fails.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs - rhs;
  }
1205};
1206
  // Integer multiply with overflow detection; width mismatch or overflow
  // rejects the fold.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {

    const unsigned originalWidth = lhs.getBitWidth();

    // Check same type
    if (lhs.getBitWidth() != rhs.getBitWidth()) {
      return failure();
    }

    // If either is `0`
    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
    APInt const result =
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    // NOTE(review): umul_ov/smul_ov keep the operand width, so the trunc
    // below appears to be a no-op kept for symmetry — confirm.
    if (overflow)
      return failure();

    return result.trunc(originalWidth);
  }

  // Floating-point multiply never fails.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs * rhs;
  }
1235};
1236
1237static bool signsDiffer(const APInt &a, const APInt &b) {
1238 return a.isNegative() != b.isNegative();
1239}
1240
1241template <bool Ceil>
  // Integer division truncating toward zero; when `Ceil` is set the result
  // is rounded up instead. Fails on width mismatch, division by zero, or
  // signed overflow (INT_MIN / -1).
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (rhs.isZero())
      return failure();

    if (isUnsigned) {
      APInt q{};
      APInt r{};
      APInt::udivrem(lhs, rhs, q, r);
      // A nonzero remainder means the exact quotient was rounded down.
      if (!r.isZero() && Ceil) {
        return q + 1;
      }
      return q;
    }

    // Signed: start from trunc-toward-zero, then adjust to ceil.
    bool overflow{false};
    APInt const q = lhs.sdiv_ov(rhs, overflow);
    if (overflow)
      return failure();
    APInt const r = lhs.srem(rhs);

    if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) {
      // Same sign => exact quotient is positive; trunc is below ceil =>
      // increment q.
      return q + 1;
    }
    return q;
  }

  // Floating-point division maps directly onto APFloat's operator/.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs / rhs;
  }
1278};
1279
  // Integer remainder; only folds a non-negative lhs with a strictly
  // positive rhs, so the result never needs a sign adjustment.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))
      return failure();

    if (isUnsigned) {
      return lhs.urem(rhs);
    }

    return lhs.srem(rhs);
  }

  // Floating-point remainder via APFloat::mod; the fold only succeeds
  // when the operation reports opOK.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto t = lhs;
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {
      return t;
    }
    return failure();
  }
1303};
1304
  // Integer maximum of two equally wide values.
  // NOTE(review): the comparison uses getSExtValue() even when `isUnsigned`
  // is set, and getSExtValue asserts for widths > 64 bits — confirm callers
  // only reach this with signed types of <= 64 bits.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;
  }

  // Floating-point maximum.
  // NOTE(review): `>=` is false for NaN operands, so a NaN lhs selects rhs —
  // verify this matches the op's intended NaN propagation mode.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs >= rhs ? lhs : rhs;
  }
1316};
1317
  // Integer minimum of two equally wide values.
  // NOTE(review): uses getSExtValue() regardless of `isUnsigned` and asserts
  // for widths > 64 bits — confirm callers only pass signed <= 64-bit types.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               bool isUnsigned) {
    if (lhs.getBitWidth() != rhs.getBitWidth())
      return failure();
    return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;
  }

  // Floating-point minimum.
  // NOTE(review): `<=` is false for NaN operands, so a NaN lhs selects rhs —
  // verify against the op's NaN propagation mode.
  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    return lhs <= rhs ? lhs : rhs;
  }
1329};
1330
  // Produce an APInt with only bit `value` set (i.e. 1 << value) in the
  // operand's width. Fails when the bit position is negative or does not
  // fit in the type.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
    if (isUnsigned) {
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
        return failure();
      return APInt::getOneBitSet(numBits, zextv);
    }
    // Signed: reject negative or too-large bit positions.
    auto const sextv = value.getSExtValue();
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
      return failure();
    return APInt::getOneBitSet(numBits, sextv);
  }
1345};
1346
  // ceil(log2(value)); only defined for strictly positive inputs.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
  }
1353};
1354
  // floor(log2(value)); only defined for strictly positive inputs.
  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
      return failure();
    return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
  }
1361};
1362
  // Integer greater-than, honouring signedness; the result is an i1.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
  }

  // Float greater-than; ordered comparisons with NaN operands yield false.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs > rhs);
  }
1372};
1373
  // Integer greater-or-equal, honouring signedness; the result is an i1.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
  }

  // Float greater-or-equal; NaN operands compare false.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs >= rhs);
  }
1383};
1384
  // Integer equality; signedness is irrelevant for bit-equality.
  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                               const bool isUnsigned) {
    return APInt(1, lhs == rhs);
  }

  // Float equality; NaN is never equal to anything, including itself.
  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs == rhs);
  }
1394};
1395
1396static bool isSplatZero(Type elemType, DenseElementsAttr val) {
1397 if (llvm::isa<FloatType>(elemType))
1398 return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
1399 if (llvm::isa<IntegerType>(elemType))
1400 return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
1401 return false;
1402}
1403
1404static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
1405 if (llvm::isa<FloatType>(elemType))
1406 return val && val.isSplat() &&
1407 val.getSplatValue<APFloat>().isExactlyValue(1.0);
1408 if (llvm::isa<IntegerType>(elemType)) {
1409 const int64_t shifted = 1LL << shift;
1410 return val && val.isSplat() &&
1411 val.getSplatValue<APInt>().getSExtValue() == shifted;
1412 }
1413 return false;
1414}
1415
// Fold tosa.add:
//  * x + splat(0) -> x (and symmetrically), when the shapes are statically
//    broadcast-compatible and the surviving operand already has the result
//    type.
//  * constant + constant -> constant via AddFoldAdaptor (overflow aborts).
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Identity: adding a splat zero returns the other operand unchanged.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1446
1447OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1448 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
1449 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1450 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1451 !outputTy.hasStaticShape())
1452 return {};
1453
1454 const Type outputElementTy = getElementTypeOrSelf(outputTy);
1455 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1456 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1457 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1458 return DenseElementsAttr::get(outputTy, zero);
1459 }
1460
1461 return {};
1462}
1463
// Fold tosa.int_div:
//  * splat(0) / x -> splat(0), for static result shapes.
//  * x / splat(1) -> x, when shapes are statically broadcast-compatible.
//  * splat / splat -> constant via the trunc-toward-zero DivFoldAdaptor.
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy.getElementType() != rhsTy.getElementType())
    return {};

  // IntDivOp inputs must be integer type, no need to check for quantized
  // type
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  // 0 / x == 0; constant division-by-zero is rejected further below.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr.resizeSplat(resultTy);
  }

  // x / 1 == x, provided the result type matches the lhs exactly.
  if (rhsAttr && rhsAttr.isSplat()) {
    const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
        lhsTy.getShape(), rhsTy.getShape());
    if (isBroadcastable && lhsTy == resultTy &&
        llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Constant-fold two splats; a zero divisor or overflow gives up.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
      auto const result =
          DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
      if (failed(result))
        return {};
      return DenseElementsAttr::get(resultTy, result.value());
    }
  }

  return {};
}
1511
1512namespace {
1513// calculate lhs * rhs >> shift according to TOSA Spec
1514// return nullopt if result is not in range of int32_t when shift > 0
// calculate lhs * rhs >> shift according to TOSA Spec
// return nullopt if result is not in range of int32_t when shift > 0
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  // Multiply in 64 bits so the intermediate product cannot wrap for any
  // TOSA-supported input width.
  bool overflow = false;
  APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);

  if (overflow)
    return std::nullopt;

  if (shift > 0) {
    // Round half-up before the arithmetic right shift.
    auto round = APInt(64, 1) << (shift - 1);
    result += round;
    result.ashrInPlace(shift);
    // REQUIRE(product >= minimum_s<i32_t>() && product <=
    // maximum_s<i32_t>())
    if (!(result.getSExtValue() >= INT32_MIN &&
          result.getSExtValue() <= INT32_MAX)) {
      // REQUIRE failed
      return std::nullopt;
    }
  }

  return result.trunc(bitwidth);
}
1538
// Fold tosa.mul of two splat constants, applying the i32 shift semantics.
// Non-splat constants are not folded.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // Without a shift, the result is a plain (wrapping) multiply.
      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      // With a shift, mulInt enforces the spec's i32 REQUIRE and reports
      // nullopt when it is violated, in which case we give up folding.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
      if (!result)
        return {};
      return DenseElementsAttr::get(ty, result.value());
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
1567} // namespace
1568
// Fold tosa.mul:
//  * splat(0) * x -> splat(0) (and symmetrically) for static result shapes.
//  * splat(1 << shift) * x -> x (identity under the shift semantics).
//  * constant * constant -> constant via mulBinaryFolder.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Result right shift on i32_t data type only. For simplification,
  // synthesize a zero shift for other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // x * 0 == 0, resized to the result type.
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
      resultTy.hasStaticShape())
    // constant values can only be resized if resulting type is static
    return lhsAttr.resizeSplat(resultTy);
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
      resultTy.hasStaticShape())
    return rhsAttr.resizeSplat(resultTy);

  // x * 1 == x; for shifted i32 the identity constant is (1 << shift).
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && rhsTy == resultTy &&
      isSplatOne(resultETy, lhsAttr, shift))
    return rhs;
  if (isBroadcastable && lhsTy == resultTy &&
      isSplatOne(resultETy, rhsAttr, shift))
    return lhs;

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
1616
// Fold tosa.sub:
//  * x - splat(0) -> x, when shapes are statically broadcast-compatible
//    and the lhs already has the result type.
//  * constant - constant -> constant via SubFoldAdaptor (overflow aborts).
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Identity: subtracting a splat zero returns the lhs unchanged.
  const bool isBroadcastable = OpTrait::util::staticallyKnownBroadcastable(
      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1645
1646OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1647 auto resultTy = llvm::cast<ShapedType>(getType());
1648 auto lhsAttr =
1649 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1650 auto rhsAttr =
1651 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1652
1653 if (!lhsAttr || !rhsAttr)
1654 return {};
1655
1656 return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1657}
1658
1659OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1660 auto resultTy = llvm::cast<ShapedType>(getType());
1661 auto lhsAttr =
1662 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1663 auto rhsAttr =
1664 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1665
1666 if (!lhsAttr || !rhsAttr)
1667 return {};
1668
1669 return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
1670}
1671
// Fold tosa.equal. An integer tensor compared with itself is all-true;
// constant operands fold elementwise via EqualFoldAdaptor.
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // An integer value compared with itself is always true. This shortcut is
  // invalid for floats because NaN is not equal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
1694
// Fold tosa.cast: identity casts fold to the input; splat constants fold
// by converting the splat value to the target element type.
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  // cast<T>(x : T) == x.
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  if (!outTy.hasRank() || !outTy.hasStaticShape())
    return {};
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: re-round in the target semantics (ties to even).
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float: the source signedness decides the conversion.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round to nearest-even; an inexact result is accepted.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate or sign-/zero-extend based on widths and the
    // source signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      // i1 types are boolean in TOSA
      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
      } else if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  // Non-splat constants are not folded.
  return {};
}
1764
1765OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1766
1767OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
1768
// Folder shared by all reduce ops: when the input already has the result
// type and the reduced axis has extent 1 (or the input is rank-0), the
// reduction is an identity and folds to its input. Comments cannot appear
// inside the macro body because `//` would swallow the line continuations.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

// Instantiate the folder for every TOSA reduction op.
REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
1788
1789OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1790 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1791 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
1792
1793 if (!inputTy || !outputTy)
1794 return {};
1795
1796 // Fold when the input and output types are the same. This is only safe
1797 // when there is at most 1 dynamic dimension. For 2 or more dynamic
1798 // dimensions, there may still be a productive reshape.
1799 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1800 return getInput1();
1801
1802 // reshape(reshape(x)) -> reshape(x)
1803 if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1804 getInput1().getDefiningOp())) {
1805 getInput1Mutable().assign(reshapeOp.getInput1());
1806 return getResult();
1807 }
1808
1809 // Cannot create an ElementsAttr from non-int/float/index types
1810 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1811 return {};
1812
1813 // reshape(const(x)) -> const(reshape-attr(x))
1814 if (auto operand =
1815 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1816 // Constants must have static shape.
1817 if (!outputTy.hasStaticShape())
1818 return {};
1819
1820 // Okay to duplicate splat constants.
1821 if (operand.isSplat())
1822 return SplatElementsAttr::get(outputTy,
1823 operand.getSplatValue<Attribute>());
1824
1825 // Don't duplicate other constants.
1826 if (!getInput1().hasOneUse())
1827 return {};
1828
1830 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeVec))
1831 return {};
1832
1833 return operand.reshape(
1834 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1835 }
1836
1837 return {};
1838}
1839
1840OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1841 // If the pad is all zeros we can fold this operation away.
1842 if (adaptor.getPadding() && getInput1().getType() == getType()) {
1843 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1844 if (densePad && densePad.isSplat() &&
1845 densePad.getSplatValue<APInt>().isZero()) {
1846 return getInput1();
1847 }
1848 }
1849
1850 return {};
1851}
1852
1853// Fold away cases where a tosa.resize operation returns a copy
1854// of the input image.
1855OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1856 auto scaleAttr =
1857 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1858 auto offsetAttr =
1859 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1860 auto borderAttr =
1861 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1862 if (!scaleAttr || !offsetAttr || !borderAttr) {
1863 return {};
1864 }
1865
1866 auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1867 auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1868 auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1869 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1870 return {};
1871 }
1872
1873 // Check unit scaling.
1874 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1875 return {};
1876 }
1877
1878 // There should be no offset.
1879 if (offset[0] != 0 || offset[1] != 0) {
1880 return {};
1881 }
1882
1883 // There should be no border.
1884 if (border[0] != 0 || border[1] != 0) {
1885 return {};
1886 }
1887
1888 return foldToInputIfTypeMatches(getType(), getInput());
1889}
1890
1891OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1892 auto operand = getInput1();
1893 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1894 auto axis = getAxis();
1895 // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
1896 const bool isSplatInput =
1897 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
1898 if (!operandTy.hasRank() ||
1899 (!isSplatInput && operandTy.getDimSize(axis) != 1))
1900 return {};
1901 return foldToInputIfTypeMatches(getType(), operand);
1902}
1903
/// Folds tosa.slice in three independent ways:
///   1. Identity: result type equals the (static) input type.
///   2. No-op slice: all starts are zero and each size either equals the
///      static input extent or is kInferableDimSize ("whole dimension").
///   3. Constant input: a splat constant is re-wrapped under the result
///      type, and a one-element result is extracted directly.
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Case 1: identical static types — nothing is cut away.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  // Check if this is a no-op slice (starts at 0 and size matches input)

  DenseElementsAttr startElems;
  if (!matchPattern(getStart(), m_Constant(&startElems)))
    return {};

  // Check if all start values are zero
  bool startIsZeros =
      llvm::all_of(startElems.getValues<APInt>(),
                   [](const APInt &val) { return val.isZero(); });

  if (startIsZeros) {

    // Check if size matches input shape
    DenseElementsAttr sizeElems;
    if (!matchPattern(getSize(), m_Constant(&sizeElems)))
      return {};

    auto inputShape = inputTy.getShape();
    auto sizeValues = sizeElems.getValues<APInt>();

    bool sizeMatchesInput = true;
    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
      int64_t size = sizeVal.getSExtValue();

      if (inputTy.isDynamicDim(i)) {
        // For dynamic dimensions, check for kInferableDimSize indicating full
        // dimension is sliced
        if (size != kInferableDimSize) {
          sizeMatchesInput = false;
          break;
        }
      } else {
        // For static dimensions, check that size must match exactly or be
        // kInferableDimSize indicating full dimension is sliced
        if (size != kInferableDimSize && size != inputShape[i]) {
          sizeMatchesInput = false;
          break;
        }
      }
    }

    if (sizeMatchesInput)
      return getInput1();
  }

  // The following checks require the input to be a constant
  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  // Case 3a: splat constants can be duplicated freely under the result type.
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // Case 3b: a single-element result is the element addressed by the start
  // indices; rebuild it as a splat of that one value.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    if (auto values = operand.tryGetValues<Attribute>())
      return SplatElementsAttr::get(outputTy, (*values)[indices]);
  }

  return {};
}
1984
1985OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1986 const Value pred = getPred();
1987 const Value onTrue = getOnTrue();
1988 const Value onFalse = getOnFalse();
1989
1990 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
1991 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
1992 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
1993 if (!predTy || !onTrueTy || !onFalseTy)
1994 return {};
1995
1996 const Type resultTy = getType();
1997
1998 const ArrayRef<int64_t> predShape = predTy.getShape();
1999 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2000
2001 if (onTrue == onFalse && onTrueTy == resultTy &&
2002 OpTrait::util::staticallyKnownBroadcastable(predShape, onTrueShape))
2003 return onTrue;
2004
2005 auto predicate =
2006 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2007 if (!predicate)
2008 return {};
2009 if (!predicate.isSplat())
2010 return {};
2011
2012 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2013
2014 SmallVector<SmallVector<int64_t>, 3> shapes;
2015 shapes.emplace_back(predShape);
2016 shapes.emplace_back(onTrueShape);
2017 shapes.emplace_back(onFalseTy.getShape());
2018 const bool isBroadcastable =
2020
2021 if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
2022 return onTrue;
2023 if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)
2024 return onFalse;
2025 return {};
2026}
2027
2028OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2029 if (getInput1().getType() == getType()) {
2030 if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2031 adaptor.getMultiples())) {
2032 if (multiples.isSplat() &&
2033 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2034 return getInput1();
2035 if (auto int_array_attr =
2036 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2037 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2038 [](APInt v) { return v.getSExtValue() == 1; }))
2039 return getInput1();
2040 }
2041 }
2042 }
2043 return {};
2044}
2045
2046OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2047 auto resultTy = llvm::cast<ShapedType>(getType());
2048
2049 // Transposing splat values just means reshaping.
2050 if (auto input =
2051 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2052 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2053 input.getType().getElementType() == resultTy.getElementType())
2054 return input.reshape(resultTy);
2055 }
2056
2057 // Transpose is not the identity transpose.
2058 const llvm::ArrayRef<int32_t> perms = getPerms();
2059
2060 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2061 return {};
2062
2063 return foldToInputIfTypeMatches(getType(), getInput1());
2064}
2065
2066OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2067 // Element-wise negate(negate(x)) = x
2068 // iff all zero points are constant 0
2069 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2070 if (!definingOp) {
2071 // defining op of input1 is not a negate, cannot fold
2072 return {};
2073 }
2074
2075 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2076 failed(maybeIZp) || *maybeIZp != 0) {
2077 // input1 zero point is not constant 0, cannot fold
2078 return {};
2079 }
2080 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2081 failed(maybeOZp) || *maybeOZp != 0) {
2082 // output zero point is not constant 0, cannot fold
2083 return {};
2084 }
2085 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2086 failed(maybeIZp) || *maybeIZp != 0) {
2087 // definingOp's input1 zero point is not constant 0, cannot fold
2088 return {};
2089 }
2090 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2091 failed(maybeOZp) || *maybeOZp != 0) {
2092 // definingOp's output zero point is not constant 0, cannot fold
2093 return {};
2094 }
2095
2096 return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
2097}
2098
2099OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2100 auto input = getInput1();
2101 // Element-wise abs(abs(x)) = abs(x)
2102 if (input.getDefiningOp<tosa::AbsOp>())
2103 return foldToInputIfTypeMatches(getType(), input);
2104
2105 return {};
2106}
2107
2108OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
2109 // Fold consecutive concats on the same axis into a single op.
2110 // Keep track of the operands so we are able to construct a new concat
2111 // later. Conservatively assume that we double the number of operands when
2112 // folding
2113 SmallVector<Value, 8> concatOperands;
2114 concatOperands.reserve(2 * getNumOperands());
2115
2116 // Find all operands that are foldable concats
2117 bool foundFoldableConcat = false;
2118 for (Value operand : getOperands()) {
2119 concatOperands.emplace_back(operand);
2120
2121 auto producer = operand.getDefiningOp<ConcatOp>();
2122 if (!producer)
2123 continue;
2124
2125 // Not foldable if axes are not the same
2126 if (getAxis() != producer.getAxis())
2127 continue;
2128
2129 // Replace the original operand with all incoming operands
2130 foundFoldableConcat = true;
2131 concatOperands.pop_back();
2132 llvm::append_range(concatOperands, producer->getOperands());
2133 }
2134
2135 if (!foundFoldableConcat)
2136 return {};
2137
2138 getOperation()->setOperands(concatOperands);
2139 return getResult();
2140}
2141
2142OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2143 auto input = adaptor.getInput1();
2144
2145 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2146 // Fold splat inputs only.
2147 if (!inputAttr || !inputAttr.isSplat())
2148 return {};
2149
2150 auto shapeType = llvm::cast<ShapedType>(getType());
2151 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2152 return {};
2153 if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2154 auto floatVal = inputAttr.getSplatValue<APFloat>();
2155 return DenseElementsAttr::get(shapeType,
2156 ReciprocalOp::calcOneElement(floatVal));
2157 }
2158
2159 return {};
2160}
2161
2162template <typename Op, typename OpFoldAdaptor>
2164 auto input1ConstShape =
2165 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2166 if (!input1ConstShape)
2167 return {};
2168
2169 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2170
2171 return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
2172 /*foldDenseValues=*/true);
2173}
2174
2175template <typename Op, typename OpFoldAdaptor>
2177 auto input1ConstShape =
2178 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2179 auto input2ConstShape =
2180 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2181 if (!input1ConstShape || !input2ConstShape)
2182 return {};
2183
2184 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2185 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2186
2187 return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
2188 input1Attr.getType(),
2189 /*foldDenseValues=*/true);
2190}
2191
2192OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2193 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
2194 if (!inputTy || !inputTy.hasRank())
2195 return {};
2196 const int32_t axis = getAxis();
2197 const int64_t dimSize = inputTy.getDimSize(axis);
2198 if (ShapedType::isDynamic(dimSize))
2199 return {};
2200
2201 OpBuilder builder(getContext());
2202 const auto resultAttrTy =
2203 RankedTensorType::get(/*rank=*/1, builder.getIndexType());
2204 return DenseElementsAttr::get(resultAttrTy, dimSize);
2205}
2206
2207OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
2208 auto const inputs = op->getInput();
2209
2210 if (inputs.empty())
2211 return {};
2212
2213 SmallVector<APInt> concatDims;
2214 concatDims.reserve(/*max elem*/ 64);
2215 for (auto const &v : inputs) {
2216 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2217 if (!vConstShape)
2218 return {};
2219
2220 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2221 assert(vAttr);
2222
2223 auto const vAttrVals = vAttr.getValues<APInt>();
2224 for (auto const &v : vAttrVals) {
2225 concatDims.push_back(v);
2226 }
2227 }
2228
2229 auto *ctx = op->getContext();
2230 assert(ctx != nullptr && "ctx is nullptr");
2231 auto const rankedTy = RankedTensorType::get(
2232 {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2233
2234 return DenseElementsAttr::get(rankedTy, concatDims);
2235}
2236
2237OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
2238 auto const input1 = op->getInput();
2239 auto const input2 = op->getStart();
2240 auto const input3 = op->getSize();
2241
2242 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2243
2244 if (!input1ConstShape)
2245 return {};
2246
2247 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2248 if (!input1Attr)
2249 return {};
2250
2251 auto const input1Vals = input1Attr.getValues<APInt>();
2252 auto const totalInput1 = input1Vals.size();
2253
2254 auto const start = getSingleI64From1ElementTensor(input2);
2255 auto const size = getSingleI64From1ElementTensor(input3);
2256
2257 if (failed(start) || failed(size))
2258 return {};
2259
2260 auto const startV = static_cast<int32_t>(start.value());
2261 auto const sizeV = static_cast<int32_t>(size.value());
2262
2263 if ((sizeV <= 0) || (startV < 0) ||
2264 (static_cast<size_t>(startV + sizeV) > totalInput1))
2265 return {};
2266
2267 SmallVector<APInt> sliceOfInput;
2268 sliceOfInput.reserve(totalInput1);
2269
2270 for (auto i = startV; i < (startV + sizeV); i++) {
2271 sliceOfInput.push_back(input1Vals[i]);
2272 }
2273
2274 auto *ctx = op->getContext();
2275 assert(ctx != nullptr && "ctx is nullptr");
2276
2277 auto const rankedTy = RankedTensorType::get(
2278 {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2279
2280 return DenseElementsAttr::get(rankedTy, sliceOfInput);
2281}
2282
2283OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2285}
2286
2287OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2289}
2290
2291OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2293}
2294
// Delegates to the shared binary shape folder using the ceiling-rounding
// variant of the division adaptor.
OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
}
2298
// Delegates to the shared binary shape folder using the floor-rounding
// variant of the division adaptor.
OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
}
2302
2303OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2305}
2306
2307OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2309}
2310
2311OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2313}
2314
2315OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2317}
2318
2319OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2321}
2322
2323OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2325}
2326
// Folding of tosa.concat_shape is implemented out-of-line in
// concatShapeFold above.
OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
  return concatShapeFold(this);
}
2330
// Folding of tosa.slice_shape is implemented out-of-line in sliceShapeFold
// above.
OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
  return sliceShapeFold(this);
}
return success()
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
DynamicAPInt round(const Fraction &f)
Definition Fraction.h:136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...