24#include "llvm/ADT/APFloat.h"
25#include "llvm/ADT/APInt.h"
45 (padConstAttr.
size() != 1)) {
50 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
51 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52 return padConstVal == 0.0f;
56 if (
auto padConstIntAttr =
57 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
66 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67 return zpVal == padConstVal;
75template <
typename OpTy>
76struct PoolPadFoldAdaptor;
79struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
80 using OpTy = tosa::MaxPool2dOp;
81 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
82 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
88 static bool checkPadConstCompliance(OpTy, Value padConst) {
90 DenseElementsAttr padConstAttr;
92 padConstAttr.
size() != 1) {
97 if (
auto padConstFpAttr =
98 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
99 const APFloat padConstVal = *padConstFpAttr.begin();
100 const APFloat lowestVal =
101 APFloat::getLargest(padConstVal.getSemantics(),
true);
102 return padConstVal == lowestVal;
104 if (
auto padConstIntAttr =
105 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
106 const APInt padConstVal = *padConstIntAttr.begin();
107 const unsigned int bitWidth = padConstVal.getBitWidth();
108 const APInt lowestVal =
109 padConstIntAttr.getElementType().isUnsignedInteger()
110 ? APInt::getZero(bitWidth)
111 : APInt::getSignedMinValue(bitWidth);
112 return padConstVal == lowestVal;
118 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
119 Value padInput, ArrayRef<int64_t> newPad) {
121 op, op.getType(), padInput, op.getKernel(), op.getStride(),
126template <
typename OpTy>
127struct ConvPadFoldAdaptor {
128 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
131 static bool checkPadConstCompliance(OpTy op, Value padConst) {
134 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
135 Value padInput, ArrayRef<int64_t> newPad) {
137 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
138 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
139 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
147template <
typename OpTy,
typename AdaptorTy>
149 using OpRewritePattern<OpTy>::OpRewritePattern;
151 LogicalResult matchAndRewrite(OpTy tensorOp,
152 PatternRewriter &rewriter)
const override {
154 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
157 "Producer must be a tosa::PadOp.");
160 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
161 if (tensorOpPad.size() != 4)
163 tensorOp,
"Tensor operation padding shall have 4 elements.");
166 DenseIntElementsAttr padOpPadding;
170 "The `padding` input specified on the tosa::PadOp must be constant.");
174 if (padOpPadding.size() != 8)
176 "Pad padding should have 8 elements.");
177 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
178 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
179 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
180 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
181 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
182 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
183 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
184 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
186 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
188 tensorOp,
"Folding padding in N or C dimensions is not supported.");
192 SmallVector<int64_t> foldedPad(tensorOpPad.size());
193 foldedPad[0] = padHBefore + tensorOpPad[0];
194 foldedPad[1] = padHAfter + tensorOpPad[1];
195 foldedPad[2] = padWBefore + tensorOpPad[2];
196 foldedPad[3] = padWAfter + tensorOpPad[3];
199 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
201 tensorOp,
"Padding size not aligned with kernel restrictions.");
205 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
208 "Padding constant is not aligned with operator zero-point.");
212 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
214 tensorOp,
"Padding size more than the 8K level limit.");
218 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
229 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
235 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
236 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
245 Value input = op.getInput();
246 Value output = op.getOutput();
247 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
248 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
250 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
256 if (outputShape[1] != 1 || outputShape[2] != 1) {
261 if (inputShape[1] != 1 || inputShape[2] != 1) {
273 FoldPadToTensorOp<tosa::MaxPool2dOp,
274 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
287 if (op.getInput1().size() != 1)
289 if (op.getInput1().front().getType() != op.getType()) {
292 op.getInput1().front())
297 rewriter.
replaceOp(op, op.getInput1().front());
307LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
308 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
312 op.getOperation()->setOperands(
313 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
325 auto innerTranspose =
326 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
329 "input must be transpose operation");
333 innerTranspose.getPerms();
335 if (transposePerms.size() != innerTransposePerms.size())
338 "transpose and inner transpose perms sizes must be equal");
339 if (transposePerms.empty())
341 transposeOp,
"transpose perms sizes must be positive");
345 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
346 perms[i] = innerTransposePerms[transposePerms[i]];
349 transposeOp, transposeOp.getResult().
getType(),
362 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
364 op,
"Src is from transpose, can compose transposes");
368 if (isa_and_nonnull<tosa::TransposeOp>(subop))
370 op,
"Dest is used by transpose, can compose transposes");
373 auto input = op.getInput1();
374 auto inputTy = llvm::cast<ShapedType>(input.getType());
375 if (!inputTy.hasRank())
379 for (
int i = 0; i < inputTy.getRank(); ++i)
380 if (inputTy.isDynamicDim(i))
389 nonZeroPerms.reserve(permValues.size());
390 for (
auto idx : permValues) {
391 auto sz = inputTy.getDimSize(idx);
393 nonZeroPerms.push_back(idx);
396 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
397 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
399 "Transpose changes memory layout.");
402 newShape.reserve(inputTy.getRank());
403 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
404 newShape.push_back(inputTy.getDimSize(permValues[i]));
407 op, op.getType(), op.getInput1(),
415 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
423 Value input = op.getInput();
424 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
425 auto inputElementType = inputType.getElementType();
427 if (isa<FloatType>(inputElementType)) {
429 const auto minClamp =
430 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
431 const auto maxClamp =
432 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
433 const bool isMin = minClamp.isNegInfinity();
434 const bool isMax = maxClamp.isInfinity();
436 if (isMin && isMax) {
444 const bool isBoolean = inputElementType.isInteger(1);
445 if (inputElementType.isUnsignedInteger() || isBoolean) {
446 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
449 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
453 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
454 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
455 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
457 if (minClamp <= intMin && maxClamp >= intMax) {
464 if (llvm::isa<IntegerType>(inputElementType)) {
466 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
468 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
470 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
471 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
472 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
474 if (minClamp <= intMin && maxClamp >= intMax) {
506 template <
typename T>
520 Value input = op.getInput();
528 const auto opNanMode = op.getNanMode();
529 const auto clampNanMode = clampOp.getNanMode();
530 if (opNanMode == NanPropagationMode::IGNORE &&
531 clampNanMode == NanPropagationMode::PROPAGATE)
534 auto maxValAttr = op.getMaxValAttr();
535 auto minValAttr = op.getMinValAttr();
536 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
537 auto clampOpMinValAttr = clampOp.getMinValAttr();
539 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
541 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
542 inputEType = quantType.getStorageType();
546 if (mlir::isa<FloatType>(inputEType)) {
547 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
548 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
549 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
550 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
553 const auto opMinFloat = floatMinValAttr.getValue();
554 const auto opMaxFloat = floatMaxValAttr.getValue();
555 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
556 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
560 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
564 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
565 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
566 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
567 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
569 assert(mlir::isa<IntegerType>(inputEType));
570 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
571 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
572 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
573 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
575 if (inputEType.isUnsignedInteger()) {
577 const auto opMinInt = intMinValAttr.getUInt();
578 const auto opMaxInt = intMaxValAttr.getUInt();
579 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
580 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
584 if (!opRangeIntRange.
intersects(clampRangeIntRange))
588 auto newMinVal = std::max(opMinInt, clampOpMinInt);
589 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
594 const auto opMinInt = intMinValAttr.getInt();
595 const auto opMaxInt = intMaxValAttr.getInt();
596 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
597 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
601 if (!opRangeIntRange.
intersects(clampRangeIntRange))
605 auto newMinVal = std::max(opMinInt, clampOpMinInt);
606 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
612 auto newMode = (opNanMode != clampNanMode)
613 ? tosa::NanPropagationMode::IGNORE
617 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
620 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
626void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
627 MLIRContext *context) {
628 results.
add<ClampIsNoOp>(context);
629 results.
add<ClampClampOptimization>(context);
637 Value sliceInput = sliceOp.getInput1();
641 sliceOp,
"slice input must be concat operation");
644 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
645 if (!concatType || !concatType.hasStaticShape())
647 sliceOp,
"slice input must be a static ranked tensor");
648 int32_t axis = concatOp.getAxis();
655 sliceOp,
"start of slice must be a static ranked shape");
659 sliceOp,
"size of slice must be a static ranked shape");
669 std::optional<Value> replaceWithSlice;
670 for (
auto input : inputs) {
671 auto inputType = dyn_cast<RankedTensorType>(input.getType());
672 if (!inputType || !inputType.hasStaticShape())
674 sliceOp,
"concat input must be a static ranked tensor");
676 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
677 inputType.getDimSize(axis)) {
683 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
684 input, start_op, size_op)
688 sliceStarts[axis] -= inputType.getDimSize(axis);
691 if (!replaceWithSlice)
693 sliceOp,
"corresponding concat input not found for slice");
695 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
705 Value sliceInput = sliceOp.getInput1();
711 "slice input must be a pad operation");
714 if (!padOp->hasOneUse())
716 "pad shall have a single consumer");
719 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
720 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
721 if (!inputTy || !padTy || !inputTy.hasRank())
723 "slice input must be a ranked tensor");
730 "`padding` input specified on the tosa::PadOp must be constant.");
733 llvm::to_vector(paddingElems.getValues<
int64_t>());
739 sliceOp,
"start of slice must be a static ranked shape");
746 sliceOp,
"size of slice must be a static ranked shape");
751 const int64_t rank = inputTy.getRank();
752 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
753 const bool isDimDynamic = inputTy.isDynamicDim(i);
754 const bool isDimSliced =
755 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
757 return isDimDynamic && isDimSliced;
760 sliceOp,
"axis that are sliced shall be statically known.");
767 bool updated =
false;
769 for (
int64_t i = 0; i < rank; ++i) {
770 const int64_t padLo = padPaddings[i * 2];
771 const int64_t padHi = padPaddings[i * 2 + 1];
772 const int64_t sliceStart = sliceStarts[i];
773 const int64_t sliceSize = sliceSizes[i];
774 const int64_t sliceEnd = sliceStart + sliceSize;
777 if (inputTy.isDynamicDim(i)) {
778 newPadPaddings[i * 2] = padLo;
779 newPadPaddings[i * 2 + 1] = padHi;
780 newSliceStarts[i] = sliceStart;
785 const int64_t dimSize = inputTy.getShape()[i];
786 const int64_t dimTotal = padLo + dimSize + padHi;
789 if (sliceStart < 0 || sliceEnd > dimTotal)
793 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
794 newSliceStarts[i] = newSliceStart;
795 updated |= newSliceStart != sliceStart;
798 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
800 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
801 newPadPaddings[i * 2] = newPadLo;
802 newPadPaddings[i * 2 + 1] = newPadHi;
803 updated |= (newPadLo != padLo) || (newPadHi != padHi);
807 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
813 sliceOp,
"terminate condition; nothing to rewrite");
819 RankedTensorType::get(newPadShape, inputTy.getElementType());
820 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
821 padOp.getInput1(), newPaddingsOp,
822 padOp.getPadConst());
828 newPadOp.getResult(), newStartOp,
843 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
845 ElementsAttr sizeElems;
848 sliceOp,
"size of slice must be a static ranked shape");
852 llvm::to_vector(sizeElems.getValues<
int64_t>());
854 bool replaceSliceSize{
false};
858 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
859 if (size == -1 && !resultType.isDynamicDim(
index)) {
860 sliceSizes[
index] = resultType.getDimSize(
index);
861 replaceSliceSize =
true;
865 if (!replaceSliceSize) {
867 sliceOp,
"no dimension of size of slice is dynamic that resolves "
868 "to static output shape");
873 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
874 sliceOp.getInput1(), sliceOp.getStart(), size_op);
876 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
881void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
882 MLIRContext *context) {
883 results.
add<ConcatSliceOptimization, PadSliceOptimization,
884 SliceDynamicSizeCanonicalization>(context);
891template <
typename IntFolder,
typename FloatFolder>
894 RankedTensorType returnTy) {
896 auto lETy = llvm::cast<ShapedType>(
lhs.getType()).getElementType();
897 auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
901 if (llvm::isa<IntegerType>(lETy)) {
902 APInt l =
lhs.getSplatValue<APInt>();
903 APInt r =
rhs.getSplatValue<APInt>();
904 auto result = IntFolder()(l, r);
908 if (llvm::isa<FloatType>(lETy)) {
909 APFloat l =
lhs.getSplatValue<APFloat>();
910 APFloat r =
rhs.getSplatValue<APFloat>();
911 auto result = FloatFolder()(l, r);
920 if (llvm::isa<FloatType>(elemType))
922 if (llvm::isa<IntegerType>(elemType))
928 if (llvm::isa<FloatType>(elemType))
931 if (llvm::isa<IntegerType>(elemType)) {
932 const int64_t shifted = 1LL << shift;
939OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
940 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
941 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
942 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
943 if (!lhsTy || !rhsTy || !resultTy)
947 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
948 !rhsTy.getElementType().isIntOrIndexOrFloat())
951 auto resultETy = resultTy.getElementType();
953 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
955 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
957 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
959 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
962 if (!lhsAttr || !rhsAttr)
969OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
970 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
971 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
972 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
973 !outputTy.hasStaticShape())
977 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
978 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
979 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
986OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
987 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
988 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
989 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
990 if (!lhsTy || !rhsTy || !resultTy)
996 auto resultETy = resultTy.getElementType();
998 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1000 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1001 if (lhsAttr && lhsAttr.isSplat()) {
1002 if (llvm::isa<IntegerType>(resultETy) &&
1003 lhsAttr.getSplatValue<APInt>().isZero())
1007 if (rhsAttr && rhsAttr.isSplat()) {
1008 if (llvm::isa<IntegerType>(resultETy) &&
1009 rhsAttr.getSplatValue<APInt>().isOne())
1013 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1014 llvm::isa<IntegerType>(resultETy)) {
1015 APInt l = lhsAttr.getSplatValue<APInt>();
1016 APInt r = rhsAttr.getSplatValue<APInt>();
1018 APInt
result = l.sdiv(r);
1029std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1030 unsigned bitwidth) {
1034 auto round = APInt(64, 1) << (shift - 1);
1036 result.ashrInPlace(shift);
1038 if (!(
result.getSExtValue() >= INT32_MIN &&
1039 result.getSExtValue() <= INT32_MAX)) {
1041 return std::nullopt;
1045 return result.trunc(bitwidth);
1048DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1049 RankedTensorType ty, int32_t shift) {
1051 if (llvm::isa<IntegerType>(ty.getElementType())) {
1052 APInt l =
lhs.getSplatValue<APInt>();
1053 APInt r =
rhs.getSplatValue<APInt>();
1059 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1060 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1066 if (llvm::isa<FloatType>(ty.getElementType())) {
1067 APFloat l =
lhs.getSplatValue<APFloat>();
1068 APFloat r =
rhs.getSplatValue<APFloat>();
1078OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1079 auto lhs = getInput1();
1080 auto rhs = getInput2();
1081 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1082 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1083 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1084 if (!lhsTy || !rhsTy || !resultTy)
1087 auto resultETy = resultTy.getElementType();
1089 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1091 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1096 if (resultETy.isInteger(32)) {
1097 ElementsAttr shift_elem;
1098 if (getShift().getImpl()) {
1102 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1106 if (rhsTy == resultTy) {
1107 if (
isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1109 return lhsAttr.resizeSplat(resultTy);
1113 if (lhsTy == resultTy) {
1114 if (
isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1115 return rhsAttr.resizeSplat(resultTy);
1120 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1123OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1124 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1125 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1126 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1127 if (!lhsTy || !rhsTy || !resultTy)
1131 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1132 !rhsTy.getElementType().isIntOrIndexOrFloat())
1135 auto resultETy = resultTy.getElementType();
1137 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1139 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1141 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1144 if (!lhsAttr || !rhsAttr)
1152template <
typename Cmp>
1153struct ComparisonFold {
1154 ComparisonFold() =
default;
1155 APInt operator()(
const APInt &l,
const APInt &r) {
1156 return APInt(1, Cmp()(l, r));
1159 APInt operator()(
const APFloat &l,
const APFloat &r) {
1160 return APInt(1, Cmp()(l, r));
1164struct APIntFoldGreater {
1165 APIntFoldGreater() =
default;
1166 APInt operator()(
const APInt &l,
const APInt &r) {
1167 return APInt(1, l.sgt(r));
1171struct APIntFoldGreaterEqual {
1172 APIntFoldGreaterEqual() =
default;
1173 APInt operator()(
const APInt &l,
const APInt &r) {
1174 return APInt(1, l.sge(r));
1179OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1180 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1182 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1184 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1186 if (!lhsAttr || !rhsAttr)
1190 lhsAttr, rhsAttr, resultTy);
1193OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1194 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1196 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1198 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1200 if (!lhsAttr || !rhsAttr)
1204 ComparisonFold<std::greater_equal<APFloat>>>(
1205 lhsAttr, rhsAttr, resultTy);
1208OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1209 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1211 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1213 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1214 Value
lhs = getInput1();
1215 Value
rhs = getInput2();
1216 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1220 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1221 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1225 if (!lhsAttr || !rhsAttr)
1229 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1233OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1237 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1241 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1242 auto outTy = llvm::cast<ShapedType>(
getType());
1243 auto inETy = inTy.getElementType();
1244 auto outETy = outTy.getElementType();
1246 if (operand.isSplat()) {
1247 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1249 auto splatVal = operand.getSplatValue<APFloat>();
1250 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1251 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1256 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1257 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1258 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1259 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1260 llvm::RoundingMode::NearestTiesToEven);
1264 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1265 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1266 auto intVal = APSInt(
1267 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1268 auto floatVal = operand.getSplatValue<APFloat>();
1270 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1275 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1276 const auto inIntType = llvm::cast<IntegerType>(inETy);
1277 auto unsignIn = inIntType.isUnsignedInteger();
1279 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1280 auto intVal = operand.getSplatValue<APInt>();
1281 auto bitwidth = outETy.getIntOrFloatBitWidth();
1284 if (outETy.isInteger(1)) {
1285 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1287 intVal = intVal.trunc(bitwidth);
1288 }
else if (unsignIn || inIntType.isInteger(1)) {
1289 intVal = intVal.zext(bitwidth);
1291 intVal = intVal.sext(bitwidth);
1301OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1303OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1305#define REDUCE_FOLDER(OP) \
1306 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1307 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1308 if (!inputTy.hasRank()) \
1310 if (inputTy != getType()) \
1312 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1313 return getInput(); \
1326 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1327 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1329 if (!inputTy || !outputTy)
1335 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1339 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1340 getInput1().getDefiningOp())) {
1341 getInput1Mutable().assign(reshapeOp.getInput1());
1346 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1351 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1353 if (!outputTy.hasStaticShape())
1357 if (operand.isSplat())
1362 if (!getInput1().hasOneUse())
1369 return operand.reshape(
1370 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1376OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1378 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1379 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1380 if (densePad && densePad.isSplat() &&
1381 densePad.getSplatValue<APInt>().isZero()) {
1391OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1393 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1395 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1397 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1398 if (!scaleAttr || !offsetAttr || !borderAttr) {
1405 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1410 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1415 if (offset[0] != 0 || offset[1] != 0) {
1420 if (border[0] != 0 || border[1] != 0) {
1424 auto input = getInput();
1425 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1426 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1427 if (inputTy != resultTy)
1433OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1434 auto operand = getInput1();
1435 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1436 auto axis = getAxis();
1438 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1443 if (operandTy.hasRank() &&
1444 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1450OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1451 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1452 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1454 if (!inputTy || !outputTy)
1457 if (inputTy == outputTy && inputTy.hasStaticShape())
1460 if (!adaptor.getInput1())
1464 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1465 !outputTy.getElementType().isIntOrIndexOrFloat())
1468 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1469 if (operand.isSplat() && outputTy.hasStaticShape()) {
1473 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1474 outputTy.getNumElements() == 1) {
1475 DenseElementsAttr startElems;
1479 llvm::SmallVector<uint64_t>
indices =
1480 llvm::to_vector(startElems.
getValues<uint64_t>());
1481 auto value = operand.getValues<Attribute>()[
indices];
1488OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1489 if (getOnTrue() == getOnFalse())
1493 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1497 if (!predicate.isSplat())
1499 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1503OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1505 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1506 adaptor.getMultiples())) {
1507 if (multiples.isSplat() &&
1508 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1510 if (
auto int_array_attr =
1511 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1512 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1513 [](APInt v) { return v.getSExtValue() == 1; }))
1521OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1522 auto resultTy = llvm::cast<ShapedType>(
getType());
1526 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1527 if (input.isSplat() && resultTy.hasStaticShape() &&
1528 input.getType().getElementType() == resultTy.getElementType())
1529 return input.reshape(resultTy);
1533 const llvm::ArrayRef<int32_t> perms = getPerms();
1535 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1541OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1544 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1550 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1551 failed(maybeIZp) || *maybeIZp != 0) {
1555 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1556 failed(maybeOZp) || *maybeOZp != 0) {
1560 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1561 failed(maybeIZp) || *maybeIZp != 0) {
1565 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1566 failed(maybeOZp) || *maybeOZp != 0) {
1571 return definingOp.getInput1();
1574OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1575 auto input = getInput1();
1577 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1584OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1589 SmallVector<Value, 8> concatOperands;
1590 concatOperands.reserve(2 * getNumOperands());
1593 bool foundFoldableConcat =
false;
1594 for (Value operand : getOperands()) {
1595 concatOperands.emplace_back(operand);
1597 auto producer = operand.getDefiningOp<ConcatOp>();
1602 if (getAxis() != producer.getAxis())
1606 foundFoldableConcat =
true;
1607 concatOperands.pop_back();
1608 llvm::append_range(concatOperands, producer->getOperands());
1611 if (!foundFoldableConcat)
1614 getOperation()->setOperands(concatOperands);
1618OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1619 auto input = adaptor.getInput1();
1621 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1623 if (!inputAttr || !inputAttr.isSplat())
1626 auto shapeType = llvm::cast<ShapedType>(
getType());
1627 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1628 auto floatVal = inputAttr.getSplatValue<APFloat>();
1630 ReciprocalOp::calcOneElement(floatVal));
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. all element values are the same.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.