25#include "llvm/ADT/APFloat.h"
26#include "llvm/ADT/APInt.h"
46 (padConstAttr.
size() != 1)) {
51 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
52 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
53 return padConstVal == 0.0f;
57 if (
auto padConstIntAttr =
58 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
67 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
68 return zpVal == padConstVal;
// Adaptor customizing pad-folding behavior per pooling op; specialized below
// (e.g. for tosa::MaxPool2dOp).
template <typename OpTy>
struct PoolPadFoldAdaptor;
80struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
81 using OpTy = tosa::MaxPool2dOp;
82 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
83 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
84 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
85 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
89 static bool checkPadConstCompliance(OpTy, Value padConst) {
91 DenseElementsAttr padConstAttr;
93 padConstAttr.
size() != 1) {
98 if (
auto padConstFpAttr =
99 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
100 const APFloat padConstVal = *padConstFpAttr.begin();
101 const APFloat lowestVal =
102 APFloat::getLargest(padConstVal.getSemantics(),
true);
103 return padConstVal == lowestVal;
105 if (
auto padConstIntAttr =
106 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
107 const APInt padConstVal = *padConstIntAttr.begin();
108 const unsigned int bitWidth = padConstVal.getBitWidth();
109 const APInt lowestVal =
110 padConstIntAttr.getElementType().isUnsignedInteger()
111 ? APInt::getZero(bitWidth)
112 : APInt::getSignedMinValue(bitWidth);
113 return padConstVal == lowestVal;
119 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
120 Value padInput, ArrayRef<int64_t> newPad) {
122 op, op.getType(), padInput, op.getKernel(), op.getStride(),
127template <
typename OpTy>
128struct ConvPadFoldAdaptor {
129 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
132 static bool checkPadConstCompliance(OpTy op, Value padConst) {
135 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
136 Value padInput, ArrayRef<int64_t> newPad) {
138 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
139 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
140 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
148template <
typename OpTy,
typename AdaptorTy>
150 using OpRewritePattern<OpTy>::OpRewritePattern;
152 LogicalResult matchAndRewrite(OpTy tensorOp,
153 PatternRewriter &rewriter)
const override {
155 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
158 "Producer must be a tosa::PadOp.");
161 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
162 if (tensorOpPad.size() != 4)
164 tensorOp,
"Tensor operation padding shall have 4 elements.");
167 DenseIntElementsAttr padOpPadding;
171 "The `padding` input specified on the tosa::PadOp must be constant.");
175 if (padOpPadding.size() != 8)
177 "Pad padding should have 8 elements.");
178 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
179 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
180 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
181 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
182 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
183 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
184 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
185 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
187 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
189 tensorOp,
"Folding padding in N or C dimensions is not supported.");
193 SmallVector<int64_t> foldedPad(tensorOpPad.size());
194 foldedPad[0] = padHBefore + tensorOpPad[0];
195 foldedPad[1] = padHAfter + tensorOpPad[1];
196 foldedPad[2] = padWBefore + tensorOpPad[2];
197 foldedPad[3] = padWAfter + tensorOpPad[3];
200 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
202 tensorOp,
"Padding size not aligned with kernel restrictions.");
206 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
209 "Padding constant is not aligned with operator zero-point.");
213 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
215 tensorOp,
"Padding size more than the 8K level limit.");
219 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
230 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
236 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
237 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
246 Value input = op.getInput();
247 Value output = op.getOutput();
248 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
249 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
251 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
257 if (outputShape[1] != 1 || outputShape[2] != 1) {
262 if (inputShape[1] != 1 || inputShape[2] != 1) {
274 FoldPadToTensorOp<tosa::MaxPool2dOp,
275 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
288 if (op.getInput1().size() != 1)
290 if (op.getInput1().front().getType() != op.getType()) {
293 op.getInput1().front())
298 rewriter.
replaceOp(op, op.getInput1().front());
308LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
309 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
313 op.getOperation()->setOperands(
314 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
326 auto innerTranspose =
327 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
330 "input must be transpose operation");
334 innerTranspose.getPerms();
336 if (transposePerms.size() != innerTransposePerms.size())
339 "transpose and inner transpose perms sizes must be equal");
340 if (transposePerms.empty())
342 transposeOp,
"transpose perms sizes must be positive");
346 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
347 perms[i] = innerTransposePerms[transposePerms[i]];
350 transposeOp, transposeOp.getResult().
getType(),
363 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
365 op,
"Src is from transpose, can compose transposes");
369 if (isa_and_nonnull<tosa::TransposeOp>(subop))
371 op,
"Dest is used by transpose, can compose transposes");
374 auto input = op.getInput1();
375 auto inputTy = llvm::cast<ShapedType>(input.getType());
376 if (!inputTy.hasRank())
380 for (
int i = 0; i < inputTy.getRank(); ++i)
381 if (inputTy.isDynamicDim(i))
390 nonZeroPerms.reserve(permValues.size());
391 for (
auto idx : permValues) {
392 auto sz = inputTy.getDimSize(idx);
394 nonZeroPerms.push_back(idx);
397 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
398 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
400 "Transpose changes memory layout.");
403 newShape.reserve(inputTy.getRank());
404 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
405 newShape.push_back(inputTy.getDimSize(permValues[i]));
408 op, op.getType(), op.getInput1(),
416 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
424 Value input = op.getInput();
425 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
426 auto inputElementType = inputType.getElementType();
428 if (isa<FloatType>(inputElementType)) {
430 const auto minClamp =
431 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
432 const auto maxClamp =
433 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
434 const bool isMin = minClamp.isNegInfinity();
435 const bool isMax = maxClamp.isInfinity();
437 if (isMin && isMax) {
445 const bool isBoolean = inputElementType.isInteger(1);
446 if (inputElementType.isUnsignedInteger() || isBoolean) {
447 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
450 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
454 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
455 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
456 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
458 if (minClamp <= intMin && maxClamp >= intMax) {
465 if (llvm::isa<IntegerType>(inputElementType)) {
467 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
469 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
471 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
472 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
473 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
475 if (minClamp <= intMin && maxClamp >= intMax) {
507 template <
typename T>
521 Value input = op.getInput();
529 const auto opNanMode = op.getNanMode();
530 const auto clampNanMode = clampOp.getNanMode();
531 if (opNanMode == NanPropagationMode::IGNORE &&
532 clampNanMode == NanPropagationMode::PROPAGATE)
535 auto maxValAttr = op.getMaxValAttr();
536 auto minValAttr = op.getMinValAttr();
537 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
538 auto clampOpMinValAttr = clampOp.getMinValAttr();
540 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
542 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
547 if (mlir::isa<FloatType>(inputEType)) {
548 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
549 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
550 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
551 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
554 const auto opMinFloat = floatMinValAttr.getValue();
555 const auto opMaxFloat = floatMaxValAttr.getValue();
556 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
557 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
561 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
565 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
566 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
567 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
568 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
570 assert(mlir::isa<IntegerType>(inputEType));
571 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
572 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
573 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
574 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
576 if (inputEType.isUnsignedInteger()) {
578 const auto opMinInt = intMinValAttr.getUInt();
579 const auto opMaxInt = intMaxValAttr.getUInt();
580 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
581 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
585 if (!opRangeIntRange.
intersects(clampRangeIntRange))
589 auto newMinVal = std::max(opMinInt, clampOpMinInt);
590 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
595 const auto opMinInt = intMinValAttr.getInt();
596 const auto opMaxInt = intMaxValAttr.getInt();
597 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
598 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
602 if (!opRangeIntRange.
intersects(clampRangeIntRange))
606 auto newMinVal = std::max(opMinInt, clampOpMinInt);
607 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
613 auto newMode = (opNanMode != clampNanMode)
614 ? tosa::NanPropagationMode::IGNORE
618 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
621 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
627void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
628 MLIRContext *context) {
629 results.
add<ClampIsNoOp>(context);
630 results.
add<ClampClampOptimization>(context);
638 Value sliceInput = sliceOp.getInput1();
642 sliceOp,
"slice input must be concat operation");
645 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
646 if (!concatType || !concatType.hasStaticShape())
648 sliceOp,
"slice input must be a static ranked tensor");
649 int32_t axis = concatOp.getAxis();
656 sliceOp,
"start of slice must be a static ranked shape");
660 sliceOp,
"size of slice must be a static ranked shape");
670 std::optional<Value> replaceWithSlice;
671 for (
auto input : inputs) {
672 auto inputType = dyn_cast<RankedTensorType>(input.getType());
673 if (!inputType || !inputType.hasStaticShape())
675 sliceOp,
"concat input must be a static ranked tensor");
677 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
678 inputType.getDimSize(axis)) {
684 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
685 input, start_op, size_op)
689 sliceStarts[axis] -= inputType.getDimSize(axis);
692 if (!replaceWithSlice)
694 sliceOp,
"corresponding concat input not found for slice");
696 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
706 Value sliceInput = sliceOp.getInput1();
712 "slice input must be a pad operation");
715 if (!padOp->hasOneUse())
717 "pad shall have a single consumer");
720 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
721 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
722 if (!inputTy || !padTy || !inputTy.hasRank())
724 "slice input must be a ranked tensor");
731 "`padding` input specified on the tosa::PadOp must be constant.");
734 llvm::to_vector(paddingElems.getValues<
int64_t>());
740 sliceOp,
"start of slice must be a static ranked shape");
747 sliceOp,
"size of slice must be a static ranked shape");
752 const int64_t rank = inputTy.getRank();
753 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
754 const bool isDimDynamic = inputTy.isDynamicDim(i);
755 const bool isDimSliced =
756 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
758 return isDimDynamic && isDimSliced;
761 sliceOp,
"axis that are sliced shall be statically known.");
768 bool updated =
false;
770 for (
int64_t i = 0; i < rank; ++i) {
771 const int64_t padLo = padPaddings[i * 2];
772 const int64_t padHi = padPaddings[i * 2 + 1];
773 const int64_t sliceStart = sliceStarts[i];
774 const int64_t sliceSize = sliceSizes[i];
775 const int64_t sliceEnd = sliceStart + sliceSize;
778 if (inputTy.isDynamicDim(i)) {
779 newPadPaddings[i * 2] = padLo;
780 newPadPaddings[i * 2 + 1] = padHi;
781 newSliceStarts[i] = sliceStart;
786 const int64_t dimSize = inputTy.getShape()[i];
787 const int64_t dimTotal = padLo + dimSize + padHi;
790 if (sliceStart < 0 || sliceEnd > dimTotal)
794 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
795 newSliceStarts[i] = newSliceStart;
796 updated |= newSliceStart != sliceStart;
799 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
801 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
802 newPadPaddings[i * 2] = newPadLo;
803 newPadPaddings[i * 2 + 1] = newPadHi;
804 updated |= (newPadLo != padLo) || (newPadHi != padHi);
808 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
814 sliceOp,
"terminate condition; nothing to rewrite");
820 RankedTensorType::get(newPadShape, inputTy.getElementType());
821 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
822 padOp.getInput1(), newPaddingsOp,
823 padOp.getPadConst());
829 newPadOp.getResult(), newStartOp,
844 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
846 ElementsAttr sizeElems;
849 sliceOp,
"size of slice must be a static ranked shape");
853 llvm::to_vector(sizeElems.getValues<
int64_t>());
855 bool replaceSliceSize{
false};
859 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
860 if (size == -1 && !resultType.isDynamicDim(
index)) {
861 sliceSizes[
index] = resultType.getDimSize(
index);
862 replaceSliceSize =
true;
866 if (!replaceSliceSize) {
868 sliceOp,
"no dimension of size of slice is dynamic that resolves "
869 "to static output shape");
874 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
875 sliceOp.getInput1(), sliceOp.getStart(), size_op);
877 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
882void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
883 MLIRContext *context) {
884 results.
add<ConcatSliceOptimization, PadSliceOptimization,
885 SliceDynamicSizeCanonicalization>(context);
893 const Value castInput = castOp.getInput();
897 "input must be cast operation");
899 const Value innerCastInput = innerCastOp.getInput();
901 const auto innerInputType =
902 llvm::cast<ShapedType>(innerCastInput.
getType());
903 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
904 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
908 if (llvm::any_of(types, [](
const ShapedType type) {
909 return !type.getElementType().isInteger();
912 "only integer types are supported");
915 const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
916 if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
918 "inner cast operation is narrowing");
921 if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
923 "outer cast operation is narrowing");
932void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
933 MLIRContext *context) {
934 results.
add<NonNarrowingCastsOptimization>(context);
941template <
typename Folder>
942static DenseElementsAttr
944 bool foldDenseValues =
false) {
949 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
953 if (
lhs.isSplat() &&
rhs.isSplat()) {
954 if (isa<FloatType>(lETy)) {
955 const APFloat l =
lhs.getSplatValue<APFloat>();
956 const APFloat r =
rhs.getSplatValue<APFloat>();
957 const auto maybeResult = Folder::fold(l, r);
958 if (failed(maybeResult))
963 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
964 const APInt l =
lhs.getSplatValue<APInt>();
965 const APInt r =
rhs.getSplatValue<APInt>();
966 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
967 if (failed(maybeResult))
973 if (foldDenseValues) {
974 assert(lETy.isIntOrIndex() &&
975 "Only integer types are currently supported.");
978 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
979 const auto maybeResult = Folder::fold(l, r,
false);
980 if (failed(maybeResult))
982 resultValues.push_back(maybeResult.value());
990template <
typename Folder>
992 bool foldDenseValues =
false) {
999 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1001 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1002 if (failed(maybeResult))
1008 if (foldDenseValues) {
1012 for (
auto const &v : val.
getValues<APInt>()) {
1013 const auto maybeResult = Folder::fold(v,
false);
1014 if (failed(maybeResult))
1016 resultValues.push_back(maybeResult.value());
1028 const bool isUnsigned) {
1031 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1037 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1044 const bool isUnsigned) {
1047 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1053 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1060 const bool isUnsigned) {
1062 const unsigned originalWidth =
lhs.getBitWidth();
1065 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1070 if (
lhs == 0 ||
rhs == 0)
1071 return APInt::getZero(originalWidth);
1073 bool overflow =
false;
1075 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1080 return result.trunc(originalWidth);
1083 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1089 return a.isNegative() !=
b.isNegative();
1096 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1104 APInt::udivrem(
lhs,
rhs, q, r);
1105 if (!r.isZero() && Ceil) {
1112 bool overflow{
false};
1113 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1116 APInt
const r =
lhs.srem(
rhs);
1126 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1134 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1136 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1146 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1148 auto const r = t.mod(
rhs);
1149 if (llvm::APFloatBase::opStatus::opOK == r) {
1159 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1161 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1164 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1172 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1174 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1177 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1183 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1184 auto const numBits = value.getBitWidth();
1186 auto const zextv = value.getZExtValue();
1187 if (zextv >= numBits)
1189 return APInt::getOneBitSet(numBits, zextv);
1191 auto const sextv = value.getSExtValue();
1192 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1194 return APInt::getOneBitSet(numBits, sextv);
1199 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1200 if (!value.isStrictlyPositive())
1202 return APInt(value.getBitWidth(), value.ceilLogBase2());
1207 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1208 if (!value.isStrictlyPositive())
1210 return APInt(value.getBitWidth(), value.logBase2());
1216 const bool isUnsigned) {
1217 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1220 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1221 return APInt(1,
lhs >
rhs);
1227 const bool isUnsigned) {
1228 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1231 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1232 return APInt(1,
lhs >=
rhs);
1238 const bool isUnsigned) {
1239 return APInt(1,
lhs ==
rhs);
1242 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1243 return APInt(1,
lhs ==
rhs);
1248 if (llvm::isa<FloatType>(elemType))
1250 if (llvm::isa<IntegerType>(elemType))
1256 if (llvm::isa<FloatType>(elemType))
1257 return val && val.
isSplat() &&
1259 if (llvm::isa<IntegerType>(elemType)) {
1260 const int64_t shifted = 1LL << shift;
1261 return val && val.
isSplat() &&
1267OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1268 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1269 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1270 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1271 if (!lhsTy || !rhsTy || !resultTy)
1275 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1276 !rhsTy.getElementType().isIntOrIndexOrFloat())
1279 auto resultETy = resultTy.getElementType();
1281 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1283 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1285 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1287 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1290 if (!lhsAttr || !rhsAttr)
1296OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1297 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1298 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1299 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1300 !outputTy.hasStaticShape())
1304 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1305 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1306 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1313OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1314 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1315 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1316 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1317 if (!lhsTy || !rhsTy || !resultTy)
1324 auto resultETy = resultTy.getElementType();
1326 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1328 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1329 if (lhsAttr && lhsAttr.isSplat()) {
1330 if (llvm::isa<IntegerType>(resultETy) &&
1331 lhsAttr.getSplatValue<APInt>().isZero())
1335 if (rhsAttr && rhsAttr.isSplat()) {
1336 if (llvm::isa<IntegerType>(resultETy) &&
1337 rhsAttr.getSplatValue<APInt>().isOne())
1341 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1342 llvm::isa<IntegerType>(resultETy)) {
1343 APInt l = lhsAttr.getSplatValue<APInt>();
1344 APInt r = rhsAttr.getSplatValue<APInt>();
1346 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1348 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1361std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1362 unsigned bitwidth) {
1363 bool overflow =
false;
1364 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1367 return std::nullopt;
1370 auto round = APInt(64, 1) << (shift - 1);
1372 result.ashrInPlace(shift);
1375 if (!(
result.getSExtValue() >= INT32_MIN &&
1376 result.getSExtValue() <= INT32_MAX)) {
1378 return std::nullopt;
1382 return result.trunc(bitwidth);
1385DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1386 RankedTensorType ty, int32_t shift) {
1388 if (llvm::isa<IntegerType>(ty.getElementType())) {
1389 APInt l =
lhs.getSplatValue<APInt>();
1390 APInt r =
rhs.getSplatValue<APInt>();
1396 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1397 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1403 if (llvm::isa<FloatType>(ty.getElementType())) {
1404 APFloat l =
lhs.getSplatValue<APFloat>();
1405 APFloat r =
rhs.getSplatValue<APFloat>();
1415OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1416 auto lhs = getInput1();
1417 auto rhs = getInput2();
1418 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1419 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1420 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1421 if (!lhsTy || !rhsTy || !resultTy)
1424 auto resultETy = resultTy.getElementType();
1426 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1428 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1433 if (resultETy.isInteger(32)) {
1434 ElementsAttr shift_elem;
1435 if (getShift().getImpl()) {
1439 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1443 if (rhsTy == resultTy) {
1444 if (
isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1446 return lhsAttr.resizeSplat(resultTy);
1450 if (lhsTy == resultTy) {
1451 if (
isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1452 return rhsAttr.resizeSplat(resultTy);
1457 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1460OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1461 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1462 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1463 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1464 if (!lhsTy || !rhsTy || !resultTy)
1468 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1469 !rhsTy.getElementType().isIntOrIndexOrFloat())
1472 auto resultETy = resultTy.getElementType();
1474 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1476 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1478 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1481 if (!lhsAttr || !rhsAttr)
1487OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1488 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1490 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1492 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1494 if (!lhsAttr || !rhsAttr)
1500OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1501 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1503 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1505 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1507 if (!lhsAttr || !rhsAttr)
1513OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1514 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1516 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1518 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1519 Value
lhs = getInput1();
1520 Value
rhs = getInput2();
1521 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1525 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1526 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1530 if (!lhsAttr || !rhsAttr)
1536OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1540 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1544 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1545 auto outTy = llvm::cast<ShapedType>(
getType());
1546 auto inETy = inTy.getElementType();
1547 auto outETy = outTy.getElementType();
1549 if (operand.isSplat()) {
1550 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1552 auto splatVal = operand.getSplatValue<APFloat>();
1553 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1554 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1559 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1560 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1561 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1562 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1563 llvm::RoundingMode::NearestTiesToEven);
1567 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1568 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1569 auto intVal = APSInt(
1570 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1571 auto floatVal = operand.getSplatValue<APFloat>();
1573 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1578 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1579 const auto inIntType = llvm::cast<IntegerType>(inETy);
1580 auto unsignIn = inIntType.isUnsignedInteger();
1582 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1583 auto intVal = operand.getSplatValue<APInt>();
1584 auto bitwidth = outETy.getIntOrFloatBitWidth();
1587 if (outETy.isInteger(1)) {
1588 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1590 intVal = intVal.trunc(bitwidth);
1591 }
else if (unsignIn || inIntType.isInteger(1)) {
1592 intVal = intVal.zext(bitwidth);
1594 intVal = intVal.sext(bitwidth);
1604OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1606OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1608#define REDUCE_FOLDER(OP) \
1609 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1610 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1611 if (!inputTy.hasRank()) \
1613 if (inputTy != getType()) \
1615 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1616 return getInput(); \
1629 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1630 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1632 if (!inputTy || !outputTy)
1638 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1642 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1643 getInput1().getDefiningOp())) {
1644 getInput1Mutable().assign(reshapeOp.getInput1());
1649 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1654 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1656 if (!outputTy.hasStaticShape())
1660 if (operand.isSplat())
1665 if (!getInput1().hasOneUse())
1672 return operand.reshape(
1673 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1679OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1681 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1682 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1683 if (densePad && densePad.isSplat() &&
1684 densePad.getSplatValue<APInt>().isZero()) {
1694OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1696 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1698 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1700 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1701 if (!scaleAttr || !offsetAttr || !borderAttr) {
1708 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1713 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1718 if (offset[0] != 0 || offset[1] != 0) {
1723 if (border[0] != 0 || border[1] != 0) {
1727 auto input = getInput();
1728 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1729 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1730 if (inputTy != resultTy)
1736OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1737 auto operand = getInput1();
1738 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1739 auto axis = getAxis();
1741 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1746 if (operandTy.hasRank() &&
1747 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1753OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1754 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1755 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1757 if (!inputTy || !outputTy)
1760 if (inputTy == outputTy && inputTy.hasStaticShape())
1763 if (!adaptor.getInput1())
1767 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1768 !outputTy.getElementType().isIntOrIndexOrFloat())
1771 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1772 if (operand.isSplat() && outputTy.hasStaticShape()) {
1776 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1777 outputTy.getNumElements() == 1) {
1778 DenseElementsAttr startElems;
1782 llvm::SmallVector<uint64_t>
indices =
1783 llvm::to_vector(startElems.
getValues<uint64_t>());
1784 auto value = operand.getValues<Attribute>()[
indices];
1793 const auto isDynamic = [](
Type ty) {
1794 const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
1795 return !shapedTy || !shapedTy.hasStaticShape();
1798 return llvm::any_of(operandTypes, isDynamic) ||
1802OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1809 if (getOnTrue() == getOnFalse())
1813 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1817 if (!predicate.isSplat())
1819 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1823OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
1825 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1826 adaptor.getMultiples())) {
1827 if (multiples.isSplat() &&
1828 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1830 if (
auto int_array_attr =
1831 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1832 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1833 [](APInt v) { return v.getSExtValue() == 1; }))
1841OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
1842 auto resultTy = llvm::cast<ShapedType>(
getType());
1846 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1847 if (input.isSplat() && resultTy.hasStaticShape() &&
1848 input.getType().getElementType() == resultTy.getElementType())
1849 return input.reshape(resultTy);
1853 const llvm::ArrayRef<int32_t> perms = getPerms();
1855 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1861OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1864 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1870 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1871 failed(maybeIZp) || *maybeIZp != 0) {
1875 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1876 failed(maybeOZp) || *maybeOZp != 0) {
1880 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1881 failed(maybeIZp) || *maybeIZp != 0) {
1885 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1886 failed(maybeOZp) || *maybeOZp != 0) {
1891 return definingOp.getInput1();
1894OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
1895 auto input = getInput1();
1897 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1904OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1909 SmallVector<Value, 8> concatOperands;
1910 concatOperands.reserve(2 * getNumOperands());
1913 bool foundFoldableConcat =
false;
1914 for (Value operand : getOperands()) {
1915 concatOperands.emplace_back(operand);
1917 auto producer = operand.getDefiningOp<ConcatOp>();
1922 if (getAxis() != producer.getAxis())
1926 foundFoldableConcat =
true;
1927 concatOperands.pop_back();
1928 llvm::append_range(concatOperands, producer->getOperands());
1931 if (!foundFoldableConcat)
1934 getOperation()->setOperands(concatOperands);
1938OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1939 auto input = adaptor.getInput1();
1941 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1943 if (!inputAttr || !inputAttr.isSplat())
1946 auto shapeType = llvm::cast<ShapedType>(
getType());
1947 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1948 auto floatVal = inputAttr.getSplatValue<APFloat>();
1950 ReciprocalOp::calcOneElement(floatVal));
1956template <
typename Op,
typename OpFoldAdaptor>
1958 auto input1ConstShape =
1959 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
1960 if (!input1ConstShape)
1963 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1969template <
typename Op,
typename OpFoldAdaptor>
1971 auto input1ConstShape =
1972 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
1973 auto input2ConstShape =
1974 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
1975 if (!input1ConstShape || !input2ConstShape)
1978 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
1979 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
1982 input1Attr.getType(),
1986OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
1987 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
1988 if (!inputTy || !inputTy.hasRank())
1990 const int32_t axis = getAxis();
1991 const int64_t dimSize = inputTy.getDimSize(axis);
1992 if (ShapedType::isDynamic(dimSize))
1996 const auto resultAttrTy =
1997 RankedTensorType::get(1, builder.getIndexType());
2001OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2005OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2009OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2013OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2014 return binaryFold<DivCeilShapeOp, DivFoldAdaptor<
true>>(
this);
2017OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2018 return binaryFold<DivFloorShapeOp, DivFoldAdaptor<
false>>(
this);
2021OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2025OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2029OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2033OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2037OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2041OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
static bool mayRequireBroadcast(ValueTypeRange< mlir::OperandRange > operandTypes)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class implements iteration on the types of a given range of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...