27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/APInt.h"
54 (padConstAttr.
size() != 1)) {
59 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
60 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
61 return padConstVal == 0.0f;
65 if (
auto padConstIntAttr =
66 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
75 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
76 return zpVal == padConstVal;
84template <
typename OpTy>
85struct PoolPadFoldAdaptor;
88struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
89 using OpTy = tosa::MaxPool2dOp;
90 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
91 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
92 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
93 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
97 static bool checkPadConstCompliance(OpTy, Value padConst) {
99 DenseElementsAttr padConstAttr;
101 padConstAttr.
size() != 1) {
106 if (
auto padConstFpAttr =
107 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
108 const APFloat padConstVal = *padConstFpAttr.begin();
109 const APFloat lowestVal =
110 APFloat::getLargest(padConstVal.getSemantics(),
true);
111 return padConstVal == lowestVal;
113 if (
auto padConstIntAttr =
114 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
115 const APInt padConstVal = *padConstIntAttr.begin();
116 const unsigned int bitWidth = padConstVal.getBitWidth();
117 const APInt lowestVal =
118 padConstIntAttr.getElementType().isUnsignedInteger()
119 ? APInt::getZero(bitWidth)
120 : APInt::getSignedMinValue(bitWidth);
121 return padConstVal == lowestVal;
127 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
128 Value padInput, ArrayRef<int64_t> newPad) {
130 op, op.getType(), padInput, op.getKernel(), op.getStride(),
135template <
typename OpTy>
136struct ConvPadFoldAdaptor {
137 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
140 static bool checkPadConstCompliance(OpTy op, Value padConst) {
143 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
144 Value padInput, ArrayRef<int64_t> newPad) {
146 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
147 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
148 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
156template <
typename OpTy,
typename AdaptorTy>
158 using OpRewritePattern<OpTy>::OpRewritePattern;
160 LogicalResult matchAndRewrite(OpTy tensorOp,
161 PatternRewriter &rewriter)
const override {
163 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
166 "Producer must be a tosa::PadOp.");
169 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
170 if (tensorOpPad.size() != 4)
172 tensorOp,
"Tensor operation padding shall have 4 elements.");
175 DenseIntElementsAttr padOpPadding;
179 "The `padding` input specified on the tosa::PadOp must be constant.");
183 if (padOpPadding.size() != 8)
185 "Pad padding should have 8 elements.");
186 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
187 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
188 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
189 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
190 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
191 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
192 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
193 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
195 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
197 tensorOp,
"Folding padding in N or C dimensions is not supported.");
201 SmallVector<int64_t> foldedPad(tensorOpPad.size());
202 foldedPad[0] = padHBefore + tensorOpPad[0];
203 foldedPad[1] = padHAfter + tensorOpPad[1];
204 foldedPad[2] = padWBefore + tensorOpPad[2];
205 foldedPad[3] = padWAfter + tensorOpPad[3];
208 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
210 tensorOp,
"Padding size not aligned with kernel restrictions.");
214 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
217 "Padding constant is not aligned with operator zero-point.");
221 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
223 tensorOp,
"Padding size more than the 8K level limit.");
227 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
238 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
244 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
245 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
262 op,
"expected constant kernel, stride, and pad operands");
265 rewriter, op.getLoc(), op.
getType(), op.getInput(), op.getInputZp(),
274void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
284 Value input = op.getInput();
285 Value output = op.getOutput();
286 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
287 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
289 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
295 if (outputShape[1] != 1 || outputShape[2] != 1) {
300 if (inputShape[1] != 1 || inputShape[2] != 1) {
312 FoldPadToTensorOp<tosa::MaxPool2dOp,
313 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
330 op,
"expected constant kernel, stride, and pad operands");
333 rewriter, op.getLoc(), op.
getType(), op.getInput(),
342void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
356 if (op.getInput1().size() != 1)
358 if (op.getInput1().front().getType() != op.getType()) {
361 op.getInput1().front())
366 rewriter.
replaceOp(op, op.getInput1().front());
381 concatOperands.reserve(2 * op.getNumOperands());
383 int32_t maxNumOperands = 0;
389 bool foundRewritableConcat =
false;
390 for (
Value operand : op.getOperands()) {
391 concatOperands.emplace_back(operand);
393 auto producer = operand.getDefiningOp<tosa::ConcatOp>();
398 if (op.getAxis() != producer.getAxis())
402 foundRewritableConcat =
true;
403 concatOperands.pop_back();
404 llvm::append_range(concatOperands, producer->getOperands());
407 if (!foundRewritableConcat)
409 "No rewritable concat operand found.");
411 if (maxNumOperands > 0 &&
412 concatOperands.size() >
static_cast<size_t>(maxNumOperands))
414 op,
"Rewriting would exceed the maximum number of operands for the "
415 "target environment level.");
418 op, op.getType(), concatOperands, op.getAxisAttr());
428LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
429 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
433 op.getOperation()->setOperands(
434 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
446 auto innerTranspose =
447 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
450 "input must be transpose operation");
454 innerTranspose.getPerms();
456 if (transposePerms.size() != innerTransposePerms.size())
459 "transpose and inner transpose perms sizes must be equal");
460 if (transposePerms.empty())
462 transposeOp,
"transpose perms sizes must be positive");
466 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
467 perms[i] = innerTransposePerms[transposePerms[i]];
470 transposeOp, transposeOp.getResult().
getType(),
483 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
485 op,
"Src is from transpose, can compose transposes");
489 if (isa_and_nonnull<tosa::TransposeOp>(subop))
491 op,
"Dest is used by transpose, can compose transposes");
494 auto input = op.getInput1();
495 auto inputTy = llvm::cast<ShapedType>(input.
getType());
496 if (!inputTy.hasRank())
500 for (
int i = 0; i < inputTy.getRank(); ++i)
501 if (inputTy.isDynamicDim(i))
510 nonZeroPerms.reserve(permValues.size());
511 for (
auto idx : permValues) {
512 auto sz = inputTy.getDimSize(idx);
514 nonZeroPerms.push_back(idx);
517 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
518 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
520 "Transpose changes memory layout.");
523 newShape.reserve(inputTy.getRank());
524 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
525 newShape.push_back(inputTy.getDimSize(permValues[i]));
528 op, op.getType(), op.getInput1(),
536 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
544 Value input = op.getInput();
545 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
546 auto inputElementType = inputType.getElementType();
548 if (isa<FloatType>(inputElementType)) {
550 const auto minClamp =
551 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
552 const auto maxClamp =
553 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
554 const bool isMin = minClamp.isNegInfinity();
555 const bool isMax = maxClamp.isInfinity();
557 if (isMin && isMax) {
565 const bool isBoolean = inputElementType.isInteger(1);
566 if (inputElementType.isUnsignedInteger() || isBoolean) {
567 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
570 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
574 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
575 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
576 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
578 if (minClamp <= intMin && maxClamp >= intMax) {
585 if (llvm::isa<IntegerType>(inputElementType)) {
587 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
589 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
591 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
592 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
593 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
595 if (minClamp <= intMin && maxClamp >= intMax) {
627 template <
typename T>
641 Value input = op.getInput();
649 const auto opNanMode = op.getNanMode();
650 const auto clampNanMode = clampOp.getNanMode();
651 if (opNanMode == NanPropagationMode::IGNORE &&
652 clampNanMode == NanPropagationMode::PROPAGATE)
655 auto maxValAttr = op.getMaxValAttr();
656 auto minValAttr = op.getMinValAttr();
657 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
658 auto clampOpMinValAttr = clampOp.getMinValAttr();
660 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
662 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
667 if (mlir::isa<FloatType>(inputEType)) {
668 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
669 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
670 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
671 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
674 const auto opMinFloat = floatMinValAttr.getValue();
675 const auto opMaxFloat = floatMaxValAttr.getValue();
676 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
677 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
681 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
685 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
686 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
687 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
688 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
690 assert(mlir::isa<IntegerType>(inputEType));
691 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
692 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
693 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
694 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
696 if (inputEType.isUnsignedInteger()) {
698 const auto opMinInt = intMinValAttr.getUInt();
699 const auto opMaxInt = intMaxValAttr.getUInt();
700 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
701 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
705 if (!opRangeIntRange.
intersects(clampRangeIntRange))
709 auto newMinVal = std::max(opMinInt, clampOpMinInt);
710 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
715 const auto opMinInt = intMinValAttr.getInt();
716 const auto opMaxInt = intMaxValAttr.getInt();
717 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
718 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
722 if (!opRangeIntRange.
intersects(clampRangeIntRange))
726 auto newMinVal = std::max(opMinInt, clampOpMinInt);
727 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
733 auto newMode = (opNanMode != clampNanMode)
734 ? tosa::NanPropagationMode::IGNORE
738 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
741 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
747void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
748 MLIRContext *context) {
749 results.
add<ClampIsNoOp>(context);
750 results.
add<ClampClampOptimization>(context);
758 Value sliceInput = sliceOp.getInput1();
762 sliceOp,
"slice input must be concat operation");
765 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
766 if (!concatType || !concatType.hasStaticShape())
768 sliceOp,
"slice input must be a static ranked tensor");
769 int32_t axis = concatOp.getAxis();
776 sliceOp,
"start of slice must be a static ranked shape");
780 sliceOp,
"size of slice must be a static ranked shape");
790 std::optional<Value> replaceWithSlice;
791 for (
auto input : inputs) {
792 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
793 if (!inputType || !inputType.hasStaticShape())
795 sliceOp,
"concat input must be a static ranked tensor");
797 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
798 inputType.getDimSize(axis)) {
804 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
805 input, start_op, size_op)
809 sliceStarts[axis] -= inputType.getDimSize(axis);
812 if (!replaceWithSlice)
814 sliceOp,
"corresponding concat input not found for slice");
816 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
826 Value sliceInput = sliceOp.getInput1();
832 "slice input must be a pad operation");
835 if (!padOp->hasOneUse())
837 "pad shall have a single consumer");
840 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
841 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
842 if (!inputTy || !padTy || !inputTy.hasRank())
844 "slice input must be a ranked tensor");
851 "`padding` input specified on the tosa::PadOp must be constant.");
854 llvm::to_vector(paddingElems.getValues<
int64_t>());
860 sliceOp,
"start of slice must be a static ranked shape");
867 sliceOp,
"size of slice must be a static ranked shape");
872 const int64_t rank = inputTy.getRank();
873 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
874 const bool isDimDynamic = inputTy.isDynamicDim(i);
875 const bool isDimSliced =
878 return isDimDynamic && isDimSliced;
881 sliceOp,
"axis that are sliced shall be statically known.");
888 bool updated =
false;
890 for (
int64_t i = 0; i < rank; ++i) {
891 const int64_t padLo = padPaddings[i * 2];
892 const int64_t padHi = padPaddings[i * 2 + 1];
893 const int64_t sliceStart = sliceStarts[i];
894 const int64_t sliceSize = sliceSizes[i];
895 const int64_t sliceEnd = sliceStart + sliceSize;
898 if (inputTy.isDynamicDim(i)) {
899 newPadPaddings[i * 2] = padLo;
900 newPadPaddings[i * 2 + 1] = padHi;
901 newSliceStarts[i] = sliceStart;
906 const int64_t dimSize = inputTy.getShape()[i];
907 const int64_t dimTotal = padLo + dimSize + padHi;
910 if (sliceStart < 0 || sliceEnd > dimTotal)
914 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
915 newSliceStarts[i] = newSliceStart;
916 updated |= newSliceStart != sliceStart;
919 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
921 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
922 newPadPaddings[i * 2] = newPadLo;
923 newPadPaddings[i * 2 + 1] = newPadHi;
924 updated |= (newPadLo != padLo) || (newPadHi != padHi);
928 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
934 sliceOp,
"terminate condition; nothing to rewrite");
940 RankedTensorType::get(newPadShape, inputTy.getElementType());
941 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
942 padOp.getInput1(), newPaddingsOp,
943 padOp.getPadConst());
949 newPadOp.getResult(), newStartOp,
964 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
965 if (!resultType.hasRank())
968 ElementsAttr sizeElems;
971 sliceOp,
"size of slice must be a static ranked shape");
975 llvm::to_vector(sizeElems.getValues<
int64_t>());
977 bool replaceSliceSize{
false};
981 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
983 sliceSizes[
index] = resultType.getDimSize(
index);
984 replaceSliceSize =
true;
988 if (!replaceSliceSize) {
990 sliceOp,
"no dimension of size of slice is dynamic that resolves "
991 "to static output shape");
996 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
997 sliceOp.getInput1(), sliceOp.getStart(), size_op);
999 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
1004void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005 MLIRContext *context) {
1006 results.
add<ConcatSliceOptimization, PadSliceOptimization,
1007 SliceDynamicSizeCanonicalization>(context);
1015 const Value castInput = castOp.getInput();
1019 "input must be cast operation");
1021 const Value innerCastInput = innerCastOp.getInput();
1023 const ShapedType innerInputType =
1024 llvm::cast<ShapedType>(innerCastInput.
getType());
1025 const ShapedType innerOutputType =
1026 llvm::cast<ShapedType>(innerCastOp.getType());
1027 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
1029 const Type innerInputElemType = innerInputType.getElementType();
1030 const Type innerOutputElemType = innerOutputType.getElementType();
1031 const Type outerOutputElemType = outerOutputType.getElementType();
1034 outerOutputElemType};
1036 if (llvm::any_of(types, [](
const Type type) {
1040 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1041 Float16Type, Float32Type>(type));
1044 castOp,
"only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
1047 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
1048 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
1050 castOp,
"avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1054 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1055 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1057 castOp,
"avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1061 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1064 castOp,
"avoid introducing fp8 -> integer casts which are not "
1069 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1071 castOp,
"avoid introducing integer -> fp8 casts which are not "
1075 if (llvm::isa<Float16Type>(innerInputElemType) &&
1076 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1078 castOp,
"avoid introducing fp16 -> bf16 casts which are not "
1082 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1083 llvm::isa<Float16Type>(outerOutputElemType)) {
1085 castOp,
"avoid introducing bf16 -> fp16 casts which are not "
1089 const auto isIntegerOneOfWidth = [](
Type type,
size_t bitwidth1,
1094 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1097 castOp,
"avoid introducing i8/i16 -> i64 casts which are not "
1101 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1104 castOp,
"avoid introducing bool/i64 to float casts which are not "
1105 "supported in all versions of TOSA");
1109 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1111 castOp,
"avoid introducing float to bool/i64 casts which are not "
1112 "supported in all versions of TOSA");
1118 "inner cast operation is narrowing");
1127 return semantics.nonFiniteBehavior !=
1128 llvm::fltNonfiniteBehavior::FiniteOnly;
1132 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1136 const ShapedType outType)
const {
1138 if (inType.getElementType().isInteger() &&
1139 outType.getElementType().isInteger()) {
1141 const auto inTypeSignedness =
1142 cast<IntegerType>(inType.getElementType()).getSignedness();
1143 const auto outTypeSignedness =
1144 cast<IntegerType>(outType.getElementType()).getSignedness();
1146 return (inTypeSignedness != outTypeSignedness ||
1147 inType.getElementTypeBitWidth() >
1148 outType.getElementTypeBitWidth());
1151 if (inType.getElementType().isFloat() &&
1152 outType.getElementType().isFloat()) {
1154 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1155 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1156 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1157 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1163 [[maybe_unused]]
const auto isSupported = [](
Type elemType) {
1164 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1165 Float16Type, Float32Type>(elemType);
1168 assert(isSupported(inElemTy) &&
1169 "unsupported input element type in isNarrowingCast");
1170 assert(isSupported(outElemTy) &&
1171 "unsupported output element type in isNarrowingCast");
1174 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1175 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1176 inTypeSemantics.precision > outTypeSemantics.precision ||
1187void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.
add<NonNarrowingCastsOptimization>(context);
1198 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1199 auto castFromBlockScaledOp =
1200 castToBlockScaledInput.
getDefiningOp<tosa::CastFromBlockScaledOp>();
1201 if (!castFromBlockScaledOp)
1203 castToBlockScaledOp,
1204 "input must be cast_from_block_scaled operation");
1206 const Value innerData = castFromBlockScaledOp.getInputData();
1207 const Value innerScale = castFromBlockScaledOp.getInputScale();
1208 const auto innerDataTy = llvm::cast<ShapedType>(innerData.
getType());
1209 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.
getType());
1211 const Value outerData = castToBlockScaledOp.getOutputData();
1212 const Value outerScale = castToBlockScaledOp.getOutputScale();
1213 const auto outerDataTy = llvm::cast<ShapedType>(outerData.
getType());
1214 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.
getType());
1216 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1218 castToBlockScaledOp,
1219 "inputs types to cast_from_block_scaled operation must match output "
1220 "types to cast_to_block_scaled");
1223 if (castFromBlockScaledOp.getBlockSize() !=
1224 castToBlockScaledOp.getBlockSize()) {
1226 castToBlockScaledOp,
"block sizes for cast_from_block_scaled and "
1227 "cast_to_block_scaled must match");
1230 rewriter.
replaceOp(castToBlockScaledOp, {innerData, innerScale});
1236void CastToBlockScaledOp::getCanonicalizationPatterns(
1237 RewritePatternSet &results, MLIRContext *context) {
1238 results.
add<CancellingBlockScaledCastsOptimization>(context);
1245template <
typename Folder>
1246static DenseElementsAttr
1248 bool foldDenseValues =
false) {
1252 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1256 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
1260 if (
lhs.isSplat() &&
rhs.isSplat()) {
1261 if (isa<FloatType>(lETy)) {
1262 const APFloat l =
lhs.getSplatValue<APFloat>();
1263 const APFloat r =
rhs.getSplatValue<APFloat>();
1264 const auto maybeResult = Folder::fold(l, r);
1265 if (failed(maybeResult))
1270 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1271 const APInt l =
lhs.getSplatValue<APInt>();
1272 const APInt r =
rhs.getSplatValue<APInt>();
1273 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1274 if (failed(maybeResult))
1280 if (foldDenseValues) {
1281 assert(lETy.isIntOrIndex() &&
1282 "Only integer types are currently supported.");
1285 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
1286 const auto maybeResult = Folder::fold(l, r,
false);
1287 if (failed(maybeResult))
1289 resultValues.push_back(maybeResult.value());
1297template <
typename Folder>
1299 bool foldDenseValues =
false) {
1303 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1309 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1311 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1312 if (failed(maybeResult))
1318 if (foldDenseValues) {
1322 for (
auto const &v : val.
getValues<APInt>()) {
1323 const auto maybeResult = Folder::fold(v,
false);
1324 if (failed(maybeResult))
1326 resultValues.push_back(maybeResult.value());
1341 assert(dense.isSplat());
1342 APInt a = dense.getSplatValue<APInt>();
1343 return a.getSExtValue();
1348 const bool isUnsigned) {
1351 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1357 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1364 const bool isUnsigned) {
1367 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1373 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1380 const bool isUnsigned) {
1382 const unsigned originalWidth =
lhs.getBitWidth();
1385 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1390 if (
lhs == 0 ||
rhs == 0)
1391 return APInt::getZero(originalWidth);
1393 bool overflow =
false;
1395 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1400 return result.trunc(originalWidth);
1403 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1409 return a.isNegative() !=
b.isNegative();
1416 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1424 APInt::udivrem(
lhs,
rhs, q, r);
1425 if (!r.isZero() && Ceil) {
1432 bool overflow{
false};
1433 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1436 APInt
const r =
lhs.srem(
rhs);
1446 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1454 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1456 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1466 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1468 auto const r = t.mod(
rhs);
1469 if (llvm::APFloatBase::opStatus::opOK == r) {
1479 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1481 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1484 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1492 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1494 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1497 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1503 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1504 auto const numBits = value.getBitWidth();
1506 auto const zextv = value.getZExtValue();
1507 if (zextv >= numBits)
1509 return APInt::getOneBitSet(numBits, zextv);
1511 auto const sextv = value.getSExtValue();
1512 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1514 return APInt::getOneBitSet(numBits, sextv);
1524 assert(!isUnsigned &&
1525 "unsigned values are not supported for shape div folders");
1526 if (
lhs.isNegative() || !
rhs.isStrictlyPositive())
1531 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1537 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1538 if (!value.isStrictlyPositive())
1540 return APInt(value.getBitWidth(), value.ceilLogBase2());
1545 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1546 if (!value.isStrictlyPositive())
1548 return APInt(value.getBitWidth(), value.logBase2());
1554 const bool isUnsigned) {
1555 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1558 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1559 return APInt(1,
lhs >
rhs);
1565 const bool isUnsigned) {
1566 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1569 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1570 return APInt(1,
lhs >=
rhs);
1576 const bool isUnsigned) {
1577 return APInt(1,
lhs ==
rhs);
1580 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1581 return APInt(1,
lhs ==
rhs);
1586 if (llvm::isa<FloatType>(elemType))
1588 if (llvm::isa<IntegerType>(elemType))
1594 if (llvm::isa<FloatType>(elemType))
1595 return val && val.
isSplat() &&
1597 if (llvm::isa<IntegerType>(elemType)) {
1598 const int64_t shifted = 1LL << shift;
1599 return val && val.
isSplat() &&
1605OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1606 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1607 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1608 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1609 if (!lhsTy || !rhsTy || !resultTy)
1613 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1614 !rhsTy.getElementType().isIntOrIndexOrFloat())
1617 auto resultETy = resultTy.getElementType();
1619 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1621 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1624 lhsTy.getShape(), rhsTy.getShape());
1625 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1627 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1630 if (!lhsAttr || !rhsAttr)
1636OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1637 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1638 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1639 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1640 !outputTy.hasStaticShape())
1644 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1645 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1646 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1653OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1654 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1655 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1656 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1657 if (!lhsTy || !rhsTy || !resultTy)
1659 if (lhsTy.getElementType() != rhsTy.getElementType())
1664 auto resultETy = resultTy.getElementType();
1666 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1668 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1669 if (lhsAttr && lhsAttr.isSplat() && rhsAttr && rhsAttr.isSplat()) {
1670 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1671 lhsAttr.getSplatValue<APInt>().isZero() &&
1672 !rhsAttr.getSplatValue<APInt>().isZero()) {
1673 return lhsAttr.resizeSplat(resultTy);
1677 if (rhsAttr && rhsAttr.isSplat()) {
1679 lhsTy.getShape(), rhsTy.getShape());
1680 if (isBroadcastable && lhsTy == resultTy &&
1681 llvm::isa<IntegerType>(resultETy) &&
1682 rhsAttr.getSplatValue<APInt>().isOne())
1686 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1687 llvm::isa<IntegerType>(resultETy)) {
1688 APInt l = lhsAttr.getSplatValue<APInt>();
1689 APInt r = rhsAttr.getSplatValue<APInt>();
1691 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1693 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1706std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1707 unsigned bitwidth) {
1708 bool overflow =
false;
1709 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1712 return std::nullopt;
1715 auto round = APInt(64, 1) << (shift - 1);
1717 result.ashrInPlace(shift);
1720 if (!(
result.getSExtValue() >= INT32_MIN &&
1721 result.getSExtValue() <= INT32_MAX)) {
1723 return std::nullopt;
1727 return result.trunc(bitwidth);
1730DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1731 RankedTensorType ty, int32_t shift) {
1733 if (llvm::isa<IntegerType>(ty.getElementType())) {
1734 APInt l =
lhs.getSplatValue<APInt>();
1735 APInt r =
rhs.getSplatValue<APInt>();
1741 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1742 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1748 if (llvm::isa<FloatType>(ty.getElementType())) {
1749 APFloat l =
lhs.getSplatValue<APFloat>();
1750 APFloat r =
rhs.getSplatValue<APFloat>();
1760OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1761 auto lhs = getInput1();
1762 auto rhs = getInput2();
1763 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1764 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1765 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1766 if (!lhsTy || !rhsTy || !resultTy)
1769 auto resultETy = resultTy.getElementType();
1771 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1773 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1778 if (resultETy.isInteger(32)) {
1779 ElementsAttr shift_elem;
1780 if (getShift().getImpl()) {
1784 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1788 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1789 resultTy.hasStaticShape())
1791 return lhsAttr.resizeSplat(resultTy);
1792 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1793 resultTy.hasStaticShape())
1794 return rhsAttr.resizeSplat(resultTy);
1797 lhsTy.getShape(), rhsTy.getShape());
1798 if (isBroadcastable && rhsTy == resultTy &&
1801 if (isBroadcastable && lhsTy == resultTy &&
1805 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1808OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1809 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1810 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1811 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1812 if (!lhsTy || !rhsTy || !resultTy)
1816 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1817 !rhsTy.getElementType().isIntOrIndexOrFloat())
1820 auto resultETy = resultTy.getElementType();
1822 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1824 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1827 lhsTy.getShape(), rhsTy.getShape());
1828 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1831 if (!lhsAttr || !rhsAttr)
1837OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1838 auto resultTy = llvm::cast<ShapedType>(
getType());
1840 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1842 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1844 if (!lhsAttr || !rhsAttr)
1850OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1851 auto resultTy = llvm::cast<ShapedType>(
getType());
1853 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1855 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1857 if (!lhsAttr || !rhsAttr)
1863OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1864 auto resultTy = llvm::cast<ShapedType>(
getType());
1866 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1868 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1869 Value
lhs = getInput1();
1870 Value
rhs = getInput2();
1871 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1875 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1876 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1880 if (!lhsAttr || !rhsAttr)
1886OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1890 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1894 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1895 auto outTy = llvm::cast<ShapedType>(
getType());
1896 if (!outTy.hasRank() || !outTy.hasStaticShape())
1898 auto inETy = inTy.getElementType();
1899 auto outETy = outTy.getElementType();
1901 if (operand.isSplat()) {
1902 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1904 auto splatVal = operand.getSplatValue<APFloat>();
1905 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1906 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1911 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1912 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1913 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1914 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1915 llvm::RoundingMode::NearestTiesToEven);
1919 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1920 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1921 auto intVal = APSInt(
1922 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1923 auto floatVal = operand.getSplatValue<APFloat>();
1925 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1930 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1931 const auto inIntType = llvm::cast<IntegerType>(inETy);
1932 auto unsignIn = inIntType.isUnsignedInteger();
1934 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1935 auto intVal = operand.getSplatValue<APInt>();
1936 auto bitwidth = outETy.getIntOrFloatBitWidth();
1939 if (outETy.isInteger(1)) {
1940 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1942 intVal = intVal.trunc(bitwidth);
1943 }
else if (unsignIn || inIntType.isInteger(1)) {
1944 intVal = intVal.zext(bitwidth);
1946 intVal = intVal.sext(bitwidth);
1956OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1958OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1960#define REDUCE_FOLDER(OP) \
1961 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1962 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1963 if (!inputTy.hasRank()) \
1965 if (inputTy != getType()) \
1967 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1968 return getInput(); \
1981 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1982 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1984 if (!inputTy || !outputTy)
1990 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1994 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1995 getInput1().getDefiningOp())) {
1996 getInput1Mutable().assign(reshapeOp.getInput1());
2001 if (!inputTy.getElementType().isIntOrIndexOrFloat())
2006 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2008 if (!outputTy.hasStaticShape())
2012 if (operand.isSplat())
2017 if (!getInput1().hasOneUse())
2024 return operand.reshape(
2025 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
2031OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
2033 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
2034 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
2035 if (densePad && densePad.isSplat() &&
2036 densePad.getSplatValue<APInt>().isZero()) {
2046OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
2048 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
2050 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
2052 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
2053 if (!scaleAttr || !offsetAttr || !borderAttr) {
2060 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
2065 if (scale[0] != scale[1] || scale[2] != scale[3]) {
2070 if (offset[0] != 0 || offset[1] != 0) {
2075 if (border[0] != 0 || border[1] != 0) {
2079 return foldToInputIfTypeMatches(
getType(), getInput());
2082OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2083 auto operand = getInput1();
2084 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2085 auto axis = getAxis();
2087 const bool isSplatInput =
2088 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2089 if (!operandTy.hasRank() ||
2090 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2092 return foldToInputIfTypeMatches(
getType(), operand);
2095OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2096 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2097 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
2099 if (!inputTy || !outputTy)
2102 if (inputTy == outputTy && inputTy.hasStaticShape())
2107 DenseElementsAttr startElems;
2113 llvm::all_of(startElems.
getValues<APInt>(),
2114 [](
const APInt &val) { return val.isZero(); });
2119 DenseElementsAttr sizeElems;
2123 auto inputShape = inputTy.getShape();
2124 auto sizeValues = sizeElems.
getValues<APInt>();
2126 bool sizeMatchesInput =
true;
2127 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2128 int64_t size = sizeVal.getSExtValue();
2130 if (inputTy.isDynamicDim(i)) {
2134 sizeMatchesInput =
false;
2141 sizeMatchesInput =
false;
2147 if (sizeMatchesInput)
2152 if (!adaptor.getInput1())
2156 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2157 !outputTy.getElementType().isIntOrIndexOrFloat())
2160 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2161 if (operand.isSplat() && outputTy.hasStaticShape()) {
2165 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2166 outputTy.getNumElements() == 1) {
2167 llvm::SmallVector<uint64_t>
indices =
2168 llvm::to_vector(startElems.
getValues<uint64_t>());
2169 if (
auto values = operand.tryGetValues<Attribute>())
2176OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2177 const Value pred = getPred();
2178 const Value onTrue = getOnTrue();
2179 const Value onFalse = getOnFalse();
2181 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
2182 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
2183 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
2184 if (!predTy || !onTrueTy || !onFalseTy)
2187 const Type resultTy =
getType();
2189 const ArrayRef<int64_t> predShape = predTy.getShape();
2190 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2192 if (onTrue == onFalse && onTrueTy == resultTy &&
2197 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2200 if (!predicate.isSplat())
2203 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2205 SmallVector<SmallVector<int64_t>, 3> shapes;
2206 shapes.emplace_back(predShape);
2207 shapes.emplace_back(onTrueShape);
2208 shapes.emplace_back(onFalseTy.getShape());
2209 const bool isBroadcastable =
2212 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
2214 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
2219OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2221 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2222 adaptor.getMultiples())) {
2223 if (multiples.isSplat() &&
2224 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2226 if (
auto int_array_attr =
2227 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2228 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2229 [](APInt v) { return v.getSExtValue() == 1; }))
2237OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2238 auto resultTy = llvm::cast<ShapedType>(
getType());
2242 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2243 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2244 input.
getType().getElementType() == resultTy.getElementType())
2245 return input.reshape(resultTy);
2249 const llvm::ArrayRef<int32_t> perms = getPerms();
2251 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2254 return foldToInputIfTypeMatches(
getType(), getInput1());
2257OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2260 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2266 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2267 failed(maybeIZp) || *maybeIZp != 0) {
2271 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2272 failed(maybeOZp) || *maybeOZp != 0) {
2276 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2277 failed(maybeIZp) || *maybeIZp != 0) {
2281 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2282 failed(maybeOZp) || *maybeOZp != 0) {
2287 return foldToInputIfTypeMatches(
getType(), definingOp.getInput1());
2290OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2291 auto input = getInput1();
2294 return foldToInputIfTypeMatches(
getType(), input);
2299OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2300 auto input = adaptor.getInput1();
2302 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2304 if (!inputAttr || !inputAttr.isSplat())
2307 auto shapeType = llvm::cast<ShapedType>(
getType());
2308 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2310 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2311 auto floatVal = inputAttr.getSplatValue<APFloat>();
2313 ReciprocalOp::calcOneElement(floatVal));
2319template <
typename Op,
typename OpFoldAdaptor>
2321 auto input1ConstShape =
2322 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2323 if (!input1ConstShape)
2326 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2332template <
typename Op,
typename OpFoldAdaptor>
2334 auto input1ConstShape =
2335 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2336 auto input2ConstShape =
2337 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2338 if (!input1ConstShape || !input2ConstShape)
2341 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2342 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2345 input1Attr.getType(),
2349OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2350 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2351 if (!inputTy || !inputTy.hasRank())
2353 const int32_t axis = getAxis();
2354 const int64_t dimSize = inputTy.getDimSize(axis);
2355 if (ShapedType::isDynamic(dimSize))
2359 const auto resultAttrTy =
2360 RankedTensorType::get(1, builder.getIndexType());
2365 auto const inputs = op->getInput();
2371 concatDims.reserve( 64);
2372 for (
auto const &v : inputs) {
2373 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2377 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2380 auto const vAttrVals = vAttr.getValues<APInt>();
2381 for (
auto const &v : vAttrVals) {
2382 concatDims.push_back(v);
2386 auto *ctx = op->getContext();
2387 assert(ctx !=
nullptr &&
"ctx is nullptr");
2388 auto const rankedTy = RankedTensorType::get(
2389 {
static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2395 auto const input1 = op->getInput();
2396 auto const input2 = op->getStart();
2397 auto const input3 = op->getSize();
2399 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2401 if (!input1ConstShape)
2404 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2408 auto const input1Vals = input1Attr.getValues<APInt>();
2409 auto const totalInput1 = input1Vals.size();
2414 if (failed(start) || failed(size))
2417 auto const startV =
static_cast<int32_t
>(start.value());
2418 auto const sizeV =
static_cast<int32_t
>(size.value());
2420 if ((sizeV <= 0) || (startV < 0) ||
2421 (
static_cast<size_t>(startV + sizeV) > totalInput1))
2425 sliceOfInput.reserve(totalInput1);
2427 for (
auto i = startV; i < (startV + sizeV); i++) {
2428 sliceOfInput.push_back(input1Vals[i]);
2431 auto *ctx = op->getContext();
2432 assert(ctx !=
nullptr &&
"ctx is nullptr");
2434 auto const rankedTy = RankedTensorType::get(
2435 {
static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2440OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2444OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2448OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2452OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2453 return binaryFold<DivCeilShapeOp, ShapeDivFoldAdaptor<
true>>(
this);
2456OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2457 return binaryFold<DivFloorShapeOp, ShapeDivFoldAdaptor<
false>>(
this);
2460OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2464OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2468OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2472OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2476OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2480OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2484OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2488OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
DynamicAPInt round(const Fraction &f)
TosaLevel getTosaLevelFromEnum(const Level level)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
TargetEnvAttr lookupTargetEnv(Operation *op)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
int32_t MAX_TENSOR_LIST_SIZE