#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
      (padConstAttr.size() != 1)) {

  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
    return padConstVal == 0.0f;
  }

  if (auto padConstIntAttr =
          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
    return zpVal == padConstVal;
  }
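// Note on the check above: a tosa.pad can only be folded into a consumer when
// it pads with that consumer's notion of "zero". For float tensors that is a
// literal 0.0 pad constant; for quantized integer tensors the pad constant
// must equal the operand's zero point, e.g. an i8 input with zero point 3 is
// only foldable when the pad constant is also 3 (illustrative values).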
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
  using OpTy = tosa::AvgPool2dOp;

    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;

  static bool checkPadConstCompliance(OpTy op, Value padConst) {
    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
  }

        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
  using OpTy = tosa::MaxPool2dOp;

    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
      return false;

  static bool checkPadConstCompliance(OpTy, Value padConst) {
        padConstAttr.size() != 1) {

    if (auto padConstFpAttr =
            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
      const APFloat padConstVal = *padConstFpAttr.begin();
      const APFloat lowestVal =
          APFloat::getLargest(padConstVal.getSemantics(), true);
      return padConstVal == lowestVal;
    } else if (auto padConstIntAttr =
                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
      const APInt padConstVal = *padConstIntAttr.begin();
      const unsigned int bitWidth = padConstVal.getBitWidth();
      const APInt lowestVal =
          padConstIntAttr.getElementType().isUnsignedInteger()
              ? APInt::getMinValue(bitWidth)
              : APInt::getSignedMinValue(bitWidth);
      return padConstVal == lowestVal;
    }

        op, op.getType(), padInput, op.getKernel(), op.getStride(),
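// Note on the MaxPool2d adaptor above: folding padding into tosa.max_pool2d is
// only safe when the pad constant can never win the max reduction, i.e. it is
// the lowest representable value of the element type (the most negative float
// via APFloat::getLargest(semantics, true), or the signed/unsigned minimum
// integer). Any other pad constant could leak into the pooled output.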
template <typename OpTy>
struct ConvPadFoldAdaptor {

  static bool checkPadConstCompliance(OpTy op, Value padConst) {

        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
template <typename OpTy, typename AdaptorTy>
struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
  LogicalResult matchAndRewrite(OpTy tensorOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
    if (!padOp)
      return rewriter.notifyMatchFailure(tensorOp,
                                         "Producer must be a tosa::PadOp.");

    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
    if (tensorOpPad.size() != 4)
      return rewriter.notifyMatchFailure(
          tensorOp, "Tensor operation padding shall have 4 elements.");

          "The `padding` input specified on the tosa::PadOp must be constant.");

    if (padOpPadding.size() != 8)
      return rewriter.notifyMatchFailure(
          tensorOp, "Pad padding should have 8 elements.");
    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();

    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
      return rewriter.notifyMatchFailure(
          tensorOp, "Folding padding in N or C dimensions is not supported.");

    llvm::SmallVector<int64_t> foldedPad(4);
    foldedPad[0] = padHBefore + tensorOpPad[0];
    foldedPad[1] = padHAfter + tensorOpPad[1];
    foldedPad[2] = padWBefore + tensorOpPad[2];
    foldedPad[3] = padWAfter + tensorOpPad[3];

    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size not aligned with kernel restrictions.");
    }

    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding constant is not aligned with operator zero-point.");
    }

    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
      return rewriter.notifyMatchFailure(
          tensorOp, "Padding size more than the 8K level limit.");
    }

    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
                                   foldedPad);
    return success();
  }
};
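// Illustrative walk-through of the fold above (hypothetical sizes): a
// tosa.pad with NHWC padding [0, 0, 1, 1, 2, 2, 0, 0] feeding an op whose own
// pad attribute is [0, 0, 0, 0] produces foldedPad = [1, 1, 2, 2]
// (top, bottom, left, right). The rewrite only fires when N and C padding are
// zero, the merged padding respects the adaptor's kernel constraints, the pad
// constant matches the operator's zero point, and no entry exceeds the 8192
// level limit.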
  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(

  results.add<
      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(

  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {

    if (outputShape[1] != 1 || outputShape[2] != 1) {

    if (inputShape[1] != 1 || inputShape[2] != 1) {
      FoldPadToTensorOp<tosa::MaxPool2dOp,
                        PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
  if (op.getInput1().size() != 1)

  if (op.getInput1().front().getType() != op.getType()) {

          op.getInput1().front())

  rewriter.replaceOp(op, op.getInput1().front());
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getInput3(), op.getInput2()});
  });
  return success();
}
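// Illustrative effect of the canonicalization above: select(logical_not(p), a, b)
// is rewritten in place to select(p, b, a); the predicate is replaced by the
// logical_not input and the two data operands are swapped.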
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    const auto innerTransposePerms =
        innerTranspose.getPerms();

    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

        transposeOp, transposeOp.getResult().getType(),
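// Worked example for the composition loop above (hypothetical perms): with an
// inner transpose of perms [2, 0, 1] followed by an outer transpose of perms
// [1, 0, 2], the combined permutation is
//   perms[i] = innerPerms[outerPerms[i]]  ->  [0, 2, 1],
// so the pair collapses into a single transpose with perms [0, 2, 1].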
    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers())
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return failure();

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

        op, op.getType(), op.getInput1(),
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {

    if (isa<FloatType>(inputElementType)) {
      auto minClamp =
          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
      auto maxClamp =
          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
      bool isMin = minClamp.isNegInfinity();
      bool isMax = maxClamp.isInfinity();
      if (isMin && isMax) {

    if (inputElementType.isUnsignedInteger()) {
      auto minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
      auto maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
      auto intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      auto intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {

    if (llvm::isa<IntegerType>(inputElementType)) {
      auto minClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
      auto maxClamp =
          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
      auto intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      auto intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {
template <typename T>
struct ClampRange {
  ClampRange(const T &start, const T &end) : start(start), end(end) {}
  T start;
  T end;

  bool intersects(const ClampRange<T> &otherRange) {
    return start < otherRange.end && otherRange.start < end;
  }
};
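// Example for ClampRange::intersects (hypothetical bounds): ranges [1, 5] and
// [3, 9] satisfy 1 < 9 && 3 < 5 and therefore intersect, while [1, 5] and
// [6, 9] fail the second comparison (6 < 5) and do not. Non-intersecting
// ranges mean the two clamps cannot be merged into a single equivalent clamp.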
    Value input = op.getInput();

    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
    if (!clampOp)
      return failure();

    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    auto maxValAttr = op.getMaxValAttr();
    auto minValAttr = op.getMinValAttr();
    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
    auto clampOpMinValAttr = clampOp.getMinValAttr();

    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
    if (auto quantType =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
      inputEType = quantType.getStorageType();
    }

    Attribute newMinValAttr, newMaxValAttr;
    if (mlir::isa<FloatType>(inputEType)) {
      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

      const auto opMinFloat = floatMinValAttr.getValue();
      const auto opMaxFloat = floatMaxValAttr.getValue();
      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();

      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
        return failure();

      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
    } else {
      assert(mlir::isa<IntegerType>(inputEType));
      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

      if (inputEType.isUnsignedInteger()) {
        const auto opMinInt = intMinValAttr.getUInt();
        const auto opMaxInt = intMaxValAttr.getUInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();

        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
      } else {
        const auto opMinInt = intMinValAttr.getInt();
        const auto opMaxInt = intMaxValAttr.getInt();
        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();

        if (!opRangeIntRange.intersects(clampRangeIntRange))
          return failure();

        auto newMinVal = std::max(opMinInt, clampOpMinInt);
        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
      }
    }

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

          sliceOp, "start of slice must be a static ranked shape");
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
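// Worked example for the rewrite above (hypothetical shapes): slicing a
// tosa.concat of A (size 4 on the concat axis) and B (size 6) with start 5 and
// size 3 falls entirely inside B, so the slice is re-targeted to B with a
// rebased start of 5 - 4 = 1 and the concat is bypassed. If the requested
// window straddles two inputs, no input satisfies the containment check and
// the pattern reports a match failure.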
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());

    ElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    bool replaceSliceSize{false};
    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
      if (size == -1 && !resultType.isDynamicDim(index)) {
        sliceSizes[index] = resultType.getDimSize(index);
        replaceSliceSize = true;
      }
    }

    if (!replaceSliceSize) {
      return rewriter.notifyMatchFailure(
          sliceOp, "no dimension of size of slice is dynamic that resolves "
                   "to static output shape");
    }

    auto newSliceOp = rewriter.create<tosa::SliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
        sliceOp.getStart(), size_op);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();

  if (llvm::isa<IntegerType>(lETy)) {
    auto result = IntFolder()(l, r);

  if (llvm::isa<FloatType>(lETy)) {
    auto result = FloatFolder()(l, r);
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
  if (llvm::isa<IntegerType>(elemType))

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
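// Note on the helpers above: isSplatZero detects a constant tensor whose every
// element is zero (the additive identity), while isSplatOne accounts for the
// tosa.mul shift encoding, where the multiplicative identity at shift s is the
// splat value 1 << s; e.g. with shift = 6 the "one" splat is 64 (illustrative
// value).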
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
      llvm::isa<IntegerType>(resultETy)) {
    APInt l = lhsAttr.getSplatValue<APInt>();
    APInt r = rhsAttr.getSplatValue<APInt>();
    if (!r.isZero()) {
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}
                                    RankedTensorType ty, int32_t shift) {
  if (llvm::isa<IntegerType>(ty.getElementType())) {

    auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
    l = l.sext(bitwidth * 2);
    r = r.sext(bitwidth * 2);
    auto result = l * r;
    result.lshrInPlace(shift);
    result = result.trunc(bitwidth);

  if (llvm::isa<FloatType>(ty.getElementType())) {
    APFloat result = l * r;
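// Worked example for the integer path above (hypothetical splats): with 32-bit
// operands l = 6 and r = 10 and shift = 2, the values are sign-extended to 64
// bits, multiplied to 60, logically shifted right by 2 to give 15, and then
// truncated back to 32 bits. Widening before the multiply keeps the
// intermediate product from overflowing the original bit width.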
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
  }

  if (lhsTy == resultTy) {

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};
struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

#undef REDUCE_FOLDER
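// Note on the folder generated above: a reduction is a no-op when the input
// already has the reduced form, i.e. the input and result types match and the
// reduced axis has size 1 (or the tensor is rank 0), in which case the input
// is returned unchanged. For example, reducing axis 1 of a 2x1x3 tensor yields
// the same 2x1x3 value (illustrative shape).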
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (!outputTy.hasStaticShape())
      return {};

    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad && densePad.isSplat() &&
        densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }
  return {};
}
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  auto scaleAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
  auto offsetAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
  auto borderAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
  if (!scaleAttr || !offsetAttr || !borderAttr) {
    return {};
  }
  if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
    return {};
  }
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }
  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};
  return input;
}
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());

  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;
OpFoldResult tosa::SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getInput2() == getInput3())
    return getInput2();
  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
                                                         : getInput3();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
          adaptor.getMultiples())) {
    if (multiples.isSplat() &&
        multiples.getSplatValue<APInt>().getSExtValue() == 1)
      return getInput1();
    if (auto int_array_attr =
            llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
      if (llvm::all_of(int_array_attr.getValues<APInt>(),
                       [](APInt v) { return v.getSExtValue() == 1; }))
        return getInput1();
    }
  }
  return {};
}
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }
  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};
  return getInput1();
}
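// Note on the fold above: a transpose whose perms are the identity sequence
// [0, 1, ..., rank-1] is a no-op and folds to its input; a splat constant
// input likewise folds, since permuting identical elements only reshapes the
// value.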
OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }
  return {};
}

OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!definingOp)
    return {};

  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    return {};
  }
  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
      failed(maybeIZp) || *maybeIZp != 0) {
    return {};
  }
  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
      failed(maybeOZp) || *maybeOZp != 0) {
    return {};
  }

  return definingOp.getInput1();
}
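// Note on the fold above: negate(negate(x)) folds to x only in the
// zero-point-free case; if any of the four input/output zero points involved
// is non-zero (or unknown), the two negations are not exact inverses and the
// fold is skipped.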
  auto input = getInput1();
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    if (getAxis() != producer.getAxis())
      continue;

    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
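// Illustrative effect of the fold above: concat(concat(a, b), c) along the
// same axis is flattened in place to concat(a, b, c) by splicing the producer
// concat's operands into the operand list; nested concats on a different axis
// are left untouched.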
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }