28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
51 (padConstAttr.
size() != 1)) {
56 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
57 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
58 return padConstVal == 0.0f;
62 if (
auto padConstIntAttr =
63 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
71 int64_t zpVal = (*zpAttr.
begin()).getSExtValue();
72 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
73 return zpVal == padConstVal;
81 template <
typename OpTy>
82 struct PoolPadFoldAdaptor;
85 struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
86 using OpTy = tosa::AvgPool2dOp;
89 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
90 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
94 static bool checkPadConstCompliance(OpTy op,
Value padConst) {
100 op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
107 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
108 using OpTy = tosa::MaxPool2dOp;
111 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
112 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
116 static bool checkPadConstCompliance(OpTy,
Value padConst) {
120 padConstAttr.
size() != 1) {
125 if (
auto padConstFpAttr =
126 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
127 const APFloat padConstVal = *padConstFpAttr.begin();
128 const APFloat lowestVal =
129 APFloat::getLargest(padConstVal.getSemantics(),
true);
130 return padConstVal == lowestVal;
131 }
else if (
auto padConstIntAttr =
132 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
133 const APInt padConstVal = *padConstIntAttr.begin();
134 const unsigned int bitWidth = padConstVal.getBitWidth();
135 const APInt lowestVal =
136 padConstIntAttr.getElementType().isUnsignedInteger()
138 : APInt::getSignedMinValue(bitWidth);
139 return padConstVal == lowestVal;
148 op, op.getType(), padInput, op.getKernel(), op.getStride(),
153 template <
typename OpTy>
154 struct ConvPadFoldAdaptor {
158 static bool checkPadConstCompliance(OpTy op,
Value padConst) {
164 op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
165 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
166 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
174 template <
typename OpTy,
typename AdaptorTy>
178 LogicalResult matchAndRewrite(OpTy tensorOp,
181 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
184 "Producer must be a tosa::PadOp.");
187 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
188 if (tensorOpPad.size() != 4)
190 tensorOp,
"Tensor operation padding shall have 4 elements.");
197 "The `padding` input specified on the tosa::PadOp must be constant.");
201 if (padOpPadding.size() != 8)
203 "Pad padding should have 8 elements.");
204 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
205 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
206 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
207 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
208 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
209 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
210 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
211 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
213 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
215 tensorOp,
"Folding padding in N or C dimensions is not supported.");
220 foldedPad[0] = padHBefore + tensorOpPad[0];
221 foldedPad[1] = padHAfter + tensorOpPad[1];
222 foldedPad[2] = padWBefore + tensorOpPad[2];
223 foldedPad[3] = padWAfter + tensorOpPad[3];
226 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
228 tensorOp,
"Padding size not aligned with kernel restrictions.");
232 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
235 "Padding constant is not aligned with operator zero-point.");
239 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
241 tensorOp,
"Padding size more than the 8K level limit.");
245 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
255 results.
add<FoldPadToTensorOp<tosa::AvgPool2dOp,
256 PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
263 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
269 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
270 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
279 Value input = op.getInput();
280 Value output = op.getOutput();
281 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
282 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
284 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
290 if (outputShape[1] != 1 || outputShape[2] != 1) {
295 if (inputShape[1] != 1 || inputShape[2] != 1) {
307 FoldPadToTensorOp<tosa::MaxPool2dOp,
308 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
321 if (op.getInput1().size() != 1)
323 if (op.getInput1().front().getType() != op.getType()) {
326 op.getInput1().front())
331 rewriter.
replaceOp(op, op.getInput1().front());
341 LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
342 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
346 op.getOperation()->setOperands(
347 {notOp.getInput1(), op.getInput3(), op.getInput2()});
359 auto innerTranspose =
360 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
363 "input must be transpose operation");
367 innerTranspose.getPerms();
369 if (transposePerms.size() != innerTransposePerms.size())
372 "transpose and inner transpose perms sizes must be equal");
373 if (transposePerms.empty())
375 transposeOp,
"transpose perms sizes must be positive");
379 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
380 perms[i] = innerTransposePerms[transposePerms[i]];
383 transposeOp, transposeOp.getResult().getType(),
396 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
398 op,
"Src is from transpose, can compose transposes");
400 Value result = op.getResult();
402 if (isa_and_nonnull<tosa::TransposeOp>(subop))
404 op,
"Dest is used by transpose, can compose transposes");
407 auto input = op.getInput1();
408 auto inputTy = llvm::cast<ShapedType>(input.getType());
409 if (!inputTy.hasRank())
412 int64_t numDynDims = 0;
413 for (
int i = 0; i < inputTy.getRank(); ++i)
414 if (inputTy.isDynamicDim(i))
423 nonZeroPerms.reserve(permValues.size());
424 for (
auto idx : permValues) {
425 auto sz = inputTy.getDimSize(idx);
427 nonZeroPerms.push_back(idx);
430 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
431 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
433 "Transpose changes memory layout.");
436 newShape.reserve(inputTy.getRank());
437 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
438 newShape.push_back(inputTy.getDimSize(permValues[i]));
441 op, op.getType(), op.getInput1(),
457 Value input = op.getInput();
458 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
459 auto inputElementType = inputType.getElementType();
461 if (!inputType.hasStaticShape()) {
465 if (isa<FloatType>(inputElementType)) {
468 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
470 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
471 bool isMin = minClamp.isNegInfinity();
472 bool isMax = maxClamp.isInfinity();
474 if (isMin && isMax) {
481 if (inputElementType.isUnsignedInteger()) {
483 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
485 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
488 APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
491 APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
494 if (minClamp <= intMin && maxClamp >= intMax) {
501 if (llvm::isa<IntegerType>(inputElementType)) {
503 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
505 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
508 APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
511 APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
514 if (minClamp <= intMin && maxClamp >= intMax) {
546 template <
typename T>
548 ClampRange(
const T &start,
const T &end) : start(start), end(end) {}
554 return start < otherRange.
end && otherRange.
start < end;
560 Value input = op.getInput();
563 auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.
getDefiningOp());
568 const auto opNanMode = op.getNanMode();
569 const auto clampNanMode = clampOp.getNanMode();
570 if (opNanMode ==
"IGNORE" && clampNanMode ==
"PROPAGATE")
573 auto maxValAttr = op.getMaxValAttr();
574 auto minValAttr = op.getMinValAttr();
575 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
576 auto clampOpMinValAttr = clampOp.getMinValAttr();
578 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
580 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
581 inputEType = quantType.getStorageType();
585 if (mlir::isa<FloatType>(inputEType)) {
586 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
587 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
588 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
589 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
592 const auto opMinFloat = floatMinValAttr.getValue();
593 const auto opMaxFloat = floatMaxValAttr.getValue();
594 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
595 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
599 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
603 auto newMinVal =
std::max(opMinFloat, clampOpMinFloat);
604 auto newMaxVal =
std::min(opMaxFloat, clampOpMaxFloat);
605 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
606 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
608 assert(mlir::isa<IntegerType>(inputEType));
609 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
610 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
611 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
612 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
614 if (inputEType.isUnsignedInteger()) {
616 const auto opMinInt = intMinValAttr.getUInt();
617 const auto opMaxInt = intMaxValAttr.getUInt();
618 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
619 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
623 if (!opRangeIntRange.
intersects(clampRangeIntRange))
627 auto newMinVal =
std::max(opMinInt, clampOpMinInt);
628 auto newMaxVal =
std::min(opMaxInt, clampOpMaxInt);
633 const auto opMinInt = intMinValAttr.getInt();
634 const auto opMaxInt = intMaxValAttr.getInt();
635 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
636 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
640 if (!opRangeIntRange.
intersects(clampRangeIntRange))
644 auto newMinVal =
std::max(opMinInt, clampOpMinInt);
645 auto newMaxVal =
std::min(opMaxInt, clampOpMaxInt);
652 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
653 rewriter.
getStringAttr((opNanMode != clampNanMode) ?
"IGNORE"
670 Value sliceInput = sliceOp.getInput1();
674 sliceOp,
"slice input must be concat operation");
677 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
678 if (!concatType || !concatType.hasStaticShape())
680 sliceOp,
"slice input must be a static ranked tensor");
681 int32_t axis = concatOp.getAxis();
688 sliceOp,
"start of slice must be a static ranked shape");
692 sliceOp,
"size of slice must be a static ranked shape");
695 llvm::to_vector(startElems.
getValues<int64_t>());
697 llvm::to_vector(sizeElems.
getValues<int64_t>());
702 std::optional<Value> replaceWithSlice;
703 for (
auto input : inputs) {
704 auto inputType = dyn_cast<RankedTensorType>(input.getType());
705 if (!inputType || !inputType.hasStaticShape())
707 sliceOp,
"concat input must be a static ranked tensor");
709 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
710 inputType.getDimSize(axis)) {
717 .
create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
718 input, start_op, size_op)
722 sliceStarts[axis] -= inputType.getDimSize(axis);
725 if (!replaceWithSlice)
727 sliceOp,
"corresponding concat input not found for slice");
729 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
739 Value sliceInput = sliceOp.getInput1();
745 "slice input must be a pad operation");
748 if (!padOp->hasOneUse())
750 "pad shall have a single consumer");
753 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
754 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
755 if (!inputTy || !padTy || !inputTy.hasRank())
757 "slice input must be a ranked tensor");
764 "`padding` input specified on the tosa::PadOp must be constant.");
767 llvm::to_vector(paddingElems.getValues<int64_t>());
773 sliceOp,
"start of slice must be a static ranked shape");
775 llvm::to_vector(startElems.
getValues<int64_t>());
780 sliceOp,
"size of slice must be a static ranked shape");
782 llvm::to_vector(sizeElems.
getValues<int64_t>());
785 const int64_t rank = inputTy.getRank();
786 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
787 const bool isDimDynamic = inputTy.isDynamicDim(i);
788 const bool isDimSliced =
789 (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
791 return isDimDynamic && isDimSliced;
794 sliceOp,
"axis that are sliced shall be statically known.");
801 bool updated =
false;
803 for (int64_t i = 0; i < rank; ++i) {
804 const int64_t padLo = padPaddings[i * 2];
805 const int64_t padHi = padPaddings[i * 2 + 1];
806 const int64_t sliceStart = sliceStarts[i];
807 const int64_t sliceSize = sliceSizes[i];
808 const int64_t sliceEnd = sliceStart + sliceSize;
811 if (inputTy.isDynamicDim(i)) {
812 newPadPaddings[i * 2] = padLo;
813 newPadPaddings[i * 2 + 1] = padHi;
814 newSliceStarts[i] = sliceStart;
819 const int64_t dimSize = inputTy.getShape()[i];
820 const int64_t dimTotal = padLo + dimSize + padHi;
823 if (sliceStart < 0 || sliceEnd > dimTotal)
827 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
828 newSliceStarts[i] = newSliceStart;
829 updated |= newSliceStart != sliceStart;
832 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
833 const int64_t newPadHi =
834 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
835 newPadPaddings[i * 2] = newPadLo;
836 newPadPaddings[i * 2 + 1] = newPadHi;
837 updated |= (newPadLo != padLo) || (newPadHi != padHi);
841 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
847 sliceOp,
"terminate condition; nothing to rewrite");
854 auto newPadOp = rewriter.
create<tosa::PadOp>(
855 padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
856 padOp.getPadConst());
862 newPadOp.getResult(), newStartOp,
877 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
879 ElementsAttr sizeElems;
882 sliceOp,
"size of slice must be a static ranked shape");
886 llvm::to_vector(sizeElems.getValues<int64_t>());
888 bool replaceSliceSize{
false};
893 if (size == -1 && !resultType.isDynamicDim(index)) {
894 sliceSizes[index] = resultType.getDimSize(index);
895 replaceSliceSize =
true;
899 if (!replaceSliceSize) {
901 sliceOp,
"no dimension of size of slice is dynamic that resolves "
902 "to static output shape");
906 auto newSliceOp = rewriter.
create<tosa::SliceOp>(
907 sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
908 sliceOp.getStart(), size_op);
910 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
925 template <
typename IntFolder,
typename FloatFolder>
927 RankedTensorType returnTy) {
930 auto rETy = llvm::cast<ShapedType>(rhs.
getType()).getElementType();
934 if (llvm::isa<IntegerType>(lETy)) {
937 auto result = IntFolder()(l, r);
941 if (llvm::isa<FloatType>(lETy)) {
944 auto result = FloatFolder()(l, r);
953 if (llvm::isa<FloatType>(elemType))
955 if (llvm::isa<IntegerType>(elemType))
961 if (llvm::isa<FloatType>(elemType))
964 if (llvm::isa<IntegerType>(elemType)) {
965 const int64_t shifted = 1LL << shift;
973 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
974 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
975 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
976 if (!lhsTy || !rhsTy || !resultTy)
980 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
981 !rhsTy.getElementType().isIntOrIndexOrFloat())
984 auto resultETy = resultTy.getElementType();
986 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
988 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
990 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
992 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
995 if (!lhsAttr || !rhsAttr)
998 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
1003 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1004 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1005 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1006 !outputTy.hasStaticShape())
1009 if (inputTy.getDimSize(getAxis()) == 1)
1016 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1017 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1018 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1019 if (!lhsTy || !rhsTy || !resultTy)
1025 auto resultETy = resultTy.getElementType();
1027 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1029 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1030 if (lhsAttr && lhsAttr.isSplat()) {
1031 if (llvm::isa<IntegerType>(resultETy) &&
1032 lhsAttr.getSplatValue<APInt>().isZero())
1036 if (rhsAttr && rhsAttr.isSplat()) {
1037 if (llvm::isa<IntegerType>(resultETy) &&
1038 rhsAttr.getSplatValue<APInt>().isOne())
1042 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1043 llvm::isa<IntegerType>(resultETy)) {
1044 APInt l = lhsAttr.getSplatValue<APInt>();
1045 APInt r = rhsAttr.getSplatValue<APInt>();
1047 APInt result = l.sdiv(r);
1058 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
1059 unsigned bitwidth) {
1060 APInt result = lhs.sext(64) * rhs.sext(64);
1063 auto round = APInt(64, 1) << (shift - 1);
1065 result.ashrInPlace(shift);
1067 if (!(result.getSExtValue() >= INT32_MIN &&
1068 result.getSExtValue() <= INT32_MAX)) {
1070 return std::nullopt;
1074 return result.trunc(bitwidth);
1078 RankedTensorType ty, int32_t shift) {
1080 if (llvm::isa<IntegerType>(ty.getElementType())) {
1088 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1089 const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
1095 if (llvm::isa<FloatType>(ty.getElementType())) {
1098 APFloat result = l * r;
1108 auto lhs = getInput1();
1109 auto rhs = getInput2();
1110 auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.
getType());
1111 auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.
getType());
1112 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1113 if (!lhsTy || !rhsTy || !resultTy)
1116 auto resultETy = resultTy.getElementType();
1118 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1120 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1125 if (resultETy.isInteger(32)) {
1126 ElementsAttr shift_elem;
1127 if (getShift().getImpl()) {
1131 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1135 if (rhsTy == resultTy) {
1137 return lhsAttr.resizeSplat(resultTy);
1141 if (lhsTy == resultTy) {
1148 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1152 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1153 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1154 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1155 if (!lhsTy || !rhsTy || !resultTy)
1159 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1160 !rhsTy.getElementType().isIntOrIndexOrFloat())
1163 auto resultETy = resultTy.getElementType();
1165 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1167 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1169 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1172 if (!lhsAttr || !rhsAttr)
1175 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
1180 template <
typename Cmp>
1181 struct ComparisonFold {
1182 ComparisonFold() =
default;
1183 APInt operator()(
const APInt &l,
const APInt &r) {
1184 return APInt(1, Cmp()(l, r));
1187 APInt operator()(
const APFloat &l,
const APFloat &r) {
1188 return APInt(1, Cmp()(l, r));
1192 struct APIntFoldGreater {
1193 APIntFoldGreater() =
default;
1194 APInt operator()(
const APInt &l,
const APInt &r) {
1195 return APInt(1, l.sgt(r));
1199 struct APIntFoldGreaterEqual {
1200 APIntFoldGreaterEqual() =
default;
1201 APInt operator()(
const APInt &l,
const APInt &r) {
1202 return APInt(1, l.sge(r));
1208 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1210 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1212 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1214 if (!lhsAttr || !rhsAttr)
1217 return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
1218 lhsAttr, rhsAttr, resultTy);
1221 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1222 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1224 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1226 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1228 if (!lhsAttr || !rhsAttr)
1232 ComparisonFold<std::greater_equal<APFloat>>>(
1233 lhsAttr, rhsAttr, resultTy);
1237 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1239 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1241 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1242 Value lhs = getInput1();
1243 Value rhs = getInput2();
1244 auto lhsTy = llvm::cast<ShapedType>(lhs.
getType());
1248 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
1249 resultTy.hasStaticShape() && lhs == rhs) {
1253 if (!lhsAttr || !rhsAttr)
1256 return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
1257 ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
1265 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1269 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1270 auto outTy = llvm::cast<ShapedType>(
getType());
1271 auto inETy = inTy.getElementType();
1272 auto outETy = outTy.getElementType();
1274 if (operand.isSplat()) {
1275 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1277 auto splatVal = operand.getSplatValue<APFloat>();
1278 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1279 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1284 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1285 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1286 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1287 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1288 llvm::RoundingMode::NearestTiesToEven);
1292 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1293 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1294 auto intVal = APSInt(
1295 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1296 auto floatVal = operand.getSplatValue<APFloat>();
1298 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1303 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1304 auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1306 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1307 auto intVal = operand.getSplatValue<APInt>();
1308 auto bitwidth = outETy.getIntOrFloatBitWidth();
1311 intVal = intVal.trunc(bitwidth);
1312 }
else if (unsignIn) {
1313 intVal = intVal.zext(bitwidth);
1315 intVal = intVal.sext(bitwidth);
1325 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1327 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1329 #define REDUCE_FOLDER(OP) \
1330 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1331 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1332 if (!inputTy.hasRank()) \
1334 if (inputTy != getType()) \
1336 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1337 return getInput(); \
1347 #undef REDUCE_FOLDER
1350 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1351 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1353 if (!inputTy || !outputTy)
1359 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1363 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1364 getInput1().getDefiningOp())) {
1365 getInput1Mutable().assign(reshapeOp.getInput1());
1370 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1375 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1377 if (!outputTy.hasStaticShape())
1381 if (operand.isSplat())
1386 if (!getInput1().hasOneUse())
1393 return operand.reshape(
1394 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1402 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1403 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1404 if (densePad && densePad.isSplat() &&
1405 densePad.getSplatValue<APInt>().isZero()) {
1417 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1419 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1421 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1422 if (!scaleAttr || !offsetAttr || !borderAttr) {
1429 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1434 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1439 if (offset[0] != 0 || offset[1] != 0) {
1444 if (border[0] != 0 || border[1] != 0) {
1448 auto input = getInput();
1449 auto inputTy = llvm::cast<RankedTensorType>(input.getType());
1450 auto resultTy = llvm::cast<RankedTensorType>(
getType());
1451 if (inputTy != resultTy)
1458 auto operand = getInput1();
1459 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1460 auto axis = getAxis();
1462 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
1467 if (operandTy.hasRank() &&
1468 (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
1475 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1476 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1478 if (!inputTy || !outputTy)
1481 if (inputTy == outputTy && inputTy.hasStaticShape())
1484 if (!adaptor.getInput1())
1488 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1489 !outputTy.getElementType().isIntOrIndexOrFloat())
1492 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1493 if (operand.isSplat() && outputTy.hasStaticShape()) {
1497 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1498 outputTy.getNumElements() == 1) {
1504 llvm::to_vector(startElems.
getValues<uint64_t>());
1505 auto value = operand.getValues<
Attribute>()[indices];
1512 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513 if (getInput2() == getInput3())
1517 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
1521 if (!predicate.isSplat())
1523 return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1529 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
1530 adaptor.getMultiples())) {
1531 if (multiples.isSplat() &&
1532 multiples.getSplatValue<APInt>().getSExtValue() == 1)
1534 if (
auto int_array_attr =
1535 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
1536 if (llvm::all_of(int_array_attr.getValues<APInt>(),
1537 [](APInt v) { return v.getSExtValue() == 1; }))
1546 auto resultTy = llvm::cast<ShapedType>(
getType());
1550 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1551 if (input.isSplat() && resultTy.hasStaticShape() &&
1552 input.getType().getElementType() == resultTy.getElementType())
1553 return input.reshape(resultTy);
1559 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
1566 auto input = getInput1();
1568 if (
auto op = input.getDefiningOp<tosa::ExpOp>()) {
1569 return op.getInput1();
1576 auto input = getInput1();
1578 if (
auto op = input.getDefiningOp<tosa::LogOp>()) {
1579 return op.getInput1();
1585 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
1588 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
1594 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
1595 failed(maybeIZp) || *maybeIZp != 0) {
1599 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1600 failed(maybeOZp) || *maybeOZp != 0) {
1604 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
1605 failed(maybeIZp) || *maybeIZp != 0) {
1609 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
1610 failed(maybeOZp) || *maybeOZp != 0) {
1615 return definingOp.getInput1();
1619 auto input = getInput1();
1621 if (
auto op = input.getDefiningOp<tosa::AbsOp>()) {
1634 concatOperands.reserve(2 * getNumOperands());
1637 bool foundFoldableConcat =
false;
1638 for (
Value operand : getOperands()) {
1639 concatOperands.emplace_back(operand);
1641 auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1646 if (getAxis() != producer.getAxis())
1650 foundFoldableConcat =
true;
1651 concatOperands.pop_back();
1652 llvm::append_range(concatOperands, producer->getOperands());
1655 if (!foundFoldableConcat)
1658 getOperation()->setOperands(concatOperands);
1662 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
1663 auto input = adaptor.getInput1();
1665 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
1667 if (!inputAttr || !inputAttr.isSplat())
1670 auto shapeType = llvm::cast<ShapedType>(
getType());
1671 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
1672 auto floatVal = inputAttr.getSplatValue<APFloat>();
1674 ReciprocalOp::calcOneElement(floatVal));
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy)
#define REDUCE_FOLDER(OP)
bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
static bool isSplatZero(Type elemType, DenseElementsAttr val)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
StringAttr getStringAttr(const Twine &bytes)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
auto getValues() const
Return the held element values as a range of the given type.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different shaped type.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt round(const Fraction &f)
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...