26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
53 (padConstAttr.
size() != 1)) {
58 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
59 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
60 return padConstVal == 0.0f;
64 if (
auto padConstIntAttr =
65 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
74 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
75 return zpVal == padConstVal;
// Per-op adaptor used when folding a producing tosa.pad into a pooling
// operation. Specializations (see the tosa::MaxPool2dOp specialization below)
// provide the compliance checks and the rewrite hook for their op type.
template <typename OpTy>
struct PoolPadFoldAdaptor;
87struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
88 using OpTy = tosa::MaxPool2dOp;
89 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
90 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
91 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
92 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
96 static bool checkPadConstCompliance(OpTy, Value padConst) {
98 DenseElementsAttr padConstAttr;
100 padConstAttr.
size() != 1) {
105 if (
auto padConstFpAttr =
106 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
107 const APFloat padConstVal = *padConstFpAttr.begin();
108 const APFloat lowestVal =
109 APFloat::getLargest(padConstVal.getSemantics(),
true);
110 return padConstVal == lowestVal;
112 if (
auto padConstIntAttr =
113 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
114 const APInt padConstVal = *padConstIntAttr.begin();
115 const unsigned int bitWidth = padConstVal.getBitWidth();
116 const APInt lowestVal =
117 padConstIntAttr.getElementType().isUnsignedInteger()
118 ? APInt::getZero(bitWidth)
119 : APInt::getSignedMinValue(bitWidth);
120 return padConstVal == lowestVal;
126 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
127 Value padInput, ArrayRef<int64_t> newPad) {
129 op, op.getType(), padInput, op.getKernel(), op.getStride(),
134template <
typename OpTy>
135struct ConvPadFoldAdaptor {
136 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
139 static bool checkPadConstCompliance(OpTy op, Value padConst) {
142 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
143 Value padInput, ArrayRef<int64_t> newPad) {
145 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
146 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
147 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
155template <
typename OpTy,
typename AdaptorTy>
157 using OpRewritePattern<OpTy>::OpRewritePattern;
159 LogicalResult matchAndRewrite(OpTy tensorOp,
160 PatternRewriter &rewriter)
const override {
162 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
165 "Producer must be a tosa::PadOp.");
168 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
169 if (tensorOpPad.size() != 4)
171 tensorOp,
"Tensor operation padding shall have 4 elements.");
174 DenseIntElementsAttr padOpPadding;
178 "The `padding` input specified on the tosa::PadOp must be constant.");
182 if (padOpPadding.size() != 8)
184 "Pad padding should have 8 elements.");
185 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
186 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
187 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
188 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
189 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
190 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
191 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
192 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
194 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
196 tensorOp,
"Folding padding in N or C dimensions is not supported.");
200 SmallVector<int64_t> foldedPad(tensorOpPad.size());
201 foldedPad[0] = padHBefore + tensorOpPad[0];
202 foldedPad[1] = padHAfter + tensorOpPad[1];
203 foldedPad[2] = padWBefore + tensorOpPad[2];
204 foldedPad[3] = padWAfter + tensorOpPad[3];
207 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
209 tensorOp,
"Padding size not aligned with kernel restrictions.");
213 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
216 "Padding constant is not aligned with operator zero-point.");
220 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
222 tensorOp,
"Padding size more than the 8K level limit.");
226 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
237 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
243 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
244 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
253 Value input = op.getInput();
254 Value output = op.getOutput();
255 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
256 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
258 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
264 if (outputShape[1] != 1 || outputShape[2] != 1) {
269 if (inputShape[1] != 1 || inputShape[2] != 1) {
281 FoldPadToTensorOp<tosa::MaxPool2dOp,
282 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
295 if (op.getInput1().size() != 1)
297 if (op.getInput1().front().getType() != op.getType()) {
300 op.getInput1().front())
305 rewriter.
replaceOp(op, op.getInput1().front());
315LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
316 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
320 op.getOperation()->setOperands(
321 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
333 auto innerTranspose =
334 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
337 "input must be transpose operation");
341 innerTranspose.getPerms();
343 if (transposePerms.size() != innerTransposePerms.size())
346 "transpose and inner transpose perms sizes must be equal");
347 if (transposePerms.empty())
349 transposeOp,
"transpose perms sizes must be positive");
353 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
354 perms[i] = innerTransposePerms[transposePerms[i]];
357 transposeOp, transposeOp.getResult().
getType(),
370 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
372 op,
"Src is from transpose, can compose transposes");
376 if (isa_and_nonnull<tosa::TransposeOp>(subop))
378 op,
"Dest is used by transpose, can compose transposes");
381 auto input = op.getInput1();
382 auto inputTy = llvm::cast<ShapedType>(input.
getType());
383 if (!inputTy.hasRank())
387 for (
int i = 0; i < inputTy.getRank(); ++i)
388 if (inputTy.isDynamicDim(i))
397 nonZeroPerms.reserve(permValues.size());
398 for (
auto idx : permValues) {
399 auto sz = inputTy.getDimSize(idx);
401 nonZeroPerms.push_back(idx);
404 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
405 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
407 "Transpose changes memory layout.");
410 newShape.reserve(inputTy.getRank());
411 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
412 newShape.push_back(inputTy.getDimSize(permValues[i]));
415 op, op.getType(), op.getInput1(),
423 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
431 Value input = op.getInput();
432 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
433 auto inputElementType = inputType.getElementType();
435 if (isa<FloatType>(inputElementType)) {
437 const auto minClamp =
438 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
439 const auto maxClamp =
440 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
441 const bool isMin = minClamp.isNegInfinity();
442 const bool isMax = maxClamp.isInfinity();
444 if (isMin && isMax) {
452 const bool isBoolean = inputElementType.isInteger(1);
453 if (inputElementType.isUnsignedInteger() || isBoolean) {
454 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
457 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
461 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
462 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
463 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
465 if (minClamp <= intMin && maxClamp >= intMax) {
472 if (llvm::isa<IntegerType>(inputElementType)) {
474 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
476 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
478 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
479 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
480 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
482 if (minClamp <= intMin && maxClamp >= intMax) {
514 template <
typename T>
528 Value input = op.getInput();
536 const auto opNanMode = op.getNanMode();
537 const auto clampNanMode = clampOp.getNanMode();
538 if (opNanMode == NanPropagationMode::IGNORE &&
539 clampNanMode == NanPropagationMode::PROPAGATE)
542 auto maxValAttr = op.getMaxValAttr();
543 auto minValAttr = op.getMinValAttr();
544 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
545 auto clampOpMinValAttr = clampOp.getMinValAttr();
547 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
549 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
554 if (mlir::isa<FloatType>(inputEType)) {
555 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
556 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
557 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
558 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
561 const auto opMinFloat = floatMinValAttr.getValue();
562 const auto opMaxFloat = floatMaxValAttr.getValue();
563 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
564 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
568 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
572 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
573 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
574 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
575 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
577 assert(mlir::isa<IntegerType>(inputEType));
578 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
579 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
580 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
581 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
583 if (inputEType.isUnsignedInteger()) {
585 const auto opMinInt = intMinValAttr.getUInt();
586 const auto opMaxInt = intMaxValAttr.getUInt();
587 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
588 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
592 if (!opRangeIntRange.
intersects(clampRangeIntRange))
596 auto newMinVal = std::max(opMinInt, clampOpMinInt);
597 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
602 const auto opMinInt = intMinValAttr.getInt();
603 const auto opMaxInt = intMaxValAttr.getInt();
604 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
605 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
609 if (!opRangeIntRange.
intersects(clampRangeIntRange))
613 auto newMinVal = std::max(opMinInt, clampOpMinInt);
614 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
620 auto newMode = (opNanMode != clampNanMode)
621 ? tosa::NanPropagationMode::IGNORE
625 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
628 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
634void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.
add<ClampIsNoOp>(context);
637 results.
add<ClampClampOptimization>(context);
645 Value sliceInput = sliceOp.getInput1();
649 sliceOp,
"slice input must be concat operation");
652 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
653 if (!concatType || !concatType.hasStaticShape())
655 sliceOp,
"slice input must be a static ranked tensor");
656 int32_t axis = concatOp.getAxis();
663 sliceOp,
"start of slice must be a static ranked shape");
667 sliceOp,
"size of slice must be a static ranked shape");
677 std::optional<Value> replaceWithSlice;
678 for (
auto input : inputs) {
679 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
680 if (!inputType || !inputType.hasStaticShape())
682 sliceOp,
"concat input must be a static ranked tensor");
684 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
685 inputType.getDimSize(axis)) {
691 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
692 input, start_op, size_op)
696 sliceStarts[axis] -= inputType.getDimSize(axis);
699 if (!replaceWithSlice)
701 sliceOp,
"corresponding concat input not found for slice");
703 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
713 Value sliceInput = sliceOp.getInput1();
719 "slice input must be a pad operation");
722 if (!padOp->hasOneUse())
724 "pad shall have a single consumer");
727 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
728 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
729 if (!inputTy || !padTy || !inputTy.hasRank())
731 "slice input must be a ranked tensor");
738 "`padding` input specified on the tosa::PadOp must be constant.");
741 llvm::to_vector(paddingElems.getValues<
int64_t>());
747 sliceOp,
"start of slice must be a static ranked shape");
754 sliceOp,
"size of slice must be a static ranked shape");
759 const int64_t rank = inputTy.getRank();
760 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
761 const bool isDimDynamic = inputTy.isDynamicDim(i);
762 const bool isDimSliced =
765 return isDimDynamic && isDimSliced;
768 sliceOp,
"axis that are sliced shall be statically known.");
775 bool updated =
false;
777 for (
int64_t i = 0; i < rank; ++i) {
778 const int64_t padLo = padPaddings[i * 2];
779 const int64_t padHi = padPaddings[i * 2 + 1];
780 const int64_t sliceStart = sliceStarts[i];
781 const int64_t sliceSize = sliceSizes[i];
782 const int64_t sliceEnd = sliceStart + sliceSize;
785 if (inputTy.isDynamicDim(i)) {
786 newPadPaddings[i * 2] = padLo;
787 newPadPaddings[i * 2 + 1] = padHi;
788 newSliceStarts[i] = sliceStart;
793 const int64_t dimSize = inputTy.getShape()[i];
794 const int64_t dimTotal = padLo + dimSize + padHi;
797 if (sliceStart < 0 || sliceEnd > dimTotal)
801 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
802 newSliceStarts[i] = newSliceStart;
803 updated |= newSliceStart != sliceStart;
806 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
808 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
809 newPadPaddings[i * 2] = newPadLo;
810 newPadPaddings[i * 2 + 1] = newPadHi;
811 updated |= (newPadLo != padLo) || (newPadHi != padHi);
815 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
821 sliceOp,
"terminate condition; nothing to rewrite");
827 RankedTensorType::get(newPadShape, inputTy.getElementType());
828 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
829 padOp.getInput1(), newPaddingsOp,
830 padOp.getPadConst());
836 newPadOp.getResult(), newStartOp,
851 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
852 if (!resultType.hasRank())
855 ElementsAttr sizeElems;
858 sliceOp,
"size of slice must be a static ranked shape");
862 llvm::to_vector(sizeElems.getValues<
int64_t>());
864 bool replaceSliceSize{
false};
868 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
870 sliceSizes[
index] = resultType.getDimSize(
index);
871 replaceSliceSize =
true;
875 if (!replaceSliceSize) {
877 sliceOp,
"no dimension of size of slice is dynamic that resolves "
878 "to static output shape");
883 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
884 sliceOp.getInput1(), sliceOp.getStart(), size_op);
886 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
891void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
892 MLIRContext *context) {
893 results.
add<ConcatSliceOptimization, PadSliceOptimization,
894 SliceDynamicSizeCanonicalization>(context);
902 const Value castInput = castOp.getInput();
906 "input must be cast operation");
908 const Value innerCastInput = innerCastOp.getInput();
910 const auto innerInputType =
911 llvm::cast<ShapedType>(innerCastInput.
getType());
912 const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
913 const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
918 if (llvm::any_of(types, [](
const ShapedType type) {
919 const auto elemTy = type.getElementType();
922 return !(elemTy.isInteger() ||
923 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
924 Float16Type, Float32Type>(elemTy));
927 castOp,
"only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
930 if (llvm::isa<Float8E5M2Type>(innerInputType.getElementType()) &&
931 llvm::isa<Float8E4M3FNType>(outerOutputType.getElementType())) {
933 castOp,
"avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
937 if (llvm::isa<Float8E4M3FNType>(innerInputType.getElementType()) &&
938 llvm::isa<Float8E5M2Type>(outerOutputType.getElementType())) {
940 castOp,
"avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
947 "inner cast operation is narrowing");
956 return semantics.nonFiniteBehavior !=
957 llvm::fltNonfiniteBehavior::FiniteOnly;
961 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
965 const ShapedType outType)
const {
967 if (inType.getElementType().isInteger() &&
968 outType.getElementType().isInteger()) {
970 const auto inTypeSignedness =
971 cast<IntegerType>(inType.getElementType()).getSignedness();
972 const auto outTypeSignedness =
973 cast<IntegerType>(outType.getElementType()).getSignedness();
975 return (inTypeSignedness != outTypeSignedness ||
976 inType.getElementTypeBitWidth() >
977 outType.getElementTypeBitWidth());
980 if (inType.getElementType().isFloat() &&
981 outType.getElementType().isFloat()) {
983 FloatType inElemTy = cast<FloatType>(inType.getElementType());
984 FloatType outElemTy = cast<FloatType>(outType.getElementType());
985 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
986 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
992 [[maybe_unused]]
const auto isSupported = [](
Type elemType) {
993 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
994 Float16Type, Float32Type>(elemType);
997 assert(isSupported(inElemTy) &&
998 "unsupported input element type in isNarrowingCast");
999 assert(isSupported(outElemTy) &&
1000 "unsupported output element type in isNarrowingCast");
1003 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1004 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1005 inTypeSemantics.precision > outTypeSemantics.precision ||
1016void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1017 MLIRContext *context) {
1018 results.
add<NonNarrowingCastsOptimization>(context);
1027 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1028 auto castFromBlockScaledOp =
1029 castToBlockScaledInput.
getDefiningOp<tosa::CastFromBlockScaledOp>();
1030 if (!castFromBlockScaledOp)
1032 castToBlockScaledOp,
1033 "input must be cast_from_block_scaled operation");
1035 const Value innerData = castFromBlockScaledOp.getInputData();
1036 const Value innerScale = castFromBlockScaledOp.getInputScale();
1037 const auto innerDataTy = llvm::cast<ShapedType>(innerData.
getType());
1038 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.
getType());
1040 const Value outerData = castToBlockScaledOp.getOutputData();
1041 const Value outerScale = castToBlockScaledOp.getOutputScale();
1042 const auto outerDataTy = llvm::cast<ShapedType>(outerData.
getType());
1043 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.
getType());
1045 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1047 castToBlockScaledOp,
1048 "inputs types to cast_from_block_scaled operation must match output "
1049 "types to cast_to_block_scaled");
1052 if (castFromBlockScaledOp.getBlockSize() !=
1053 castToBlockScaledOp.getBlockSize()) {
1055 castToBlockScaledOp,
"block sizes for cast_from_block_scaled and "
1056 "cast_to_block_scaled must match");
1059 rewriter.
replaceOp(castToBlockScaledOp, {innerData, innerScale});
1065void CastToBlockScaledOp::getCanonicalizationPatterns(
1066 RewritePatternSet &results, MLIRContext *context) {
1067 results.
add<CancellingBlockScaledCastsOptimization>(context);
1074template <
typename Folder>
1075static DenseElementsAttr
1077 bool foldDenseValues =
false) {
1081 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1085 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
1089 if (
lhs.isSplat() &&
rhs.isSplat()) {
1090 if (isa<FloatType>(lETy)) {
1091 const APFloat l =
lhs.getSplatValue<APFloat>();
1092 const APFloat r =
rhs.getSplatValue<APFloat>();
1093 const auto maybeResult = Folder::fold(l, r);
1094 if (failed(maybeResult))
1099 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1100 const APInt l =
lhs.getSplatValue<APInt>();
1101 const APInt r =
rhs.getSplatValue<APInt>();
1102 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1103 if (failed(maybeResult))
1109 if (foldDenseValues) {
1110 assert(lETy.isIntOrIndex() &&
1111 "Only integer types are currently supported.");
1114 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
1115 const auto maybeResult = Folder::fold(l, r,
false);
1116 if (failed(maybeResult))
1118 resultValues.push_back(maybeResult.value());
1126template <
typename Folder>
1128 bool foldDenseValues =
false) {
1132 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1138 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1140 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1141 if (failed(maybeResult))
1147 if (foldDenseValues) {
1151 for (
auto const &v : val.
getValues<APInt>()) {
1152 const auto maybeResult = Folder::fold(v,
false);
1153 if (failed(maybeResult))
1155 resultValues.push_back(maybeResult.value());
1170 assert(dense.isSplat());
1171 APInt a = dense.getSplatValue<APInt>();
1172 return a.getSExtValue();
1177 const bool isUnsigned) {
1180 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1186 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1193 const bool isUnsigned) {
1196 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1202 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1209 const bool isUnsigned) {
1211 const unsigned originalWidth =
lhs.getBitWidth();
1214 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1219 if (
lhs == 0 ||
rhs == 0)
1220 return APInt::getZero(originalWidth);
1222 bool overflow =
false;
1224 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1229 return result.trunc(originalWidth);
1232 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1238 return a.isNegative() !=
b.isNegative();
1245 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1253 APInt::udivrem(
lhs,
rhs, q, r);
1254 if (!r.isZero() && Ceil) {
1261 bool overflow{
false};
1262 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1265 APInt
const r =
lhs.srem(
rhs);
1275 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1283 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1285 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1295 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1297 auto const r = t.mod(
rhs);
1298 if (llvm::APFloatBase::opStatus::opOK == r) {
1308 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1310 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1313 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1321 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1323 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1326 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1332 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1333 auto const numBits = value.getBitWidth();
1335 auto const zextv = value.getZExtValue();
1336 if (zextv >= numBits)
1338 return APInt::getOneBitSet(numBits, zextv);
1340 auto const sextv = value.getSExtValue();
1341 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1343 return APInt::getOneBitSet(numBits, sextv);
1348 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1349 if (!value.isStrictlyPositive())
1351 return APInt(value.getBitWidth(), value.ceilLogBase2());
1356 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1357 if (!value.isStrictlyPositive())
1359 return APInt(value.getBitWidth(), value.logBase2());
1365 const bool isUnsigned) {
1366 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1369 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1370 return APInt(1,
lhs >
rhs);
1376 const bool isUnsigned) {
1377 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1380 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1381 return APInt(1,
lhs >=
rhs);
1387 const bool isUnsigned) {
1388 return APInt(1,
lhs ==
rhs);
1391 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1392 return APInt(1,
lhs ==
rhs);
1397 if (llvm::isa<FloatType>(elemType))
1399 if (llvm::isa<IntegerType>(elemType))
1405 if (llvm::isa<FloatType>(elemType))
1406 return val && val.
isSplat() &&
1408 if (llvm::isa<IntegerType>(elemType)) {
1409 const int64_t shifted = 1LL << shift;
1410 return val && val.
isSplat() &&
1416OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1417 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1418 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1419 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1420 if (!lhsTy || !rhsTy || !resultTy)
1424 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1425 !rhsTy.getElementType().isIntOrIndexOrFloat())
1428 auto resultETy = resultTy.getElementType();
1430 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1432 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1435 lhsTy.getShape(), rhsTy.getShape());
1436 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1438 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1441 if (!lhsAttr || !rhsAttr)
1447OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1448 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1449 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1450 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1451 !outputTy.hasStaticShape())
1455 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1456 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1457 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1464OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1465 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1466 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1467 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1468 if (!lhsTy || !rhsTy || !resultTy)
1470 if (lhsTy.getElementType() != rhsTy.getElementType())
1475 auto resultETy = resultTy.getElementType();
1477 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1479 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1480 if (lhsAttr && lhsAttr.isSplat()) {
1481 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1482 lhsAttr.getSplatValue<APInt>().isZero())
1483 return lhsAttr.resizeSplat(resultTy);
1486 if (rhsAttr && rhsAttr.isSplat()) {
1488 lhsTy.getShape(), rhsTy.getShape());
1489 if (isBroadcastable && lhsTy == resultTy &&
1490 llvm::isa<IntegerType>(resultETy) &&
1491 rhsAttr.getSplatValue<APInt>().isOne())
1495 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1496 llvm::isa<IntegerType>(resultETy)) {
1497 APInt l = lhsAttr.getSplatValue<APInt>();
1498 APInt r = rhsAttr.getSplatValue<APInt>();
1500 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1502 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1515std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1516 unsigned bitwidth) {
1517 bool overflow =
false;
1518 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1521 return std::nullopt;
1524 auto round = APInt(64, 1) << (shift - 1);
1526 result.ashrInPlace(shift);
1529 if (!(
result.getSExtValue() >= INT32_MIN &&
1530 result.getSExtValue() <= INT32_MAX)) {
1532 return std::nullopt;
1536 return result.trunc(bitwidth);
1539DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1540 RankedTensorType ty, int32_t shift) {
1542 if (llvm::isa<IntegerType>(ty.getElementType())) {
1543 APInt l =
lhs.getSplatValue<APInt>();
1544 APInt r =
rhs.getSplatValue<APInt>();
1550 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1551 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1557 if (llvm::isa<FloatType>(ty.getElementType())) {
1558 APFloat l =
lhs.getSplatValue<APFloat>();
1559 APFloat r =
rhs.getSplatValue<APFloat>();
1569OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1570 auto lhs = getInput1();
1571 auto rhs = getInput2();
1572 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1573 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1574 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1575 if (!lhsTy || !rhsTy || !resultTy)
1578 auto resultETy = resultTy.getElementType();
1580 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1582 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1587 if (resultETy.isInteger(32)) {
1588 ElementsAttr shift_elem;
1589 if (getShift().getImpl()) {
1593 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1597 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1598 resultTy.hasStaticShape())
1600 return lhsAttr.resizeSplat(resultTy);
1601 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1602 resultTy.hasStaticShape())
1603 return rhsAttr.resizeSplat(resultTy);
1606 lhsTy.getShape(), rhsTy.getShape());
1607 if (isBroadcastable && rhsTy == resultTy &&
1610 if (isBroadcastable && lhsTy == resultTy &&
1614 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1617OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1618 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1619 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1620 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1621 if (!lhsTy || !rhsTy || !resultTy)
1625 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1626 !rhsTy.getElementType().isIntOrIndexOrFloat())
1629 auto resultETy = resultTy.getElementType();
1631 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1633 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1636 lhsTy.getShape(), rhsTy.getShape());
1637 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1640 if (!lhsAttr || !rhsAttr)
1646OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1647 auto resultTy = llvm::cast<ShapedType>(
getType());
1649 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1651 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1653 if (!lhsAttr || !rhsAttr)
1659OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1660 auto resultTy = llvm::cast<ShapedType>(
getType());
1662 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1664 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1666 if (!lhsAttr || !rhsAttr)
1672OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1673 auto resultTy = llvm::cast<ShapedType>(
getType());
1675 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1677 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1678 Value
lhs = getInput1();
1679 Value
rhs = getInput2();
1680 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1684 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1685 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1689 if (!lhsAttr || !rhsAttr)
1695OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1699 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1703 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1704 auto outTy = llvm::cast<ShapedType>(
getType());
1705 if (!outTy.hasRank() || !outTy.hasStaticShape())
1707 auto inETy = inTy.getElementType();
1708 auto outETy = outTy.getElementType();
1710 if (operand.isSplat()) {
1711 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1713 auto splatVal = operand.getSplatValue<APFloat>();
1714 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1715 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1720 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1721 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1722 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1723 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1724 llvm::RoundingMode::NearestTiesToEven);
1728 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1729 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1730 auto intVal = APSInt(
1731 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1732 auto floatVal = operand.getSplatValue<APFloat>();
1734 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1739 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1740 const auto inIntType = llvm::cast<IntegerType>(inETy);
1741 auto unsignIn = inIntType.isUnsignedInteger();
1743 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1744 auto intVal = operand.getSplatValue<APInt>();
1745 auto bitwidth = outETy.getIntOrFloatBitWidth();
1748 if (outETy.isInteger(1)) {
1749 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1751 intVal = intVal.trunc(bitwidth);
1752 }
else if (unsignIn || inIntType.isInteger(1)) {
1753 intVal = intVal.zext(bitwidth);
1755 intVal = intVal.sext(bitwidth);
1765OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1767OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1769#define REDUCE_FOLDER(OP) \
1770 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1771 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1772 if (!inputTy.hasRank()) \
1774 if (inputTy != getType()) \
1776 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1777 return getInput(); \
// NOTE(review): fragment of a fold hook whose signature line was dropped from
// this extract; the getInput1()/adaptor usage suggests ReshapeOp::fold —
// confirm against the full source. Visible logic: identity reshape folds to
// the input, reshape-of-reshape is collapsed by rewiring the operand, and a
// dense constant input is folded by reshaping the attribute.
1790 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1791 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1793 if (!inputTy || !outputTy)
1799 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
// Collapse reshape(reshape(x)) by pointing this op at the inner input.
1803 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1804 getInput1().getDefiningOp())) {
1805 getInput1Mutable().assign(reshapeOp.getInput1());
1810 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1815 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1817 if (!outputTy.hasStaticShape())
1821 if (operand.isSplat())
1826 if (!getInput1().hasOneUse())
1833 return operand.reshape(
1834 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
// Fold tosa.pad to its input when the padding attribute is a splat of zero
// (i.e. no padding is added) and the input type equals the result type.
// NOTE(review): this extract drops lines (returns / closing braces); verify
// against the full source.
1840OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1842 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1843 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1844 if (densePad && densePad.isSplat() &&
1845 densePad.getSplatValue<APInt>().isZero()) {
// Fold tosa.resize to its input when the resize is an identity: each scale
// numerator equals its denominator (scale[0]==scale[1], scale[2]==scale[3]),
// and offset and border are all zero. Requires constant scale/offset/border.
// NOTE(review): lines are missing from this extract (e.g. where scale/offset/
// border vectors are extracted from the attributes); verify in full source.
1855OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1857 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1859 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1861 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1862 if (!scaleAttr || !offsetAttr || !borderAttr) {
1869 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1874 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1879 if (offset[0] != 0 || offset[1] != 0) {
1884 if (border[0] != 0 || border[1] != 0) {
1888 return foldToInputIfTypeMatches(
getType(), getInput());
// Fold tosa.reverse to its input when reversing cannot change the data:
// the input is a splat constant, or the reversed axis has size 1
// (and the input/result types match via foldToInputIfTypeMatches).
1891OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
1892 auto operand = getInput1();
1893 auto operandTy = llvm::cast<ShapedType>(operand.getType());
1894 auto axis = getAxis();
1896 const bool isSplatInput =
1897 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
1898 if (!operandTy.hasRank() ||
1899 (!isSplatInput && operandTy.getDimSize(axis) != 1))
1901 return foldToInputIfTypeMatches(
getType(), operand);
// Fold tosa.slice in several no-op / constant cases:
//  1. identity slice (same static input/output type) folds to the input;
//  2. zero start with sizes equal to the (static) input shape folds to input;
//  3. a splat constant input with static output shape folds to a constant;
//  4. a single-element static slice of a constant extracts that element.
// NOTE(review): several lines were dropped in this extract (returns, and the
// matchPattern calls that populate startElems/sizeElems); verify against the
// full source before relying on the exact control flow.
1904OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1905 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1906 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1908 if (!inputTy || !outputTy)
1911 if (inputTy == outputTy && inputTy.hasStaticShape())
1916 DenseElementsAttr startElems;
1922 llvm::all_of(startElems.
getValues<APInt>(),
1923 [](
const APInt &val) { return val.isZero(); });
1928 DenseElementsAttr sizeElems;
1932 auto inputShape = inputTy.getShape();
1933 auto sizeValues = sizeElems.
getValues<APInt>();
1935 bool sizeMatchesInput =
true;
1936 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
1937 int64_t size = sizeVal.getSExtValue();
1939 if (inputTy.isDynamicDim(i)) {
1943 sizeMatchesInput =
false;
1950 sizeMatchesInput =
false;
1956 if (sizeMatchesInput)
1961 if (!adaptor.getInput1())
1965 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
1966 !outputTy.getElementType().isIntOrIndexOrFloat())
1969 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
1970 if (operand.isSplat() && outputTy.hasStaticShape()) {
1974 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
1975 outputTy.getNumElements() == 1) {
1976 llvm::SmallVector<uint64_t>
indices =
1977 llvm::to_vector(startElems.
getValues<uint64_t>());
1978 if (
auto values = operand.tryGetValues<Attribute>())
// Fold tosa.select:
//  - if on_true == on_false (and its type matches the result), the predicate
//    is irrelevant and the op folds to that value;
//  - if the predicate is a constant splat, fold to the selected side when the
//    three shapes are statically known to be broadcastable and the chosen
//    operand's type matches the result type.
1985OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1986 const Value pred = getPred();
1987 const Value onTrue = getOnTrue();
1988 const Value onFalse = getOnFalse();
1990 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
1991 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
1992 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
1993 if (!predTy || !onTrueTy || !onFalseTy)
1996 const Type resultTy =
getType();
1998 const ArrayRef<int64_t> predShape = predTy.getShape();
1999 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2001 if (onTrue == onFalse && onTrueTy == resultTy &&
// Constant splat predicate: the selection is decided at compile time.
2006 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2009 if (!predicate.isSplat())
2012 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2014 SmallVector<SmallVector<int64_t>, 3> shapes;
2015 shapes.emplace_back(predShape);
2016 shapes.emplace_back(onTrueShape);
2017 shapes.emplace_back(onFalseTy.getShape());
2018 const bool isBroadcastable =
2021 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
2023 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
// Fold tosa.tile when every multiple is 1 (a tile that copies the input once
// per axis), either as a splat of 1 or element-wise all-ones.
// NOTE(review): the folded return values are missing from this extract.
2028OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2030 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2031 adaptor.getMultiples())) {
2032 if (multiples.isSplat() &&
2033 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2035 if (
auto int_array_attr =
2036 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2037 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2038 [](APInt v) { return v.getSExtValue() == 1; }))
// Fold tosa.transpose:
//  - a splat constant input is reshaped straight to the (static) result type
//    when the element types agree (transposing a splat is a no-op on data);
//  - an identity permutation (perms == [0, 1, ..., n-1]) folds to the input.
2046OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2047 auto resultTy = llvm::cast<ShapedType>(
getType());
2051 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2052 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2053 input.
getType().getElementType() == resultTy.getElementType())
2054 return input.reshape(resultTy);
2058 const llvm::ArrayRef<int32_t> perms = getPerms();
// Non-identity permutations cannot be folded here.
2060 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2063 return foldToInputIfTypeMatches(
getType(), getInput1());
// Fold negate(negate(x)) -> x, but only when all four zero points — the input
// and output zero points of both this op and the producing negate — are
// statically known to be zero; non-zero zero points change the arithmetic.
2066OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2069 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2075 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2076 failed(maybeIZp) || *maybeIZp != 0) {
2080 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2081 failed(maybeOZp) || *maybeOZp != 0) {
2085 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2086 failed(maybeIZp) || *maybeIZp != 0) {
2090 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2091 failed(maybeOZp) || *maybeOZp != 0) {
2096 return foldToInputIfTypeMatches(
getType(), definingOp.getInput1());
// Fold tosa.abs to its input when a guarding condition holds.
// NOTE(review): the condition lines (original lines 2101-2102) are missing
// from this extract — confirm in the full source which inputs qualify
// (presumably an input already produced by an abs, making this abs a no-op).
2099OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2100 auto input = getInput1();
2103 return foldToInputIfTypeMatches(
getType(), input);
// Flatten nested concats: for each operand produced by another ConcatOp on
// the same axis, replace that operand with the producer's operands in place.
// Returns failure-equivalent when no nested concat was inlined; otherwise the
// op's operand list is updated in place via setOperands.
2108OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
2113 SmallVector<Value, 8> concatOperands;
2114 concatOperands.reserve(2 * getNumOperands());
2117 bool foundFoldableConcat =
false;
2118 for (Value operand : getOperands()) {
2119 concatOperands.emplace_back(operand);
2121 auto producer = operand.getDefiningOp<ConcatOp>();
// Only concats along the same axis can be merged.
2126 if (getAxis() != producer.getAxis())
2130 foundFoldableConcat =
true;
// Replace the just-pushed operand with the producer's operand list.
2131 concatOperands.pop_back();
2132 llvm::append_range(concatOperands, producer->getOperands());
2135 if (!foundFoldableConcat)
2138 getOperation()->setOperands(concatOperands);
// Constant-fold tosa.reciprocal for a splat float constant input when the
// result type has a static shape, computing the value element with
// ReciprocalOp::calcOneElement.
2142OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2143 auto input = adaptor.getInput1();
2145 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2147 if (!inputAttr || !inputAttr.isSplat())
2150 auto shapeType = llvm::cast<ShapedType>(
getType());
2151 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2153 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2154 auto floatVal = inputAttr.getSplatValue<APFloat>();
2156 ReciprocalOp::calcOneElement(floatVal));
// NOTE(review): fragment of a templated unary shape-op fold helper (its
// function signature was dropped from this extract — the doxygen index below
// suggests unaryShapeFold). It requires the operand to be produced by a
// tosa.const_shape and reads that op's dense values attribute.
2162template <
typename Op,
typename OpFoldAdaptor>
2164 auto input1ConstShape =
2165 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2166 if (!input1ConstShape)
2169 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
// NOTE(review): fragment of a templated binary shape-op fold helper (its
// function signature was dropped from this extract — the doxygen index below
// suggests binaryFold). Both operands must come from tosa.const_shape ops;
// their dense values are combined (combination lines missing here).
2175template <
typename Op,
typename OpFoldAdaptor>
2177 auto input1ConstShape =
2178 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2179 auto input2ConstShape =
2180 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2181 if (!input1ConstShape || !input2ConstShape)
2184 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2185 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2188 input1Attr.getType(),
// Fold tosa.dim to a constant when the queried dimension of the ranked input
// is statically known; the result is built as a 1-element index tensor.
// NOTE(review): the final attribute construction (and the Builder setup for
// `builder`) is missing from this extract.
2192OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2193 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2194 if (!inputTy || !inputTy.hasRank())
2196 const int32_t axis = getAxis();
2197 const int64_t dimSize = inputTy.getDimSize(axis);
2198 if (ShapedType::isDynamic(dimSize))
2202 const auto resultAttrTy =
2203 RankedTensorType::get(1, builder.getIndexType());
// NOTE(review): fragment of a shape-concat fold helper whose signature was
// dropped from this extract (the doxygen index below suggests
// concatShapeFold). Every input must come from a tosa.const_shape; their
// dimension values are appended into concatDims and a 1-D index tensor type
// is built for the result.
2208 auto const inputs = op->getInput();
2214 concatDims.reserve( 64);
2215 for (
auto const &v : inputs) {
2216 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2220 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2223 auto const vAttrVals = vAttr.getValues<APInt>();
2224 for (
auto const &v : vAttrVals) {
2225 concatDims.push_back(v);
2229 auto *ctx = op->getContext();
2230 assert(ctx !=
nullptr &&
"ctx is nullptr");
2231 auto const rankedTy = RankedTensorType::get(
2232 {
static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
// NOTE(review): fragment of a shape-slice fold helper whose signature was
// dropped from this extract (the doxygen index below suggests
// sliceShapeFold). Requires the sliced value to come from tosa.const_shape
// and the start/size to resolve to constants; the selected, bounds-checked
// range of dimension values is copied into a new 1-D index tensor type.
2238 auto const input1 = op->getInput();
2239 auto const input2 = op->getStart();
2240 auto const input3 = op->getSize();
2242 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2244 if (!input1ConstShape)
2247 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2251 auto const input1Vals = input1Attr.getValues<APInt>();
2252 auto const totalInput1 = input1Vals.size();
2257 if (failed(start) || failed(size))
2260 auto const startV =
static_cast<int32_t
>(start.value());
2261 auto const sizeV =
static_cast<int32_t
>(size.value());
// Reject empty, negative-start, or out-of-bounds slices.
2263 if ((sizeV <= 0) || (startV < 0) ||
2264 (
static_cast<size_t>(startV + sizeV) > totalInput1))
2268 sliceOfInput.reserve(totalInput1);
2270 for (
auto i = startV; i < (startV + sizeV); i++) {
2271 sliceOfInput.push_back(input1Vals[i]);
2274 auto *ctx = op->getContext();
2275 assert(ctx !=
nullptr &&
"ctx is nullptr");
2277 auto const rankedTy = RankedTensorType::get(
2278 {
static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
// Fold hooks for the elementwise shape ops. Each delegates to the shared
// binaryFold/unaryShapeFold-style helpers; the div variants select ceil vs
// floor rounding via DivFoldAdaptor's boolean template argument.
// NOTE(review): most delegation bodies were dropped from this extract (only
// the DivCeil/DivFloor returns are visible); verify against the full source.
2283OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2287OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2291OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2295OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2296 return binaryFold<DivCeilShapeOp, DivFoldAdaptor<
true>>(
this);
2299OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2300 return binaryFold<DivFloorShapeOp, DivFoldAdaptor<
false>>(
this);
2303OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2307OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2311OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2315OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2319OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2323OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2327OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2331OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
DynamicAPInt round(const Fraction &f)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...