26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
47 (padConstAttr.
size() != 1)) {
52 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
53 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
54 return padConstVal == 0.0f;
58 if (
auto padConstIntAttr =
59 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
68 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
69 return zpVal == padConstVal;
// Forward declaration of the per-op adaptor used by the generic pad-folding
// pattern: specializations (e.g. for tosa::MaxPool2dOp below) provide
// checkKernelCompliance / checkPadConstCompliance and the rewrite hook that
// re-creates the consumer op with the folded padding.
77template <
typename OpTy>
78struct PoolPadFoldAdaptor;
81struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
82 using OpTy = tosa::MaxPool2dOp;
83 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
84 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
85 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
86 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
90 static bool checkPadConstCompliance(OpTy, Value padConst) {
92 DenseElementsAttr padConstAttr;
94 padConstAttr.
size() != 1) {
99 if (
auto padConstFpAttr =
100 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
101 const APFloat padConstVal = *padConstFpAttr.begin();
102 const APFloat lowestVal =
103 APFloat::getLargest(padConstVal.getSemantics(),
true);
104 return padConstVal == lowestVal;
106 if (
auto padConstIntAttr =
107 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
108 const APInt padConstVal = *padConstIntAttr.begin();
109 const unsigned int bitWidth = padConstVal.getBitWidth();
110 const APInt lowestVal =
111 padConstIntAttr.getElementType().isUnsignedInteger()
112 ? APInt::getZero(bitWidth)
113 : APInt::getSignedMinValue(bitWidth);
114 return padConstVal == lowestVal;
120 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
121 Value padInput, ArrayRef<int64_t> newPad) {
123 op, op.getType(), padInput, op.getKernel(), op.getStride(),
128template <
typename OpTy>
129struct ConvPadFoldAdaptor {
130 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
133 static bool checkPadConstCompliance(OpTy op, Value padConst) {
136 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
137 Value padInput, ArrayRef<int64_t> newPad) {
139 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
140 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
141 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
149template <
typename OpTy,
typename AdaptorTy>
151 using OpRewritePattern<OpTy>::OpRewritePattern;
153 LogicalResult matchAndRewrite(OpTy tensorOp,
154 PatternRewriter &rewriter)
const override {
156 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
159 "Producer must be a tosa::PadOp.");
162 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
163 if (tensorOpPad.size() != 4)
165 tensorOp,
"Tensor operation padding shall have 4 elements.");
168 DenseIntElementsAttr padOpPadding;
172 "The `padding` input specified on the tosa::PadOp must be constant.");
176 if (padOpPadding.size() != 8)
178 "Pad padding should have 8 elements.");
179 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
180 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
181 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
182 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
183 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
184 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
185 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
186 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
188 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
190 tensorOp,
"Folding padding in N or C dimensions is not supported.");
194 SmallVector<int64_t> foldedPad(tensorOpPad.size());
195 foldedPad[0] = padHBefore + tensorOpPad[0];
196 foldedPad[1] = padHAfter + tensorOpPad[1];
197 foldedPad[2] = padWBefore + tensorOpPad[2];
198 foldedPad[3] = padWAfter + tensorOpPad[3];
201 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
203 tensorOp,
"Padding size not aligned with kernel restrictions.");
207 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
210 "Padding constant is not aligned with operator zero-point.");
214 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
216 tensorOp,
"Padding size more than the 8K level limit.");
220 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
231 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
237 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
238 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
247 Value input = op.getInput();
248 Value output = op.getOutput();
249 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
250 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
252 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
258 if (outputShape[1] != 1 || outputShape[2] != 1) {
263 if (inputShape[1] != 1 || inputShape[2] != 1) {
275 FoldPadToTensorOp<tosa::MaxPool2dOp,
276 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
289 if (op.getInput1().size() != 1)
291 if (op.getInput1().front().getType() != op.getType()) {
294 op.getInput1().front())
299 rewriter.
replaceOp(op, op.getInput1().front());
309LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
310 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
314 op.getOperation()->setOperands(
315 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
327 auto innerTranspose =
328 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
331 "input must be transpose operation");
335 innerTranspose.getPerms();
337 if (transposePerms.size() != innerTransposePerms.size())
340 "transpose and inner transpose perms sizes must be equal");
341 if (transposePerms.empty())
343 transposeOp,
"transpose perms sizes must be positive");
347 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
348 perms[i] = innerTransposePerms[transposePerms[i]];
351 transposeOp, transposeOp.getResult().
getType(),
364 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
366 op,
"Src is from transpose, can compose transposes");
370 if (isa_and_nonnull<tosa::TransposeOp>(subop))
372 op,
"Dest is used by transpose, can compose transposes");
375 auto input = op.getInput1();
376 auto inputTy = llvm::cast<ShapedType>(input.getType());
377 if (!inputTy.hasRank())
381 for (
int i = 0; i < inputTy.getRank(); ++i)
382 if (inputTy.isDynamicDim(i))
391 nonZeroPerms.reserve(permValues.size());
392 for (
auto idx : permValues) {
393 auto sz = inputTy.getDimSize(idx);
395 nonZeroPerms.push_back(idx);
398 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
399 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
401 "Transpose changes memory layout.");
404 newShape.reserve(inputTy.getRank());
405 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
406 newShape.push_back(inputTy.getDimSize(permValues[i]));
409 op, op.getType(), op.getInput1(),
417 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
425 Value input = op.getInput();
426 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
427 auto inputElementType = inputType.getElementType();
429 if (isa<FloatType>(inputElementType)) {
431 const auto minClamp =
432 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
433 const auto maxClamp =
434 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
435 const bool isMin = minClamp.isNegInfinity();
436 const bool isMax = maxClamp.isInfinity();
438 if (isMin && isMax) {
446 const bool isBoolean = inputElementType.isInteger(1);
447 if (inputElementType.isUnsignedInteger() || isBoolean) {
448 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
451 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
455 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
456 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
457 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
459 if (minClamp <= intMin && maxClamp >= intMax) {
466 if (llvm::isa<IntegerType>(inputElementType)) {
468 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
470 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
472 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
473 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
474 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
476 if (minClamp <= intMin && maxClamp >= intMax) {
508 template <
typename T>
522 Value input = op.getInput();
530 const auto opNanMode = op.getNanMode();
531 const auto clampNanMode = clampOp.getNanMode();
532 if (opNanMode == NanPropagationMode::IGNORE &&
533 clampNanMode == NanPropagationMode::PROPAGATE)
536 auto maxValAttr = op.getMaxValAttr();
537 auto minValAttr = op.getMinValAttr();
538 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
539 auto clampOpMinValAttr = clampOp.getMinValAttr();
541 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
543 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
548 if (mlir::isa<FloatType>(inputEType)) {
549 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
550 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
551 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
552 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
555 const auto opMinFloat = floatMinValAttr.getValue();
556 const auto opMaxFloat = floatMaxValAttr.getValue();
557 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
558 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
562 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
566 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
567 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
568 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
569 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
571 assert(mlir::isa<IntegerType>(inputEType));
572 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
573 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
574 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
575 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
577 if (inputEType.isUnsignedInteger()) {
579 const auto opMinInt = intMinValAttr.getUInt();
580 const auto opMaxInt = intMaxValAttr.getUInt();
581 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
582 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
586 if (!opRangeIntRange.
intersects(clampRangeIntRange))
590 auto newMinVal = std::max(opMinInt, clampOpMinInt);
591 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
596 const auto opMinInt = intMinValAttr.getInt();
597 const auto opMaxInt = intMaxValAttr.getInt();
598 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
599 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
603 if (!opRangeIntRange.
intersects(clampRangeIntRange))
607 auto newMinVal = std::max(opMinInt, clampOpMinInt);
608 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
614 auto newMode = (opNanMode != clampNanMode)
615 ? tosa::NanPropagationMode::IGNORE
619 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
622 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
628void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
629 MLIRContext *context) {
630 results.
add<ClampIsNoOp>(context);
631 results.
add<ClampClampOptimization>(context);
639 Value sliceInput = sliceOp.getInput1();
643 sliceOp,
"slice input must be concat operation");
646 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
647 if (!concatType || !concatType.hasStaticShape())
649 sliceOp,
"slice input must be a static ranked tensor");
650 int32_t axis = concatOp.getAxis();
657 sliceOp,
"start of slice must be a static ranked shape");
661 sliceOp,
"size of slice must be a static ranked shape");
671 std::optional<Value> replaceWithSlice;
672 for (
auto input : inputs) {
673 auto inputType = dyn_cast<RankedTensorType>(input.getType());
674 if (!inputType || !inputType.hasStaticShape())
676 sliceOp,
"concat input must be a static ranked tensor");
678 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
679 inputType.getDimSize(axis)) {
685 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
686 input, start_op, size_op)
690 sliceStarts[axis] -= inputType.getDimSize(axis);
693 if (!replaceWithSlice)
695 sliceOp,
"corresponding concat input not found for slice");
697 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
707 Value sliceInput = sliceOp.getInput1();
713 "slice input must be a pad operation");
716 if (!padOp->hasOneUse())
718 "pad shall have a single consumer");
721 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
722 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
723 if (!inputTy || !padTy || !inputTy.hasRank())
725 "slice input must be a ranked tensor");
732 "`padding` input specified on the tosa::PadOp must be constant.");
735 llvm::to_vector(paddingElems.getValues<
int64_t>());
741 sliceOp,
"start of slice must be a static ranked shape");
748 sliceOp,
"size of slice must be a static ranked shape");
753 const int64_t rank = inputTy.getRank();
754 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
755 const bool isDimDynamic = inputTy.isDynamicDim(i);
756 const bool isDimSliced =
759 return isDimDynamic && isDimSliced;
762 sliceOp,
"axis that are sliced shall be statically known.");
769 bool updated =
false;
771 for (
int64_t i = 0; i < rank; ++i) {
772 const int64_t padLo = padPaddings[i * 2];
773 const int64_t padHi = padPaddings[i * 2 + 1];
774 const int64_t sliceStart = sliceStarts[i];
775 const int64_t sliceSize = sliceSizes[i];
776 const int64_t sliceEnd = sliceStart + sliceSize;
779 if (inputTy.isDynamicDim(i)) {
780 newPadPaddings[i * 2] = padLo;
781 newPadPaddings[i * 2 + 1] = padHi;
782 newSliceStarts[i] = sliceStart;
787 const int64_t dimSize = inputTy.getShape()[i];
788 const int64_t dimTotal = padLo + dimSize + padHi;
791 if (sliceStart < 0 || sliceEnd > dimTotal)
795 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
796 newSliceStarts[i] = newSliceStart;
797 updated |= newSliceStart != sliceStart;
800 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
802 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
803 newPadPaddings[i * 2] = newPadLo;
804 newPadPaddings[i * 2 + 1] = newPadHi;
805 updated |= (newPadLo != padLo) || (newPadHi != padHi);
809 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
815 sliceOp,
"terminate condition; nothing to rewrite");
821 RankedTensorType::get(newPadShape, inputTy.getElementType());
822 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
823 padOp.getInput1(), newPaddingsOp,
824 padOp.getPadConst());
830 newPadOp.getResult(), newStartOp,
845 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
847 ElementsAttr sizeElems;
850 sliceOp,
"size of slice must be a static ranked shape");
854 llvm::to_vector(sizeElems.getValues<
int64_t>());
856 bool replaceSliceSize{
false};
860 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
862 sliceSizes[
index] = resultType.getDimSize(
index);
863 replaceSliceSize =
true;
867 if (!replaceSliceSize) {
869 sliceOp,
"no dimension of size of slice is dynamic that resolves "
870 "to static output shape");
875 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
876 sliceOp.getInput1(), sliceOp.getStart(), size_op);
878 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
883void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
884 MLIRContext *context) {
885 results.
add<ConcatSliceOptimization, PadSliceOptimization,
886 SliceDynamicSizeCanonicalization>(context);
894 const Value castInput = castOp.getInput();
898 "input must be cast operation");
900 const Value innerCastInput = innerCastOp.getInput();
902 const auto innerInputType =
903 llvm::cast<ShapedType>(innerCastInput.
getType());
904 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
905 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
909 if (llvm::any_of(types, [](
const ShapedType type) {
910 return !type.getElementType().isInteger();
913 "only integer types are supported");
916 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
917 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
919 "inner cast operation is narrowing");
922 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
924 "outer cast operation is narrowing");
933void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
934 MLIRContext *context) {
935 results.
add<NonNarrowingCastsOptimization>(context);
942template <
typename Folder>
943static DenseElementsAttr
945 bool foldDenseValues =
false) {
950 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
954 if (
lhs.isSplat() &&
rhs.isSplat()) {
955 if (isa<FloatType>(lETy)) {
956 const APFloat l =
lhs.getSplatValue<APFloat>();
957 const APFloat r =
rhs.getSplatValue<APFloat>();
958 const auto maybeResult = Folder::fold(l, r);
959 if (failed(maybeResult))
964 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
965 const APInt l =
lhs.getSplatValue<APInt>();
966 const APInt r =
rhs.getSplatValue<APInt>();
967 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
968 if (failed(maybeResult))
974 if (foldDenseValues) {
975 assert(lETy.isIntOrIndex() &&
976 "Only integer types are currently supported.");
979 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
980 const auto maybeResult = Folder::fold(l, r,
false);
981 if (failed(maybeResult))
983 resultValues.push_back(maybeResult.value());
991template <
typename Folder>
993 bool foldDenseValues =
false) {
1000 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1002 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1003 if (failed(maybeResult))
1009 if (foldDenseValues) {
1013 for (
auto const &v : val.
getValues<APInt>()) {
1014 const auto maybeResult = Folder::fold(v,
false);
1015 if (failed(maybeResult))
1017 resultValues.push_back(maybeResult.value());
1029 const bool isUnsigned) {
1032 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1038 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1045 const bool isUnsigned) {
1048 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1054 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1061 const bool isUnsigned) {
1063 const unsigned originalWidth =
lhs.getBitWidth();
1066 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1071 if (
lhs == 0 ||
rhs == 0)
1072 return APInt::getZero(originalWidth);
1074 bool overflow =
false;
1076 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1081 return result.trunc(originalWidth);
1084 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1090 return a.isNegative() !=
b.isNegative();
1097 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1105 APInt::udivrem(
lhs,
rhs, q, r);
1106 if (!r.isZero() && Ceil) {
1113 bool overflow{
false};
1114 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1117 APInt
const r =
lhs.srem(
rhs);
1127 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1135 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1137 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1147 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1149 auto const r = t.mod(
rhs);
1150 if (llvm::APFloatBase::opStatus::opOK == r) {
1160 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1162 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1165 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1173 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1175 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1178 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1184 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1185 auto const numBits = value.getBitWidth();
1187 auto const zextv = value.getZExtValue();
1188 if (zextv >= numBits)
1190 return APInt::getOneBitSet(numBits, zextv);
1192 auto const sextv = value.getSExtValue();
1193 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1195 return APInt::getOneBitSet(numBits, sextv);
1200 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1201 if (!value.isStrictlyPositive())
1203 return APInt(value.getBitWidth(), value.ceilLogBase2());
1208 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1209 if (!value.isStrictlyPositive())
1211 return APInt(value.getBitWidth(), value.logBase2());
1217 const bool isUnsigned) {
1218 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1221 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1222 return APInt(1,
lhs >
rhs);
1228 const bool isUnsigned) {
1229 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1232 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1233 return APInt(1,
lhs >=
rhs);
1239 const bool isUnsigned) {
1240 return APInt(1,
lhs ==
rhs);
1243 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1244 return APInt(1,
lhs ==
rhs);
1249 if (llvm::isa<FloatType>(elemType))
1251 if (llvm::isa<IntegerType>(elemType))
1257 if (llvm::isa<FloatType>(elemType))
1258 return val && val.
isSplat() &&
1260 if (llvm::isa<IntegerType>(elemType)) {
1261 const int64_t shifted = 1LL << shift;
1262 return val && val.
isSplat() &&
1268OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1269 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1270 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1271 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1272 if (!lhsTy || !rhsTy || !resultTy)
1276 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1277 !rhsTy.getElementType().isIntOrIndexOrFloat())
1280 auto resultETy = resultTy.getElementType();
1282 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1284 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1287 lhsTy.getShape(), rhsTy.getShape());
1288 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1290 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1293 if (!lhsAttr || !rhsAttr)
1299OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1300 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1301 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1302 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1303 !outputTy.hasStaticShape())
1307 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1308 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1309 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1316OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1317 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1318 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1319 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1320 if (!lhsTy || !rhsTy || !resultTy)
1322 if (lhsTy.getElementType() != rhsTy.getElementType())
1327 auto resultETy = resultTy.getElementType();
1329 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1331 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1332 if (lhsAttr && lhsAttr.isSplat()) {
1333 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1334 lhsAttr.getSplatValue<APInt>().isZero())
1335 return lhsAttr.resizeSplat(resultTy);
1338 if (rhsAttr && rhsAttr.isSplat()) {
1340 lhsTy.getShape(), rhsTy.getShape());
1341 if (isBroadcastable && lhsTy == resultTy &&
1342 llvm::isa<IntegerType>(resultETy) &&
1343 rhsAttr.getSplatValue<APInt>().isOne())
1347 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1348 llvm::isa<IntegerType>(resultETy)) {
1349 APInt l = lhsAttr.getSplatValue<APInt>();
1350 APInt r = rhsAttr.getSplatValue<APInt>();
1352 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1354 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1367std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1368 unsigned bitwidth) {
1369 bool overflow =
false;
1370 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1373 return std::nullopt;
1376 auto round = APInt(64, 1) << (shift - 1);
1378 result.ashrInPlace(shift);
1381 if (!(
result.getSExtValue() >= INT32_MIN &&
1382 result.getSExtValue() <= INT32_MAX)) {
1384 return std::nullopt;
1388 return result.trunc(bitwidth);
1391DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1392 RankedTensorType ty, int32_t shift) {
1394 if (llvm::isa<IntegerType>(ty.getElementType())) {
1395 APInt l =
lhs.getSplatValue<APInt>();
1396 APInt r =
rhs.getSplatValue<APInt>();
1402 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1403 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1409 if (llvm::isa<FloatType>(ty.getElementType())) {
1410 APFloat l =
lhs.getSplatValue<APFloat>();
1411 APFloat r =
rhs.getSplatValue<APFloat>();
1421OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1422 auto lhs = getInput1();
1423 auto rhs = getInput2();
1424 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1425 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1426 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1427 if (!lhsTy || !rhsTy || !resultTy)
1430 auto resultETy = resultTy.getElementType();
1432 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1434 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1439 if (resultETy.isInteger(32)) {
1440 ElementsAttr shift_elem;
1441 if (getShift().getImpl()) {
1445 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1449 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1450 resultTy.hasStaticShape())
1452 return lhsAttr.resizeSplat(resultTy);
1453 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1454 resultTy.hasStaticShape())
1455 return rhsAttr.resizeSplat(resultTy);
1458 lhsTy.getShape(), rhsTy.getShape());
1459 if (isBroadcastable && rhsTy == resultTy &&
1462 if (isBroadcastable && lhsTy == resultTy &&
1466 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1469OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1470 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1471 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1472 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1473 if (!lhsTy || !rhsTy || !resultTy)
1477 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1478 !rhsTy.getElementType().isIntOrIndexOrFloat())
1481 auto resultETy = resultTy.getElementType();
1483 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1485 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1488 lhsTy.getShape(), rhsTy.getShape());
1489 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1492 if (!lhsAttr || !rhsAttr)
1498OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1499 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1501 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1503 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1505 if (!lhsAttr || !rhsAttr)
1511OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1512 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1514 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1516 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1518 if (!lhsAttr || !rhsAttr)
1524OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1525 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1527 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1529 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1530 Value
lhs = getInput1();
1531 Value
rhs = getInput2();
1532 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1536 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1537 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1541 if (!lhsAttr || !rhsAttr)
1547OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1551 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1555 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1556 auto outTy = llvm::cast<ShapedType>(
getType());
1557 auto inETy = inTy.getElementType();
1558 auto outETy = outTy.getElementType();
1560 if (operand.isSplat()) {
1561 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1563 auto splatVal = operand.getSplatValue<APFloat>();
1564 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1565 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1570 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1571 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1572 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1573 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1574 llvm::RoundingMode::NearestTiesToEven);
1578 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1579 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1580 auto intVal = APSInt(
1581 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1582 auto floatVal = operand.getSplatValue<APFloat>();
1584 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1589 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1590 const auto inIntType = llvm::cast<IntegerType>(inETy);
1591 auto unsignIn = inIntType.isUnsignedInteger();
1593 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1594 auto intVal = operand.getSplatValue<APInt>();
1595 auto bitwidth = outETy.getIntOrFloatBitWidth();
1598 if (outETy.isInteger(1)) {
1599 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1601 intVal = intVal.trunc(bitwidth);
1602 }
else if (unsignIn || inIntType.isInteger(1)) {
1603 intVal = intVal.zext(bitwidth);
1605 intVal = intVal.sext(bitwidth);
1615OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1617OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1619#define REDUCE_FOLDER(OP) \
1620 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1621 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1622 if (!inputTy.hasRank()) \
1624 if (inputTy != getType()) \
1626 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1627 return getInput(); \
1640 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1641 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1643 if (!inputTy || !outputTy)
1649 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1653 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1654 getInput1().getDefiningOp())) {
1655 getInput1Mutable().assign(reshapeOp.getInput1());
1660 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1665 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1667 if (!outputTy.hasStaticShape())
1671 if (operand.isSplat())
1676 if (!getInput1().hasOneUse())
1683 return operand.reshape(
1684 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1690OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1692 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1693 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1694 if (densePad && densePad.isSplat() &&
1695 densePad.getSplatValue<APInt>().isZero()) {
1705OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1707 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1709 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1711 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1712 if (!scaleAttr || !offsetAttr || !borderAttr) {
1719 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1724 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1729 if (offset[0] != 0 || offset[1] != 0) {
1734 if (border[0] != 0 || border[1] != 0) {
1738 auto input = getInput();
1739 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1740 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1741 if (inputTy != resultTy)
1747OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1748 auto operand = getInput1();
1749 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1750 auto axis = getAxis();
1752 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1757 if (operandTy.hasRank() &&
1758 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1764OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1765 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1766 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1768 if (!inputTy || !outputTy)
1771 if (inputTy == outputTy && inputTy.hasStaticShape())
1776 DenseElementsAttr startElems;
1782 llvm::all_of(startElems.
getValues<APInt>(),
1783 [](
const APInt &val) { return val.isZero(); });
1788 DenseElementsAttr sizeElems;
1792 auto inputShape = inputTy.getShape();
1793 auto sizeValues = sizeElems.
getValues<APInt>();
1795 bool sizeMatchesInput =
true;
1796 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
1797 int64_t size = sizeVal.getSExtValue();
1799 if (inputTy.isDynamicDim(i)) {
1803 sizeMatchesInput =
false;
1810 sizeMatchesInput =
false;
1816 if (sizeMatchesInput)
1821 if (!adaptor.getInput1())
1825 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1826 !outputTy.getElementType().isIntOrIndexOrFloat())
1829 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1830 if (operand.isSplat() && outputTy.hasStaticShape()) {
1834 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1835 outputTy.getNumElements() == 1) {
1836 llvm::SmallVector<uint64_t>
indices =
1837 llvm::to_vector(startElems.
getValues<uint64_t>());
1838 auto value = operand.getValues<Attribute>()[
indices];
1845OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1846 const Value pred = getPred();
1847 const Value onTrue = getOnTrue();
1848 const Value onFalse = getOnFalse();
1850 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
1851 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
1852 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
1853 if (!predTy || !onTrueTy || !onFalseTy)
1856 const Type resultTy =
getType();
1858 const ArrayRef<int64_t> predShape = predTy.getShape();
1859 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
1861 if (onTrue == onFalse && onTrueTy == resultTy &&
1866 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1869 if (!predicate.isSplat())
1872 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
1874 SmallVector<SmallVector<int64_t>, 3> shapes;
1875 shapes.emplace_back(predShape);
1876 shapes.emplace_back(onTrueShape);
1877 shapes.emplace_back(onFalseTy.getShape());
1878 const bool isBroadcastable =
1881 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
1883 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
1888OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1890 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1891 adaptor.getMultiples())) {
1892 if (multiples.isSplat() &&
1893 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1895 if (
auto int_array_attr =
1896 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1897 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1898 [](APInt v) { return v.getSExtValue() == 1; }))
1906OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1907 auto resultTy = llvm::cast<ShapedType>(
getType());
1911 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1912 if (input.isSplat() && resultTy.hasStaticShape() &&
1913 input.getType().getElementType() == resultTy.getElementType())
1914 return input.reshape(resultTy);
1918 const llvm::ArrayRef<int32_t> perms = getPerms();
1920 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1926OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1929 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1935 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1936 failed(maybeIZp) || *maybeIZp != 0) {
1940 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1941 failed(maybeOZp) || *maybeOZp != 0) {
1945 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1946 failed(maybeIZp) || *maybeIZp != 0) {
1950 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1951 failed(maybeOZp) || *maybeOZp != 0) {
1956 return definingOp.getInput1();
1959OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1960 auto input = getInput1();
1962 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1969OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1974 SmallVector<Value, 8> concatOperands;
1975 concatOperands.reserve(2 * getNumOperands());
1978 bool foundFoldableConcat =
false;
1979 for (Value operand : getOperands()) {
1980 concatOperands.emplace_back(operand);
1982 auto producer = operand.getDefiningOp<ConcatOp>();
1987 if (getAxis() != producer.getAxis())
1991 foundFoldableConcat =
true;
1992 concatOperands.pop_back();
1993 llvm::append_range(concatOperands, producer->getOperands());
1996 if (!foundFoldableConcat)
1999 getOperation()->setOperands(concatOperands);
2003OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2004 auto input = adaptor.getInput1();
2006 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2008 if (!inputAttr || !inputAttr.isSplat())
2011 auto shapeType = llvm::cast<ShapedType>(
getType());
2012 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2013 auto floatVal = inputAttr.getSplatValue<APFloat>();
2015 ReciprocalOp::calcOneElement(floatVal));
2021template <
typename Op,
typename OpFoldAdaptor>
2023 auto input1ConstShape =
2024 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2025 if (!input1ConstShape)
2028 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2034template <
typename Op,
typename OpFoldAdaptor>
2036 auto input1ConstShape =
2037 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2038 auto input2ConstShape =
2039 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2040 if (!input1ConstShape || !input2ConstShape)
2043 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2044 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2047 input1Attr.getType(),
2051OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2052 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2053 if (!inputTy || !inputTy.hasRank())
2055 const int32_t axis = getAxis();
2056 const int64_t dimSize = inputTy.getDimSize(axis);
2057 if (ShapedType::isDynamic(dimSize))
2061 const auto resultAttrTy =
2062 RankedTensorType::get(1, builder.getIndexType());
2066OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2070OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2074OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2078OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2079 return binaryFold<DivCeilShapeOp, DivFoldAdaptor<
true>>(
this);
2082OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2083 return binaryFold<DivFloorShapeOp, DivFoldAdaptor<
false>>(
this);
2086OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2090OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2094OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2098OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2102OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2106OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
DynamicAPInt round(const Fraction &f)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...