25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
46 (padConstAttr.
size() != 1)) {
51 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
57 if (
auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
76template <
typename OpTy>
77struct PoolPadFoldAdaptor;
80struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
81 using OpTy = tosa::MaxPool2dOp;
82 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
83 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
84 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
85 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
89 static bool checkPadConstCompliance(OpTy, Value padConst) {
91 DenseElementsAttr padConstAttr;
93 padConstAttr.
size() != 1) {
98 if (
auto padConstFpAttr =
99 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
100 const APFloat padConstVal = *padConstFpAttr.begin();
101 const APFloat lowestVal =
102 APFloat::getLargest(padConstVal.getSemantics(),
true);
103 return padConstVal == lowestVal;
105 if (
auto padConstIntAttr =
106 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
107 const APInt padConstVal = *padConstIntAttr.begin();
108 const unsigned int bitWidth = padConstVal.getBitWidth();
109 const APInt lowestVal =
110 padConstIntAttr.getElementType().isUnsignedInteger()
111 ? APInt::getZero(bitWidth)
112 : APInt::getSignedMinValue(bitWidth);
113 return padConstVal == lowestVal;
119 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
120 Value padInput, ArrayRef<int64_t> newPad) {
122 op, op.getType(), padInput, op.getKernel(), op.getStride(),
127template <
typename OpTy>
128struct ConvPadFoldAdaptor {
129 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
132 static bool checkPadConstCompliance(OpTy op, Value padConst) {
135 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
136 Value padInput, ArrayRef<int64_t> newPad) {
138 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
139 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
140 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
148template <
typename OpTy,
typename AdaptorTy>
150 using OpRewritePattern<OpTy>::OpRewritePattern;
152 LogicalResult matchAndRewrite(OpTy tensorOp,
153 PatternRewriter &rewriter)
const override {
155 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
158 "Producer must be a tosa::PadOp.");
161 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
162 if (tensorOpPad.size() != 4)
164 tensorOp,
"Tensor operation padding shall have 4 elements.");
167 DenseIntElementsAttr padOpPadding;
171 "The `padding` input specified on the tosa::PadOp must be constant.");
175 if (padOpPadding.size() != 8)
177 "Pad padding should have 8 elements.");
178 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
179 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
180 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
181 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
182 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
183 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
184 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
185 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
187 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
189 tensorOp,
"Folding padding in N or C dimensions is not supported.");
193 SmallVector<int64_t> foldedPad(tensorOpPad.size());
194 foldedPad[0] = padHBefore + tensorOpPad[0];
195 foldedPad[1] = padHAfter + tensorOpPad[1];
196 foldedPad[2] = padWBefore + tensorOpPad[2];
197 foldedPad[3] = padWAfter + tensorOpPad[3];
200 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
202 tensorOp,
"Padding size not aligned with kernel restrictions.");
206 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
209 "Padding constant is not aligned with operator zero-point.");
213 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
215 tensorOp,
"Padding size more than the 8K level limit.");
219 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
236 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
288 if (op.getInput1().size() != 1)
290 if (op.getInput1().front().getType() != op.getType()) {
293 op.getInput1().front())
298 rewriter.
replaceOp(op, op.getInput1().front());
308LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
330 "input must be transpose operation");
334 innerTranspose.getPerms();
336 if (transposePerms.size() != innerTransposePerms.size())
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
342 transposeOp,
"transpose perms sizes must be positive");
346 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
350 transposeOp, transposeOp.getResult().
getType(),
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
365 op,
"Src is from transpose, can compose transposes");
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
371 op,
"Dest is used by transpose, can compose transposes");
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
380 for (
int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
390 nonZeroPerms.reserve(permValues.size());
391 for (
auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
394 nonZeroPerms.push_back(idx);
397 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
400 "Transpose changes memory layout.");
403 newShape.reserve(inputTy.getRank());
404 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
408 op, op.getType(), op.getInput1(),
416 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
428 if (isa<FloatType>(inputElementType)) {
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
437 if (isMin && isMax) {
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
458 if (minClamp <= intMin && maxClamp >= intMax) {
465 if (llvm::isa<IntegerType>(inputElementType)) {
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
475 if (minClamp <= intMin && maxClamp >= intMax) {
507 template <
typename T>
521 Value input = op.getInput();
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
540 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
561 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
576 if (inputEType.isUnsignedInteger()) {
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
585 if (!opRangeIntRange.
intersects(clampRangeIntRange))
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
602 if (!opRangeIntRange.
intersects(clampRangeIntRange))
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
618 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.
add<ClampIsNoOp>(context);
630 results.
add<ClampClampOptimization>(context);
638 Value sliceInput = sliceOp.getInput1();
642 sliceOp,
"slice input must be concat operation");
645 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
646 if (!concatType || !concatType.hasStaticShape())
648 sliceOp,
"slice input must be a static ranked tensor");
649 int32_t axis = concatOp.getAxis();
656 sliceOp,
"start of slice must be a static ranked shape");
660 sliceOp,
"size of slice must be a static ranked shape");
670 std::optional<Value> replaceWithSlice;
671 for (
auto input : inputs) {
672 auto inputType = dyn_cast<RankedTensorType>(input.getType());
673 if (!inputType || !inputType.hasStaticShape())
675 sliceOp,
"concat input must be a static ranked tensor");
677 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
678 inputType.getDimSize(axis)) {
684 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
685 input, start_op, size_op)
689 sliceStarts[axis] -= inputType.getDimSize(axis);
692 if (!replaceWithSlice)
694 sliceOp,
"corresponding concat input not found for slice");
696 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
706 Value sliceInput = sliceOp.getInput1();
712 "slice input must be a pad operation");
715 if (!padOp->hasOneUse())
717 "pad shall have a single consumer");
720 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
721 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
722 if (!inputTy || !padTy || !inputTy.hasRank())
724 "slice input must be a ranked tensor");
731 "`padding` input specified on the tosa::PadOp must be constant.");
734 llvm::to_vector(paddingElems.getValues<
int64_t>());
740 sliceOp,
"start of slice must be a static ranked shape");
747 sliceOp,
"size of slice must be a static ranked shape");
752 const int64_t rank = inputTy.getRank();
753 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
754 const bool isDimDynamic = inputTy.isDynamicDim(i);
755 const bool isDimSliced =
756 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
758 return isDimDynamic && isDimSliced;
761 sliceOp,
"axis that are sliced shall be statically known.");
768 bool updated =
false;
770 for (
int64_t i = 0; i < rank; ++i) {
771 const int64_t padLo = padPaddings[i * 2];
772 const int64_t padHi = padPaddings[i * 2 + 1];
773 const int64_t sliceStart = sliceStarts[i];
774 const int64_t sliceSize = sliceSizes[i];
775 const int64_t sliceEnd = sliceStart + sliceSize;
778 if (inputTy.isDynamicDim(i)) {
779 newPadPaddings[i * 2] = padLo;
780 newPadPaddings[i * 2 + 1] = padHi;
781 newSliceStarts[i] = sliceStart;
786 const int64_t dimSize = inputTy.getShape()[i];
787 const int64_t dimTotal = padLo + dimSize + padHi;
790 if (sliceStart < 0 || sliceEnd > dimTotal)
794 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
795 newSliceStarts[i] = newSliceStart;
796 updated |= newSliceStart != sliceStart;
799 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
801 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
802 newPadPaddings[i * 2] = newPadLo;
803 newPadPaddings[i * 2 + 1] = newPadHi;
804 updated |= (newPadLo != padLo) || (newPadHi != padHi);
808 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
814 sliceOp,
"terminate condition; nothing to rewrite");
820 RankedTensorType::get(newPadShape, inputTy.getElementType());
821 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
822 padOp.getInput1(), newPaddingsOp,
823 padOp.getPadConst());
829 newPadOp.getResult(), newStartOp,
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
846 ElementsAttr sizeElems;
849 sliceOp,
"size of slice must be a static ranked shape");
853 llvm::to_vector(sizeElems.getValues<
int64_t>());
855 bool replaceSliceSize{
false};
859 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(
index)) {
861 sliceSizes[
index] = resultType.getDimSize(
index);
862 replaceSliceSize =
true;
866 if (!replaceSliceSize) {
868 sliceOp,
"no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
877 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.
add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
892template <
typename IntFolder,
typename FloatFolder>
895 RankedTensorType returnTy) {
897 auto lETy = llvm::cast<ShapedType>(
lhs.getType()).getElementType();
898 auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
902 if (llvm::isa<IntegerType>(lETy)) {
903 APInt l =
lhs.getSplatValue<APInt>();
904 APInt r =
rhs.getSplatValue<APInt>();
905 auto result = IntFolder()(l, r);
909 if (llvm::isa<FloatType>(lETy)) {
910 APFloat l =
lhs.getSplatValue<APFloat>();
911 APFloat r =
rhs.getSplatValue<APFloat>();
912 auto result = FloatFolder()(l, r);
921 if (llvm::isa<FloatType>(elemType))
923 if (llvm::isa<IntegerType>(elemType))
929 if (llvm::isa<FloatType>(elemType))
932 if (llvm::isa<IntegerType>(elemType)) {
933 const int64_t shifted = 1LL << shift;
940OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
941 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
942 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
943 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
944 if (!lhsTy || !rhsTy || !resultTy)
948 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
949 !rhsTy.getElementType().isIntOrIndexOrFloat())
952 auto resultETy = resultTy.getElementType();
954 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
956 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
958 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
960 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
963 if (!lhsAttr || !rhsAttr)
970OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
971 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
972 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
973 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
974 !outputTy.hasStaticShape())
978 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
979 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
980 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
987OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
988 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
989 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
990 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
991 if (!lhsTy || !rhsTy || !resultTy)
997 auto resultETy = resultTy.getElementType();
999 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1001 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1002 if (lhsAttr && lhsAttr.isSplat()) {
1003 if (llvm::isa<IntegerType>(resultETy) &&
1004 lhsAttr.getSplatValue<APInt>().isZero())
1008 if (rhsAttr && rhsAttr.isSplat()) {
1009 if (llvm::isa<IntegerType>(resultETy) &&
1010 rhsAttr.getSplatValue<APInt>().isOne())
1014 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1015 llvm::isa<IntegerType>(resultETy)) {
1016 APInt l = lhsAttr.getSplatValue<APInt>();
1017 APInt r = rhsAttr.getSplatValue<APInt>();
1019 APInt
result = l.sdiv(r);
1030std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1031 unsigned bitwidth) {
1035 auto round = APInt(64, 1) << (shift - 1);
1037 result.ashrInPlace(shift);
1039 if (!(
result.getSExtValue() >= INT32_MIN &&
1040 result.getSExtValue() <= INT32_MAX)) {
1042 return std::nullopt;
1046 return result.trunc(bitwidth);
1049DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1050 RankedTensorType ty, int32_t shift) {
1052 if (llvm::isa<IntegerType>(ty.getElementType())) {
1053 APInt l =
lhs.getSplatValue<APInt>();
1054 APInt r =
rhs.getSplatValue<APInt>();
1060 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1061 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1067 if (llvm::isa<FloatType>(ty.getElementType())) {
1068 APFloat l =
lhs.getSplatValue<APFloat>();
1069 APFloat r =
rhs.getSplatValue<APFloat>();
1079OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1080 auto lhs = getInput1();
1081 auto rhs = getInput2();
1082 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1083 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1084 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1085 if (!lhsTy || !rhsTy || !resultTy)
1088 auto resultETy = resultTy.getElementType();
1090 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1092 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1097 if (resultETy.isInteger(32)) {
1098 ElementsAttr shift_elem;
1099 if (getShift().getImpl()) {
1103 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1107 if (rhsTy == resultTy) {
1108 if (
isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1110 return lhsAttr.resizeSplat(resultTy);
1114 if (lhsTy == resultTy) {
1115 if (
isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1116 return rhsAttr.resizeSplat(resultTy);
1121 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1124OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1125 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1126 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1127 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1128 if (!lhsTy || !rhsTy || !resultTy)
1132 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1133 !rhsTy.getElementType().isIntOrIndexOrFloat())
1136 auto resultETy = resultTy.getElementType();
1138 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1140 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1142 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1145 if (!lhsAttr || !rhsAttr)
1153template <
typename Cmp>
1154struct ComparisonFold {
1155 ComparisonFold() =
default;
1156 APInt operator()(
const APInt &l,
const APInt &r) {
1157 return APInt(1, Cmp()(l, r));
1160 APInt operator()(
const APFloat &l,
const APFloat &r) {
1161 return APInt(1, Cmp()(l, r));
1165struct APIntFoldGreater {
1166 APIntFoldGreater() =
default;
1167 APInt operator()(
const APInt &l,
const APInt &r) {
1168 return APInt(1, l.sgt(r));
1172struct APIntFoldGreaterEqual {
1173 APIntFoldGreaterEqual() =
default;
1174 APInt operator()(
const APInt &l,
const APInt &r) {
1175 return APInt(1, l.sge(r));
1180OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1181 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1183 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1185 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1187 if (!lhsAttr || !rhsAttr)
1191 lhsAttr, rhsAttr, resultTy);
1194OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1195 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1197 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1199 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1201 if (!lhsAttr || !rhsAttr)
1205 ComparisonFold<std::greater_equal<APFloat>>>(
1206 lhsAttr, rhsAttr, resultTy);
1209OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1210 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1212 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1214 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1215 Value
lhs = getInput1();
1216 Value
rhs = getInput2();
1217 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1221 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1222 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1226 if (!lhsAttr || !rhsAttr)
1230 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1234OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1238 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1242 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1243 auto outTy = llvm::cast<ShapedType>(
getType());
1244 auto inETy = inTy.getElementType();
1245 auto outETy = outTy.getElementType();
1247 if (operand.isSplat()) {
1248 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1250 auto splatVal = operand.getSplatValue<APFloat>();
1251 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1252 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1257 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1258 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1259 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1260 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1261 llvm::RoundingMode::NearestTiesToEven);
1265 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1266 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1267 auto intVal = APSInt(
1268 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1269 auto floatVal = operand.getSplatValue<APFloat>();
1271 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1276 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1277 const auto inIntType = llvm::cast<IntegerType>(inETy);
1278 auto unsignIn = inIntType.isUnsignedInteger();
1280 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1281 auto intVal = operand.getSplatValue<APInt>();
1282 auto bitwidth = outETy.getIntOrFloatBitWidth();
1285 if (outETy.isInteger(1)) {
1286 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1288 intVal = intVal.trunc(bitwidth);
1289 }
else if (unsignIn || inIntType.isInteger(1)) {
1290 intVal = intVal.zext(bitwidth);
1292 intVal = intVal.sext(bitwidth);
1302OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1304OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1306#define REDUCE_FOLDER(OP) \
1307 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1308 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1309 if (!inputTy.hasRank()) \
1311 if (inputTy != getType()) \
1313 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1314 return getInput(); \
1327 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1328 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1330 if (!inputTy || !outputTy)
1336 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1340 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1341 getInput1().getDefiningOp())) {
1342 getInput1Mutable().assign(reshapeOp.getInput1());
1347 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1352 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1354 if (!outputTy.hasStaticShape())
1358 if (operand.isSplat())
1363 if (!getInput1().hasOneUse())
1370 return operand.reshape(
1371 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1377OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1379 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1380 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1381 if (densePad && densePad.isSplat() &&
1382 densePad.getSplatValue<APInt>().isZero()) {
1392OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1394 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1396 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1398 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1399 if (!scaleAttr || !offsetAttr || !borderAttr) {
1406 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1411 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1416 if (offset[0] != 0 || offset[1] != 0) {
1421 if (border[0] != 0 || border[1] != 0) {
1425 auto input = getInput();
1426 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1427 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1428 if (inputTy != resultTy)
1434OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1435 auto operand = getInput1();
1436 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1437 auto axis = getAxis();
1439 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1444 if (operandTy.hasRank() &&
1445 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1451OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1452 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1453 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1455 if (!inputTy || !outputTy)
1458 if (inputTy == outputTy && inputTy.hasStaticShape())
1461 if (!adaptor.getInput1())
1465 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1466 !outputTy.getElementType().isIntOrIndexOrFloat())
1469 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1470 if (operand.isSplat() && outputTy.hasStaticShape()) {
1474 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1475 outputTy.getNumElements() == 1) {
1476 DenseElementsAttr startElems;
1480 llvm::SmallVector<uint64_t>
indices =
1481 llvm::to_vector(startElems.
getValues<uint64_t>());
1482 auto value = operand.getValues<Attribute>()[
indices];
1491 const auto isDynamic = [](
Type ty) {
1492 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1493 return !shapedTy || !shapedTy.hasStaticShape();
1496 return llvm::any_of(operandTypes, isDynamic) ||
1500OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1507 if (getOnTrue() == getOnFalse())
1511 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1515 if (!predicate.isSplat())
1517 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1521OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1523 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1524 adaptor.getMultiples())) {
1525 if (multiples.isSplat() &&
1526 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1528 if (
auto int_array_attr =
1529 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1530 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1531 [](APInt v) { return v.getSExtValue() == 1; }))
1539OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1540 auto resultTy = llvm::cast<ShapedType>(
getType());
1544 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1545 if (input.isSplat() && resultTy.hasStaticShape() &&
1546 input.getType().getElementType() == resultTy.getElementType())
1547 return input.reshape(resultTy);
1551 const llvm::ArrayRef<int32_t> perms = getPerms();
1553 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1559OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1562 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1568 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1569 failed(maybeIZp) || *maybeIZp != 0) {
1573 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1574 failed(maybeOZp) || *maybeOZp != 0) {
1578 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1579 failed(maybeIZp) || *maybeIZp != 0) {
1583 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1584 failed(maybeOZp) || *maybeOZp != 0) {
1589 return definingOp.getInput1();
1592OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1593 auto input = getInput1();
1595 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1602OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1607 SmallVector<Value, 8> concatOperands;
1608 concatOperands.reserve(2 * getNumOperands());
1611 bool foundFoldableConcat =
false;
1612 for (Value operand : getOperands()) {
1613 concatOperands.emplace_back(operand);
1615 auto producer = operand.getDefiningOp<ConcatOp>();
1620 if (getAxis() != producer.getAxis())
1624 foundFoldableConcat =
true;
1625 concatOperands.pop_back();
1626 llvm::append_range(concatOperands, producer->getOperands());
1629 if (!foundFoldableConcat)
1632 getOperation()->setOperands(concatOperands);
// Constant-folds reciprocal of a splat constant: computes 1/x once via
// calcOneElement and materializes a new splat of the result type.
// NOTE(review): doxygen-extracted excerpt — the fold is cut off
// mid-body here (the DenseElementsAttr construction and the integer /
// fallback paths are not visible).
1636OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 1637 auto input = adaptor.getInput1();
// Only fold when the input is a constant splat attribute.
 1639 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
 1641 if (!inputAttr || !inputAttr.isSplat())
 1644 auto shapeType = llvm::cast<ShapedType>(
getType());
// Float path: compute the reciprocal of the single splat value.
 1645 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
 1646 auto floatVal = inputAttr.getSplatValue<APFloat>();
 1648 ReciprocalOp::calcOneElement(floatVal));
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
static bool mayRequireBroadcast(ValueTypeRange< mlir::OperandRange > operandTypes)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries have compatible shape.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...