#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

      (padConstAttr.size() != 1)) {

  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
    return padConstVal == 0.0f;

  if (auto padConstIntAttr = mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
    return zpVal == padConstVal;
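// The fragment above belongs to checkMatchingPadConstAndZp: the pad constant
// must equal 0.0f for float types or the operator's zero point for integer
// types.
//
// PoolPadFoldAdaptor / ConvPadFoldAdaptor describe, per consumer op, when an
// explicit tosa::PadOp feeding that op may be folded into the op's own `pad`
// attribute (checkKernelCompliance / checkPadConstCompliance) and how to
// rebuild the consumer once the fold is legal (replaceOpWithNewPad).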
template <typename OpTy>
struct PoolPadFoldAdaptor;

struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;
  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])

  static bool checkPadConstCompliance(OpTy, Value padConst) {
    DenseElementsAttr padConstAttr;
        padConstAttr.size() != 1) {

    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;

    if (auto padConstIntAttr =
            mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getZero(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;

  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
        op, op.getType(), padInput, op.getKernel(), op.getStride(),

template <typename OpTy>
struct ConvPadFoldAdaptor {
  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {

  static bool checkPadConstCompliance(OpTy op, Value padConst) {

  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
                                  Value padInput, ArrayRef<int64_t> newPad) {
        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
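// FoldPadToTensorOp<OpTy, AdaptorTy> merges a producer tosa::PadOp into the
// consumer's pad attribute. It requires a constant 8-element pad tensor
// (N/H/W/C, before/after), rejects any padding on the N or C dimensions, adds
// the H/W amounts onto the consumer's existing 4-element pad, and defers the
// op-specific legality checks and the rebuild to the adaptor. For example, a
// pad of 1 on each side of H and W followed by a conv2d with pad [0, 0, 0, 0]
// becomes a single conv2d with pad [1, 1, 1, 1], provided the pad constant
// matches the operator's zero point and each folded amount stays within the
// 8192 level limit.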
template <typename OpTy, typename AdaptorTy>
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
                                         "Producer must be a tosa::PadOp.");

    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4)
          tensorOp, "Tensor operation padding shall have 4 elements.");

    DenseIntElementsAttr padOpPadding;
          "The `padding` input specified on the tosa::PadOp must be constant.");

    if (padOpPadding.size() != 8)
                                         "Pad padding should have 8 elements.");

    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    SmallVector<int64_t> foldedPad(tensorOpPad.size());
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
          tensorOp, "Padding size not aligned with kernel restrictions.");

    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
          "Padding constant is not aligned with operator zero-point.");

    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
          tensorOp, "Padding size more than the 8K level limit.");

    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
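// The pad-folding pattern above is registered for tosa::Conv2DOp,
// tosa::DepthwiseConv2DOp and (further below) tosa::MaxPool2dOp through their
// getCanonicalizationPatterns hooks.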
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(

  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(

  Value input = op.getInput();
  Value output = op.getOutput();
  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
  ShapedType outputType = llvm::cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {

  if (outputShape[1] != 1 || outputShape[2] != 1) {

  if (inputShape[1] != 1 || inputShape[2] != 1) {

      FoldPadToTensorOp<tosa::MaxPool2dOp,
                        PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(

  if (op.getInput1().size() != 1)

  if (op.getInput1().front().getType() != op.getType()) {
                                       op.getInput1().front())

  rewriter.replaceOp(op, op.getInput1().front());

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
  op.getOperation()->setOperands(
      {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
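// ConsolidateTransposeOptimization: transpose(transpose(x)) collapses into a
// single transpose whose permutation is the composition
// perms[i] = innerPerms[outerPerms[i]].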
  auto innerTranspose =
      transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
                                       "input must be transpose operation");

      innerTranspose.getPerms();

  if (transposePerms.size() != innerTransposePerms.size())
        "transpose and inner transpose perms sizes must be equal");
  if (transposePerms.empty())
        transposeOp, "transpose perms sizes must be positive");

  for (int i = 0, s = transposePerms.size(); i < s; ++i)
    perms[i] = innerTransposePerms[transposePerms[i]];

      transposeOp, transposeOp.getResult().getType(),
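// TransposeIsReshape: a transpose that only moves unit dimensions (the
// non-unit dimensions keep their relative order) does not change the memory
// layout of the elements and is rewritten as a reshape of the input.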
  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
        op, "Src is from transpose, can compose transposes");

    if (isa_and_nonnull<tosa::TransposeOp>(subop))
          op, "Dest is used by transpose, can compose transposes");

  auto input = op.getInput1();
  auto inputTy = llvm::cast<ShapedType>(input.getType());
  if (!inputTy.hasRank())

  for (int i = 0; i < inputTy.getRank(); ++i)
    if (inputTy.isDynamicDim(i))

  nonZeroPerms.reserve(permValues.size());
  for (auto idx : permValues) {
    auto sz = inputTy.getDimSize(idx);
      nonZeroPerms.push_back(idx);

  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
    if (nonZeroPerms[i - 1] > nonZeroPerms[i])
                                       "Transpose changes memory layout.");

  newShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
    newShape.push_back(inputTy.getDimSize(permValues[i]));

      op, op.getType(), op.getInput1(),

  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
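// ClampIsNoOp: a clamp whose bounds already cover the entire representable
// range of the element type (-inf/+inf for floats, the full unsigned or
// signed integer range otherwise) is replaced by its input.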
  Value input = op.getInput();
  auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
  auto inputElementType = inputType.getElementType();

  if (isa<FloatType>(inputElementType)) {
    const auto minClamp =
        llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
    const auto maxClamp =
        llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();

    const bool isMin = minClamp.isNegInfinity();
    const bool isMax = maxClamp.isInfinity();

    if (isMin && isMax) {

  const bool isBoolean = inputElementType.isInteger(1);
  if (inputElementType.isUnsignedInteger() || isBoolean) {
    const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
    const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())

    const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
    const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
    const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();

    if (minClamp <= intMin && maxClamp >= intMax) {

  if (llvm::isa<IntegerType>(inputElementType)) {
        llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
        llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

    const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
    const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
    const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();

    if (minClamp <= intMin && maxClamp >= intMax) {
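// ClampClampOptimization: clamp(clamp(x, a, b), c, d) is merged into a single
// clamp over the intersection [max(a, c), min(b, d)], provided the two ranges
// intersect and the NaN-propagation modes are compatible.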
  template <typename T>

  Value input = op.getInput();

  const auto opNanMode = op.getNanMode();
  const auto clampNanMode = clampOp.getNanMode();
  if (opNanMode == NanPropagationMode::IGNORE &&
      clampNanMode == NanPropagationMode::PROPAGATE)

  auto maxValAttr = op.getMaxValAttr();
  auto minValAttr = op.getMinValAttr();
  auto clampOpMaxValAttr = clampOp.getMaxValAttr();
  auto clampOpMinValAttr = clampOp.getMinValAttr();

  auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {

  if (mlir::isa<FloatType>(inputEType)) {
    auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
    auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
    auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

    const auto opMinFloat = floatMinValAttr.getValue();
    const auto opMaxFloat = floatMaxValAttr.getValue();
    const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
    const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();

    if (!opRangeFloatRange.intersects(clampRangeFloatRange))

    auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
    auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
    newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
    newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);

    assert(mlir::isa<IntegerType>(inputEType));
    auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
    auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
    auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

    if (inputEType.isUnsignedInteger()) {
      const auto opMinInt = intMinValAttr.getUInt();
      const auto opMaxInt = intMaxValAttr.getUInt();
      const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
      const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();

      if (!opRangeIntRange.intersects(clampRangeIntRange))

      auto newMinVal = std::max(opMinInt, clampOpMinInt);
      auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);

      const auto opMinInt = intMinValAttr.getInt();
      const auto opMaxInt = intMaxValAttr.getInt();
      const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
      const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();

      if (!opRangeIntRange.intersects(clampRangeIntRange))

      auto newMinVal = std::max(opMinInt, clampOpMinInt);
      auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);

  auto newMode = (opNanMode != clampNanMode)
                     ? tosa::NanPropagationMode::IGNORE
      NanPropagationModeAttr::get(rewriter.getContext(), newMode);
      op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,

void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp>(context);
  results.add<ClampClampOptimization>(context);
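// ConcatSliceOptimization: when a slice reads entirely from a single input of
// a producing concat, the slice is retargeted to that input directly; the
// start offset is shifted down by the sizes of the preceding concat operands
// along the concat axis.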
  Value sliceInput = sliceOp.getInput1();
        sliceOp, "slice input must be concat operation");

  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
  if (!concatType || !concatType.hasStaticShape())
        sliceOp, "slice input must be a static ranked tensor");

  int32_t axis = concatOp.getAxis();
        sliceOp, "start of slice must be a static ranked shape");
        sliceOp, "size of slice must be a static ranked shape");

  std::optional<Value> replaceWithSlice;
  for (auto input : inputs) {
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
          sliceOp, "concat input must be a static ranked tensor");

    if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                      inputType.getDimSize(axis)) {
          tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                                input, start_op, size_op)

    sliceStarts[axis] -= inputType.getDimSize(axis);

  if (!replaceWithSlice)
        sliceOp, "corresponding concat input not found for slice");

  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
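// PadSliceOptimization: a slice of a padded tensor is rewritten so that the
// slice is taken from the unpadded input and only the still-required padding
// is re-applied; per dimension, the pad amounts and the slice start shrink by
// whatever padded region the slice cuts away.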
  Value sliceInput = sliceOp.getInput1();
                                       "slice input must be a pad operation");

  if (!padOp->hasOneUse())
                                       "pad shall have a single consumer");

  auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
  auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
  if (!inputTy || !padTy || !inputTy.hasRank())
                                       "slice input must be a ranked tensor");

        "`padding` input specified on the tosa::PadOp must be constant.");
      llvm::to_vector(paddingElems.getValues<int64_t>());

        sliceOp, "start of slice must be a static ranked shape");
        sliceOp, "size of slice must be a static ranked shape");

  const int64_t rank = inputTy.getRank();
  if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
        const bool isDimDynamic = inputTy.isDynamicDim(i);
        const bool isDimSliced =
        return isDimDynamic && isDimSliced;
        sliceOp, "axis that are sliced shall be statically known.");

  bool updated = false;

  for (int64_t i = 0; i < rank; ++i) {
    const int64_t padLo = padPaddings[i * 2];
    const int64_t padHi = padPaddings[i * 2 + 1];
    const int64_t sliceStart = sliceStarts[i];
    const int64_t sliceSize = sliceSizes[i];
    const int64_t sliceEnd = sliceStart + sliceSize;

    if (inputTy.isDynamicDim(i)) {
      newPadPaddings[i * 2] = padLo;
      newPadPaddings[i * 2 + 1] = padHi;
      newSliceStarts[i] = sliceStart;

    const int64_t dimSize = inputTy.getShape()[i];
    const int64_t dimTotal = padLo + dimSize + padHi;

    if (sliceStart < 0 || sliceEnd > dimTotal)

    const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
    newSliceStarts[i] = newSliceStart;
    updated |= newSliceStart != sliceStart;

    const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
        std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
    newPadPaddings[i * 2] = newPadLo;
    newPadPaddings[i * 2 + 1] = newPadHi;
    updated |= (newPadLo != padLo) || (newPadHi != padHi);

        newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];

        sliceOp, "terminate condition; nothing to rewrite");

      RankedTensorType::get(newPadShape, inputTy.getElementType());
  auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
                                      padOp.getInput1(), newPaddingsOp,
                                      padOp.getPadConst());
                         newPadOp.getResult(), newStartOp,
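// SliceDynamicSizeCanonicalization: dynamic entries in the slice `size`
// operand are replaced with the corresponding static dimension sizes already
// present on the result type, when those are known.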
  ShapedType resultType = cast<ShapedType>(sliceOp.getType());
  if (!resultType.hasRank())

  ElementsAttr sizeElems;
        sliceOp, "size of slice must be a static ranked shape");
      llvm::to_vector(sizeElems.getValues<int64_t>());

  bool replaceSliceSize{false};

  for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      sliceSizes[index] = resultType.getDimSize(index);
      replaceSliceSize = true;

  if (!replaceSliceSize) {
        sliceOp, "no dimension of size of slice is dynamic that resolves "
                 "to static output shape");

      tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
                            sliceOp.getInput1(), sliceOp.getStart(), size_op);

  rewriter.replaceOp(sliceOp, newSliceOp.getResult());

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization, PadSliceOptimization,
              SliceDynamicSizeCanonicalization>(context);
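// NonNarrowingCastsOptimization: cast(cast(x)) over integer element types is
// collapsed when neither cast narrows the bit width, so the intermediate
// widening cannot have changed the value.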
  const Value castInput = castOp.getInput();
                                       "input must be cast operation");

  const Value innerCastInput = innerCastOp.getInput();

  const auto innerInputType = llvm::cast<ShapedType>(innerCastInput.getType());
  const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
  const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());

  if (llvm::any_of(types, [](const ShapedType type) {
        return !type.getElementType().isInteger();
                                       "only integer types are supported");

  const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
  if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
                                       "inner cast operation is narrowing");

  if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
                                       "outer cast operation is narrowing");

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<NonNarrowingCastsOptimization>(context);
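// CancellingBlockScaledCastsOptimization: a cast_to_block_scaled fed by a
// cast_from_block_scaled with matching data/scale types and block size is
// removed, forwarding the original data and scale values.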
  const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
  auto castFromBlockScaledOp =
      castToBlockScaledInput.getDefiningOp<tosa::CastFromBlockScaledOp>();
  if (!castFromBlockScaledOp)
        "input must be cast_from_block_scaled operation");

  const Value innerData = castFromBlockScaledOp.getInputData();
  const Value innerScale = castFromBlockScaledOp.getInputScale();
  const auto innerDataTy = llvm::cast<ShapedType>(innerData.getType());
  const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.getType());

  const Value outerData = castToBlockScaledOp.getOutputData();
  const Value outerScale = castToBlockScaledOp.getOutputScale();
  const auto outerDataTy = llvm::cast<ShapedType>(outerData.getType());
  const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.getType());

  if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
        "inputs types to cast_from_block_scaled operation must match output "
        "types to cast_to_block_scaled");

  if (castFromBlockScaledOp.getBlockSize() !=
      castToBlockScaledOp.getBlockSize()) {
        castToBlockScaledOp, "block sizes for cast_from_block_scaled and "
                             "cast_to_block_scaled must match");

  rewriter.replaceOp(castToBlockScaledOp, {innerData, innerScale});

void CastToBlockScaledOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CancellingBlockScaledCastsOptimization>(context);
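// binaryFolder / unaryFolder apply a Folder adaptor to constant operands:
// splat inputs are folded per element type (APFloat or APInt), and full dense
// folding is only attempted when foldDenseValues is set (currently integers
// only).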
template <typename Folder>
static DenseElementsAttr
                                        bool foldDenseValues = false) {
  if (!returnTy.hasRank() || !returnTy.hasStaticShape())

  const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();

  if (lhs.isSplat() && rhs.isSplat()) {
    if (isa<FloatType>(lETy)) {
      const APFloat l = lhs.getSplatValue<APFloat>();
      const APFloat r = rhs.getSplatValue<APFloat>();
      const auto maybeResult = Folder::fold(l, r);
      if (failed(maybeResult))

    if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
      const APInt l = lhs.getSplatValue<APInt>();
      const APInt r = rhs.getSplatValue<APInt>();
      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
      if (failed(maybeResult))

  if (foldDenseValues) {
    assert(lETy.isIntOrIndex() &&
           "Only integer types are currently supported.");
         llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
      const auto maybeResult = Folder::fold(l, r, false);
      if (failed(maybeResult))
      resultValues.push_back(maybeResult.value());

template <typename Folder>
                                       bool foldDenseValues = false) {
  if (!returnTy.hasRank() || !returnTy.hasStaticShape())

  if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
    const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
    if (failed(maybeResult))

  if (foldDenseValues) {
    for (auto const &v : val.getValues<APInt>()) {
      const auto maybeResult = Folder::fold(v, false);
      if (failed(maybeResult))
      resultValues.push_back(maybeResult.value());
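// The *FoldAdaptor structs below provide the per-element arithmetic consumed
// by binaryFolder/unaryFolder: add/sub/mul with overflow checks, ceil/floor
// division, mod, min/max, exp2/log2, and comparisons returning 1-bit results.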
  assert(dense.isSplat());
  APInt a = dense.getSplatValue<APInt>();
  return a.getSExtValue();

                             const bool isUnsigned) {
        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

                             const bool isUnsigned) {
        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

                             const bool isUnsigned) {
    const unsigned originalWidth = lhs.getBitWidth();
    if (lhs.getBitWidth() != rhs.getBitWidth()) {

    if (lhs == 0 || rhs == 0)
      return APInt::getZero(originalWidth);

    bool overflow = false;
        isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow);

    return result.trunc(originalWidth);

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

    return a.isNegative() != b.isNegative();
    if (lhs.getBitWidth() != rhs.getBitWidth())

      APInt::udivrem(lhs, rhs, q, r);
      if (!r.isZero() && Ceil) {

    bool overflow{false};
    APInt const q = lhs.sdiv_ov(rhs, overflow);
    APInt const r = lhs.srem(rhs);

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

    if (lhs.getBitWidth() != rhs.getBitWidth())
    if (lhs.isNegative() || (!rhs.isStrictlyPositive()))

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
    auto const r = t.mod(rhs);
    if (llvm::APFloatBase::opStatus::opOK == r) {

    if (lhs.getBitWidth() != rhs.getBitWidth())
    return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs;

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

    if (lhs.getBitWidth() != rhs.getBitWidth())
    return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs;

  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {

  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    auto const numBits = value.getBitWidth();
      auto const zextv = value.getZExtValue();
      if (zextv >= numBits)
      return APInt::getOneBitSet(numBits, zextv);

    auto const sextv = value.getSExtValue();
    if (sextv < 0 || sextv >= numBits || (value.isNegative()))
    return APInt::getOneBitSet(numBits, sextv);

  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
    return APInt(value.getBitWidth(), value.ceilLogBase2());

  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
    if (!value.isStrictlyPositive())
    return APInt(value.getBitWidth(), value.logBase2());
                             const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));

  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs > rhs);

                             const bool isUnsigned) {
    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));

  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs >= rhs);

                             const bool isUnsigned) {
    return APInt(1, lhs == rhs);

  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
    return APInt(1, lhs == rhs);

  if (llvm::isa<FloatType>(elemType))
  if (llvm::isa<IntegerType>(elemType))

  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
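// Elementwise op folders: identities such as x + 0, x - 0, x * 1 and x / 1
// fold to the unchanged operand when the operand type already matches the
// result type, and fully-constant splat operands are folded through the
// adaptors above.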
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)

  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())

  auto resultETy = resultTy.getElementType();
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
  if (isBroadcastable && rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))

  if (!lhsAttr || !rhsAttr)

OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())

  if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
    const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
    const APInt zero = APInt::getZero(outputElemIntTy.getWidth());

OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
  if (lhsTy.getElementType() != rhsTy.getElementType())

  auto resultETy = resultTy.getElementType();
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr.resizeSplat(resultTy);

  if (rhsAttr && rhsAttr.isSplat()) {
        lhsTy.getShape(), rhsTy.getShape());
    if (isBroadcastable && lhsTy == resultTy &&
        llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
        DivFoldAdaptor<false>::fold(l, r, intTy.isUnsigned());
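// mulInt implements tosa.mul's i32 path: the product is computed in 64 bits,
// a rounding term of 1 << (shift - 1) is applied before an arithmetic right
// shift by `shift`, and the fold is rejected (std::nullopt) on overflow or if
// the result leaves the i32 range. As an illustration (assuming the rounding
// add implied by the `round` term), with shift = 6: 100 * 3 = 300 and
// (300 + 32) >> 6 = 5. mulBinaryFolder uses it for splat integer operands and
// falls back to APFloat multiplication for floats.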
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
                            unsigned bitwidth) {
  bool overflow = false;

  APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);
    return std::nullopt;

    auto round = APInt(64, 1) << (shift - 1);
    result.ashrInPlace(shift);

  if (!(result.getSExtValue() >= INT32_MIN &&
        result.getSExtValue() <= INT32_MAX)) {
    return std::nullopt;

  return result.trunc(bitwidth);

DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (llvm::isa<IntegerType>(ty.getElementType())) {
    APInt l = lhs.getSplatValue<APInt>();
    APInt r = rhs.getSplatValue<APInt>();

    auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
    const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);

  if (llvm::isa<FloatType>(ty.getElementType())) {
    APFloat l = lhs.getSplatValue<APFloat>();
    APFloat r = rhs.getSplatValue<APFloat>();
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)

  auto resultETy = resultTy.getElementType();
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();

  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr) &&
      resultTy.hasStaticShape())
    return lhsAttr.resizeSplat(resultTy);
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr) &&
      resultTy.hasStaticShape())
    return rhsAttr.resizeSplat(resultTy);

      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && rhsTy == resultTy &&
  if (isBroadcastable && lhsTy == resultTy &&

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);

OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)

  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())

  auto resultETy = resultTy.getElementType();
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

      lhsTy.getShape(), rhsTy.getShape());
  if (isBroadcastable && lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))

  if (!lhsAttr || !rhsAttr)

OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)

OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
      resultTy.hasStaticShape() && lhs == rhs) {

  if (!lhsAttr || !rhsAttr)

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  if (!outTy.hasRank() || !outTy.hasStaticShape())

  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      const auto inIntType = llvm::cast<IntegerType>(inETy);
      auto unsignIn = inIntType.isUnsignedInteger();
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (outETy.isInteger(1)) {
        intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn || inIntType.isInteger(1)) {
        intVal = intVal.zext(bitwidth);
        intVal = intVal.sext(bitwidth);

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
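// REDUCE_FOLDER(OP) generates the fold hook for each reduction op: when the
// input is rank-0 or the reduced axis already has size 1, and the input type
// equals the result type, the reduction is an identity and returns its input.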
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
    if (inputTy != getType())                                                  \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \

  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)

  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)

  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());

  if (!inputTy.getElementType().isIntOrIndexOrFloat())

      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (!outputTy.hasStaticShape())

    if (operand.isSplat())

    if (!getInput1().hasOneUse())

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));

OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {

  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {

  if (scale[0] != scale[1] || scale[2] != scale[3]) {

  if (offset[0] != 0 || offset[1] != 0) {

  if (border[0] != 0 || border[1] != 0) {

  return foldToInputIfTypeMatches(getType(), getInput());

OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();

  const bool isSplatInput =
      llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
  if (!operandTy.hasRank() ||
      (!isSplatInput && operandTy.getDimSize(axis) != 1))

  return foldToInputIfTypeMatches(getType(), operand);

OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)

  if (inputTy == outputTy && inputTy.hasStaticShape())

  DenseElementsAttr startElems;
      llvm::all_of(startElems.getValues<APInt>(),
                   [](const APInt &val) { return val.isZero(); });

  DenseElementsAttr sizeElems;

  auto inputShape = inputTy.getShape();
  auto sizeValues = sizeElems.getValues<APInt>();

  bool sizeMatchesInput = true;
  for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
    int64_t size = sizeVal.getSExtValue();
    if (inputTy.isDynamicDim(i)) {
      sizeMatchesInput = false;

      sizeMatchesInput = false;

  if (sizeMatchesInput)

  if (!adaptor.getInput1())

  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    if (auto values = operand.tryGetValues<Attribute>())
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  const Value pred = getPred();
  const Value onTrue = getOnTrue();
  const Value onFalse = getOnFalse();

  const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.getType());
  const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.getType());
  const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.getType());
  if (!predTy || !onTrueTy || !onFalseTy)

  const Type resultTy = getType();

  const ArrayRef<int64_t> predShape = predTy.getShape();
  const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();

  if (onTrue == onFalse && onTrueTy == resultTy &&

      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate.isSplat())

  const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();

  SmallVector<SmallVector<int64_t>, 3> shapes;
  shapes.emplace_back(predShape);
  shapes.emplace_back(onTrueShape);
  shapes.emplace_back(onFalseTy.getShape());
  const bool isBroadcastable =

  if (predicateValue == true && onTrueTy == resultTy && isBroadcastable)
  if (predicateValue == false && onFalseTy == resultTy && isBroadcastable)

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
          adaptor.getMultiples())) {
    if (multiples.isSplat() &&
        multiples.getSplatValue<APInt>().getSExtValue() == 1)
    if (auto int_array_attr =
            llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
      if (llvm::all_of(int_array_attr.getValues<APInt>(),
                       [](APInt v) { return v.getSExtValue() == 1; }))

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);

  const llvm::ArrayRef<int32_t> perms = getPerms();
  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))

  return foldToInputIfTypeMatches(getType(), getInput1());

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {

  return foldToInputIfTypeMatches(getType(), definingOp.getInput1());

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  return foldToInputIfTypeMatches(getType(), input);
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = operand.getDefiningOp<ConcatOp>();

    if (getAxis() != producer.getAxis())

    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());

  if (!foundFoldableConcat)

  getOperation()->setOperands(concatOperands);

OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  if (!inputAttr || !inputAttr.isSplat())

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (!shapeType.hasRank() || !shapeType.hasStaticShape())

  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
        ReciprocalOp::calcOneElement(floatVal));
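// The tosa shape operations (const_shape combined with add/sub/mul, ceil and
// floor division, mod, min/max, exp2/log2, concat and slice) are folded
// below: unaryShapeFold / binaryFold / concatShapeFold / sliceShapeFold read
// the operands' const_shape values and rebuild an index-typed
// DenseElementsAttr.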
template <typename Op, typename OpFoldAdaptor>
  auto input1ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
  if (!input1ConstShape)

  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());

template <typename Op, typename OpFoldAdaptor>
  auto input1ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
  auto input2ConstShape =
      dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
  if (!input1ConstShape || !input2ConstShape)

  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
      input1Attr.getType(),

OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
  const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!inputTy || !inputTy.hasRank())

  const int32_t axis = getAxis();
  const int64_t dimSize = inputTy.getDimSize(axis);
  if (ShapedType::isDynamic(dimSize))

  const auto resultAttrTy = RankedTensorType::get(1, builder.getIndexType());
  auto const inputs = op->getInput();

  concatDims.reserve(64);
  for (auto const &v : inputs) {
    auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());

    const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
    auto const vAttrVals = vAttr.getValues<APInt>();
    for (auto const &v : vAttrVals) {
      concatDims.push_back(v);

  auto *ctx = op->getContext();
  assert(ctx != nullptr && "ctx is nullptr");
  auto const rankedTy = RankedTensorType::get(
      {static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));

  auto const input1 = op->getInput();
  auto const input2 = op->getStart();
  auto const input3 = op->getSize();

  auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
  if (!input1ConstShape)

  auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
  auto const input1Vals = input1Attr.getValues<APInt>();
  auto const totalInput1 = input1Vals.size();

  if (failed(start) || failed(size))

  auto const startV = static_cast<int32_t>(start.value());
  auto const sizeV = static_cast<int32_t>(size.value());

  if ((sizeV <= 0) || (startV < 0) ||
      (static_cast<size_t>(startV + sizeV) > totalInput1))

  sliceOfInput.reserve(totalInput1);
  for (auto i = startV; i < (startV + sizeV); i++) {
    sliceOfInput.push_back(input1Vals[i]);

  auto *ctx = op->getContext();
  assert(ctx != nullptr && "ctx is nullptr");
  auto const rankedTy = RankedTensorType::get(
      {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivCeilShapeOp, DivFoldAdaptor<true>>(this);

OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
  return binaryFold<DivFloorShapeOp, DivFoldAdaptor<false>>(this);

OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {

OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {