26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
53 (padConstAttr.
size() != 1)) {
58 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
59 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
60 return padConstVal == 0.0f;
64 if (
auto padConstIntAttr =
65 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
74 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
75 return zpVal == padConstVal;
// Adaptor trait used by the pad-folding rewrite pattern: each pooling op type
// provides checkKernelCompliance / checkPadConstCompliance / replaceOpWithNewPad
// hooks (see the tosa::MaxPool2dOp specialization below). Declared without a
// generic definition so that unsupported op types fail at compile time.
 83template <
typename OpTy>
 84struct PoolPadFoldAdaptor;
87struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
88 using OpTy = tosa::MaxPool2dOp;
89 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
90 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
91 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
92 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
96 static bool checkPadConstCompliance(OpTy, Value padConst) {
98 DenseElementsAttr padConstAttr;
100 padConstAttr.
size() != 1) {
105 if (
auto padConstFpAttr =
106 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
107 const APFloat padConstVal = *padConstFpAttr.begin();
108 const APFloat lowestVal =
109 APFloat::getLargest(padConstVal.getSemantics(),
true);
110 return padConstVal == lowestVal;
112 if (
auto padConstIntAttr =
113 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
114 const APInt padConstVal = *padConstIntAttr.begin();
115 const unsigned int bitWidth = padConstVal.getBitWidth();
116 const APInt lowestVal =
117 padConstIntAttr.getElementType().isUnsignedInteger()
118 ? APInt::getZero(bitWidth)
119 : APInt::getSignedMinValue(bitWidth);
120 return padConstVal == lowestVal;
126 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
127 Value padInput, ArrayRef<int64_t> newPad) {
129 op, op.getType(), padInput, op.getKernel(), op.getStride(),
134template <
typename OpTy>
135struct ConvPadFoldAdaptor {
136 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
139 static bool checkPadConstCompliance(OpTy op, Value padConst) {
142 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
143 Value padInput, ArrayRef<int64_t> newPad) {
145 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
146 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
147 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
155template <
typename OpTy,
typename AdaptorTy>
157 using OpRewritePattern<OpTy>::OpRewritePattern;
159 LogicalResult matchAndRewrite(OpTy tensorOp,
160 PatternRewriter &rewriter)
const override {
162 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
165 "Producer must be a tosa::PadOp.");
168 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
169 if (tensorOpPad.size() != 4)
171 tensorOp,
"Tensor operation padding shall have 4 elements.");
174 DenseIntElementsAttr padOpPadding;
178 "The `padding` input specified on the tosa::PadOp must be constant.");
182 if (padOpPadding.size() != 8)
184 "Pad padding should have 8 elements.");
185 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
186 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
187 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
188 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
189 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
190 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
191 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
192 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
194 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
196 tensorOp,
"Folding padding in N or C dimensions is not supported.");
200 SmallVector<int64_t> foldedPad(tensorOpPad.size());
201 foldedPad[0] = padHBefore + tensorOpPad[0];
202 foldedPad[1] = padHAfter + tensorOpPad[1];
203 foldedPad[2] = padWBefore + tensorOpPad[2];
204 foldedPad[3] = padWAfter + tensorOpPad[3];
207 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
209 tensorOp,
"Padding size not aligned with kernel restrictions.");
213 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
216 "Padding constant is not aligned with operator zero-point.");
220 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
222 tensorOp,
"Padding size more than the 8K level limit.");
226 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
237 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
243 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
244 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
261 op,
"expected constant kernel, stride, and pad operands");
264 rewriter, op.getLoc(), op.
getType(), op.getInput(), op.getInputZp(),
273void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
283 Value input = op.getInput();
284 Value output = op.getOutput();
285 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
286 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
288 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
294 if (outputShape[1] != 1 || outputShape[2] != 1) {
299 if (inputShape[1] != 1 || inputShape[2] != 1) {
311 FoldPadToTensorOp<tosa::MaxPool2dOp,
312 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
329 op,
"expected constant kernel, stride, and pad operands");
332 rewriter, op.getLoc(), op.
getType(), op.getInput(),
341void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
355 if (op.getInput1().size() != 1)
357 if (op.getInput1().front().getType() != op.getType()) {
360 op.getInput1().front())
365 rewriter.
replaceOp(op, op.getInput1().front());
375LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
376 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
380 op.getOperation()->setOperands(
381 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
393 auto innerTranspose =
394 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
397 "input must be transpose operation");
401 innerTranspose.getPerms();
403 if (transposePerms.size() != innerTransposePerms.size())
406 "transpose and inner transpose perms sizes must be equal");
407 if (transposePerms.empty())
409 transposeOp,
"transpose perms sizes must be positive");
413 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
414 perms[i] = innerTransposePerms[transposePerms[i]];
417 transposeOp, transposeOp.getResult().
getType(),
430 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
432 op,
"Src is from transpose, can compose transposes");
436 if (isa_and_nonnull<tosa::TransposeOp>(subop))
438 op,
"Dest is used by transpose, can compose transposes");
441 auto input = op.getInput1();
442 auto inputTy = llvm::cast<ShapedType>(input.
getType());
443 if (!inputTy.hasRank())
447 for (
int i = 0; i < inputTy.getRank(); ++i)
448 if (inputTy.isDynamicDim(i))
457 nonZeroPerms.reserve(permValues.size());
458 for (
auto idx : permValues) {
459 auto sz = inputTy.getDimSize(idx);
461 nonZeroPerms.push_back(idx);
464 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
465 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
467 "Transpose changes memory layout.");
470 newShape.reserve(inputTy.getRank());
471 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
472 newShape.push_back(inputTy.getDimSize(permValues[i]));
475 op, op.getType(), op.getInput1(),
483 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
491 Value input = op.getInput();
492 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
493 auto inputElementType = inputType.getElementType();
495 if (isa<FloatType>(inputElementType)) {
497 const auto minClamp =
498 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
499 const auto maxClamp =
500 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
501 const bool isMin = minClamp.isNegInfinity();
502 const bool isMax = maxClamp.isInfinity();
504 if (isMin && isMax) {
512 const bool isBoolean = inputElementType.isInteger(1);
513 if (inputElementType.isUnsignedInteger() || isBoolean) {
514 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
517 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
521 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
522 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
523 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
525 if (minClamp <= intMin && maxClamp >= intMax) {
532 if (llvm::isa<IntegerType>(inputElementType)) {
534 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
536 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
538 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
539 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
540 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
542 if (minClamp <= intMin && maxClamp >= intMax) {
574 template <
typename T>
588 Value input = op.getInput();
596 const auto opNanMode = op.getNanMode();
597 const auto clampNanMode = clampOp.getNanMode();
598 if (opNanMode == NanPropagationMode::IGNORE &&
599 clampNanMode == NanPropagationMode::PROPAGATE)
602 auto maxValAttr = op.getMaxValAttr();
603 auto minValAttr = op.getMinValAttr();
604 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
605 auto clampOpMinValAttr = clampOp.getMinValAttr();
607 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
609 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
614 if (mlir::isa<FloatType>(inputEType)) {
615 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
616 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
617 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
618 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
621 const auto opMinFloat = floatMinValAttr.getValue();
622 const auto opMaxFloat = floatMaxValAttr.getValue();
623 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
624 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
628 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
632 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
633 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
634 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
635 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
637 assert(mlir::isa<IntegerType>(inputEType));
638 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
639 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
640 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
641 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
643 if (inputEType.isUnsignedInteger()) {
645 const auto opMinInt = intMinValAttr.getUInt();
646 const auto opMaxInt = intMaxValAttr.getUInt();
647 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
648 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
652 if (!opRangeIntRange.
intersects(clampRangeIntRange))
656 auto newMinVal = std::max(opMinInt, clampOpMinInt);
657 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
662 const auto opMinInt = intMinValAttr.getInt();
663 const auto opMaxInt = intMaxValAttr.getInt();
664 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
665 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
669 if (!opRangeIntRange.
intersects(clampRangeIntRange))
673 auto newMinVal = std::max(opMinInt, clampOpMinInt);
674 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
680 auto newMode = (opNanMode != clampNanMode)
681 ? tosa::NanPropagationMode::IGNORE
685 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
688 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
694void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
695 MLIRContext *context) {
696 results.
add<ClampIsNoOp>(context);
697 results.
add<ClampClampOptimization>(context);
705 Value sliceInput = sliceOp.getInput1();
709 sliceOp,
"slice input must be concat operation");
712 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
713 if (!concatType || !concatType.hasStaticShape())
715 sliceOp,
"slice input must be a static ranked tensor");
716 int32_t axis = concatOp.getAxis();
723 sliceOp,
"start of slice must be a static ranked shape");
727 sliceOp,
"size of slice must be a static ranked shape");
737 std::optional<Value> replaceWithSlice;
738 for (
auto input : inputs) {
739 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
740 if (!inputType || !inputType.hasStaticShape())
742 sliceOp,
"concat input must be a static ranked tensor");
744 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
745 inputType.getDimSize(axis)) {
751 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
752 input, start_op, size_op)
756 sliceStarts[axis] -= inputType.getDimSize(axis);
759 if (!replaceWithSlice)
761 sliceOp,
"corresponding concat input not found for slice");
763 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
773 Value sliceInput = sliceOp.getInput1();
779 "slice input must be a pad operation");
782 if (!padOp->hasOneUse())
784 "pad shall have a single consumer");
787 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
788 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
789 if (!inputTy || !padTy || !inputTy.hasRank())
791 "slice input must be a ranked tensor");
798 "`padding` input specified on the tosa::PadOp must be constant.");
801 llvm::to_vector(paddingElems.getValues<
int64_t>());
807 sliceOp,
"start of slice must be a static ranked shape");
814 sliceOp,
"size of slice must be a static ranked shape");
819 const int64_t rank = inputTy.getRank();
820 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
821 const bool isDimDynamic = inputTy.isDynamicDim(i);
822 const bool isDimSliced =
825 return isDimDynamic && isDimSliced;
828 sliceOp,
"axis that are sliced shall be statically known.");
835 bool updated =
false;
837 for (
int64_t i = 0; i < rank; ++i) {
838 const int64_t padLo = padPaddings[i * 2];
839 const int64_t padHi = padPaddings[i * 2 + 1];
840 const int64_t sliceStart = sliceStarts[i];
841 const int64_t sliceSize = sliceSizes[i];
842 const int64_t sliceEnd = sliceStart + sliceSize;
845 if (inputTy.isDynamicDim(i)) {
846 newPadPaddings[i * 2] = padLo;
847 newPadPaddings[i * 2 + 1] = padHi;
848 newSliceStarts[i] = sliceStart;
853 const int64_t dimSize = inputTy.getShape()[i];
854 const int64_t dimTotal = padLo + dimSize + padHi;
857 if (sliceStart < 0 || sliceEnd > dimTotal)
861 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
862 newSliceStarts[i] = newSliceStart;
863 updated |= newSliceStart != sliceStart;
866 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
868 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
869 newPadPaddings[i * 2] = newPadLo;
870 newPadPaddings[i * 2 + 1] = newPadHi;
871 updated |= (newPadLo != padLo) || (newPadHi != padHi);
875 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
881 sliceOp,
"terminate condition; nothing to rewrite");
887 RankedTensorType::get(newPadShape, inputTy.getElementType());
888 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
889 padOp.getInput1(), newPaddingsOp,
890 padOp.getPadConst());
896 newPadOp.getResult(), newStartOp,
911 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
912 if (!resultType.hasRank())
915 ElementsAttr sizeElems;
918 sliceOp,
"size of slice must be a static ranked shape");
922 llvm::to_vector(sizeElems.getValues<
int64_t>());
924 bool replaceSliceSize{
false};
928 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
930 sliceSizes[
index] = resultType.getDimSize(
index);
931 replaceSliceSize =
true;
935 if (!replaceSliceSize) {
937 sliceOp,
"no dimension of size of slice is dynamic that resolves "
938 "to static output shape");
943 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
944 sliceOp.getInput1(), sliceOp.getStart(), size_op);
946 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
951void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
952 MLIRContext *context) {
953 results.
add<ConcatSliceOptimization, PadSliceOptimization,
954 SliceDynamicSizeCanonicalization>(context);
962 const Value castInput = castOp.getInput();
966 "input must be cast operation");
968 const Value innerCastInput = innerCastOp.getInput();
970 const ShapedType innerInputType =
971 llvm::cast<ShapedType>(innerCastInput.
getType());
972 const ShapedType innerOutputType =
973 llvm::cast<ShapedType>(innerCastOp.getType());
974 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
976 const Type innerInputElemType = innerInputType.getElementType();
977 const Type innerOutputElemType = innerOutputType.getElementType();
978 const Type outerOutputElemType = outerOutputType.getElementType();
981 outerOutputElemType};
983 if (llvm::any_of(types, [](
const Type type) {
987 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
988 Float16Type, Float32Type>(type));
991 castOp,
"only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
994 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
995 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
997 castOp,
"avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1001 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1002 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1004 castOp,
"avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1008 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1011 castOp,
"avoid introducing fp8 -> integer casts which are not "
1016 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1018 castOp,
"avoid introducing integer -> fp8 casts which are not "
1022 if (llvm::isa<Float16Type>(innerInputElemType) &&
1023 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1025 castOp,
"avoid introducing fp16 -> bf16 casts which are not "
1029 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1030 llvm::isa<Float16Type>(outerOutputElemType)) {
1032 castOp,
"avoid introducing bf16 -> fp16 casts which are not "
1036 const auto isIntegerOneOfWidth = [](
Type type,
size_t bitwidth1,
1041 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1044 castOp,
"avoid introducing i8/i16 -> i64 casts which are not "
1048 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1051 castOp,
"avoid introducing bool/i64 to float casts which are not "
1052 "supported in all versions of TOSA");
1056 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1058 castOp,
"avoid introducing float to bool/i64 casts which are not "
1059 "supported in all versions of TOSA");
1065 "inner cast operation is narrowing");
1074 return semantics.nonFiniteBehavior !=
1075 llvm::fltNonfiniteBehavior::FiniteOnly;
1079 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1083 const ShapedType outType)
const {
1085 if (inType.getElementType().isInteger() &&
1086 outType.getElementType().isInteger()) {
1088 const auto inTypeSignedness =
1089 cast<IntegerType>(inType.getElementType()).getSignedness();
1090 const auto outTypeSignedness =
1091 cast<IntegerType>(outType.getElementType()).getSignedness();
1093 return (inTypeSignedness != outTypeSignedness ||
1094 inType.getElementTypeBitWidth() >
1095 outType.getElementTypeBitWidth());
1098 if (inType.getElementType().isFloat() &&
1099 outType.getElementType().isFloat()) {
1101 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1102 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1103 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1104 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1110 [[maybe_unused]]
const auto isSupported = [](
Type elemType) {
1111 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1112 Float16Type, Float32Type>(elemType);
1115 assert(isSupported(inElemTy) &&
1116 "unsupported input element type in isNarrowingCast");
1117 assert(isSupported(outElemTy) &&
1118 "unsupported output element type in isNarrowingCast");
1121 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1122 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1123 inTypeSemantics.precision > outTypeSemantics.precision ||
1134void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1135 MLIRContext *context) {
1136 results.
add<NonNarrowingCastsOptimization>(context);
1145 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1146 auto castFromBlockScaledOp =
1147 castToBlockScaledInput.
getDefiningOp<tosa::CastFromBlockScaledOp>();
1148 if (!castFromBlockScaledOp)
1150 castToBlockScaledOp,
1151 "input must be cast_from_block_scaled operation");
1153 const Value innerData = castFromBlockScaledOp.getInputData();
1154 const Value innerScale = castFromBlockScaledOp.getInputScale();
1155 const auto innerDataTy = llvm::cast<ShapedType>(innerData.
getType());
1156 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.
getType());
1158 const Value outerData = castToBlockScaledOp.getOutputData();
1159 const Value outerScale = castToBlockScaledOp.getOutputScale();
1160 const auto outerDataTy = llvm::cast<ShapedType>(outerData.
getType());
1161 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.
getType());
1163 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1165 castToBlockScaledOp,
1166 "inputs types to cast_from_block_scaled operation must match output "
1167 "types to cast_to_block_scaled");
1170 if (castFromBlockScaledOp.getBlockSize() !=
1171 castToBlockScaledOp.getBlockSize()) {
1173 castToBlockScaledOp,
"block sizes for cast_from_block_scaled and "
1174 "cast_to_block_scaled must match");
1177 rewriter.
replaceOp(castToBlockScaledOp, {innerData, innerScale});
1183void CastToBlockScaledOp::getCanonicalizationPatterns(
1184 RewritePatternSet &results, MLIRContext *context) {
1185 results.
add<CancellingBlockScaledCastsOptimization>(context);
1192template <
typename Folder>
1193static DenseElementsAttr
1195 bool foldDenseValues =
false) {
1199 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1203 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
1207 if (
lhs.isSplat() &&
rhs.isSplat()) {
1208 if (isa<FloatType>(lETy)) {
1209 const APFloat l =
lhs.getSplatValue<APFloat>();
1210 const APFloat r =
rhs.getSplatValue<APFloat>();
1211 const auto maybeResult = Folder::fold(l, r);
1212 if (failed(maybeResult))
1217 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1218 const APInt l =
lhs.getSplatValue<APInt>();
1219 const APInt r =
rhs.getSplatValue<APInt>();
1220 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1221 if (failed(maybeResult))
1227 if (foldDenseValues) {
1228 assert(lETy.isIntOrIndex() &&
1229 "Only integer types are currently supported.");
1232 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
1233 const auto maybeResult = Folder::fold(l, r,
false);
1234 if (failed(maybeResult))
1236 resultValues.push_back(maybeResult.value());
1244template <
typename Folder>
1246 bool foldDenseValues =
false) {
1250 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1256 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1258 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1259 if (failed(maybeResult))
1265 if (foldDenseValues) {
1269 for (
auto const &v : val.
getValues<APInt>()) {
1270 const auto maybeResult = Folder::fold(v,
false);
1271 if (failed(maybeResult))
1273 resultValues.push_back(maybeResult.value());
1288 assert(dense.isSplat());
1289 APInt a = dense.getSplatValue<APInt>();
1290 return a.getSExtValue();
1295 const bool isUnsigned) {
1298 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1304 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1311 const bool isUnsigned) {
1314 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1320 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1327 const bool isUnsigned) {
1329 const unsigned originalWidth =
lhs.getBitWidth();
1332 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1337 if (
lhs == 0 ||
rhs == 0)
1338 return APInt::getZero(originalWidth);
1340 bool overflow =
false;
1342 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1347 return result.trunc(originalWidth);
1350 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1356 return a.isNegative() !=
b.isNegative();
1363 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1371 APInt::udivrem(
lhs,
rhs, q, r);
1372 if (!r.isZero() && Ceil) {
1379 bool overflow{
false};
1380 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1383 APInt
const r =
lhs.srem(
rhs);
1393 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1401 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1403 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1413 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1415 auto const r = t.mod(
rhs);
1416 if (llvm::APFloatBase::opStatus::opOK == r) {
1426 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1428 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1431 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1439 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1441 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1444 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1450 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1451 auto const numBits = value.getBitWidth();
1453 auto const zextv = value.getZExtValue();
1454 if (zextv >= numBits)
1456 return APInt::getOneBitSet(numBits, zextv);
1458 auto const sextv = value.getSExtValue();
1459 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1461 return APInt::getOneBitSet(numBits, sextv);
1466 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1467 if (!value.isStrictlyPositive())
1469 return APInt(value.getBitWidth(), value.ceilLogBase2());
1474 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1475 if (!value.isStrictlyPositive())
1477 return APInt(value.getBitWidth(), value.logBase2());
1483 const bool isUnsigned) {
1484 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1487 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1488 return APInt(1,
lhs >
rhs);
1494 const bool isUnsigned) {
1495 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1498 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1499 return APInt(1,
lhs >=
rhs);
1505 const bool isUnsigned) {
1506 return APInt(1,
lhs ==
rhs);
1509 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1510 return APInt(1,
lhs ==
rhs);
1515 if (llvm::isa<FloatType>(elemType))
1517 if (llvm::isa<IntegerType>(elemType))
1523 if (llvm::isa<FloatType>(elemType))
1524 return val && val.
isSplat() &&
1526 if (llvm::isa<IntegerType>(elemType)) {
1527 const int64_t shifted = 1LL << shift;
1528 return val && val.
isSplat() &&
1534OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1535 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1536 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1537 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1538 if (!lhsTy || !rhsTy || !resultTy)
1542 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1543 !rhsTy.getElementType().isIntOrIndexOrFloat())
1546 auto resultETy = resultTy.getElementType();
1548 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1550 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1553 lhsTy.getShape(), rhsTy.getShape());
1554 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1556 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1559 if (!lhsAttr || !rhsAttr)
1565OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1566 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1567 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1568 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1569 !outputTy.hasStaticShape())
1573 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1574 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1575 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1582OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1583 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1584 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1585 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1586 if (!lhsTy || !rhsTy || !resultTy)
1588 if (lhsTy.getElementType() != rhsTy.getElementType())
1593 auto resultETy = resultTy.getElementType();
1595 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1597 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1598 if (lhsAttr && lhsAttr.isSplat()) {
1599 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1600 lhsAttr.getSplatValue<APInt>().isZero())
1601 return lhsAttr.resizeSplat(resultTy);
1604 if (rhsAttr && rhsAttr.isSplat()) {
1606 lhsTy.getShape(), rhsTy.getShape());
1607 if (isBroadcastable && lhsTy == resultTy &&
1608 llvm::isa<IntegerType>(resultETy) &&
1609 rhsAttr.getSplatValue<APInt>().isOne())
1613 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1614 llvm::isa<IntegerType>(resultETy)) {
1615 APInt l = lhsAttr.getSplatValue<APInt>();
1616 APInt r = rhsAttr.getSplatValue<APInt>();
1618 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1620 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1633std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1634 unsigned bitwidth) {
1635 bool overflow =
false;
1636 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1639 return std::nullopt;
1642 auto round = APInt(64, 1) << (shift - 1);
1644 result.ashrInPlace(shift);
1647 if (!(
result.getSExtValue() >= INT32_MIN &&
1648 result.getSExtValue() <= INT32_MAX)) {
1650 return std::nullopt;
1654 return result.trunc(bitwidth);
1657DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1658 RankedTensorType ty, int32_t shift) {
1660 if (llvm::isa<IntegerType>(ty.getElementType())) {
1661 APInt l =
lhs.getSplatValue<APInt>();
1662 APInt r =
rhs.getSplatValue<APInt>();
1668 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1669 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1675 if (llvm::isa<FloatType>(ty.getElementType())) {
1676 APFloat l =
lhs.getSplatValue<APFloat>();
1677 APFloat r =
rhs.getSplatValue<APFloat>();
1687OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1688 auto lhs = getInput1();
1689 auto rhs = getInput2();
1690 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1691 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1692 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1693 if (!lhsTy || !rhsTy || !resultTy)
1696 auto resultETy = resultTy.getElementType();
1698 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1700 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1705 if (resultETy.isInteger(32)) {
1706 ElementsAttr shift_elem;
1707 if (getShift().getImpl()) {
1711 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1715 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1716 resultTy.hasStaticShape())
1718 return lhsAttr.resizeSplat(resultTy);
1719 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1720 resultTy.hasStaticShape())
1721 return rhsAttr.resizeSplat(resultTy);
1724 lhsTy.getShape(), rhsTy.getShape());
1725 if (isBroadcastable && rhsTy == resultTy &&
1728 if (isBroadcastable && lhsTy == resultTy &&
1732 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1735OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1736 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1737 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1738 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1739 if (!lhsTy || !rhsTy || !resultTy)
1743 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1744 !rhsTy.getElementType().isIntOrIndexOrFloat())
1747 auto resultETy = resultTy.getElementType();
1749 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1751 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1754 lhsTy.getShape(), rhsTy.getShape());
1755 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1758 if (!lhsAttr || !rhsAttr)
1764OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1765 auto resultTy = llvm::cast<ShapedType>(
getType());
1767 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1769 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1771 if (!lhsAttr || !rhsAttr)
1777OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1778 auto resultTy = llvm::cast<ShapedType>(
getType());
1780 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1782 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1784 if (!lhsAttr || !rhsAttr)
1790OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1791 auto resultTy = llvm::cast<ShapedType>(
getType());
1793 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1795 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1796 Value
lhs = getInput1();
1797 Value
rhs = getInput2();
1798 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1802 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1803 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1807 if (!lhsAttr || !rhsAttr)
1813OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1817 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1821 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1822 auto outTy = llvm::cast<ShapedType>(
getType());
1823 if (!outTy.hasRank() || !outTy.hasStaticShape())
1825 auto inETy = inTy.getElementType();
1826 auto outETy = outTy.getElementType();
1828 if (operand.isSplat()) {
1829 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1831 auto splatVal = operand.getSplatValue<APFloat>();
1832 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1833 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1838 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1839 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1840 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1841 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1842 llvm::RoundingMode::NearestTiesToEven);
1846 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1847 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1848 auto intVal = APSInt(
1849 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1850 auto floatVal = operand.getSplatValue<APFloat>();
1852 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1857 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1858 const auto inIntType = llvm::cast<IntegerType>(inETy);
1859 auto unsignIn = inIntType.isUnsignedInteger();
1861 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1862 auto intVal = operand.getSplatValue<APInt>();
1863 auto bitwidth = outETy.getIntOrFloatBitWidth();
1866 if (outETy.isInteger(1)) {
1867 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1869 intVal = intVal.trunc(bitwidth);
1870 }
else if (unsignIn || inIntType.isInteger(1)) {
1871 intVal = intVal.zext(bitwidth);
1873 intVal = intVal.sext(bitwidth);
1883OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1885OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1887#define REDUCE_FOLDER(OP) \
1888 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1889 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1890 if (!inputTy.hasRank()) \
1892 if (inputTy != getType()) \
1894 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1895 return getInput(); \
1908 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1909 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1911 if (!inputTy || !outputTy)
1917 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1921 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1922 getInput1().getDefiningOp())) {
1923 getInput1Mutable().assign(reshapeOp.getInput1());
1928 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1933 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1935 if (!outputTy.hasStaticShape())
1939 if (operand.isSplat())
1944 if (!getInput1().hasOneUse())
1951 return operand.reshape(
1952 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
1958OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
1960 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
1961 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
1962 if (densePad && densePad.isSplat() &&
1963 densePad.getSplatValue<APInt>().isZero()) {
1973OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1975 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1977 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1979 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1980 if (!scaleAttr || !offsetAttr || !borderAttr) {
1987 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1992 if (scale[0] != scale[1] || scale[2] != scale[3]) {
1997 if (offset[0] != 0 || offset[1] != 0) {
2002 if (border[0] != 0 || border[1] != 0) {
2006 return foldToInputIfTypeMatches(
getType(), getInput());
2009OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2010 auto operand = getInput1();
2011 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2012 auto axis = getAxis();
2014 const bool isSplatInput =
2015 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2016 if (!operandTy.hasRank() ||
2017 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2019 return foldToInputIfTypeMatches(
getType(), operand);
2022OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2023 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2024 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
2026 if (!inputTy || !outputTy)
2029 if (inputTy == outputTy && inputTy.hasStaticShape())
2034 DenseElementsAttr startElems;
2040 llvm::all_of(startElems.
getValues<APInt>(),
2041 [](
const APInt &val) { return val.isZero(); });
2046 DenseElementsAttr sizeElems;
2050 auto inputShape = inputTy.getShape();
2051 auto sizeValues = sizeElems.
getValues<APInt>();
2053 bool sizeMatchesInput =
true;
2054 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2055 int64_t size = sizeVal.getSExtValue();
2057 if (inputTy.isDynamicDim(i)) {
2061 sizeMatchesInput =
false;
2068 sizeMatchesInput =
false;
2074 if (sizeMatchesInput)
2079 if (!adaptor.getInput1())
2083 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2084 !outputTy.getElementType().isIntOrIndexOrFloat())
2087 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2088 if (operand.isSplat() && outputTy.hasStaticShape()) {
2092 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2093 outputTy.getNumElements() == 1) {
2094 llvm::SmallVector<uint64_t>
indices =
2095 llvm::to_vector(startElems.
getValues<uint64_t>());
2096 if (
auto values = operand.tryGetValues<Attribute>())
2103OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2104 const Value pred = getPred();
2105 const Value onTrue = getOnTrue();
2106 const Value onFalse = getOnFalse();
2108 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
2109 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
2110 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
2111 if (!predTy || !onTrueTy || !onFalseTy)
2114 const Type resultTy =
getType();
2116 const ArrayRef<int64_t> predShape = predTy.getShape();
2117 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2119 if (onTrue == onFalse && onTrueTy == resultTy &&
2124 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2127 if (!predicate.isSplat())
2130 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2132 SmallVector<SmallVector<int64_t>, 3> shapes;
2133 shapes.emplace_back(predShape);
2134 shapes.emplace_back(onTrueShape);
2135 shapes.emplace_back(onFalseTy.getShape());
2136 const bool isBroadcastable =
2139 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
2141 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
2146OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2148 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2149 adaptor.getMultiples())) {
2150 if (multiples.isSplat() &&
2151 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2153 if (
auto int_array_attr =
2154 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2155 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2156 [](APInt v) { return v.getSExtValue() == 1; }))
2164OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2165 auto resultTy = llvm::cast<ShapedType>(
getType());
2169 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2170 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2171 input.
getType().getElementType() == resultTy.getElementType())
2172 return input.reshape(resultTy);
2176 const llvm::ArrayRef<int32_t> perms = getPerms();
2178 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2181 return foldToInputIfTypeMatches(
getType(), getInput1());
2184OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2187 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2193 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2194 failed(maybeIZp) || *maybeIZp != 0) {
2198 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2199 failed(maybeOZp) || *maybeOZp != 0) {
2203 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2204 failed(maybeIZp) || *maybeIZp != 0) {
2208 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2209 failed(maybeOZp) || *maybeOZp != 0) {
2214 return foldToInputIfTypeMatches(
getType(), definingOp.getInput1());
2217OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2218 auto input = getInput1();
2221 return foldToInputIfTypeMatches(
getType(), input);
2226OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
2231 SmallVector<Value, 8> concatOperands;
2232 concatOperands.reserve(2 * getNumOperands());
2235 bool foundFoldableConcat =
false;
2236 for (Value operand : getOperands()) {
2237 concatOperands.emplace_back(operand);
2239 auto producer = operand.getDefiningOp<ConcatOp>();
2244 if (getAxis() != producer.getAxis())
2248 foundFoldableConcat =
true;
2249 concatOperands.pop_back();
2250 llvm::append_range(concatOperands, producer->getOperands());
2253 if (!foundFoldableConcat)
2256 getOperation()->setOperands(concatOperands);
2260OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2261 auto input = adaptor.getInput1();
2263 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2265 if (!inputAttr || !inputAttr.isSplat())
2268 auto shapeType = llvm::cast<ShapedType>(
getType());
2269 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2271 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2272 auto floatVal = inputAttr.getSplatValue<APFloat>();
2274 ReciprocalOp::calcOneElement(floatVal));
2280template <
typename Op,
typename OpFoldAdaptor>
2282 auto input1ConstShape =
2283 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2284 if (!input1ConstShape)
2287 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2293template <
typename Op,
typename OpFoldAdaptor>
2295 auto input1ConstShape =
2296 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2297 auto input2ConstShape =
2298 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2299 if (!input1ConstShape || !input2ConstShape)
2302 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2303 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2306 input1Attr.getType(),
2310OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2311 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2312 if (!inputTy || !inputTy.hasRank())
2314 const int32_t axis = getAxis();
2315 const int64_t dimSize = inputTy.getDimSize(axis);
2316 if (ShapedType::isDynamic(dimSize))
2320 const auto resultAttrTy =
2321 RankedTensorType::get(1, builder.getIndexType());
2326 auto const inputs = op->getInput();
2332 concatDims.reserve( 64);
2333 for (
auto const &v : inputs) {
2334 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2338 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2341 auto const vAttrVals = vAttr.getValues<APInt>();
2342 for (
auto const &v : vAttrVals) {
2343 concatDims.push_back(v);
2347 auto *ctx = op->getContext();
2348 assert(ctx !=
nullptr &&
"ctx is nullptr");
2349 auto const rankedTy = RankedTensorType::get(
2350 {
static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2356 auto const input1 = op->getInput();
2357 auto const input2 = op->getStart();
2358 auto const input3 = op->getSize();
2360 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2362 if (!input1ConstShape)
2365 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2369 auto const input1Vals = input1Attr.getValues<APInt>();
2370 auto const totalInput1 = input1Vals.size();
2375 if (failed(start) || failed(size))
2378 auto const startV =
static_cast<int32_t
>(start.value());
2379 auto const sizeV =
static_cast<int32_t
>(size.value());
2381 if ((sizeV <= 0) || (startV < 0) ||
2382 (
static_cast<size_t>(startV + sizeV) > totalInput1))
2386 sliceOfInput.reserve(totalInput1);
2388 for (
auto i = startV; i < (startV + sizeV); i++) {
2389 sliceOfInput.push_back(input1Vals[i]);
2392 auto *ctx = op->getContext();
2393 assert(ctx !=
nullptr &&
"ctx is nullptr");
2395 auto const rankedTy = RankedTensorType::get(
2396 {
static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2401OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2405OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2409OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2413OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2414 return binaryFold<DivCeilShapeOp, DivFoldAdaptor<
true>>(
this);
2417OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2418 return binaryFold<DivFloorShapeOp, DivFoldAdaptor<
false>>(
this);
2421OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2425OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2429OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2433OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2437OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2441OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2445OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2449OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
DynamicAPInt round(const Fraction &f)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...