24 #include "llvm/ADT/APFloat.h" 
   25 #include "llvm/ADT/APInt.h" 
   45       (padConstAttr.size() != 1)) {
   50   if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
   51     float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
   52     return padConstVal == 0.0f;
   56   if (auto padConstIntAttr =
   57           mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
   65     int64_t zpVal = (*zpAttr.begin()).getSExtValue();
   66     int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
   67     return zpVal == padConstVal;
 
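// The adaptor structs below describe, per consumer op, when a producing tosa.pad can
// be folded into that op's own `pad` attribute. For max_pool2d the folded padding
// must stay smaller than the kernel in each spatial direction, and the pad constant
// must equal the lowest value of the element type so that padding can never win the
// max reduction.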
   75 template <typename OpTy>
 
   76 struct PoolPadFoldAdaptor;
 
   79 struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
 
   80   using OpTy = tosa::MaxPool2dOp;
 
   83     if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
 
   84         newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
 
   88   static bool checkPadConstCompliance(OpTy, Value padConst) {
 
   92         padConstAttr.size() != 1) {
 
   97     if (auto padConstFpAttr =
 
   98             mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
 
   99       const APFloat padConstVal = *padConstFpAttr.begin();
 
  100       const APFloat lowestVal =
 
  101           APFloat::getLargest(padConstVal.getSemantics(), true);
 
  102       return padConstVal == lowestVal;
 
  104     if (auto padConstIntAttr =
 
  105             mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
 
  106       const APInt padConstVal = *padConstIntAttr.begin();
 
  107       const unsigned int bitWidth = padConstVal.getBitWidth();
 
  108       const APInt lowestVal =
 
  109           padConstIntAttr.getElementType().isUnsignedInteger()
 
  111               : APInt::getSignedMinValue(bitWidth);
 
  112       return padConstVal == lowestVal;
 
  121         op, op.getType(), padInput, op.getKernel(), op.getStride(),
 
  126 template <typename OpTy>
 
  127 struct ConvPadFoldAdaptor {
 
  131   static bool checkPadConstCompliance(OpTy op, Value padConst) {
 
  137         op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
 
  138         op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
 
  139         op.getDilationAttr(), op.getAccType(), op.getLocalBound());
 
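// FoldPadToTensorOp<OpTy, AdaptorTy> folds an explicit tosa.pad producer into the
// consumer's `pad` attribute. It requires a 4-element pad on the consumer, an
// 8-element constant padding on the tosa.pad, zero padding in the N and C
// dimensions, the adaptor-specific kernel/pad-const checks, and folded pad values
// no larger than the 8192 (8K) level limit. Roughly (illustrative IR sketch, not
// verbatim from this file):
//   %p = tosa.pad %x, %padding, %pad_const
//   %y = tosa.max_pool2d %p {pad = array<i64: 0, 0, 0, 0>, ...}
// becomes
//   %y = tosa.max_pool2d %x {pad = <pad folded from %padding>, ...}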
  147 template <typename OpTy, typename AdaptorTy>
 
  151   LogicalResult matchAndRewrite(OpTy tensorOp,
 
  154     auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
 
  157                                          "Producer must be a tosa::PadOp.");
 
  160     const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
 
  161     if (tensorOpPad.size() != 4) 
 
  163           tensorOp, "Tensor operation padding shall have 4 elements.");
 
  170           "The `padding` input specified on the tosa::PadOp must be constant.");
 
  174     if (padOpPadding.size() != 8)
 
  176                                          "Pad padding should have 8 elements.");
 
  177     int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
  178     int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
  179     int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
  180     int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
  181     int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
  182     int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
  183     int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
  184     int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
 
  186     if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
 
  188           tensorOp, "Folding padding in N or C dimensions is not supported.");
 
  193     foldedPad[0] = padHBefore + tensorOpPad[0];
 
  194     foldedPad[1] = padHAfter + tensorOpPad[1];
 
  195     foldedPad[2] = padWBefore + tensorOpPad[2];
 
  196     foldedPad[3] = padWAfter + tensorOpPad[3];
 
  199     if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
 
  201           tensorOp, "Padding size not aligned with kernel restrictions.");
 
  205     if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
 
  208           "Padding constant is not aligned with operator zero-point.");
 
  212     if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
 
  214           tensorOp, "Padding size more than the 8K level limit.");
 
  218     AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
 
  229       FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
 
  235   results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
 
  236                                 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
 
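// Fragment of a max_pool2d canonicalization: when input and output both have static
// shapes with 1x1 spatial dimensions, the pooling appears to reduce nothing and the
// op can be replaced by its input.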
  245     Value input = op.getInput();
 
  246     Value output = op.getOutput();
 
  247     ShapedType inputType = llvm::cast<ShapedType>(input.getType());
 
  248     ShapedType outputType = llvm::cast<ShapedType>(output.getType());
 
  250     if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
 
  256     if (outputShape[1] != 1 || outputShape[2] != 1) {
 
  261     if (inputShape[1] != 1 || inputShape[2] != 1) {
 
  273               FoldPadToTensorOp<tosa::MaxPool2dOp,
 
  274                                 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
 
  287     if (op.getInput1().size() != 1)
 
  289     if (op.getInput1().front().getType() != op.getType()) {
 
  292                                               op.getInput1().front())
 
  297     rewriter.replaceOp(op, op.getInput1().front());
 
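// SelectOp canonicalization: when the predicate comes from tosa.logical_not, use the
// not's input as the predicate and swap the on_true/on_false operands in place.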
  307 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
 
  308   auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
 
  312     op.getOperation()->setOperands(
 
  313         {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
 
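// Composes two back-to-back tosa.transpose ops into one transpose whose permutation
// is perms[i] = innerTransposePerms[transposePerms[i]], provided both permutations
// have the same non-zero size.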
  325     auto innerTranspose =
 
  326         transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
 
  329                                          "input must be transpose operation");
 
  333         innerTranspose.getPerms();
 
  335     if (transposePerms.size() != innerTransposePerms.size())
 
  338           "transpose and inner transpose perms sizes must be equal");
 
  339     if (transposePerms.empty())
 
  341           transposeOp, "transpose perms sizes must be positive");
 
  345     for (int i = 0, s = transposePerms.size(); i < s; ++i)
 
  346       perms[i] = innerTransposePerms[transposePerms[i]];
 
  349         transposeOp, transposeOp.getResult().getType(),
 
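// Rewrites a transpose that cannot change memory layout into a reshape: the non-unit
// dimensions must keep their relative order under the permutation, and the new shape
// is the input shape read through the permutation.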
  362     if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
 
  364           op, "Src is from transpose, can compose transposes");
 
  366     Value result = op.getResult();
 
  368       if (isa_and_nonnull<tosa::TransposeOp>(subop))
 
  370             op, "Dest is used by transpose, can compose transposes");
 
  373     auto input = op.getInput1();
 
  374     auto inputTy = llvm::cast<ShapedType>(input.getType());
 
  375     if (!inputTy.hasRank())
 
  378     int64_t numDynDims = 0;
 
  379     for (int i = 0; i < inputTy.getRank(); ++i)
 
  380       if (inputTy.isDynamicDim(i))
 
  389     nonZeroPerms.reserve(permValues.size());
 
  390     for (auto idx : permValues) {
 
  391       auto sz = inputTy.getDimSize(idx);
 
  393         nonZeroPerms.push_back(idx);
 
  396     for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
 
  397       if (nonZeroPerms[i - 1] > nonZeroPerms[i])
 
  399                                            "Transpose changes memory layout.");
 
  402     newShape.reserve(inputTy.getRank());
 
  403     for (int i = 0, s = inputTy.getRank(); i < s; ++i)
 
  404       newShape.push_back(inputTy.getDimSize(permValues[i]));
 
  407         op, op.getType(), op.getInput1(),
 
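// Clamp-is-no-op checks: a tosa.clamp whose bounds cover the entire representable
// range of its element type (-inf/+inf for floats, the full unsigned or signed
// integer range otherwise) constrains nothing and can be removed.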
  423     Value input = op.getInput();
 
  424     auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
 
  425     auto inputElementType = inputType.getElementType();
 
  427     if (isa<FloatType>(inputElementType)) {
 
  429       const auto minClamp =
 
  430           llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
 
  431       const auto maxClamp =
 
  432           llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
 
  433       const bool isMin = minClamp.isNegInfinity();
 
  434       const bool isMax = maxClamp.isInfinity();
 
  436       if (isMin && isMax) {
 
  444     const bool isBoolean = inputElementType.isInteger(1);
 
  445     if (inputElementType.isUnsignedInteger() || isBoolean) {
 
  446       const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
 
  449       const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
 
  453       const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
 
  454       const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
 
  455       const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
 
  457       if (minClamp <= intMin && maxClamp >= intMax) {
 
  464     if (llvm::isa<IntegerType>(inputElementType)) {
 
  465       const int64_t minClamp =
 
  466           llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
 
  467       const int64_t maxClamp =
 
  468           llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
 
  470       const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
 
  471       const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
 
  472       const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
 
  474       if (minClamp <= intMin && maxClamp >= intMax) {
 
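// ClampRange is a small helper interval used to merge two consecutive clamps: the
// ranges are intersected (std::max of the mins, std::min of the maxes), and the
// NaN-propagation mode falls back to IGNORE when the two ops disagree.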
  506   template <typename T>
  508     ClampRange(const T &start, const T &end) : start(start), end(end) {}
 
  514       return start < otherRange.end && otherRange.start < end;
 
  520     Value input = op.getInput();
 
  528     const auto opNanMode = op.getNanMode();
 
  529     const auto clampNanMode = clampOp.getNanMode();
 
  530     if (opNanMode == NanPropagationMode::IGNORE &&
 
  531         clampNanMode == NanPropagationMode::PROPAGATE)
 
  534     auto maxValAttr = op.getMaxValAttr();
 
  535     auto minValAttr = op.getMinValAttr();
 
  536     auto clampOpMaxValAttr = clampOp.getMaxValAttr();
 
  537     auto clampOpMinValAttr = clampOp.getMinValAttr();
 
  539     auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
 
  541             llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
 
  542       inputEType = quantType.getStorageType();
 
  546     if (mlir::isa<FloatType>(inputEType)) {
 
  547       auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
 
  548       auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
 
  549       auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
 
  550       auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
 
  553       const auto opMinFloat = floatMinValAttr.getValue();
 
  554       const auto opMaxFloat = floatMaxValAttr.getValue();
 
  555       const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
 
  556       const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
 
  560       if (!opRangeFloatRange.intersects(clampRangeFloatRange))
  564       auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
  565       auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
  566       newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
  567       newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
 
  569       assert(mlir::isa<IntegerType>(inputEType));
 
  570       auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
 
  571       auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
 
  572       auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
 
  573       auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
 
  575       if (inputEType.isUnsignedInteger()) {
 
  577         const auto opMinInt = intMinValAttr.getUInt();
 
  578         const auto opMaxInt = intMaxValAttr.getUInt();
 
  579         const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
 
  580         const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
 
  584         if (!opRangeIntRange.intersects(clampRangeIntRange))
  588         auto newMinVal = std::max(opMinInt, clampOpMinInt);
  589         auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
 
  594         const auto opMinInt = intMinValAttr.getInt();
 
  595         const auto opMaxInt = intMaxValAttr.getInt();
 
  596         const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
 
  597         const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
 
  601         if (!opRangeIntRange.intersects(clampRangeIntRange))
  605         auto newMinVal = std::max(opMinInt, clampOpMinInt);
  606         auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
 
  612     auto newMode = (opNanMode != clampNanMode)
 
  613                        ? tosa::NanPropagationMode::IGNORE
 
  620         op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
 
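// Slice-of-concat: when the sliced region lies entirely within one input of the
// producing tosa.concat along the concat axis, the slice is re-targeted to that
// input directly, shifting the start by the sizes of the preceding inputs.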
  637     Value sliceInput = sliceOp.getInput1();
 
  641           sliceOp, "slice input must be concat operation");
 
  644     auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
 
  645     if (!concatType || !concatType.hasStaticShape())
 
  647           sliceOp, "slice input must be a static ranked tensor");
 
  648     int32_t axis = concatOp.getAxis();
 
  655           sliceOp, "start of slice must be a static ranked shape");
 
  659           sliceOp, "size of slice must be a static ranked shape");
 
  662         llvm::to_vector(startElems.getValues<int64_t>());
  664         llvm::to_vector(sizeElems.getValues<int64_t>());
 
  669     std::optional<Value> replaceWithSlice;
 
  670     for (auto input : inputs) {
 
  671       auto inputType = dyn_cast<RankedTensorType>(input.getType());
 
  672       if (!inputType || !inputType.hasStaticShape())
 
  674             sliceOp, "concat input must be a static ranked tensor");
 
  676       if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
 
  677                                         inputType.getDimSize(axis)) {
 
  683             tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
 
  684                                   input, start_op, size_op)
 
  688       sliceStarts[axis] -= inputType.getDimSize(axis);
 
  691     if (!replaceWithSlice)
 
  693           sliceOp, "corresponding concat input not found for slice");
  695     rewriter.replaceOp(sliceOp, replaceWithSlice.value());
 
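// Slice-of-pad: a tosa.slice of a single-use tosa.pad is rewritten so the slice
// consumes as little padding as possible; pad amounts and slice starts are
// recomputed per dimension, and dynamic or unsliced dimensions are left untouched.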
  705     Value sliceInput = sliceOp.getInput1();
 
  711                                          "slice input must be a pad operation");
 
  714     if (!padOp->hasOneUse())
 
  716                                          "pad shall have a single consumer");
 
  719     auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
 
  720     auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
 
  721     if (!inputTy || !padTy || !inputTy.hasRank())
 
  723                                          "slice input must be a ranked tensor");
 
  730           "`padding` input specified on the tosa::PadOp must be constant.");
 
  733         llvm::to_vector(paddingElems.getValues<int64_t>());
 
  739           sliceOp, "start of slice must be a static ranked shape");
  741         llvm::to_vector(startElems.getValues<int64_t>());
  746           sliceOp, "size of slice must be a static ranked shape");
  748         llvm::to_vector(sizeElems.getValues<int64_t>());
 
  751     const int64_t rank = inputTy.getRank();
 
  752     if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
 
  753           const bool isDimDynamic = inputTy.isDynamicDim(i);
 
  754           const bool isDimSliced =
 
  755               (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
 
  757           return isDimDynamic && isDimSliced;
 
  760           sliceOp, "axis that are sliced shall be statically known.");
  767     bool updated = false;
 
  769     for (int64_t i = 0; i < rank; ++i) {
 
  770       const int64_t padLo = padPaddings[i * 2];
 
  771       const int64_t padHi = padPaddings[i * 2 + 1];
 
  772       const int64_t sliceStart = sliceStarts[i];
 
  773       const int64_t sliceSize = sliceSizes[i];
 
  774       const int64_t sliceEnd = sliceStart + sliceSize;
 
  777       if (inputTy.isDynamicDim(i)) {
 
  778         newPadPaddings[i * 2] = padLo;
 
  779         newPadPaddings[i * 2 + 1] = padHi;
 
  780         newSliceStarts[i] = sliceStart;
 
  785       const int64_t dimSize = inputTy.getShape()[i];
 
  786       const int64_t dimTotal = padLo + dimSize + padHi;
 
  789       if (sliceStart < 0 || sliceEnd > dimTotal)
 
  793       const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
 
  794       newSliceStarts[i] = newSliceStart;
 
  795       updated |= newSliceStart != sliceStart;
 
  798       const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
 
  799       const int64_t newPadHi =
 
  800           std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
 
  801       newPadPaddings[i * 2] = newPadLo;
 
  802       newPadPaddings[i * 2 + 1] = newPadHi;
 
  803       updated |= (newPadLo != padLo) || (newPadHi != padHi);
 
  807           newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
 
  813           sliceOp, "terminate condition; nothing to rewrite");
 
  820     auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
 
  821                                         padOp.getInput1(), newPaddingsOp,
 
  822                                         padOp.getPadConst());
 
  828                                                newPadOp.getResult(), newStartOp,
 
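// Replaces -1 entries in a tosa.slice `size` operand with the corresponding static
// dimension of the result type, rebuilding the slice with a static size operand
// when at least one entry could be resolved.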
  843     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
 
  845     ElementsAttr sizeElems;
 
  848           sliceOp, "size of slice must be a static ranked shape");
 
  852         llvm::to_vector(sizeElems.getValues<int64_t>());
 
  854     bool replaceSliceSize{false};
 
  859       if (size == -1 && !resultType.isDynamicDim(index)) {
 
  860         sliceSizes[index] = resultType.getDimSize(index);
 
  861         replaceSliceSize = true;
 
  865     if (!replaceSliceSize) {
 
  867           sliceOp, "no dimension of size of slice is dynamic that resolves "
  868                    "to static output shape");
 
  873         tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(),
 
  874                               sliceOp.getInput1(), sliceOp.getStart(), size_op);
 
  876     rewriter.replaceOp(sliceOp, newSliceOp.getResult());
 
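// binaryFolder applies IntFolder/FloatFolder to constant integer or floating-point
// operands and materializes the result as a DenseElementsAttr of the return type;
// it backs the elementwise fold() implementations that follow.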
  891 template <typename IntFolder, typename FloatFolder>
 
  894                                       RankedTensorType returnTy) {
 
  897   auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
 
  901     if (llvm::isa<IntegerType>(lETy)) {
 
  904       auto result = IntFolder()(l, r);
 
  908     if (llvm::isa<FloatType>(lETy)) {
 
  911       auto result = FloatFolder()(l, r);
 
  920   if (llvm::isa<FloatType>(elemType))
 
  922   if (llvm::isa<IntegerType>(elemType))
 
  928   if (llvm::isa<FloatType>(elemType))
 
  931   if (llvm::isa<IntegerType>(elemType)) {
 
  932     const int64_t shifted = 1LL << shift;
 
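// AddOp::fold: adding a splat zero is the identity when the operand type matches the
// result type; otherwise both operands must be constant and are folded with
// std::plus<APInt> / std::plus<APFloat> through binaryFolder.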
  940   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  941   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  942   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
  943   if (!lhsTy || !rhsTy || !resultTy)
 
  947   if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
 
  948       !rhsTy.getElementType().isIntOrIndexOrFloat())
 
  951   auto resultETy = resultTy.getElementType();
 
  953       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
  955       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
  957   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
  959   if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
 
  962   if (!lhsAttr || !rhsAttr)
 
  965   return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
 
  970   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  971   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
 
  972   if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
 
  973       !outputTy.hasStaticShape())
 
  977   if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
 
  978     const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
 
  987   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  988   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  989   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
  990   if (!lhsTy || !rhsTy || !resultTy)
 
  996   auto resultETy = resultTy.getElementType();
 
  998       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1000       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1001   if (lhsAttr && lhsAttr.isSplat()) {
 
 1002     if (llvm::isa<IntegerType>(resultETy) &&
 
 1003         lhsAttr.getSplatValue<APInt>().isZero())
 
 1007   if (rhsAttr && rhsAttr.isSplat()) {
 
 1008     if (llvm::isa<IntegerType>(resultETy) &&
 
 1009         rhsAttr.getSplatValue<APInt>().isOne())
 
 1013   if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
 
 1014       llvm::isa<IntegerType>(resultETy)) {
 
 1015     APInt l = lhsAttr.getSplatValue<APInt>();
 
 1016     APInt r = rhsAttr.getSplatValue<APInt>();
 
 1018       APInt result = l.sdiv(r);
 
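// mulInt multiplies in 64 bits, applies a rounding arithmetic right shift by `shift`
// (using the 1 << (shift - 1) rounding term), gives up if the result falls outside
// the int32 range, and truncates back to the requested bit width.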
 1029 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
 
 1030                             unsigned bitwidth) {
 
 1031   APInt result = lhs.sext(64) * rhs.sext(64);
 
 1034     auto round = APInt(64, 1) << (shift - 1);
 
 1036     result.ashrInPlace(shift);
 
 1038     if (!(result.getSExtValue() >= INT32_MIN &&
 
 1039           result.getSExtValue() <= INT32_MAX)) {
 
 1041       return std::nullopt;
 
 1045   return result.trunc(bitwidth);
 
 1049                                   RankedTensorType ty, int32_t shift) {
 
 1051     if (llvm::isa<IntegerType>(ty.getElementType())) {
 
 1059       auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
 
 1060       const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
 
 1066     if (llvm::isa<FloatType>(ty.getElementType())) {
 
 1069       APFloat result = l * r;
 
 1079   auto lhs = getInput1();
 
 1080   auto rhs = getInput2();
 
 1081   auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
 1082   auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
 1083   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1084   if (!lhsTy || !rhsTy || !resultTy)
 
 1087   auto resultETy = resultTy.getElementType();
 
 1089       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1091       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1096   if (resultETy.isInteger(32)) {
 
 1097     ElementsAttr shift_elem;
 
 1098     if (getShift().getImpl()) {
 
 1102       shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
 
 1106   if (rhsTy == resultTy) {
 
 1107     if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
 
 1109       return lhsAttr.resizeSplat(resultTy);
 
 1113   if (lhsTy == resultTy) {
 
 1114     if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
 
 1120   return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
 
 1124   auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
 1125   auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
 1126   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1127   if (!lhsTy || !rhsTy || !resultTy)
 
 1131   if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
 
 1132       !rhsTy.getElementType().isIntOrIndexOrFloat())
 
 1135   auto resultETy = resultTy.getElementType();
 
 1137       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1139       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1141   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
 
 1144   if (!lhsAttr || !rhsAttr)
 
 1147   return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
 
 1152 template <typename Cmp>
 
 1153 struct ComparisonFold {
 
 1154   ComparisonFold() = default;
 
 1155   APInt operator()(const APInt &l, const APInt &r) {
 
 1156     return APInt(1, Cmp()(l, r));
 
 1159   APInt operator()(const APFloat &l, const APFloat &r) {
 
 1160     return APInt(1, Cmp()(l, r));
 
 1164 struct APIntFoldGreater {
 
 1165   APIntFoldGreater() = default;
 
 1166   APInt operator()(const APInt &l, const APInt &r) {
 
 1167     return APInt(1, l.sgt(r));
 
 1171 struct APIntFoldGreaterEqual {
 
 1172   APIntFoldGreaterEqual() = default;
 
 1173   APInt operator()(const APInt &l, const APInt &r) {
 
 1174     return APInt(1, l.sge(r));
 
 1180   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1182       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1184       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1186   if (!lhsAttr || !rhsAttr)
 
 1189   return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
 
 1190       lhsAttr, rhsAttr, resultTy);
 
 1193 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
 
 1194   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1196       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1198       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1200   if (!lhsAttr || !rhsAttr)
 
 1204                       ComparisonFold<std::greater_equal<APFloat>>>(
 
 1205       lhsAttr, rhsAttr, resultTy);
 
 1209   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1211       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
 
 1213       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
 1214   Value lhs = getInput1();
 
 1215   Value rhs = getInput2();
 
 1216   auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
 
 1220   if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
 
 1221       resultTy.hasStaticShape() && lhs == rhs) {
 
 1225   if (!lhsAttr || !rhsAttr)
 
 1228   return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
 
 1229                       ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
 
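// CastOp::fold handles splat inputs by converting the splat value between element
// types: float->float via APFloat::convert, int->float via convertFromAPInt,
// float->int via convertToInteger, and int->int via trunc/zext/sext depending on
// widths and signedness, with a special case for i1 results.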
 1237   auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
 
 1241   auto inTy = llvm::cast<ShapedType>(getInput().getType());
 1242   auto outTy = llvm::cast<ShapedType>(getType());
 
 1243   auto inETy = inTy.getElementType();
 
 1244   auto outETy = outTy.getElementType();
 
 1246   if (operand.isSplat()) {
 
 1247     if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
 
 1249       auto splatVal = operand.getSplatValue<APFloat>();
 
 1250       auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
 
 1251       splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
 
 1256     if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
 
 1257       auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
 
 1258       APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
 
 1259       splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
 
 1260                                 llvm::RoundingMode::NearestTiesToEven);
 
 1264     if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
 
 1265       auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
 
 1266       auto intVal = APSInt(
 
 1267           llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
 
 1268       auto floatVal = operand.getSplatValue<APFloat>();
 
 1270       floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
 
 1275     if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
 
 1276       const auto inIntType = llvm::cast<IntegerType>(inETy);
 
 1277       auto unsignIn = inIntType.isUnsignedInteger();
 
 1279           inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
 
 1280       auto intVal = operand.getSplatValue<APInt>();
 
 1281       auto bitwidth = outETy.getIntOrFloatBitWidth();
 
 1284       if (outETy.isInteger(1)) {
 
 1285         intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
 
 1287         intVal = intVal.trunc(bitwidth);
 
 1288       } else if (unsignIn || inIntType.isInteger(1)) {
 
 1289         intVal = intVal.zext(bitwidth);
 
 1291         intVal = intVal.sext(bitwidth);
 
 1301 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
 1303 OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValuesAttr(); }
 
 1305 #define REDUCE_FOLDER(OP)                                                      \ 
 1306   OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \ 
 1307     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \ 
 1308     if (!inputTy.hasRank())                                                    \ 
 1310     if (inputTy != getType())                                                  \ 
 1312     if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \ 
 1313       return getInput();                                                       \ 
 1323 #undef REDUCE_FOLDER 
 1326   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
 1327   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1329   if (!inputTy || !outputTy)
 
 1335   if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
 
 1339   if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
 
 1340           getInput1().getDefiningOp())) {
 
 1341     getInput1Mutable().assign(reshapeOp.getInput1());
 
 1346   if (!inputTy.getElementType().isIntOrIndexOrFloat())
 
 1351           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
 
 1353     if (!outputTy.hasStaticShape())
 
 1357     if (operand.isSplat())
 
 1362     if (!getInput1().hasOneUse())
 
 1369     return operand.reshape(
 
 1370         llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
 
 1378   if (adaptor.getPadding() && getInput1().getType() == getType()) {
 
 1379     auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
 
 1380     if (densePad && densePad.isSplat() &&
 
 1381         densePad.getSplatValue<APInt>().isZero()) {
 
 1393       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
 
 1395       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
 
 1397       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
 
 1398   if (!scaleAttr || !offsetAttr || !borderAttr) {
 
 1405   if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
 
 1410   if (scale[0] != scale[1] || scale[2] != scale[3]) {
 
 1415   if (offset[0] != 0 || offset[1] != 0) {
 
 1420   if (border[0] != 0 || border[1] != 0) {
 
 1424   auto input = getInput();
 
 1425   auto inputTy = llvm::cast<RankedTensorType>(input.getType());
 
 1426   auto resultTy = llvm::cast<RankedTensorType>(getType());
 
 1427   if (inputTy != resultTy)
 
 1434   auto operand = getInput1();
 
 1435   auto operandTy = llvm::cast<ShapedType>(operand.getType());
 
 1436   auto axis = getAxis();
 
 1438       llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
 
 1443   if (operandTy.hasRank() &&
 
 1444       (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
 
 1451   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
 1452   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
 
 1454   if (!inputTy || !outputTy)
 
 1457   if (inputTy == outputTy && inputTy.hasStaticShape())
 
 1460   if (!adaptor.getInput1())
 
 1464   if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
 
 1465       !outputTy.getElementType().isIntOrIndexOrFloat())
 
 1468   auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
 
 1469   if (operand.isSplat() && outputTy.hasStaticShape()) {
 
 1473   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
 
 1474       outputTy.getNumElements() == 1) {
 
 1480         llvm::to_vector(startElems.getValues<uint64_t>());
 1481     auto value = operand.getValues<Attribute>()[indices];
 
 1488 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
 
 1489   if (getOnTrue() == getOnFalse())
 
 1493       llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
 
 1497   if (!predicate.isSplat())
 
 1499   return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
 
 1505     if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
 
 1506             adaptor.getMultiples())) {
 
 1507       if (multiples.isSplat() &&
 
 1508           multiples.getSplatValue<APInt>().getSExtValue() == 1)
 
 1510       if (auto int_array_attr =
 
 1511               llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
 
 1512         if (llvm::all_of(int_array_attr.getValues<APInt>(),
 
 1513                          [](APInt v) { return v.getSExtValue() == 1; }))
 
 1522   auto resultTy = llvm::cast<ShapedType>(getType());
 
 1526           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
 
 1527     if (input.isSplat() && resultTy.hasStaticShape() &&
 
 1528         input.getType().getElementType() == resultTy.getElementType())
 
 1529       return input.reshape(resultTy);
 
 1535   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
 
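// NegateOp::fold: negate(negate(x)) folds back to x, but only when every input and
// output zero point involved, on this op and on the producing negate, is known to be
// zero.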
 1541 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
 
 1544   auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
 
 1550   if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
 
 1551       failed(maybeIZp) || *maybeIZp != 0) {
 
 1555   if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
 
 1556       failed(maybeOZp) || *maybeOZp != 0) {
 
 1560   if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
 
 1561       failed(maybeIZp) || *maybeIZp != 0) {
 
 1565   if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
 
 1566       failed(maybeOZp) || *maybeOZp != 0) {
 
 1571   return definingOp.getInput1();
 
 1575   auto input = getInput1();
 
 1577   if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
 
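// Concat flattening: operands that are themselves produced by a tosa.concat with the
// same axis are replaced by that producer's operands, collapsing nested concats
// along one axis into a single op.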
 1590   concatOperands.reserve(2 * getNumOperands());
 
 1593   bool foundFoldableConcat = false;
 
 1594   for (Value operand : getOperands()) {
 
 1595     concatOperands.emplace_back(operand);
 
 1597     auto producer = operand.getDefiningOp<ConcatOp>();
 
 1602     if (getAxis() != producer.getAxis())
 
 1606     foundFoldableConcat = true;
 
 1607     concatOperands.pop_back();
 
 1608     llvm::append_range(concatOperands, producer->getOperands());
 
 1611   if (!foundFoldableConcat)
 
 1614   getOperation()->setOperands(concatOperands);
 
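// ReciprocalOp::fold constant-folds a splat input by computing the reciprocal of the
// single element via ReciprocalOp::calcOneElement and rebuilding a splat of the
// result type.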
 1618 OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
 
 1619   auto input = adaptor.getInput1();
 
 1621   auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
 
 1623   if (!inputAttr || !inputAttr.isSplat())
 
 1626   auto shapeType = llvm::cast<ShapedType>(getType());
 1627   if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
 
 1628     auto floatVal = inputAttr.getSplatValue<APFloat>();
 
 1630                                   ReciprocalOp::calcOneElement(floatVal));
 