24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/APInt.h"
45 (padConstAttr.
size() != 1)) {
50 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
51 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
52 return padConstVal == 0.0f;
56 if (
auto padConstIntAttr =
57 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
65 int64_t zpVal = (*zpAttr.
begin()).getSExtValue();
66 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
67 return zpVal == padConstVal;
75 template <
typename OpTy>
76 struct PoolPadFoldAdaptor;
79 struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
80 using OpTy = tosa::AvgPool2dOp;
83 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
88 static bool checkPadConstCompliance(OpTy op,
Value padConst) {
94 op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
101 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
102 using OpTy = tosa::MaxPool2dOp;
105 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
106 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
110 static bool checkPadConstCompliance(OpTy,
Value padConst) {
114 padConstAttr.
size() != 1) {
119 if (
auto padConstFpAttr =
120 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
121 const APFloat padConstVal = *padConstFpAttr.begin();
122 const APFloat lowestVal =
123 APFloat::getLargest(padConstVal.getSemantics(),
true);
124 return padConstVal == lowestVal;
126 if (
auto padConstIntAttr =
127 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
128 const APInt padConstVal = *padConstIntAttr.begin();
129 const unsigned int bitWidth = padConstVal.getBitWidth();
130 const APInt lowestVal =
131 padConstIntAttr.getElementType().isUnsignedInteger()
133 : APInt::getSignedMinValue(bitWidth);
134 return padConstVal == lowestVal;
143 op, op.getType(), padInput, op.getKernel(), op.getStride(),
148 template <
typename OpTy>
149 struct ConvPadFoldAdaptor {
153 static bool checkPadConstCompliance(OpTy op,
Value padConst) {
159 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
160 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
161 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
169 template <
typename OpTy,
typename AdaptorTy>
173 LogicalResult matchAndRewrite(OpTy tensorOp,
176 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
179 "Producer must be a tosa::PadOp.");
182 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
183 if (tensorOpPad.size() != 4)
185 tensorOp,
"Tensor operation padding shall have 4 elements.");
192 "The `padding` input specified on the tosa::PadOp must be constant.");
196 if (padOpPadding.size() != 8)
198 "Pad padding should have 8 elements.");
199 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
200 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
201 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
202 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
203 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
204 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
205 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
206 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
208 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
210 tensorOp,
"Folding padding in N or C dimensions is not supported.");
215 foldedPad[0] = padHBefore + tensorOpPad[0];
216 foldedPad[1] = padHAfter + tensorOpPad[1];
217 foldedPad[2] = padWBefore + tensorOpPad[2];
218 foldedPad[3] = padWAfter + tensorOpPad[3];
221 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
223 tensorOp,
"Padding size not aligned with kernel restrictions.");
227 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
230 "Padding constant is not aligned with operator zero-point.");
234 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
236 tensorOp,
"Padding size more than the 8K level limit.");
240 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
250 results.
add<FoldPadToTensorOp<tosa::AvgPool2dOp,
251 PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
258 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
264 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
265 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
274 Value input = op.getInput();
275 Value output = op.getOutput();
276 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
277 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
279 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
285 if (outputShape[1] != 1 || outputShape[2] != 1) {
290 if (inputShape[1] != 1 || inputShape[2] != 1) {
302 FoldPadToTensorOp<tosa::MaxPool2dOp,
303 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
316 if (op.getInput1().size() != 1)
318 if (op.getInput1().front().getType() != op.getType()) {
321 op.getInput1().front())
326 rewriter.
replaceOp(op, op.getInput1().front());
336 LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
337 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
341 op.getOperation()->setOperands(
342 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
354 auto innerTranspose =
355 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
358 "input must be transpose operation");
362 innerTranspose.getPerms();
364 if (transposePerms.size() != innerTransposePerms.size())
367 "transpose and inner transpose perms sizes must be equal");
368 if (transposePerms.empty())
370 transposeOp,
"transpose perms sizes must be positive");
374 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
375 perms[i] = innerTransposePerms[transposePerms[i]];
378 transposeOp, transposeOp.getResult().getType(),
391 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
393 op,
"Src is from transpose, can compose transposes");
395 Value result = op.getResult();
397 if (isa_and_nonnull<tosa::TransposeOp>(subop))
399 op,
"Dest is used by transpose, can compose transposes");
402 auto input = op.getInput1();
403 auto inputTy = llvm::cast<ShapedType>(input.getType());
404 if (!inputTy.hasRank())
407 int64_t numDynDims = 0;
408 for (
int i = 0; i < inputTy.getRank(); ++i)
409 if (inputTy.isDynamicDim(i))
418 nonZeroPerms.reserve(permValues.size());
419 for (
auto idx : permValues) {
420 auto sz = inputTy.getDimSize(idx);
422 nonZeroPerms.push_back(idx);
425 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
426 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
428 "Transpose changes memory layout.");
431 newShape.reserve(inputTy.getRank());
432 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
433 newShape.push_back(inputTy.getDimSize(permValues[i]));
436 op, op.getType(), op.getInput1(),
452 Value input = op.getInput();
453 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
454 auto inputElementType = inputType.getElementType();
456 if (isa<FloatType>(inputElementType)) {
458 const auto minClamp =
459 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
460 const auto maxClamp =
461 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
462 const bool isMin = minClamp.isNegInfinity();
463 const bool isMax = maxClamp.isInfinity();
465 if (isMin && isMax) {
473 const bool isBoolean = inputElementType.isInteger(1);
474 if (inputElementType.isUnsignedInteger() || isBoolean) {
475 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
478 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
482 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
483 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
484 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
486 if (minClamp <= intMin && maxClamp >= intMax) {
493 if (llvm::isa<IntegerType>(inputElementType)) {
494 const int64_t minClamp =
495 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
496 const int64_t maxClamp =
497 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
499 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
500 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
501 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
503 if (minClamp <= intMin && maxClamp >= intMax) {
535 template <
typename T>
537 ClampRange(
const T &start,
const T &end) : start(start), end(end) {}
543 return start < otherRange.
end && otherRange.
start < end;
549 Value input = op.getInput();
557 const auto opNanMode = op.getNanMode();
558 const auto clampNanMode = clampOp.getNanMode();
559 if (opNanMode == NanPropagationMode::IGNORE &&
560 clampNanMode == NanPropagationMode::PROPAGATE)
563 auto maxValAttr = op.getMaxValAttr();
564 auto minValAttr = op.getMinValAttr();
565 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
566 auto clampOpMinValAttr = clampOp.getMinValAttr();
568 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
570 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
571 inputEType = quantType.getStorageType();
575 if (mlir::isa<FloatType>(inputEType)) {
576 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
577 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
578 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
579 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
582 const auto opMinFloat = floatMinValAttr.getValue();
583 const auto opMaxFloat = floatMaxValAttr.getValue();
584 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
585 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
589 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
593 auto newMinVal =
std::max(opMinFloat, clampOpMinFloat);
594 auto newMaxVal =
std::min(opMaxFloat, clampOpMaxFloat);
595 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
596 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
598 assert(mlir::isa<IntegerType>(inputEType));
599 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
600 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
601 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
602 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
604 if (inputEType.isUnsignedInteger()) {
606 const auto opMinInt = intMinValAttr.getUInt();
607 const auto opMaxInt = intMaxValAttr.getUInt();
608 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
609 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
613 if (!opRangeIntRange.
intersects(clampRangeIntRange))
617 auto newMinVal =
std::max(opMinInt, clampOpMinInt);
618 auto newMaxVal =
std::min(opMaxInt, clampOpMaxInt);
623 const auto opMinInt = intMinValAttr.getInt();
624 const auto opMaxInt = intMaxValAttr.getInt();
625 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
626 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
630 if (!opRangeIntRange.
intersects(clampRangeIntRange))
634 auto newMinVal =
std::max(opMinInt, clampOpMinInt);
635 auto newMaxVal =
std::min(opMaxInt, clampOpMaxInt);
641 auto newMode = (opNanMode != clampNanMode)
642 ? tosa::NanPropagationMode::IGNORE
649 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
666 Value sliceInput = sliceOp.getInput1();
670 sliceOp,
"slice input must be concat operation");
673 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
674 if (!concatType || !concatType.hasStaticShape())
676 sliceOp,
"slice input must be a static ranked tensor");
677 int32_t axis = concatOp.getAxis();
684 sliceOp,
"start of slice must be a static ranked shape");
688 sliceOp,
"size of slice must be a static ranked shape");
691 llvm::to_vector(startElems.
getValues<int64_t>());
693 llvm::to_vector(sizeElems.
getValues<int64_t>());
698 std::optional<Value> replaceWithSlice;
699 for (
auto input : inputs) {
700 auto inputType = dyn_cast<RankedTensorType>(input.getType());
701 if (!inputType || !inputType.hasStaticShape())
703 sliceOp,
"concat input must be a static ranked tensor");
705 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
706 inputType.getDimSize(axis)) {
712 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
713 input, start_op, size_op)
717 sliceStarts[axis] -= inputType.getDimSize(axis);
720 if (!replaceWithSlice)
722 sliceOp,
"corresponding concat input not found for slice");
724 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
734 Value sliceInput = sliceOp.getInput1();
740 "slice input must be a pad operation");
743 if (!padOp->hasOneUse())
745 "pad shall have a single consumer");
748 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
749 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
750 if (!inputTy || !padTy || !inputTy.hasRank())
752 "slice input must be a ranked tensor");
759 "`padding` input specified on the tosa::PadOp must be constant.");
762 llvm::to_vector(paddingElems.getValues<int64_t>());
768 sliceOp,
"start of slice must be a static ranked shape");
770 llvm::to_vector(startElems.
getValues<int64_t>());
775 sliceOp,
"size of slice must be a static ranked shape");
777 llvm::to_vector(sizeElems.
getValues<int64_t>());
780 const int64_t rank = inputTy.getRank();
781 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
782 const bool isDimDynamic = inputTy.isDynamicDim(i);
783 const bool isDimSliced =
784 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
786 return isDimDynamic && isDimSliced;
789 sliceOp,
"axis that are sliced shall be statically known.");
796 bool updated =
false;
798 for (int64_t i = 0; i < rank; ++i) {
799 const int64_t padLo = padPaddings[i * 2];
800 const int64_t padHi = padPaddings[i * 2 + 1];
801 const int64_t sliceStart = sliceStarts[i];
802 const int64_t sliceSize = sliceSizes[i];
803 const int64_t sliceEnd = sliceStart + sliceSize;
806 if (inputTy.isDynamicDim(i)) {
807 newPadPaddings[i * 2] = padLo;
808 newPadPaddings[i * 2 + 1] = padHi;
809 newSliceStarts[i] = sliceStart;
814 const int64_t dimSize = inputTy.getShape()[i];
815 const int64_t dimTotal = padLo + dimSize + padHi;
818 if (sliceStart < 0 || sliceEnd > dimTotal)
822 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
823 newSliceStarts[i] = newSliceStart;
824 updated |= newSliceStart != sliceStart;
827 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
828 const int64_t newPadHi =
829 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
830 newPadPaddings[i * 2] = newPadLo;
831 newPadPaddings[i * 2 + 1] = newPadHi;
832 updated |= (newPadLo != padLo) || (newPadHi != padHi);
836 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
842 sliceOp,
"terminate condition; nothing to rewrite");
849 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
850 padOp.getInput1(), newPaddingsOp,
851 padOp.getPadConst());
857 newPadOp.getResult(), newStartOp,
872 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
874 ElementsAttr sizeElems;
877 sliceOp,
"size of slice must be a static ranked shape");
881 llvm::to_vector(sizeElems.getValues<int64_t>());
883 bool replaceSliceSize{
false};
888 if (size == -1 && !resultType.isDynamicDim(index)) {
889 sliceSizes[index] = resultType.getDimSize(index);
890 replaceSliceSize =
true;
894 if (!replaceSliceSize) {
896 sliceOp,
"no dimension of size of slice is dynamic that resolves "
897 "to static output shape");
902 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
903 sliceOp.getInput1(), sliceOp.getStart(), size_op);
905 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
920 template <
typename IntFolder,
typename FloatFolder>
922 RankedTensorType returnTy) {
925 auto rETy = llvm::cast<ShapedType>(rhs.
getType()).getElementType();
929 if (llvm::isa<IntegerType>(lETy)) {
932 auto result = IntFolder()(l, r);
936 if (llvm::isa<FloatType>(lETy)) {
939 auto result = FloatFolder()(l, r);
948 if (llvm::isa<FloatType>(elemType))
950 if (llvm::isa<IntegerType>(elemType))
956 if (llvm::isa<FloatType>(elemType))
959 if (llvm::isa<IntegerType>(elemType)) {
960 const int64_t shifted = 1LL << shift;
968 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
969 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
970 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
971 if (!lhsTy || !rhsTy || !resultTy)
975 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
976 !rhsTy.getElementType().isIntOrIndexOrFloat())
979 auto resultETy = resultTy.getElementType();
981 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
983 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
985 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
987 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
990 if (!lhsAttr || !rhsAttr)
993 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
998 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
999 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1000 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1001 !outputTy.hasStaticShape())
1004 if (inputTy.getDimSize(getAxis()) == 1)
1011 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1012 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1013 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1014 if (!lhsTy || !rhsTy || !resultTy)
1020 auto resultETy = resultTy.getElementType();
1022 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1024 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1025 if (lhsAttr && lhsAttr.isSplat()) {
1026 if (llvm::isa<IntegerType>(resultETy) &&
1027 lhsAttr.getSplatValue<APInt>().isZero())
1031 if (rhsAttr && rhsAttr.isSplat()) {
1032 if (llvm::isa<IntegerType>(resultETy) &&
1033 rhsAttr.getSplatValue<APInt>().isOne())
1037 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1038 llvm::isa<IntegerType>(resultETy)) {
1039 APInt l = lhsAttr.getSplatValue<APInt>();
1040 APInt r = rhsAttr.getSplatValue<APInt>();
1042 APInt result = l.sdiv(r);
1053 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1054 unsigned bitwidth) {
1055 APInt result = lhs.sext(64) * rhs.sext(64);
1058 auto round = APInt(64, 1) << (shift - 1);
1060 result.ashrInPlace(shift);
1062 if (!(result.getSExtValue() >= INT32_MIN &&
1063 result.getSExtValue() <= INT32_MAX)) {
1065 return std::nullopt;
1069 return result.trunc(bitwidth);
1073 RankedTensorType ty, int32_t shift) {
1075 if (llvm::isa<IntegerType>(ty.getElementType())) {
1083 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1084 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1090 if (llvm::isa<FloatType>(ty.getElementType())) {
1093 APFloat result = l * r;
1103 auto lhs = getInput1();
1104 auto rhs = getInput2();
1105 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.
getType());
1106 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.
getType());
1107 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1108 if (!lhsTy || !rhsTy || !resultTy)
1111 auto resultETy = resultTy.getElementType();
1113 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1115 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1120 if (resultETy.isInteger(32)) {
1121 ElementsAttr shift_elem;
1122 if (getShift().getImpl()) {
1126 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1130 if (rhsTy == resultTy) {
1131 if (
isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1133 return lhsAttr.resizeSplat(resultTy);
1137 if (lhsTy == resultTy) {
1138 if (
isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
1144 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1148 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1149 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1150 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1151 if (!lhsTy || !rhsTy || !resultTy)
1155 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1156 !rhsTy.getElementType().isIntOrIndexOrFloat())
1159 auto resultETy = resultTy.getElementType();
1161 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1163 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1165 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1168 if (!lhsAttr || !rhsAttr)
1171 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1176 template <
typename Cmp>
1177 struct ComparisonFold {
1178 ComparisonFold() =
default;
1179 APInt operator()(
const APInt &l,
const APInt &r) {
1180 return APInt(1, Cmp()(l, r));
1183 APInt operator()(
const APFloat &l,
const APFloat &r) {
1184 return APInt(1, Cmp()(l, r));
1188 struct APIntFoldGreater {
1189 APIntFoldGreater() =
default;
1190 APInt operator()(
const APInt &l,
const APInt &r) {
1191 return APInt(1, l.sgt(r));
1195 struct APIntFoldGreaterEqual {
1196 APIntFoldGreaterEqual() =
default;
1197 APInt operator()(
const APInt &l,
const APInt &r) {
1198 return APInt(1, l.sge(r));
1204 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1206 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1208 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1210 if (!lhsAttr || !rhsAttr)
1213 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1214 lhsAttr, rhsAttr, resultTy);
1217 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1218 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1220 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1222 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1224 if (!lhsAttr || !rhsAttr)
1228 ComparisonFold<std::greater_equal<APFloat>>>(
1229 lhsAttr, rhsAttr, resultTy);
1233 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1235 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1237 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1238 Value lhs = getInput1();
1239 Value rhs = getInput2();
1240 auto lhsTy = llvm::cast<ShapedType>(lhs.
getType());
1244 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1245 resultTy.hasStaticShape() && lhs == rhs) {
1249 if (!lhsAttr || !rhsAttr)
1252 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1253 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1261 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1265 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1266 auto outTy = llvm::cast<ShapedType>(
getType());
1267 auto inETy = inTy.getElementType();
1268 auto outETy = outTy.getElementType();
1270 if (operand.isSplat()) {
1271 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1273 auto splatVal = operand.getSplatValue<APFloat>();
1274 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1275 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1280 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1281 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1282 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1283 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1284 llvm::RoundingMode::NearestTiesToEven);
1288 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1289 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1290 auto intVal = APSInt(
1291 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1292 auto floatVal = operand.getSplatValue<APFloat>();
1294 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1299 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1300 const auto inIntType = llvm::cast<IntegerType>(inETy);
1301 auto unsignIn = inIntType.isUnsignedInteger();
1303 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1304 auto intVal = operand.getSplatValue<APInt>();
1305 auto bitwidth = outETy.getIntOrFloatBitWidth();
1308 if (outETy.isInteger(1)) {
1309 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1311 intVal = intVal.trunc(bitwidth);
1312 }
else if (unsignIn || inIntType.isInteger(1)) {
1313 intVal = intVal.zext(bitwidth);
1315 intVal = intVal.sext(bitwidth);
1325 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1327 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1329 #define REDUCE_FOLDER(OP) \
1330 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1331 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1332 if (!inputTy.hasRank()) \
1334 if (inputTy != getType()) \
1336 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1337 return getInput(); \
1347 #undef REDUCE_FOLDER
1350 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1351 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1353 if (!inputTy || !outputTy)
1359 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1363 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1364 getInput1().getDefiningOp())) {
1365 getInput1Mutable().assign(reshapeOp.getInput1());
1370 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1375 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1377 if (!outputTy.hasStaticShape())
1381 if (operand.isSplat())
1386 if (!getInput1().hasOneUse())
1393 return operand.reshape(
1394 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1402 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1403 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1404 if (densePad && densePad.isSplat() &&
1405 densePad.getSplatValue<APInt>().isZero()) {
1417 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1419 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1421 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1422 if (!scaleAttr || !offsetAttr || !borderAttr) {
1429 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1434 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1439 if (offset[0] != 0 || offset[1] != 0) {
1444 if (border[0] != 0 || border[1] != 0) {
1448 auto input = getInput();
1449 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1450 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1451 if (inputTy != resultTy)
1458 auto operand = getInput1();
1459 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1460 auto axis = getAxis();
1462 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1467 if (operandTy.hasRank() &&
1468 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1475 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1476 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1478 if (!inputTy || !outputTy)
1481 if (inputTy == outputTy && inputTy.hasStaticShape())
1484 if (!adaptor.getInput1())
1488 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1489 !outputTy.getElementType().isIntOrIndexOrFloat())
1492 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1493 if (operand.isSplat() && outputTy.hasStaticShape()) {
1497 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1498 outputTy.getNumElements() == 1) {
1504 llvm::to_vector(startElems.
getValues<uint64_t>());
1505 auto value = operand.getValues<
Attribute>()[indices];
1512 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513 if (getOnTrue() == getOnFalse())
1517 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1521 if (!predicate.isSplat())
1523 return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1529 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1530 adaptor.getMultiples())) {
1531 if (multiples.isSplat() &&
1532 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1534 if (
auto int_array_attr =
1535 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1536 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1537 [](APInt v) { return v.getSExtValue() == 1; }))
1546 auto resultTy = llvm::cast<ShapedType>(
getType());
1550 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1551 if (input.isSplat() && resultTy.hasStaticShape() &&
1552 input.getType().getElementType() == resultTy.getElementType())
1553 return input.reshape(resultTy);
1559 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1566 auto input = getInput1();
1568 if (
auto op = input.getDefiningOp<tosa::ExpOp>()) {
1569 return op.getInput1();
1576 auto input = getInput1();
1578 if (
auto op = input.getDefiningOp<tosa::LogOp>()) {
1579 return op.getInput1();
1585 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1588 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1594 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1595 failed(maybeIZp) || *maybeIZp != 0) {
1599 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1600 failed(maybeOZp) || *maybeOZp != 0) {
1604 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1605 failed(maybeIZp) || *maybeIZp != 0) {
1609 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1610 failed(maybeOZp) || *maybeOZp != 0) {
1615 return definingOp.getInput1();
1619 auto input = getInput1();
1621 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1634 concatOperands.reserve(2 * getNumOperands());
1637 bool foundFoldableConcat =
false;
1638 for (
Value operand : getOperands()) {
1639 concatOperands.emplace_back(operand);
1641 auto producer = operand.getDefiningOp<ConcatOp>();
1646 if (getAxis() != producer.getAxis())
1650 foundFoldableConcat =
true;
1651 concatOperands.pop_back();
1652 llvm::append_range(concatOperands, producer->getOperands());
1655 if (!foundFoldableConcat)
1658 getOperation()->setOperands(concatOperands);
1662 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1663 auto input = adaptor.getInput1();
1665 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1667 if (!inputAttr || !inputAttr.isSplat())
1670 auto shapeType = llvm::cast<ShapedType>(
getType());
1671 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1672 auto floatVal = inputAttr.getSplatValue<APFloat>();
1674 ReciprocalOp::calcOneElement(floatVal));
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
#define REDUCE_FOLDER(OP)
bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool isSplatZero(Type elemType, DenseElementsAttr val)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
auto getValues() const
Return the held element values as a range of the given type.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different shape for a splat type.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...