25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
46 (padConstAttr.
size() != 1)) {
51 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
57 if (
auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
76template <
typename OpTy>
77struct PoolPadFoldAdaptor;
80struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
81 using OpTy = tosa::MaxPool2dOp;
82 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
83 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
84 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
85 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
89 static bool checkPadConstCompliance(OpTy, Value padConst) {
91 DenseElementsAttr padConstAttr;
93 padConstAttr.
size() != 1) {
98 if (
auto padConstFpAttr =
99 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
100 const APFloat padConstVal = *padConstFpAttr.begin();
101 const APFloat lowestVal =
102 APFloat::getLargest(padConstVal.getSemantics(),
true);
103 return padConstVal == lowestVal;
105 if (
auto padConstIntAttr =
106 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
107 const APInt padConstVal = *padConstIntAttr.begin();
108 const unsigned int bitWidth = padConstVal.getBitWidth();
109 const APInt lowestVal =
110 padConstIntAttr.getElementType().isUnsignedInteger()
111 ? APInt::getZero(bitWidth)
112 : APInt::getSignedMinValue(bitWidth);
113 return padConstVal == lowestVal;
119 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
120 Value padInput, ArrayRef<int64_t> newPad) {
122 op, op.getType(), padInput, op.getKernel(), op.getStride(),
127template <
typename OpTy>
128struct ConvPadFoldAdaptor {
129 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
132 static bool checkPadConstCompliance(OpTy op, Value padConst) {
135 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
136 Value padInput, ArrayRef<int64_t> newPad) {
138 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
139 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
140 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
148template <
typename OpTy,
typename AdaptorTy>
150 using OpRewritePattern<OpTy>::OpRewritePattern;
152 LogicalResult matchAndRewrite(OpTy tensorOp,
153 PatternRewriter &rewriter)
const override {
155 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
158 "Producer must be a tosa::PadOp.");
161 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
162 if (tensorOpPad.size() != 4)
164 tensorOp,
"Tensor operation padding shall have 4 elements.");
167 DenseIntElementsAttr padOpPadding;
171 "The `padding` input specified on the tosa::PadOp must be constant.");
175 if (padOpPadding.size() != 8)
177 "Pad padding should have 8 elements.");
178 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
179 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
180 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
181 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
182 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
183 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
184 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
185 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
187 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
189 tensorOp,
"Folding padding in N or C dimensions is not supported.");
193 SmallVector<int64_t> foldedPad(tensorOpPad.size());
194 foldedPad[0] = padHBefore + tensorOpPad[0];
195 foldedPad[1] = padHAfter + tensorOpPad[1];
196 foldedPad[2] = padWBefore + tensorOpPad[2];
197 foldedPad[3] = padWAfter + tensorOpPad[3];
200 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
202 tensorOp,
"Padding size not aligned with kernel restrictions.");
206 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
209 "Padding constant is not aligned with operator zero-point.");
213 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
215 tensorOp,
"Padding size more than the 8K level limit.");
219 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
236 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
288 if (op.getInput1().size() != 1)
290 if (op.getInput1().front().getType() != op.getType()) {
293 op.getInput1().front())
298 rewriter.
replaceOp(op, op.getInput1().front());
308LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
330 "input must be transpose operation");
334 innerTranspose.getPerms();
336 if (transposePerms.size() != innerTransposePerms.size())
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
342 transposeOp,
"transpose perms sizes must be positive");
346 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
350 transposeOp, transposeOp.getResult().
getType(),
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
365 op,
"Src is from transpose, can compose transposes");
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
371 op,
"Dest is used by transpose, can compose transposes");
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
380 for (
int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
390 nonZeroPerms.reserve(permValues.size());
391 for (
auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
394 nonZeroPerms.push_back(idx);
397 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
400 "Transpose changes memory layout.");
403 newShape.reserve(inputTy.getRank());
404 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
408 op, op.getType(), op.getInput1(),
416 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
428 if (isa<FloatType>(inputElementType)) {
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
437 if (isMin && isMax) {
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
458 if (minClamp <= intMin && maxClamp >= intMax) {
465 if (llvm::isa<IntegerType>(inputElementType)) {
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
475 if (minClamp <= intMin && maxClamp >= intMax) {
507 template <
typename T>
521 Value input = op.getInput();
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
540 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
561 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
576 if (inputEType.isUnsignedInteger()) {
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
585 if (!opRangeIntRange.
intersects(clampRangeIntRange))
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
602 if (!opRangeIntRange.
intersects(clampRangeIntRange))
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
618 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.
add<ClampIsNoOp>(context);
630 results.
add<ClampClampOptimization>(context);
638 Value sliceInput = sliceOp.getInput1();
642 sliceOp,
"slice input must be concat operation");
645 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
646 if (!concatType || !concatType.hasStaticShape())
648 sliceOp,
"slice input must be a static ranked tensor");
649 int32_t axis = concatOp.getAxis();
656 sliceOp,
"start of slice must be a static ranked shape");
660 sliceOp,
"size of slice must be a static ranked shape");
670 std::optional<Value> replaceWithSlice;
671 for (
auto input : inputs) {
672 auto inputType = dyn_cast<RankedTensorType>(input.getType());
673 if (!inputType || !inputType.hasStaticShape())
675 sliceOp,
"concat input must be a static ranked tensor");
677 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
678 inputType.getDimSize(axis)) {
684 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
685 input, start_op, size_op)
689 sliceStarts[axis] -= inputType.getDimSize(axis);
692 if (!replaceWithSlice)
694 sliceOp,
"corresponding concat input not found for slice");
696 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
706 Value sliceInput = sliceOp.getInput1();
712 "slice input must be a pad operation");
715 if (!padOp->hasOneUse())
717 "pad shall have a single consumer");
720 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
721 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
722 if (!inputTy || !padTy || !inputTy.hasRank())
724 "slice input must be a ranked tensor");
731 "`padding` input specified on the tosa::PadOp must be constant.");
734 llvm::to_vector(paddingElems.getValues<
int64_t>());
740 sliceOp,
"start of slice must be a static ranked shape");
747 sliceOp,
"size of slice must be a static ranked shape");
752 const int64_t rank = inputTy.getRank();
753 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
754 const bool isDimDynamic = inputTy.isDynamicDim(i);
755 const bool isDimSliced =
756 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
758 return isDimDynamic && isDimSliced;
761 sliceOp,
"axis that are sliced shall be statically known.");
768 bool updated =
false;
770 for (
int64_t i = 0; i < rank; ++i) {
771 const int64_t padLo = padPaddings[i * 2];
772 const int64_t padHi = padPaddings[i * 2 + 1];
773 const int64_t sliceStart = sliceStarts[i];
774 const int64_t sliceSize = sliceSizes[i];
775 const int64_t sliceEnd = sliceStart + sliceSize;
778 if (inputTy.isDynamicDim(i)) {
779 newPadPaddings[i * 2] = padLo;
780 newPadPaddings[i * 2 + 1] = padHi;
781 newSliceStarts[i] = sliceStart;
786 const int64_t dimSize = inputTy.getShape()[i];
787 const int64_t dimTotal = padLo + dimSize + padHi;
790 if (sliceStart < 0 || sliceEnd > dimTotal)
794 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
795 newSliceStarts[i] = newSliceStart;
796 updated |= newSliceStart != sliceStart;
799 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
801 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
802 newPadPaddings[i * 2] = newPadLo;
803 newPadPaddings[i * 2 + 1] = newPadHi;
804 updated |= (newPadLo != padLo) || (newPadHi != padHi);
808 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
814 sliceOp,
"terminate condition; nothing to rewrite");
820 RankedTensorType::get(newPadShape, inputTy.getElementType());
821 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
822 padOp.getInput1(), newPaddingsOp,
823 padOp.getPadConst());
829 newPadOp.getResult(), newStartOp,
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
846 ElementsAttr sizeElems;
849 sliceOp,
"size of slice must be a static ranked shape");
853 llvm::to_vector(sizeElems.getValues<
int64_t>());
855 bool replaceSliceSize{
false};
859 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(
index)) {
861 sliceSizes[
index] = resultType.getDimSize(
index);
862 replaceSliceSize =
true;
866 if (!replaceSliceSize) {
868 sliceOp,
"no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
877 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.
add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
893 const Value castInput = castOp.getInput();
897 "input must be cast operation");
899 const Value innerCastInput = innerCastOp.getInput();
901 const auto innerInputType =
902 llvm::cast<ShapedType>(innerCastInput.
getType());
903 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
904 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
908 if (llvm::any_of(types, [](
const ShapedType type) {
909 return !type.getElementType().isInteger();
912 "only integer types are supported");
915 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
916 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
918 "inner cast operation is narrowing");
921 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
923 "outer cast operation is narrowing");
932void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
933 MLIRContext *context) {
934 results.
add<NonNarrowingCastsOptimization>(context);
941template <
typename Folder>
942static DenseElementsAttr
944 bool foldDenseValues =
false) {
949 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
953 if (
lhs.isSplat() &&
rhs.isSplat()) {
954 if (isa<FloatType>(lETy)) {
955 const APFloat l =
lhs.getSplatValue<APFloat>();
956 const APFloat r =
rhs.getSplatValue<APFloat>();
957 const auto maybeResult = Folder::fold(l, r);
958 if (failed(maybeResult))
963 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
964 const APInt l =
lhs.getSplatValue<APInt>();
965 const APInt r =
rhs.getSplatValue<APInt>();
966 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
967 if (failed(maybeResult))
973 if (foldDenseValues) {
974 assert(lETy.isIntOrIndex() &&
975 "Only integer types are currently supported.");
978 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
979 const auto maybeResult = Folder::fold(l, r,
false);
980 if (failed(maybeResult))
982 resultValues.push_back(maybeResult.value());
990 static FailureOr<APInt>
fold(
const APInt &
lhs,
const APInt &
rhs,
991 const bool isUnsigned) {
994 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1000 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1007 const bool isUnsigned) {
1010 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1016 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1023 const bool isUnsigned) {
1024 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1027 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1028 return APInt(1,
lhs >
rhs);
1034 const bool isUnsigned) {
1035 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1038 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1039 return APInt(1,
lhs >=
rhs);
1045 const bool isUnsigned) {
1046 return APInt(1,
lhs ==
rhs);
1049 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1050 return APInt(1,
lhs ==
rhs);
1055 if (llvm::isa<FloatType>(elemType))
1057 if (llvm::isa<IntegerType>(elemType))
1063 if (llvm::isa<FloatType>(elemType))
1064 return val && val.
isSplat() &&
1066 if (llvm::isa<IntegerType>(elemType)) {
1067 const int64_t shifted = 1LL << shift;
1068 return val && val.
isSplat() &&
1074OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1075 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1076 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1077 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1078 if (!lhsTy || !rhsTy || !resultTy)
1082 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1083 !rhsTy.getElementType().isIntOrIndexOrFloat())
1086 auto resultETy = resultTy.getElementType();
1088 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1090 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1092 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1094 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1097 if (!lhsAttr || !rhsAttr)
1103OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1104 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1105 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1106 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1107 !outputTy.hasStaticShape())
1111 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1112 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1113 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1120OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1121 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1122 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1123 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1124 if (!lhsTy || !rhsTy || !resultTy)
1130 auto resultETy = resultTy.getElementType();
1132 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1134 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1135 if (lhsAttr && lhsAttr.isSplat()) {
1136 if (llvm::isa<IntegerType>(resultETy) &&
1137 lhsAttr.getSplatValue<APInt>().isZero())
1141 if (rhsAttr && rhsAttr.isSplat()) {
1142 if (llvm::isa<IntegerType>(resultETy) &&
1143 rhsAttr.getSplatValue<APInt>().isOne())
1147 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1148 llvm::isa<IntegerType>(resultETy)) {
1149 APInt l = lhsAttr.getSplatValue<APInt>();
1150 APInt r = rhsAttr.getSplatValue<APInt>();
1152 APInt
result = l.sdiv(r);
1163std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1164 unsigned bitwidth) {
1168 auto round = APInt(64, 1) << (shift - 1);
1170 result.ashrInPlace(shift);
1172 if (!(
result.getSExtValue() >= INT32_MIN &&
1173 result.getSExtValue() <= INT32_MAX)) {
1175 return std::nullopt;
1179 return result.trunc(bitwidth);
1182DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1183 RankedTensorType ty, int32_t shift) {
1185 if (llvm::isa<IntegerType>(ty.getElementType())) {
1186 APInt l =
lhs.getSplatValue<APInt>();
1187 APInt r =
rhs.getSplatValue<APInt>();
1193 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1194 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1200 if (llvm::isa<FloatType>(ty.getElementType())) {
1201 APFloat l =
lhs.getSplatValue<APFloat>();
1202 APFloat r =
rhs.getSplatValue<APFloat>();
1212OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1213 auto lhs = getInput1();
1214 auto rhs = getInput2();
1215 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1216 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1217 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1218 if (!lhsTy || !rhsTy || !resultTy)
1221 auto resultETy = resultTy.getElementType();
1223 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1225 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1230 if (resultETy.isInteger(32)) {
1231 ElementsAttr shift_elem;
1232 if (getShift().getImpl()) {
1236 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1240 if (rhsTy == resultTy) {
1241 if (
isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1243 return lhsAttr.resizeSplat(resultTy);
1247 if (lhsTy == resultTy) {
1248 if (
isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1249 return rhsAttr.resizeSplat(resultTy);
1254 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1257OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1258 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1259 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1260 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1261 if (!lhsTy || !rhsTy || !resultTy)
1265 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1266 !rhsTy.getElementType().isIntOrIndexOrFloat())
1269 auto resultETy = resultTy.getElementType();
1271 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1273 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1275 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1278 if (!lhsAttr || !rhsAttr)
1284OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1285 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1287 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1289 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1291 if (!lhsAttr || !rhsAttr)
1297OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1298 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1300 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1302 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1304 if (!lhsAttr || !rhsAttr)
1310OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1311 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1313 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1315 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1316 Value
lhs = getInput1();
1317 Value
rhs = getInput2();
1318 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1322 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1323 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1327 if (!lhsAttr || !rhsAttr)
1333OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1337 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1341 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1342 auto outTy = llvm::cast<ShapedType>(
getType());
1343 auto inETy = inTy.getElementType();
1344 auto outETy = outTy.getElementType();
1346 if (operand.isSplat()) {
1347 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1349 auto splatVal = operand.getSplatValue<APFloat>();
1350 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1351 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1356 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1357 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1358 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1359 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1360 llvm::RoundingMode::NearestTiesToEven);
1364 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1365 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1366 auto intVal = APSInt(
1367 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1368 auto floatVal = operand.getSplatValue<APFloat>();
1370 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1375 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1376 const auto inIntType = llvm::cast<IntegerType>(inETy);
1377 auto unsignIn = inIntType.isUnsignedInteger();
1379 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1380 auto intVal = operand.getSplatValue<APInt>();
1381 auto bitwidth = outETy.getIntOrFloatBitWidth();
1384 if (outETy.isInteger(1)) {
1385 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1387 intVal = intVal.trunc(bitwidth);
1388 }
else if (unsignIn || inIntType.isInteger(1)) {
1389 intVal = intVal.zext(bitwidth);
1391 intVal = intVal.sext(bitwidth);
1401OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1403OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1405#define REDUCE_FOLDER(OP) \
1406 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1407 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1408 if (!inputTy.hasRank()) \
1410 if (inputTy != getType()) \
1412 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1413 return getInput(); \
1426 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1427 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1429 if (!inputTy || !outputTy)
1435 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1439 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1440 getInput1().getDefiningOp())) {
1441 getInput1Mutable().assign(reshapeOp.getInput1());
1446 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1451 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1453 if (!outputTy.hasStaticShape())
1457 if (operand.isSplat())
1462 if (!getInput1().hasOneUse())
1469 return operand.reshape(
1470 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1476OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1478 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1479 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1480 if (densePad && densePad.isSplat() &&
1481 densePad.getSplatValue<APInt>().isZero()) {
1491OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1493 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1495 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1497 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1498 if (!scaleAttr || !offsetAttr || !borderAttr) {
1505 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1510 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1515 if (offset[0] != 0 || offset[1] != 0) {
1520 if (border[0] != 0 || border[1] != 0) {
1524 auto input = getInput();
1525 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1526 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1527 if (inputTy != resultTy)
1533OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1534 auto operand = getInput1();
1535 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1536 auto axis = getAxis();
1538 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1543 if (operandTy.hasRank() &&
1544 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1550OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1551 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1552 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1554 if (!inputTy || !outputTy)
1557 if (inputTy == outputTy && inputTy.hasStaticShape())
1560 if (!adaptor.getInput1())
1564 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1565 !outputTy.getElementType().isIntOrIndexOrFloat())
1568 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1569 if (operand.isSplat() && outputTy.hasStaticShape()) {
1573 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1574 outputTy.getNumElements() == 1) {
1575 DenseElementsAttr startElems;
1579 llvm::SmallVector<uint64_t>
indices =
1580 llvm::to_vector(startElems.
getValues<uint64_t>());
1581 auto value = operand.getValues<Attribute>()[
indices];
1590 const auto isDynamic = [](
Type ty) {
1591 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1592 return !shapedTy || !shapedTy.hasStaticShape();
1595 return llvm::any_of(operandTypes, isDynamic) ||
1599OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1606 if (getOnTrue() == getOnFalse())
1610 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1614 if (!predicate.isSplat())
1616 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1620OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1622 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1623 adaptor.getMultiples())) {
1624 if (multiples.isSplat() &&
1625 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1627 if (
auto int_array_attr =
1628 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1629 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1630 [](APInt v) { return v.getSExtValue() == 1; }))
1638OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1639 auto resultTy = llvm::cast<ShapedType>(
getType());
1643 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1644 if (input.isSplat() && resultTy.hasStaticShape() &&
1645 input.getType().getElementType() == resultTy.getElementType())
1646 return input.reshape(resultTy);
1650 const llvm::ArrayRef<int32_t> perms = getPerms();
1652 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1658OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1661 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1667 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1668 failed(maybeIZp) || *maybeIZp != 0) {
1672 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1673 failed(maybeOZp) || *maybeOZp != 0) {
1677 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1678 failed(maybeIZp) || *maybeIZp != 0) {
1682 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1683 failed(maybeOZp) || *maybeOZp != 0) {
1688 return definingOp.getInput1();
1691OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1692 auto input = getInput1();
1694 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1701OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1706 SmallVector<Value, 8> concatOperands;
1707 concatOperands.reserve(2 * getNumOperands());
1710 bool foundFoldableConcat =
false;
1711 for (Value operand : getOperands()) {
1712 concatOperands.emplace_back(operand);
1714 auto producer = operand.getDefiningOp<ConcatOp>();
1719 if (getAxis() != producer.getAxis())
1723 foundFoldableConcat =
true;
1724 concatOperands.pop_back();
1725 llvm::append_range(concatOperands, producer->getOperands());
1728 if (!foundFoldableConcat)
1731 getOperation()->setOperands(concatOperands);
1735OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1736 auto input = adaptor.getInput1();
1738 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1740 if (!inputAttr || !inputAttr.isSplat())
1743 auto shapeType = llvm::cast<ShapedType>(
getType());
1744 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1745 auto floatVal = inputAttr.getSplatValue<APFloat>();
1747 ReciprocalOp::calcOneElement(floatVal));
1753OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
1754 auto input1ConstShape =
1755 dyn_cast<tosa::ConstShapeOp>(getInput1().getDefiningOp());
1756 auto input2ConstShape =
1757 dyn_cast<tosa::ConstShapeOp>(getInput2().getDefiningOp());
1758 if (!input1ConstShape || !input2ConstShape)
1761 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1762 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
1765 input1Attr, input2Attr, input1Attr.getType(),
true);
1768OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
1769 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
1770 if (!inputTy || !inputTy.hasRank())
1772 const int32_t axis = getAxis();
1773 const int64_t dimSize = inputTy.getDimSize(axis);
1774 if (ShapedType::isDynamic(dimSize))
1778 const auto resultAttrTy =
1779 RankedTensorType::get(1, builder.getIndexType());
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
static bool mayRequireBroadcast(ValueTypeRange< mlir::OperandRange > operandTypes)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...