27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/APInt.h"
54 (padConstAttr.
size() != 1)) {
59 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
60 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
61 return padConstVal == 0.0f;
65 if (
auto padConstIntAttr =
66 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
75 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
76 return zpVal == padConstVal;
// Primary template of the pooling pad-fold adaptor. It is only declared at
// this point (no body); the tosa::MaxPool2dOp specialization that follows
// supplies an implementation.
84template <
typename OpTy>
85struct PoolPadFoldAdaptor;
// Pad-fold adaptor specialized for tosa::MaxPool2dOp. It exposes the three
// static hooks that FoldPadToTensorOp calls through AdaptorTy:
// checkKernelCompliance, checkPadConstCompliance and replaceOpWithNewPad.
// NOTE(review): this excerpt omits several original lines (return statements,
// closing braces), so the comments below describe only the visible code.
88struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
89 using OpTy = tosa::MaxPool2dOp;
// Compares the folded pad against the pooling kernel: newPad[0] and newPad[1]
// are tested against kernel[0], newPad[2] and newPad[3] against kernel[1].
// The branch taken when any pad is >= its kernel extent is not visible here;
// presumably it rejects the fold -- confirm against the full file.
90 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
91 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
92 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
93 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
// Checks the pad constant against the lowest representable value of its
// element type. The op argument is unnamed and therefore unused.
97 static bool checkPadConstCompliance(OpTy, Value padConst) {
99 DenseElementsAttr padConstAttr;
// Guard on the attribute not holding exactly one element (the start of this
// condition and the branch body are on lines not shown).
101 padConstAttr.
size() != 1) {
// Floating-point case: the constant must equal the largest-magnitude finite
// value of its semantics with the negative flag set (named lowestVal).
106 if (
auto padConstFpAttr =
107 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
108 const APFloat padConstVal = *padConstFpAttr.begin();
109 const APFloat lowestVal =
110 APFloat::getLargest(padConstVal.getSemantics(),
true);
111 return padConstVal == lowestVal;
// Integer case: the constant must equal zero for unsigned element types, or
// the signed minimum of the same bit width otherwise.
113 if (
auto padConstIntAttr =
114 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
115 const APInt padConstVal = *padConstIntAttr.begin();
116 const unsigned int bitWidth = padConstVal.getBitWidth();
117 const APInt lowestVal =
118 padConstIntAttr.getElementType().isUnsignedInteger()
119 ? APInt::getZero(bitWidth)
120 : APInt::getSignedMinValue(bitWidth);
121 return padConstVal == lowestVal;
// Replaces op with a new one built on padInput, reusing the op's own result
// type, kernel and stride. The callee and the remaining arguments (presumably
// including newPad) are on lines not shown.
127 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
128 Value padInput, ArrayRef<int64_t> newPad) {
130 op, op.getType(), padInput, op.getKernel(), op.getStride(),
// Generic pad-fold adaptor for convolution-style ops. Later in this file it
// is instantiated for tosa::Conv2DOp and tosa::DepthwiseConv2DOp.
// NOTE(review): the bodies of the two check functions are not visible in this
// excerpt, so their results are not documented here.
135template <
typename OpTy>
136struct ConvPadFoldAdaptor {
// Both parameters are unnamed, so the kernel check cannot depend on the op or
// on the folded pad values.
137 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
// Pad-constant check for the op (body not shown).
140 static bool checkPadConstCompliance(OpTy op, Value padConst) {
// Rebuilds the op with padInput as the input and newPad as the pad, while
// forwarding the original result type, weight, bias, both zero points,
// stride, dilation, accumulator type and local-bound flag unchanged. The
// callee name is on a line not shown.
143 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
144 Value padInput, ArrayRef<int64_t> newPad) {
146 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
147 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
148 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
// Rewrite pattern that folds a producing tosa::PadOp into the padding of the
// consuming tensor op (OpTy). Op-specific decisions are delegated to the
// static hooks of AdaptorTy.
// NOTE(review): the struct header, the failure returns and the closing braces
// are on lines missing from this excerpt; the struct name FoldPadToTensorOp
// is taken from the registrations further down.
156template <
typename OpTy,
typename AdaptorTy>
158 using OpRewritePattern<OpTy>::OpRewritePattern;
// Visible preconditions, in order:
//   1. the input of tensorOp is defined by a tosa::PadOp;
//   2. tensorOp's own pad has 4 elements;
//   3. the PadOp `padding` operand is constant and has 8 elements;
//   4. the N and C paddings are all zero;
//   5. AdaptorTy accepts the folded pad (kernel check) and the pad constant;
//   6. no folded pad entry exceeds 8192.
// When they hold, AdaptorTy::replaceOpWithNewPad is invoked on the PadOp's
// input.
160 LogicalResult matchAndRewrite(OpTy tensorOp,
161 PatternRewriter &rewriter)
const override {
163 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
166 "Producer must be a tosa::PadOp.");
169 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
170 if (tensorOpPad.size() != 4)
172 tensorOp,
"Tensor operation padding shall have 4 elements.");
175 DenseIntElementsAttr padOpPadding;
179 "The `padding` input specified on the tosa::PadOp must be constant.");
183 if (padOpPadding.size() != 8)
185 "Pad padding should have 8 elements.");
// The 8 padding entries are read as (before, after) pairs for N, H, W, C.
// NOTE(review): getLimitedValue() reads each element as an unsigned value;
// confirm that negative paddings cannot reach this point.
186 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
187 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
188 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
189 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
190 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
191 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
192 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
193 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
195 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
197 tensorOp,
"Folding padding in N or C dimensions is not supported.");
// Folded pad layout follows the tensor op's pad order:
// [H before, H after, W before, W after].
201 SmallVector<int64_t> foldedPad(tensorOpPad.size());
202 foldedPad[0] = padHBefore + tensorOpPad[0];
203 foldedPad[1] = padHAfter + tensorOpPad[1];
204 foldedPad[2] = padWBefore + tensorOpPad[2];
205 foldedPad[3] = padWAfter + tensorOpPad[3];
208 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
210 tensorOp,
"Padding size not aligned with kernel restrictions.");
214 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
217 "Padding constant is not aligned with operator zero-point.");
// Upper bound on each folded pad entry; the diagnostic below calls it the
// "8K level limit".
221 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
223 tensorOp,
"Padding size more than the 8K level limit.");
227 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
238 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
244 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
245 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
262 op,
"expected constant kernel, stride, and pad operands");
265 rewriter, op.getLoc(), op.
getType(), op.getInput(), op.getInputZp(),
274void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
284 Value input = op.getInput();
285 Value output = op.getOutput();
286 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
287 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
289 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
295 if (outputShape[1] != 1 || outputShape[2] != 1) {
300 if (inputShape[1] != 1 || inputShape[2] != 1) {
312 FoldPadToTensorOp<tosa::MaxPool2dOp,
313 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
330 op,
"expected constant kernel, stride, and pad operands");
333 rewriter, op.getLoc(), op.
getType(), op.getInput(),
342void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
356 if (op.getInput1().size() != 1)
358 if (op.getInput1().front().getType() != op.getType()) {
361 op.getInput1().front())
366 rewriter.
replaceOp(op, op.getInput1().front());
381 concatOperands.reserve(2 * op.getNumOperands());
383 int32_t maxNumOperands = 0;
389 bool foundRewritableConcat =
false;
390 for (
Value operand : op.getOperands()) {
391 concatOperands.emplace_back(operand);
393 auto producer = operand.getDefiningOp<tosa::ConcatOp>();
398 if (op.getAxis() != producer.getAxis())
402 foundRewritableConcat =
true;
403 concatOperands.pop_back();
404 llvm::append_range(concatOperands, producer->getOperands());
407 if (!foundRewritableConcat)
409 "No rewritable concat operand found.");
411 if (maxNumOperands > 0 &&
412 concatOperands.size() >
static_cast<size_t>(maxNumOperands))
414 op,
"Rewriting would exceed the maximum number of operands for the "
415 "target environment level.");
418 op, op.getType(), concatOperands, op.getAxisAttr());
// Canonicalization hook for tosa.select. It looks for a tosa::LogicalNotOp
// defining input1 and resets the select's operands to
// (input of the not, on_false, on_true): the negation is dropped from the
// predicate and the two value operands trade places.
// NOTE(review): the null check on notOp, the wrapper around setOperands and
// the return statements are on lines missing from this excerpt.
428LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
429 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
433 op.getOperation()->setOperands(
434 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
446 auto innerTranspose =
447 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
450 "input must be transpose operation");
454 innerTranspose.getPerms();
456 if (transposePerms.size() != innerTransposePerms.size())
459 "transpose and inner transpose perms sizes must be equal");
460 if (transposePerms.empty())
462 transposeOp,
"transpose perms sizes must be positive");
466 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
467 perms[i] = innerTransposePerms[transposePerms[i]];
470 transposeOp, transposeOp.getResult().
getType(),
483 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
485 op,
"Src is from transpose, can compose transposes");
489 if (isa_and_nonnull<tosa::TransposeOp>(subop))
491 op,
"Dest is used by transpose, can compose transposes");
494 auto input = op.getInput1();
495 auto inputTy = llvm::cast<ShapedType>(input.
getType());
496 if (!inputTy.hasRank())
500 for (
int i = 0; i < inputTy.getRank(); ++i)
501 if (inputTy.isDynamicDim(i))
510 nonZeroPerms.reserve(permValues.size());
511 for (
auto idx : permValues) {
512 auto sz = inputTy.getDimSize(idx);
514 nonZeroPerms.push_back(idx);
517 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
518 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
520 "Transpose changes memory layout.");
523 newShape.reserve(inputTy.getRank());
524 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
525 newShape.push_back(inputTy.getDimSize(permValues[i]));
528 op, op.getType(), op.getInput1(),
536 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
544 Value input = op.getInput();
545 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
546 auto inputElementType = inputType.getElementType();
548 if (isa<FloatType>(inputElementType)) {
550 const auto minClamp =
551 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
552 const auto maxClamp =
553 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
554 const bool isMin = minClamp.isNegInfinity();
555 const bool isMax = maxClamp.isInfinity();
557 if (isMin && isMax) {
565 const bool isBoolean = inputElementType.isInteger(1);
566 if (inputElementType.isUnsignedInteger() || isBoolean) {
567 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
570 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
574 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
575 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
576 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
578 if (minClamp <= intMin && maxClamp >= intMax) {
585 if (llvm::isa<IntegerType>(inputElementType)) {
587 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
589 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
591 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
592 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
593 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
595 if (minClamp <= intMin && maxClamp >= intMax) {
627 template <
typename T>
641 Value input = op.getInput();
649 const auto opNanMode = op.getNanMode();
650 const auto clampNanMode = clampOp.getNanMode();
651 if (opNanMode == NanPropagationMode::IGNORE &&
652 clampNanMode == NanPropagationMode::PROPAGATE)
655 auto maxValAttr = op.getMaxValAttr();
656 auto minValAttr = op.getMinValAttr();
657 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
658 auto clampOpMinValAttr = clampOp.getMinValAttr();
660 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
662 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
667 if (mlir::isa<FloatType>(inputEType)) {
668 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
669 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
670 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
671 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
674 const auto opMinFloat = floatMinValAttr.getValue();
675 const auto opMaxFloat = floatMaxValAttr.getValue();
676 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
677 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
681 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
685 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
686 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
687 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
688 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
690 assert(mlir::isa<IntegerType>(inputEType));
691 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
692 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
693 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
694 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
696 if (inputEType.isUnsignedInteger()) {
698 const auto opMinInt = intMinValAttr.getUInt();
699 const auto opMaxInt = intMaxValAttr.getUInt();
700 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
701 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
705 if (!opRangeIntRange.
intersects(clampRangeIntRange))
709 auto newMinVal = std::max(opMinInt, clampOpMinInt);
710 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
715 const auto opMinInt = intMinValAttr.getInt();
716 const auto opMaxInt = intMaxValAttr.getInt();
717 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
718 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
722 if (!opRangeIntRange.
intersects(clampRangeIntRange))
726 auto newMinVal = std::max(opMinInt, clampOpMinInt);
727 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
733 auto newMode = (opNanMode != clampNanMode)
734 ? tosa::NanPropagationMode::IGNORE
738 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
741 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
// Registers the canonicalization patterns of tosa.clamp on `results`:
// ClampIsNoOp and ClampClampOptimization, each constructed with `context`.
747void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
748 MLIRContext *context) {
749 results.
add<ClampIsNoOp>(context);
750 results.
add<ClampClampOptimization>(context);
758 Value sliceInput = sliceOp.getInput1();
762 sliceOp,
"slice input must be concat operation");
765 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
766 if (!concatType || !concatType.hasStaticShape())
768 sliceOp,
"slice input must be a static ranked tensor");
769 int32_t axis = concatOp.getAxis();
776 sliceOp,
"start of slice must be a static ranked shape");
780 sliceOp,
"size of slice must be a static ranked shape");
790 std::optional<Value> replaceWithSlice;
791 for (
auto input : inputs) {
792 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
793 if (!inputType || !inputType.hasStaticShape())
795 sliceOp,
"concat input must be a static ranked tensor");
797 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
798 inputType.getDimSize(axis)) {
804 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
805 input, start_op, size_op)
809 sliceStarts[axis] -= inputType.getDimSize(axis);
812 if (!replaceWithSlice)
814 sliceOp,
"corresponding concat input not found for slice");
816 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
826 Value sliceInput = sliceOp.getInput1();
832 "slice input must be a pad operation");
835 if (!padOp->hasOneUse())
837 "pad shall have a single consumer");
840 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
841 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
842 if (!inputTy || !padTy || !inputTy.hasRank())
844 "slice input must be a ranked tensor");
851 "`padding` input specified on the tosa::PadOp must be constant.");
854 llvm::to_vector(paddingElems.getValues<
int64_t>());
860 sliceOp,
"start of slice must be a static ranked shape");
867 sliceOp,
"size of slice must be a static ranked shape");
872 const int64_t rank = inputTy.getRank();
873 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
874 const bool isDimDynamic = inputTy.isDynamicDim(i);
875 const bool isDimSliced =
878 return isDimDynamic && isDimSliced;
881 sliceOp,
"axis that are sliced shall be statically known.");
888 bool updated =
false;
890 for (
int64_t i = 0; i < rank; ++i) {
891 const int64_t padLo = padPaddings[i * 2];
892 const int64_t padHi = padPaddings[i * 2 + 1];
893 const int64_t sliceStart = sliceStarts[i];
894 const int64_t sliceSize = sliceSizes[i];
895 const int64_t sliceEnd = sliceStart + sliceSize;
898 if (inputTy.isDynamicDim(i)) {
899 newPadPaddings[i * 2] = padLo;
900 newPadPaddings[i * 2 + 1] = padHi;
901 newSliceStarts[i] = sliceStart;
906 const int64_t dimSize = inputTy.getShape()[i];
907 const int64_t dimTotal = padLo + dimSize + padHi;
910 if (sliceStart < 0 || sliceEnd > dimTotal)
914 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
915 newSliceStarts[i] = newSliceStart;
916 updated |= newSliceStart != sliceStart;
919 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
921 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
922 newPadPaddings[i * 2] = newPadLo;
923 newPadPaddings[i * 2 + 1] = newPadHi;
924 updated |= (newPadLo != padLo) || (newPadHi != padHi);
928 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
934 sliceOp,
"terminate condition; nothing to rewrite");
940 RankedTensorType::get(newPadShape, inputTy.getElementType());
941 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
942 padOp.getInput1(), newPaddingsOp,
943 padOp.getPadConst());
949 newPadOp.getResult(), newStartOp,
964 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
965 if (!resultType.hasRank())
968 ElementsAttr sizeElems;
971 sliceOp,
"size of slice must be a static ranked shape");
975 llvm::to_vector(sizeElems.getValues<
int64_t>());
977 bool replaceSliceSize{
false};
981 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
983 sliceSizes[
index] = resultType.getDimSize(
index);
984 replaceSliceSize =
true;
988 if (!replaceSliceSize) {
990 sliceOp,
"no dimension of size of slice is dynamic that resolves "
991 "to static output shape");
996 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
997 sliceOp.getInput1(), sliceOp.getStart(), size_op);
999 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
// Registers the canonicalization patterns of tosa.slice on `results` in a
// single add<> call: ConcatSliceOptimization, PadSliceOptimization and
// SliceDynamicSizeCanonicalization, all constructed with `context`.
1004void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005 MLIRContext *context) {
1006 results.
add<ConcatSliceOptimization, PadSliceOptimization,
1007 SliceDynamicSizeCanonicalization>(context);
1015 const Value castInput = castOp.getInput();
1019 "input must be cast operation");
1021 const Value innerCastInput = innerCastOp.getInput();
1023 const ShapedType innerInputType =
1024 llvm::cast<ShapedType>(innerCastInput.
getType());
1025 const ShapedType innerOutputType =
1026 llvm::cast<ShapedType>(innerCastOp.getType());
1027 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
1029 const Type innerInputElemType = innerInputType.getElementType();
1030 const Type innerOutputElemType = innerOutputType.getElementType();
1031 const Type outerOutputElemType = outerOutputType.getElementType();
1034 outerOutputElemType};
1036 if (llvm::any_of(types, [](
const Type type) {
1040 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1041 Float16Type, Float32Type>(type));
1044 castOp,
"only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
1047 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
1048 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
1050 castOp,
"avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1054 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1055 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1057 castOp,
"avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1061 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1064 castOp,
"avoid introducing fp8 -> integer casts which are not "
1069 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1071 castOp,
"avoid introducing integer -> fp8 casts which are not "
1075 if (llvm::isa<Float16Type>(innerInputElemType) &&
1076 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1078 castOp,
"avoid introducing fp16 -> bf16 casts which are not "
1082 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1083 llvm::isa<Float16Type>(outerOutputElemType)) {
1085 castOp,
"avoid introducing bf16 -> fp16 casts which are not "
1089 const auto isIntegerOneOfWidth = [](
Type type,
size_t bitwidth1,
1094 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1097 castOp,
"avoid introducing i8/i16 -> i64 casts which are not "
1101 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1104 castOp,
"avoid introducing bool/i64 to float casts which are not "
1105 "supported in all versions of TOSA");
1109 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1111 castOp,
"avoid introducing float to bool/i64 casts which are not "
1112 "supported in all versions of TOSA");
1118 "inner cast operation is narrowing");
1127 return semantics.nonFiniteBehavior !=
1128 llvm::fltNonfiniteBehavior::FiniteOnly;
1132 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1136 const ShapedType outType)
const {
1138 if (inType.getElementType().isInteger() &&
1139 outType.getElementType().isInteger()) {
1141 const auto inTypeSignedness =
1142 cast<IntegerType>(inType.getElementType()).getSignedness();
1143 const auto outTypeSignedness =
1144 cast<IntegerType>(outType.getElementType()).getSignedness();
1146 return (inTypeSignedness != outTypeSignedness ||
1147 inType.getElementTypeBitWidth() >
1148 outType.getElementTypeBitWidth());
1151 if (inType.getElementType().isFloat() &&
1152 outType.getElementType().isFloat()) {
1154 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1155 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1156 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1157 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1163 [[maybe_unused]]
const auto isSupported = [](
Type elemType) {
1164 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1165 Float16Type, Float32Type>(elemType);
1168 assert(isSupported(inElemTy) &&
1169 "unsupported input element type in isNarrowingCast");
1170 assert(isSupported(outElemTy) &&
1171 "unsupported output element type in isNarrowingCast");
1174 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1175 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1176 inTypeSemantics.precision > outTypeSemantics.precision ||
// Registers the single canonicalization pattern of tosa.cast on `results`:
// NonNarrowingCastsOptimization, constructed with `context`.
1187void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.
add<NonNarrowingCastsOptimization>(context);
1198 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1199 auto castFromBlockScaledOp =
1200 castToBlockScaledInput.
getDefiningOp<tosa::CastFromBlockScaledOp>();
1201 if (!castFromBlockScaledOp)
1203 castToBlockScaledOp,
1204 "input must be cast_from_block_scaled operation");
1206 const Value innerData = castFromBlockScaledOp.getInputData();
1207 const Value innerScale = castFromBlockScaledOp.getInputScale();
1208 const auto innerDataTy = llvm::cast<ShapedType>(innerData.
getType());
1209 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.
getType());
1211 const Value outerData = castToBlockScaledOp.getOutputData();
1212 const Value outerScale = castToBlockScaledOp.getOutputScale();
1213 const auto outerDataTy = llvm::cast<ShapedType>(outerData.
getType());
1214 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.
getType());
1216 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1218 castToBlockScaledOp,
1219 "inputs types to cast_from_block_scaled operation must match output "
1220 "types to cast_to_block_scaled");
1223 if (castFromBlockScaledOp.getBlockSize() !=
1224 castToBlockScaledOp.getBlockSize()) {
1226 castToBlockScaledOp,
"block sizes for cast_from_block_scaled and "
1227 "cast_to_block_scaled must match");
1230 rewriter.
replaceOp(castToBlockScaledOp, {innerData, innerScale});
// Registers the single canonicalization pattern of tosa.cast_to_block_scaled
// on `results`: CancellingBlockScaledCastsOptimization, constructed with
// `context`.
1236void CastToBlockScaledOp::getCanonicalizationPatterns(
1237 RewritePatternSet &results, MLIRContext *context) {
1238 results.
add<CancellingBlockScaledCastsOptimization>(context);
1246 const FailureOr<int32_t> rowCount =
1248 if (failed(rowCount) || rowCount.value() != 1)
1252 op, op.getOutput().
getType(), op.getValues(), op.getIndices());
// Registers the single canonicalization pattern of the row-gather op on
// `results`: RowGatherToGather, constructed with `context`.
1257void RowGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
1258 MLIRContext *context) {
1259 results.
add<RowGatherToGather>(context);
1266template <
typename Folder>
1267static DenseElementsAttr
1269 bool foldDenseValues =
false) {
1273 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1277 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
1281 if (
lhs.isSplat() &&
rhs.isSplat()) {
1282 if (isa<FloatType>(lETy)) {
1283 const APFloat l =
lhs.getSplatValue<APFloat>();
1284 const APFloat r =
rhs.getSplatValue<APFloat>();
1285 const auto maybeResult = Folder::fold(l, r);
1286 if (failed(maybeResult))
1291 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1292 const APInt l =
lhs.getSplatValue<APInt>();
1293 const APInt r =
rhs.getSplatValue<APInt>();
1294 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1295 if (failed(maybeResult))
1301 if (foldDenseValues) {
1302 assert(lETy.isIntOrIndex() &&
1303 "Only integer types are currently supported.");
1306 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
1307 const auto maybeResult = Folder::fold(l, r,
false);
1308 if (failed(maybeResult))
1310 resultValues.push_back(maybeResult.value());
1318template <
typename Folder>
1320 bool foldDenseValues =
false) {
1324 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1330 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1332 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1333 if (failed(maybeResult))
1339 if (foldDenseValues) {
1343 for (
auto const &v : val.
getValues<APInt>()) {
1344 const auto maybeResult = Folder::fold(v,
false);
1345 if (failed(maybeResult))
1347 resultValues.push_back(maybeResult.value());
1362 assert(dense.isSplat());
1363 APInt a = dense.getSplatValue<APInt>();
1364 return a.getSExtValue();
1369 const bool isUnsigned) {
1372 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1378 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1385 const bool isUnsigned) {
1388 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1394 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1401 const bool isUnsigned) {
1403 const unsigned originalWidth =
lhs.getBitWidth();
1406 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1411 if (
lhs == 0 ||
rhs == 0)
1412 return APInt::getZero(originalWidth);
1414 bool overflow =
false;
1416 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1421 return result.trunc(originalWidth);
1424 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1430 return a.isNegative() !=
b.isNegative();
1437 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1445 APInt::udivrem(
lhs,
rhs, q, r);
1446 if (!r.isZero() && Ceil) {
1453 bool overflow{
false};
1454 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1457 APInt
const r =
lhs.srem(
rhs);
1467 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1475 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1477 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1487 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1489 auto const r = t.mod(
rhs);
1490 if (llvm::APFloatBase::opStatus::opOK == r) {
1500 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1502 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1505 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1513 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1515 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1518 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
// Unary integer folder producing an APInt of the same bit width as `value`
// with exactly one bit set, at position `value` (i.e. 2^value).
// Two paths are visible: one reads `value` zero-extended, the other
// sign-extended. Each guards against a bit position that does not fit in the
// width; the signed path also rejects negative values.
// NOTE(review): the line selecting between the paths (presumably on
// isUnsigned) and the bodies of the guards (presumably failure()) are
// missing from this excerpt -- confirm against the full file.
1524 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1525 auto const numBits = value.getBitWidth();
1527 auto const zextv = value.getZExtValue();
1528 if (zextv >= numBits)
1530 return APInt::getOneBitSet(numBits, zextv);
1532 auto const sextv = value.getSExtValue();
1533 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1535 return APInt::getOneBitSet(numBits, sextv);
1545 assert(!isUnsigned &&
1546 "unsigned values are not supported for shape div folders");
1547 if (
lhs.isNegative() || !
rhs.isStrictlyPositive())
1552 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
// Unary integer folder returning value.ceilLogBase2() as an APInt of the
// same bit width as `value`. Values that are not strictly positive hit the
// guard below; its body is not shown in this excerpt (presumably failure()).
// isUnsigned is not read by any visible line.
1558 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1559 if (!value.isStrictlyPositive())
1561 return APInt(value.getBitWidth(), value.ceilLogBase2());
// Unary integer folder returning value.logBase2() as an APInt of the same
// bit width as `value`. Values that are not strictly positive hit the guard
// below; its body is not shown in this excerpt (presumably failure()).
// isUnsigned is not read by any visible line.
1566 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1567 if (!value.isStrictlyPositive())
1569 return APInt(value.getBitWidth(), value.logBase2());
1575 const bool isUnsigned) {
1576 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
// Floating-point "greater than" folder: returns a 1-bit APInt holding the
// result of lhs > rhs.
1579 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1580 return APInt(1,
lhs >
rhs);
1586 const bool isUnsigned) {
1587 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
// Floating-point "greater than or equal" folder: returns a 1-bit APInt
// holding the result of lhs >= rhs.
1590 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1591 return APInt(1,
lhs >=
rhs);
1597 const bool isUnsigned) {
1598 return APInt(1,
lhs ==
rhs);
// Floating-point equality folder: returns a 1-bit APInt holding the result
// of lhs == rhs.
1601 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1602 return APInt(1,
lhs ==
rhs);
1607 if (llvm::isa<FloatType>(elemType))
1609 if (llvm::isa<IntegerType>(elemType))
1615 if (llvm::isa<FloatType>(elemType))
1616 return val && val.
isSplat() &&
1618 if (llvm::isa<IntegerType>(elemType)) {
1619 const int64_t shifted = 1LL << shift;
1620 return val && val.
isSplat() &&
1626OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1627 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1628 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1629 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1630 if (!lhsTy || !rhsTy || !resultTy)
1634 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1635 !rhsTy.getElementType().isIntOrIndexOrFloat())
1638 auto resultETy = resultTy.getElementType();
1640 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1642 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1645 lhsTy.getShape(), rhsTy.getShape());
1646 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1648 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1651 if (!lhsAttr || !rhsAttr)
1657OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1658 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1659 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1660 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1661 !outputTy.hasStaticShape())
1665 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1666 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1667 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1674OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1675 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1676 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1677 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1678 if (!lhsTy || !rhsTy || !resultTy)
1680 if (lhsTy.getElementType() != rhsTy.getElementType())
1685 auto resultETy = resultTy.getElementType();
1687 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1689 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1690 if (lhsAttr && lhsAttr.isSplat() && rhsAttr && rhsAttr.isSplat()) {
1691 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1692 lhsAttr.getSplatValue<APInt>().isZero() &&
1693 !rhsAttr.getSplatValue<APInt>().isZero()) {
1694 return lhsAttr.resizeSplat(resultTy);
1698 if (rhsAttr && rhsAttr.isSplat()) {
1700 lhsTy.getShape(), rhsTy.getShape());
1701 if (isBroadcastable && lhsTy == resultTy &&
1702 llvm::isa<IntegerType>(resultETy) &&
1703 rhsAttr.getSplatValue<APInt>().isOne())
1707 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1708 llvm::isa<IntegerType>(resultETy)) {
1709 APInt l = lhsAttr.getSplatValue<APInt>();
1710 APInt r = rhsAttr.getSplatValue<APInt>();
1712 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1714 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1727std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1728 unsigned bitwidth) {
1729 bool overflow =
false;
1730 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1733 return std::nullopt;
1736 auto round = APInt(64, 1) << (shift - 1);
1738 result.ashrInPlace(shift);
1741 if (!(
result.getSExtValue() >= INT32_MIN &&
1742 result.getSExtValue() <= INT32_MAX)) {
1744 return std::nullopt;
1748 return result.trunc(bitwidth);
1751DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1752 RankedTensorType ty, int32_t shift) {
1754 if (llvm::isa<IntegerType>(ty.getElementType())) {
1755 APInt l =
lhs.getSplatValue<APInt>();
1756 APInt r =
rhs.getSplatValue<APInt>();
1762 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1763 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1769 if (llvm::isa<FloatType>(ty.getElementType())) {
1770 APFloat l =
lhs.getSplatValue<APFloat>();
1771 APFloat r =
rhs.getSplatValue<APFloat>();
1781OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1782 auto lhs = getInput1();
1783 auto rhs = getInput2();
1784 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1785 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1786 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1787 if (!lhsTy || !rhsTy || !resultTy)
1790 auto resultETy = resultTy.getElementType();
1792 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1794 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1799 if (resultETy.isInteger(32)) {
1800 ElementsAttr shift_elem;
1801 if (getShift().getImpl()) {
1805 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1809 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1810 resultTy.hasStaticShape())
1812 return lhsAttr.resizeSplat(resultTy);
1813 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1814 resultTy.hasStaticShape())
1815 return rhsAttr.resizeSplat(resultTy);
1818 lhsTy.getShape(), rhsTy.getShape());
1819 if (isBroadcastable && rhsTy == resultTy &&
1822 if (isBroadcastable && lhsTy == resultTy &&
1826 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1829OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1830 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1831 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1832 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1833 if (!lhsTy || !rhsTy || !resultTy)
1837 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1838 !rhsTy.getElementType().isIntOrIndexOrFloat())
1841 auto resultETy = resultTy.getElementType();
1843 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1845 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1848 lhsTy.getShape(), rhsTy.getShape());
1849 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1852 if (!lhsAttr || !rhsAttr)
1858OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1859 auto resultTy = llvm::cast<ShapedType>(
getType());
1861 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1863 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1865 if (!lhsAttr || !rhsAttr)
1871OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1872 auto resultTy = llvm::cast<ShapedType>(
getType());
1874 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1876 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1878 if (!lhsAttr || !rhsAttr)
1884OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1885 auto resultTy = llvm::cast<ShapedType>(
getType());
1887 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1889 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1890 Value
lhs = getInput1();
1891 Value
rhs = getInput2();
1892 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1896 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1897 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1901 if (!lhsAttr || !rhsAttr)
1907OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1911 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1915 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1916 auto outTy = llvm::cast<ShapedType>(
getType());
1917 if (!outTy.hasRank() || !outTy.hasStaticShape())
1919 auto inETy = inTy.getElementType();
1920 auto outETy = outTy.getElementType();
1922 if (operand.isSplat()) {
1923 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1925 auto splatVal = operand.getSplatValue<APFloat>();
1926 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1927 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1932 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1933 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1934 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1935 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1936 llvm::RoundingMode::NearestTiesToEven);
1940 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1941 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1942 auto intVal = APSInt(
1943 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1944 auto floatVal = operand.getSplatValue<APFloat>();
1946 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1951 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1952 const auto inIntType = llvm::cast<IntegerType>(inETy);
1953 auto unsignIn = inIntType.isUnsignedInteger();
1955 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1956 auto intVal = operand.getSplatValue<APInt>();
1957 auto bitwidth = outETy.getIntOrFloatBitWidth();
1960 if (outETy.isInteger(1)) {
1961 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1963 intVal = intVal.trunc(bitwidth);
1964 }
else if (unsignIn || inIntType.isInteger(1)) {
1965 intVal = intVal.zext(bitwidth);
1967 intVal = intVal.sext(bitwidth);
// Fold hook for ConstOp: the fold result is the attribute returned by
// getValuesAttr(). The `adaptor` argument is not consulted.
1977OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
// Fold hook for ConstShapeOp: the fold result is the attribute returned by
// getValuesAttr(). The `adaptor` argument is not consulted.
1979OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1981#define REDUCE_FOLDER(OP) \
1982 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1983 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1984 if (!inputTy.hasRank()) \
1986 if (inputTy != getType()) \
1988 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1989 return getInput(); \
2002 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2003 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
2005 if (!inputTy || !outputTy)
2011 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
2015 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
2016 getInput1().getDefiningOp())) {
2017 getInput1Mutable().assign(reshapeOp.getInput1());
2022 if (!inputTy.getElementType().isIntOrIndexOrFloat())
2027 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2029 if (!outputTy.hasStaticShape())
2033 if (operand.isSplat())
2038 if (!getInput1().hasOneUse())
2045 return operand.reshape(
2046 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
2052OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
2054 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
2055 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
2056 if (densePad && densePad.isSplat() &&
2057 densePad.getSplatValue<APInt>().isZero()) {
2067OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
2069 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
2071 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
2073 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
2074 if (!scaleAttr || !offsetAttr || !borderAttr) {
2081 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
2086 if (scale[0] != scale[1] || scale[2] != scale[3]) {
2091 if (offset[0] != 0 || offset[1] != 0) {
2096 if (border[0] != 0 || border[1] != 0) {
2100 return foldToInputIfTypeMatches(
getType(), getInput());
2103OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2104 auto operand = getInput1();
2105 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2106 auto axis = getAxis();
2108 const bool isSplatInput =
2109 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2110 if (!operandTy.hasRank() ||
2111 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2113 return foldToInputIfTypeMatches(
getType(), operand);
2116OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2117 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2118 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
2120 if (!inputTy || !outputTy)
2123 if (inputTy == outputTy && inputTy.hasStaticShape())
2128 DenseElementsAttr startElems;
2134 llvm::all_of(startElems.
getValues<APInt>(),
2135 [](
const APInt &val) { return val.isZero(); });
2140 DenseElementsAttr sizeElems;
2144 auto inputShape = inputTy.getShape();
2145 auto sizeValues = sizeElems.
getValues<APInt>();
2147 bool sizeMatchesInput =
true;
2148 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2149 int64_t size = sizeVal.getSExtValue();
2151 if (inputTy.isDynamicDim(i)) {
2155 sizeMatchesInput =
false;
2162 sizeMatchesInput =
false;
2168 if (sizeMatchesInput)
2173 if (!adaptor.getInput1())
2177 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2178 !outputTy.getElementType().isIntOrIndexOrFloat())
2181 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2182 if (operand.isSplat() && outputTy.hasStaticShape()) {
2186 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2187 outputTy.getNumElements() == 1) {
2188 llvm::SmallVector<uint64_t>
indices =
2189 llvm::to_vector(startElems.
getValues<uint64_t>());
2190 if (
auto values = operand.tryGetValues<Attribute>())
2197OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2198 const Value pred = getPred();
2199 const Value onTrue = getOnTrue();
2200 const Value onFalse = getOnFalse();
2202 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
2203 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
2204 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
2205 if (!predTy || !onTrueTy || !onFalseTy)
2208 const Type resultTy =
getType();
2210 const ArrayRef<int64_t> predShape = predTy.getShape();
2211 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2213 if (onTrue == onFalse && onTrueTy == resultTy &&
2218 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2221 if (!predicate.isSplat())
2224 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2226 SmallVector<SmallVector<int64_t>, 3> shapes;
2227 shapes.emplace_back(predShape);
2228 shapes.emplace_back(onTrueShape);
2229 shapes.emplace_back(onFalseTy.getShape());
2230 const bool isBroadcastable =
2233 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
2235 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
2240OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2242 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2243 adaptor.getMultiples())) {
2244 if (multiples.isSplat() &&
2245 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2247 if (
auto int_array_attr =
2248 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2249 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2250 [](APInt v) { return v.getSExtValue() == 1; }))
2258OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2259 auto resultTy = llvm::cast<ShapedType>(
getType());
2263 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2264 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2265 input.
getType().getElementType() == resultTy.getElementType())
2266 return input.reshape(resultTy);
2270 const llvm::ArrayRef<int32_t> perms = getPerms();
2272 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2275 return foldToInputIfTypeMatches(
getType(), getInput1());
2278OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2281 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2287 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2288 failed(maybeIZp) || *maybeIZp != 0) {
2292 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2293 failed(maybeOZp) || *maybeOZp != 0) {
2297 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2298 failed(maybeIZp) || *maybeIZp != 0) {
2302 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2303 failed(maybeOZp) || *maybeOZp != 0) {
2308 return foldToInputIfTypeMatches(
getType(), definingOp.getInput1());
2311OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2312 auto input = getInput1();
2315 return foldToInputIfTypeMatches(
getType(), input);
2320OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2321 auto input = adaptor.getInput1();
2323 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2325 if (!inputAttr || !inputAttr.isSplat())
2328 auto shapeType = llvm::cast<ShapedType>(
getType());
2329 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2331 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2332 auto floatVal = inputAttr.getSplatValue<APFloat>();
2334 ReciprocalOp::calcOneElement(floatVal));
2340template <
typename Op,
typename OpFoldAdaptor>
2342 auto input1ConstShape =
2343 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2344 if (!input1ConstShape)
2347 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2353template <
typename Op,
typename OpFoldAdaptor>
2355 auto input1ConstShape =
2356 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2357 auto input2ConstShape =
2358 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2359 if (!input1ConstShape || !input2ConstShape)
2362 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2363 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2366 input1Attr.getType(),
2370OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2371 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2372 if (!inputTy || !inputTy.hasRank())
2374 const int32_t axis = getAxis();
2375 const int64_t dimSize = inputTy.getDimSize(axis);
2376 if (ShapedType::isDynamic(dimSize))
2380 const auto resultAttrTy =
2381 RankedTensorType::get(1, builder.getIndexType());
2386 auto const inputs = op->getInput();
2392 concatDims.reserve( 64);
2393 for (
auto const &v : inputs) {
2394 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2398 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2401 auto const vAttrVals = vAttr.getValues<APInt>();
2402 for (
auto const &v : vAttrVals) {
2403 concatDims.push_back(v);
2407 auto *ctx = op->getContext();
2408 assert(ctx !=
nullptr &&
"ctx is nullptr");
2409 auto const rankedTy = RankedTensorType::get(
2410 {
static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2416 auto const input1 = op->getInput();
2417 auto const input2 = op->getStart();
2418 auto const input3 = op->getSize();
2420 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2422 if (!input1ConstShape)
2425 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2429 auto const input1Vals = input1Attr.getValues<APInt>();
2430 auto const totalInput1 = input1Vals.size();
2435 if (failed(start) || failed(size))
2438 auto const startV =
static_cast<int32_t
>(start.value());
2439 auto const sizeV =
static_cast<int32_t
>(size.value());
2441 if ((sizeV <= 0) || (startV < 0) ||
2442 (
static_cast<size_t>(startV + sizeV) > totalInput1))
2446 sliceOfInput.reserve(totalInput1);
2448 for (
auto i = startV; i < (startV + sizeV); i++) {
2449 sliceOfInput.push_back(input1Vals[i]);
2452 auto *ctx = op->getContext();
2453 assert(ctx !=
nullptr &&
"ctx is nullptr");
2455 auto const rankedTy = RankedTensorType::get(
2456 {
static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2461OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2465OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2469OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2473OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2474 return binaryFold<DivCeilShapeOp, ShapeDivFoldAdaptor<
true>>(
this);
2477OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2478 return binaryFold<DivFloorShapeOp, ShapeDivFoldAdaptor<
false>>(
this);
2481OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2485OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2489OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2493OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2497OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2501OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2505OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2509OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
Returns failure if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies, respectively, should be inserted. Since `begin` can itself be invalidated by the rewriting, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, the output argument `nEnd` is set to the new end.
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacement).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an error.
DynamicAPInt round(const Fraction &f)
TosaLevel getTosaLevelFromEnum(const Level level)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dimensions.
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
TargetEnvAttr lookupTargetEnv(Operation *op)
FailureOr< T > getConstantScalarIntValue(Value val)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::RowGatherOp op, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
int32_t MAX_TENSOR_LIST_SIZE