27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/APInt.h"
54 (padConstAttr.
size() != 1)) {
59 if (
auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
60 float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
61 return padConstVal == 0.0f;
65 if (
auto padConstIntAttr =
66 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
75 int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
76 return zpVal == padConstVal;
84template <
typename OpTy>
85struct PoolPadFoldAdaptor;
88struct PoolPadFoldAdaptor<
tosa::MaxPool2dOp> {
89 using OpTy = tosa::MaxPool2dOp;
90 static bool checkKernelCompliance(OpTy op,
const ArrayRef<int64_t> newPad) {
91 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
92 if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
93 newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
97 static bool checkPadConstCompliance(OpTy, Value padConst) {
99 DenseElementsAttr padConstAttr;
101 padConstAttr.
size() != 1) {
106 if (
auto padConstFpAttr =
107 mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
108 const APFloat padConstVal = *padConstFpAttr.begin();
109 const APFloat lowestVal =
110 APFloat::getLargest(padConstVal.getSemantics(),
true);
111 return padConstVal == lowestVal;
113 if (
auto padConstIntAttr =
114 mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
115 const APInt padConstVal = *padConstIntAttr.begin();
116 const unsigned int bitWidth = padConstVal.getBitWidth();
117 const APInt lowestVal =
118 padConstIntAttr.getElementType().isUnsignedInteger()
119 ? APInt::getZero(bitWidth)
120 : APInt::getSignedMinValue(bitWidth);
121 return padConstVal == lowestVal;
127 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
128 Value padInput, ArrayRef<int64_t> newPad) {
130 op, op.getType(), padInput, op.getKernel(), op.getStride(),
135template <
typename OpTy>
136struct ConvPadFoldAdaptor {
137 static bool checkKernelCompliance(OpTy,
const ArrayRef<int64_t>) {
140 static bool checkPadConstCompliance(OpTy op, Value padConst) {
143 static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
144 Value padInput, ArrayRef<int64_t> newPad) {
146 op, op.getResult().
getType(), padInput, op.getWeight(), op.getBias(),
147 op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
148 op.getDilationAttr(), op.getAccType(), op.getLocalBound());
156template <
typename OpTy,
typename AdaptorTy>
158 using OpRewritePattern<OpTy>::OpRewritePattern;
160 LogicalResult matchAndRewrite(OpTy tensorOp,
161 PatternRewriter &rewriter)
const override {
163 auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
166 "Producer must be a tosa::PadOp.");
169 const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
170 if (tensorOpPad.size() != 4)
172 tensorOp,
"Tensor operation padding shall have 4 elements.");
175 DenseIntElementsAttr padOpPadding;
179 "The `padding` input specified on the tosa::PadOp must be constant.");
183 if (padOpPadding.size() != 8)
185 "Pad padding should have 8 elements.");
186 int64_t padNBefore = (*(padOpPadding.
begin() + 0)).getLimitedValue();
187 int64_t padNAfter = (*(padOpPadding.
begin() + 1)).getLimitedValue();
188 int64_t padHBefore = (*(padOpPadding.
begin() + 2)).getLimitedValue();
189 int64_t padHAfter = (*(padOpPadding.
begin() + 3)).getLimitedValue();
190 int64_t padWBefore = (*(padOpPadding.
begin() + 4)).getLimitedValue();
191 int64_t padWAfter = (*(padOpPadding.
begin() + 5)).getLimitedValue();
192 int64_t padCBefore = (*(padOpPadding.
begin() + 6)).getLimitedValue();
193 int64_t padCAfter = (*(padOpPadding.
begin() + 7)).getLimitedValue();
195 if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
197 tensorOp,
"Folding padding in N or C dimensions is not supported.");
201 SmallVector<int64_t> foldedPad(tensorOpPad.size());
202 foldedPad[0] = padHBefore + tensorOpPad[0];
203 foldedPad[1] = padHAfter + tensorOpPad[1];
204 foldedPad[2] = padWBefore + tensorOpPad[2];
205 foldedPad[3] = padWAfter + tensorOpPad[3];
208 if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
210 tensorOp,
"Padding size not aligned with kernel restrictions.");
214 if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
217 "Padding constant is not aligned with operator zero-point.");
221 if (llvm::any_of(foldedPad, [](int64_t padVal) {
return padVal > 8192; })) {
223 tensorOp,
"Padding size more than the 8K level limit.");
227 AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
238 FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
244 results.
add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
245 ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
262 op,
"expected constant kernel, stride, and pad operands");
265 rewriter, op.getLoc(), op.
getType(), op.getInput(), op.getInputZp(),
274void AvgPool2dAdaptiveOp::getCanonicalizationPatterns(
284 Value input = op.getInput();
285 Value output = op.getOutput();
286 ShapedType inputType = llvm::cast<ShapedType>(input.
getType());
287 ShapedType outputType = llvm::cast<ShapedType>(output.
getType());
289 if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
295 if (outputShape[1] != 1 || outputShape[2] != 1) {
300 if (inputShape[1] != 1 || inputShape[2] != 1) {
312 FoldPadToTensorOp<tosa::MaxPool2dOp,
313 PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
330 op,
"expected constant kernel, stride, and pad operands");
333 rewriter, op.getLoc(), op.
getType(), op.getInput(),
342void MaxPool2dAdaptiveOp::getCanonicalizationPatterns(
356 if (op.getInput1().size() != 1)
358 if (op.getInput1().front().getType() != op.getType()) {
361 op.getInput1().front())
366 rewriter.
replaceOp(op, op.getInput1().front());
381 concatOperands.reserve(2 * op.getNumOperands());
383 int32_t maxNumOperands = 0;
389 bool foundRewritableConcat =
false;
390 for (
Value operand : op.getOperands()) {
391 concatOperands.emplace_back(operand);
393 auto producer = operand.getDefiningOp<tosa::ConcatOp>();
398 if (op.getAxis() != producer.getAxis())
402 foundRewritableConcat =
true;
403 concatOperands.pop_back();
404 llvm::append_range(concatOperands, producer->getOperands());
407 if (!foundRewritableConcat)
409 "No rewritable concat operand found.");
411 if (maxNumOperands > 0 &&
412 concatOperands.size() >
static_cast<size_t>(maxNumOperands))
414 op,
"Rewriting would exceed the maximum number of operands for the "
415 "target environment level.");
418 op, op.getType(), concatOperands, op.getAxisAttr());
428LogicalResult SelectOp::canonicalize(SelectOp op,
PatternRewriter &rewriter) {
429 auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
433 op.getOperation()->setOperands(
434 {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
446 auto innerTranspose =
447 transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
450 "input must be transpose operation");
454 innerTranspose.getPerms();
456 if (transposePerms.size() != innerTransposePerms.size())
459 "transpose and inner transpose perms sizes must be equal");
460 if (transposePerms.empty())
462 transposeOp,
"transpose perms sizes must be positive");
466 for (
int i = 0, s = transposePerms.size(); i < s; ++i)
467 perms[i] = innerTransposePerms[transposePerms[i]];
470 transposeOp, transposeOp.getResult().
getType(),
483 if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
485 op,
"Src is from transpose, can compose transposes");
489 if (isa_and_nonnull<tosa::TransposeOp>(subop))
491 op,
"Dest is used by transpose, can compose transposes");
494 auto input = op.getInput1();
495 auto inputTy = llvm::cast<ShapedType>(input.
getType());
496 if (!inputTy.hasRank())
500 for (
int i = 0; i < inputTy.getRank(); ++i)
501 if (inputTy.isDynamicDim(i))
510 nonZeroPerms.reserve(permValues.size());
511 for (
auto idx : permValues) {
512 auto sz = inputTy.getDimSize(idx);
514 nonZeroPerms.push_back(idx);
517 for (
int i = 1, s = nonZeroPerms.size(); i < s; ++i)
518 if (nonZeroPerms[i - 1] > nonZeroPerms[i])
520 "Transpose changes memory layout.");
523 newShape.reserve(inputTy.getRank());
524 for (
int i = 0, s = inputTy.getRank(); i < s; ++i)
525 newShape.push_back(inputTy.getDimSize(permValues[i]));
528 op, op.getType(), op.getInput1(),
536 results.
add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
544 Value input = op.getInput();
545 auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
546 auto inputElementType = inputType.getElementType();
548 if (isa<FloatType>(inputElementType)) {
550 const auto minClamp =
551 llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
552 const auto maxClamp =
553 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
554 const bool isMin = minClamp.isNegInfinity();
555 const bool isMax = maxClamp.isInfinity();
557 if (isMin && isMax) {
565 const bool isBoolean = inputElementType.isInteger(1);
566 if (inputElementType.isUnsignedInteger() || isBoolean) {
567 const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
570 const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
574 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
575 const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
576 const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
578 if (minClamp <= intMin && maxClamp >= intMax) {
585 if (llvm::isa<IntegerType>(inputElementType)) {
587 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
589 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
591 const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
592 const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
593 const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
595 if (minClamp <= intMin && maxClamp >= intMax) {
627 template <
typename T>
641 Value input = op.getInput();
649 const auto opNanMode = op.getNanMode();
650 const auto clampNanMode = clampOp.getNanMode();
651 if (opNanMode == NanPropagationMode::IGNORE &&
652 clampNanMode == NanPropagationMode::PROPAGATE)
655 auto maxValAttr = op.getMaxValAttr();
656 auto minValAttr = op.getMinValAttr();
657 auto clampOpMaxValAttr = clampOp.getMaxValAttr();
658 auto clampOpMinValAttr = clampOp.getMinValAttr();
660 auto inputEType = llvm::cast<ShapedType>(input.
getType()).getElementType();
662 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
667 if (mlir::isa<FloatType>(inputEType)) {
668 auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
669 auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
670 auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
671 auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
674 const auto opMinFloat = floatMinValAttr.getValue();
675 const auto opMaxFloat = floatMaxValAttr.getValue();
676 const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
677 const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
681 if (!opRangeFloatRange.
intersects(clampRangeFloatRange))
685 auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
686 auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
687 newMinValAttr = rewriter.
getFloatAttr(inputEType, newMinVal);
688 newMaxValAttr = rewriter.
getFloatAttr(inputEType, newMaxVal);
690 assert(mlir::isa<IntegerType>(inputEType));
691 auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
692 auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
693 auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
694 auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
696 if (inputEType.isUnsignedInteger()) {
698 const auto opMinInt = intMinValAttr.getUInt();
699 const auto opMaxInt = intMaxValAttr.getUInt();
700 const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
701 const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
705 if (!opRangeIntRange.
intersects(clampRangeIntRange))
709 auto newMinVal = std::max(opMinInt, clampOpMinInt);
710 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
715 const auto opMinInt = intMinValAttr.getInt();
716 const auto opMaxInt = intMaxValAttr.getInt();
717 const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
718 const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
722 if (!opRangeIntRange.
intersects(clampRangeIntRange))
726 auto newMinVal = std::max(opMinInt, clampOpMinInt);
727 auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
733 auto newMode = (opNanMode != clampNanMode)
734 ? tosa::NanPropagationMode::IGNORE
738 NanPropagationModeAttr::get(rewriter.
getContext(), newMode);
741 op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
747void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
748 MLIRContext *context) {
749 results.
add<ClampIsNoOp>(context);
750 results.
add<ClampClampOptimization>(context);
758 Value sliceInput = sliceOp.getInput1();
762 sliceOp,
"slice input must be concat operation");
765 auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
766 if (!concatType || !concatType.hasStaticShape())
768 sliceOp,
"slice input must be a static ranked tensor");
769 int32_t axis = concatOp.getAxis();
776 sliceOp,
"start of slice must be a static ranked shape");
780 sliceOp,
"size of slice must be a static ranked shape");
790 std::optional<Value> replaceWithSlice;
791 for (
auto input : inputs) {
792 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
793 if (!inputType || !inputType.hasStaticShape())
795 sliceOp,
"concat input must be a static ranked tensor");
797 if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
798 inputType.getDimSize(axis)) {
804 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
805 input, start_op, size_op)
809 sliceStarts[axis] -= inputType.getDimSize(axis);
812 if (!replaceWithSlice)
814 sliceOp,
"corresponding concat input not found for slice");
816 rewriter.
replaceOp(sliceOp, replaceWithSlice.value());
826 Value sliceInput = sliceOp.getInput1();
832 "slice input must be a pad operation");
835 if (!padOp->hasOneUse())
837 "pad shall have a single consumer");
840 auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
841 auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
842 if (!inputTy || !padTy || !inputTy.hasRank())
844 "slice input must be a ranked tensor");
851 "`padding` input specified on the tosa::PadOp must be constant.");
854 llvm::to_vector(paddingElems.getValues<
int64_t>());
860 sliceOp,
"start of slice must be a static ranked shape");
867 sliceOp,
"size of slice must be a static ranked shape");
872 const int64_t rank = inputTy.getRank();
873 if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](
int64_t i) {
874 const bool isDimDynamic = inputTy.isDynamicDim(i);
875 const bool isDimSliced =
878 return isDimDynamic && isDimSliced;
881 sliceOp,
"axis that are sliced shall be statically known.");
888 bool updated =
false;
890 for (
int64_t i = 0; i < rank; ++i) {
891 const int64_t padLo = padPaddings[i * 2];
892 const int64_t padHi = padPaddings[i * 2 + 1];
893 const int64_t sliceStart = sliceStarts[i];
894 const int64_t sliceSize = sliceSizes[i];
895 const int64_t sliceEnd = sliceStart + sliceSize;
898 if (inputTy.isDynamicDim(i)) {
899 newPadPaddings[i * 2] = padLo;
900 newPadPaddings[i * 2 + 1] = padHi;
901 newSliceStarts[i] = sliceStart;
906 const int64_t dimSize = inputTy.getShape()[i];
907 const int64_t dimTotal = padLo + dimSize + padHi;
910 if (sliceStart < 0 || sliceEnd > dimTotal)
914 const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
915 newSliceStarts[i] = newSliceStart;
916 updated |= newSliceStart != sliceStart;
919 const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
921 std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
922 newPadPaddings[i * 2] = newPadLo;
923 newPadPaddings[i * 2 + 1] = newPadHi;
924 updated |= (newPadLo != padLo) || (newPadHi != padHi);
928 newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
934 sliceOp,
"terminate condition; nothing to rewrite");
940 RankedTensorType::get(newPadShape, inputTy.getElementType());
941 auto newPadOp = tosa::PadOp::create(rewriter, padOp.getLoc(), newPadTy,
942 padOp.getInput1(), newPaddingsOp,
943 padOp.getPadConst());
949 newPadOp.getResult(), newStartOp,
964 ShapedType resultType = cast<ShapedType>(sliceOp.getType());
965 if (!resultType.hasRank())
968 ElementsAttr sizeElems;
971 sliceOp,
"size of slice must be a static ranked shape");
975 llvm::to_vector(sizeElems.getValues<
int64_t>());
977 bool replaceSliceSize{
false};
981 for (
const auto &[
index, size] : llvm::enumerate(sliceSizes)) {
983 sliceSizes[
index] = resultType.getDimSize(
index);
984 replaceSliceSize =
true;
988 if (!replaceSliceSize) {
990 sliceOp,
"no dimension of size of slice is dynamic that resolves "
991 "to static output shape");
996 tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.
getType(),
997 sliceOp.getInput1(), sliceOp.getStart(), size_op);
999 rewriter.
replaceOp(sliceOp, newSliceOp.getResult());
1004void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005 MLIRContext *context) {
1006 results.
add<ConcatSliceOptimization, PadSliceOptimization,
1007 SliceDynamicSizeCanonicalization>(context);
1015 const Value castInput = castOp.getInput();
1019 "input must be cast operation");
1021 const Value innerCastInput = innerCastOp.getInput();
1023 const ShapedType innerInputType =
1024 llvm::cast<ShapedType>(innerCastInput.
getType());
1025 const ShapedType innerOutputType =
1026 llvm::cast<ShapedType>(innerCastOp.getType());
1027 const ShapedType outerOutputType = llvm::cast<ShapedType>(castOp.getType());
1029 const Type innerInputElemType = innerInputType.getElementType();
1030 const Type innerOutputElemType = innerOutputType.getElementType();
1031 const Type outerOutputElemType = outerOutputType.getElementType();
1034 outerOutputElemType};
1036 if (llvm::any_of(types, [](
const Type type) {
1040 llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1041 Float16Type, Float32Type>(type));
1044 castOp,
"only integer and f32, f16, bf16, f8E4M3FN, f8E5M2 types are "
1047 if (llvm::isa<Float8E5M2Type>(innerInputElemType) &&
1048 llvm::isa<Float8E4M3FNType>(outerOutputElemType)) {
1050 castOp,
"avoid introducing f8E5M2 -> f8E4M3FN casts which are not "
1054 if (llvm::isa<Float8E4M3FNType>(innerInputElemType) &&
1055 llvm::isa<Float8E5M2Type>(outerOutputElemType)) {
1057 castOp,
"avoid introducing f8E4M3FN -> f8E5M2 casts which are not "
1061 if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(innerInputElemType) &&
1064 castOp,
"avoid introducing fp8 -> integer casts which are not "
1069 llvm::isa<Float8E5M2Type, Float8E4M3FNType>(outerOutputElemType)) {
1071 castOp,
"avoid introducing integer -> fp8 casts which are not "
1075 if (llvm::isa<Float16Type>(innerInputElemType) &&
1076 llvm::isa<BFloat16Type>(outerOutputElemType)) {
1078 castOp,
"avoid introducing fp16 -> bf16 casts which are not "
1082 if (llvm::isa<BFloat16Type>(innerInputElemType) &&
1083 llvm::isa<Float16Type>(outerOutputElemType)) {
1085 castOp,
"avoid introducing bf16 -> fp16 casts which are not "
1089 const auto isIntegerOneOfWidth = [](
Type type,
size_t bitwidth1,
1094 if (isIntegerOneOfWidth(innerInputElemType, 8, 16) &&
1097 castOp,
"avoid introducing i8/i16 -> i64 casts which are not "
1101 if (isIntegerOneOfWidth(innerInputElemType, 1, 64) &&
1104 castOp,
"avoid introducing bool/i64 to float casts which are not "
1105 "supported in all versions of TOSA");
1109 isIntegerOneOfWidth(outerOutputElemType, 1, 64)) {
1111 castOp,
"avoid introducing float to bool/i64 casts which are not "
1112 "supported in all versions of TOSA");
1118 "inner cast operation is narrowing");
1127 return semantics.nonFiniteBehavior !=
1128 llvm::fltNonfiniteBehavior::FiniteOnly;
1132 return semantics.nonFiniteBehavior == llvm::fltNonfiniteBehavior::IEEE754;
1136 const ShapedType outType)
const {
1138 if (inType.getElementType().isInteger() &&
1139 outType.getElementType().isInteger()) {
1141 const auto inTypeSignedness =
1142 cast<IntegerType>(inType.getElementType()).getSignedness();
1143 const auto outTypeSignedness =
1144 cast<IntegerType>(outType.getElementType()).getSignedness();
1146 return (inTypeSignedness != outTypeSignedness ||
1147 inType.getElementTypeBitWidth() >
1148 outType.getElementTypeBitWidth());
1151 if (inType.getElementType().isFloat() &&
1152 outType.getElementType().isFloat()) {
1154 FloatType inElemTy = cast<FloatType>(inType.getElementType());
1155 FloatType outElemTy = cast<FloatType>(outType.getElementType());
1156 llvm::fltSemantics inTypeSemantics = inElemTy.getFloatSemantics();
1157 llvm::fltSemantics outTypeSemantics = outElemTy.getFloatSemantics();
1163 [[maybe_unused]]
const auto isSupported = [](
Type elemType) {
1164 return llvm::isa<Float8E4M3FNType, Float8E5M2Type, BFloat16Type,
1165 Float16Type, Float32Type>(elemType);
1168 assert(isSupported(inElemTy) &&
1169 "unsupported input element type in isNarrowingCast");
1170 assert(isSupported(outElemTy) &&
1171 "unsupported output element type in isNarrowingCast");
1174 inTypeSemantics.maxExponent > outTypeSemantics.maxExponent ||
1175 inTypeSemantics.minExponent < outTypeSemantics.minExponent ||
1176 inTypeSemantics.precision > outTypeSemantics.precision ||
1187void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.
add<NonNarrowingCastsOptimization>(context);
1198 const Value castToBlockScaledInput = castToBlockScaledOp.getInputData();
1199 auto castFromBlockScaledOp =
1200 castToBlockScaledInput.
getDefiningOp<tosa::CastFromBlockScaledOp>();
1201 if (!castFromBlockScaledOp)
1203 castToBlockScaledOp,
1204 "input must be cast_from_block_scaled operation");
1206 const Value innerData = castFromBlockScaledOp.getInputData();
1207 const Value innerScale = castFromBlockScaledOp.getInputScale();
1208 const auto innerDataTy = llvm::cast<ShapedType>(innerData.
getType());
1209 const auto innerScaleTy = llvm::cast<ShapedType>(innerScale.
getType());
1211 const Value outerData = castToBlockScaledOp.getOutputData();
1212 const Value outerScale = castToBlockScaledOp.getOutputScale();
1213 const auto outerDataTy = llvm::cast<ShapedType>(outerData.
getType());
1214 const auto outerScaleTy = llvm::cast<ShapedType>(outerScale.
getType());
1216 if (innerDataTy != outerDataTy || innerScaleTy != outerScaleTy) {
1218 castToBlockScaledOp,
1219 "inputs types to cast_from_block_scaled operation must match output "
1220 "types to cast_to_block_scaled");
1223 if (castFromBlockScaledOp.getBlockSize() !=
1224 castToBlockScaledOp.getBlockSize()) {
1226 castToBlockScaledOp,
"block sizes for cast_from_block_scaled and "
1227 "cast_to_block_scaled must match");
1230 rewriter.
replaceOp(castToBlockScaledOp, {innerData, innerScale});
1236void CastToBlockScaledOp::getCanonicalizationPatterns(
1237 RewritePatternSet &results, MLIRContext *context) {
1238 results.
add<CancellingBlockScaledCastsOptimization>(context);
1245template <
typename Folder>
1246static DenseElementsAttr
1248 bool foldDenseValues =
false) {
1252 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1256 const auto rETy = llvm::cast<ShapedType>(
rhs.getType()).getElementType();
1260 if (
lhs.isSplat() &&
rhs.isSplat()) {
1261 if (isa<FloatType>(lETy)) {
1262 const APFloat l =
lhs.getSplatValue<APFloat>();
1263 const APFloat r =
rhs.getSplatValue<APFloat>();
1264 const auto maybeResult = Folder::fold(l, r);
1265 if (failed(maybeResult))
1270 if (
const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
1271 const APInt l =
lhs.getSplatValue<APInt>();
1272 const APInt r =
rhs.getSplatValue<APInt>();
1273 const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
1274 if (failed(maybeResult))
1280 if (foldDenseValues) {
1281 assert(lETy.isIntOrIndex() &&
1282 "Only integer types are currently supported.");
1285 llvm::zip(
lhs.getValues<APInt>(),
rhs.getValues<APInt>())) {
1286 const auto maybeResult = Folder::fold(l, r,
false);
1287 if (failed(maybeResult))
1289 resultValues.push_back(maybeResult.value());
1297template <
typename Folder>
1299 bool foldDenseValues =
false) {
1303 if (!returnTy.hasRank() || !returnTy.hasStaticShape())
1309 if (
const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) {
1311 const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned());
1312 if (failed(maybeResult))
1318 if (foldDenseValues) {
1322 for (
auto const &v : val.
getValues<APInt>()) {
1323 const auto maybeResult = Folder::fold(v,
false);
1324 if (failed(maybeResult))
1326 resultValues.push_back(maybeResult.value());
1341 assert(dense.isSplat());
1342 APInt a = dense.getSplatValue<APInt>();
1343 return a.getSExtValue();
1348 const bool isUnsigned) {
1351 isUnsigned ?
lhs.uadd_ov(
rhs, overflow) :
lhs.sadd_ov(
rhs, overflow);
1357 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1364 const bool isUnsigned) {
1367 isUnsigned ?
lhs.usub_ov(
rhs, overflow) :
lhs.ssub_ov(
rhs, overflow);
1373 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1380 const bool isUnsigned) {
1382 const unsigned originalWidth =
lhs.getBitWidth();
1385 if (
lhs.getBitWidth() !=
rhs.getBitWidth()) {
1390 if (
lhs == 0 ||
rhs == 0)
1391 return APInt::getZero(originalWidth);
1393 bool overflow =
false;
1395 isUnsigned ?
lhs.umul_ov(
rhs, overflow) :
lhs.smul_ov(
rhs, overflow);
1400 return result.trunc(originalWidth);
1403 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1409 return a.isNegative() !=
b.isNegative();
1416 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1424 APInt::udivrem(
lhs,
rhs, q, r);
1425 if (!r.isZero() && Ceil) {
1432 bool overflow{
false};
1433 APInt
const q =
lhs.sdiv_ov(
rhs, overflow);
1436 APInt
const r =
lhs.srem(
rhs);
1446 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1454 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1456 if (
lhs.isNegative() || (!
rhs.isStrictlyPositive()))
1466 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1468 auto const r = t.mod(
rhs);
1469 if (llvm::APFloatBase::opStatus::opOK == r) {
1479 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1481 return lhs.getSExtValue() >=
rhs.getSExtValue() ?
lhs :
rhs;
1484 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1492 if (
lhs.getBitWidth() !=
rhs.getBitWidth())
1494 return lhs.getSExtValue() <=
rhs.getSExtValue() ?
lhs :
rhs;
1497 static FailureOr<APFloat>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1503 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1504 auto const numBits = value.getBitWidth();
1506 auto const zextv = value.getZExtValue();
1507 if (zextv >= numBits)
1509 return APInt::getOneBitSet(numBits, zextv);
1511 auto const sextv = value.getSExtValue();
1512 if (sextv < 0 || sextv >= numBits || (value.isNegative()))
1514 return APInt::getOneBitSet(numBits, sextv);
1519 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1520 if (!value.isStrictlyPositive())
1522 return APInt(value.getBitWidth(), value.ceilLogBase2());
1527 static FailureOr<APInt>
fold(
const APInt &value,
bool isUnsigned) {
1528 if (!value.isStrictlyPositive())
1530 return APInt(value.getBitWidth(), value.logBase2());
1536 const bool isUnsigned) {
1537 return isUnsigned ? APInt(1,
lhs.ugt(
rhs)) : APInt(1,
lhs.sgt(
rhs));
1540 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1541 return APInt(1,
lhs >
rhs);
1547 const bool isUnsigned) {
1548 return isUnsigned ? APInt(1,
lhs.uge(
rhs)) : APInt(1,
lhs.sge(
rhs));
1551 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1552 return APInt(1,
lhs >=
rhs);
1558 const bool isUnsigned) {
1559 return APInt(1,
lhs ==
rhs);
1562 static FailureOr<APInt>
fold(
const APFloat &
lhs,
const APFloat &
rhs) {
1563 return APInt(1,
lhs ==
rhs);
1568 if (llvm::isa<FloatType>(elemType))
1570 if (llvm::isa<IntegerType>(elemType))
1576 if (llvm::isa<FloatType>(elemType))
1577 return val && val.
isSplat() &&
1579 if (llvm::isa<IntegerType>(elemType)) {
1580 const int64_t shifted = 1LL << shift;
1581 return val && val.
isSplat() &&
1587OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1588 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1589 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1590 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1591 if (!lhsTy || !rhsTy || !resultTy)
1595 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1596 !rhsTy.getElementType().isIntOrIndexOrFloat())
1599 auto resultETy = resultTy.getElementType();
1601 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1603 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1606 lhsTy.getShape(), rhsTy.getShape());
1607 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1609 if (isBroadcastable && rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr))
1612 if (!lhsAttr || !rhsAttr)
1618OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
1619 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().
getType());
1620 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1621 if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
1622 !outputTy.hasStaticShape())
1626 if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.
isInteger()) {
1627 const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1628 const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1635OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
1636 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1637 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1638 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1639 if (!lhsTy || !rhsTy || !resultTy)
1641 if (lhsTy.getElementType() != rhsTy.getElementType())
1646 auto resultETy = resultTy.getElementType();
1648 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1650 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1651 if (lhsAttr && lhsAttr.isSplat()) {
1652 if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
1653 lhsAttr.getSplatValue<APInt>().isZero())
1654 return lhsAttr.resizeSplat(resultTy);
1657 if (rhsAttr && rhsAttr.isSplat()) {
1659 lhsTy.getShape(), rhsTy.getShape());
1660 if (isBroadcastable && lhsTy == resultTy &&
1661 llvm::isa<IntegerType>(resultETy) &&
1662 rhsAttr.getSplatValue<APInt>().isOne())
1666 if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() &&
1667 llvm::isa<IntegerType>(resultETy)) {
1668 APInt l = lhsAttr.getSplatValue<APInt>();
1669 APInt r = rhsAttr.getSplatValue<APInt>();
1671 auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
1673 DivFoldAdaptor<
false>::fold(l, r, intTy.isUnsigned());
1686std::optional<APInt> mulInt(APInt
lhs, APInt
rhs, int32_t shift,
1687 unsigned bitwidth) {
1688 bool overflow =
false;
1689 APInt
result =
lhs.sext(64).smul_ov(
rhs.sext(64), overflow);
1692 return std::nullopt;
1695 auto round = APInt(64, 1) << (shift - 1);
1697 result.ashrInPlace(shift);
1700 if (!(
result.getSExtValue() >= INT32_MIN &&
1701 result.getSExtValue() <= INT32_MAX)) {
1703 return std::nullopt;
1707 return result.trunc(bitwidth);
1710DenseElementsAttr mulBinaryFolder(DenseElementsAttr
lhs, DenseElementsAttr
rhs,
1711 RankedTensorType ty, int32_t shift) {
1713 if (llvm::isa<IntegerType>(ty.getElementType())) {
1714 APInt l =
lhs.getSplatValue<APInt>();
1715 APInt r =
rhs.getSplatValue<APInt>();
1721 auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
1722 const std::optional<APInt>
result = mulInt(l, r, shift, bitwidth);
1728 if (llvm::isa<FloatType>(ty.getElementType())) {
1729 APFloat l =
lhs.getSplatValue<APFloat>();
1730 APFloat r =
rhs.getSplatValue<APFloat>();
1740OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1741 auto lhs = getInput1();
1742 auto rhs = getInput2();
1743 auto lhsTy = llvm::dyn_cast<RankedTensorType>(
lhs.getType());
1744 auto rhsTy = llvm::dyn_cast<RankedTensorType>(
rhs.getType());
1745 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1746 if (!lhsTy || !rhsTy || !resultTy)
1749 auto resultETy = resultTy.getElementType();
1751 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1753 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1758 if (resultETy.isInteger(32)) {
1759 ElementsAttr shift_elem;
1760 if (getShift().getImpl()) {
1764 shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1768 if (rhsTy == resultTy &&
isSplatZero(resultETy, lhsAttr) &&
1769 resultTy.hasStaticShape())
1771 return lhsAttr.resizeSplat(resultTy);
1772 if (lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr) &&
1773 resultTy.hasStaticShape())
1774 return rhsAttr.resizeSplat(resultTy);
1777 lhsTy.getShape(), rhsTy.getShape());
1778 if (isBroadcastable && rhsTy == resultTy &&
1781 if (isBroadcastable && lhsTy == resultTy &&
1785 return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
1788OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1789 auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1790 auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().
getType());
1791 auto resultTy = llvm::dyn_cast<RankedTensorType>(
getType());
1792 if (!lhsTy || !rhsTy || !resultTy)
1796 if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
1797 !rhsTy.getElementType().isIntOrIndexOrFloat())
1800 auto resultETy = resultTy.getElementType();
1802 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1804 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1807 lhsTy.getShape(), rhsTy.getShape());
1808 if (isBroadcastable && lhsTy == resultTy &&
isSplatZero(resultETy, rhsAttr))
1811 if (!lhsAttr || !rhsAttr)
1817OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
1818 auto resultTy = llvm::cast<ShapedType>(
getType());
1820 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1822 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1824 if (!lhsAttr || !rhsAttr)
1830OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
1831 auto resultTy = llvm::cast<ShapedType>(
getType());
1833 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1835 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1837 if (!lhsAttr || !rhsAttr)
1843OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
1844 auto resultTy = llvm::cast<ShapedType>(
getType());
1846 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
1848 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
1849 Value
lhs = getInput1();
1850 Value
rhs = getInput2();
1851 auto lhsTy = llvm::cast<ShapedType>(
lhs.getType());
1855 if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
1856 resultTy.hasStaticShape() &&
lhs ==
rhs) {
1860 if (!lhsAttr || !rhsAttr)
1866OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
1870 auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
1874 auto inTy = llvm::cast<ShapedType>(getInput().
getType());
1875 auto outTy = llvm::cast<ShapedType>(
getType());
1876 if (!outTy.hasRank() || !outTy.hasStaticShape())
1878 auto inETy = inTy.getElementType();
1879 auto outETy = outTy.getElementType();
1881 if (operand.isSplat()) {
1882 if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
1884 auto splatVal = operand.getSplatValue<APFloat>();
1885 auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
1886 splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
1891 if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
1892 auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
1893 APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
1894 splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
1895 llvm::RoundingMode::NearestTiesToEven);
1899 if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1900 auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
1901 auto intVal = APSInt(
1902 llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
1903 auto floatVal = operand.getSplatValue<APFloat>();
1905 floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
1910 if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
1911 const auto inIntType = llvm::cast<IntegerType>(inETy);
1912 auto unsignIn = inIntType.isUnsignedInteger();
1914 inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
1915 auto intVal = operand.getSplatValue<APInt>();
1916 auto bitwidth = outETy.getIntOrFloatBitWidth();
1919 if (outETy.isInteger(1)) {
1920 intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
1922 intVal = intVal.trunc(bitwidth);
1923 }
else if (unsignIn || inIntType.isInteger(1)) {
1924 intVal = intVal.zext(bitwidth);
1926 intVal = intVal.sext(bitwidth);
1936OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1938OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
return getValuesAttr(); }
1940#define REDUCE_FOLDER(OP) \
1941 OpFoldResult OP::fold(FoldAdaptor adaptor) { \
1942 ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
1943 if (!inputTy.hasRank()) \
1945 if (inputTy != getType()) \
1947 if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
1948 return getInput(); \
1961 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
1962 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
1964 if (!inputTy || !outputTy)
1970 if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
1974 if (
auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
1975 getInput1().getDefiningOp())) {
1976 getInput1Mutable().assign(reshapeOp.getInput1());
1981 if (!inputTy.getElementType().isIntOrIndexOrFloat())
1986 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
1988 if (!outputTy.hasStaticShape())
1992 if (operand.isSplat())
1997 if (!getInput1().hasOneUse())
2004 return operand.reshape(
2005 llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
2011OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
2013 if (adaptor.getPadding() && getInput1().
getType() ==
getType()) {
2014 auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
2015 if (densePad && densePad.isSplat() &&
2016 densePad.getSplatValue<APInt>().isZero()) {
2026OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
2028 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
2030 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
2032 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
2033 if (!scaleAttr || !offsetAttr || !borderAttr) {
2040 if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
2045 if (scale[0] != scale[1] || scale[2] != scale[3]) {
2050 if (offset[0] != 0 || offset[1] != 0) {
2055 if (border[0] != 0 || border[1] != 0) {
2059 return foldToInputIfTypeMatches(
getType(), getInput());
2062OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
2063 auto operand = getInput1();
2064 auto operandTy = llvm::cast<ShapedType>(operand.getType());
2065 auto axis = getAxis();
2067 const bool isSplatInput =
2068 llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
2069 if (!operandTy.hasRank() ||
2070 (!isSplatInput && operandTy.getDimSize(axis) != 1))
2072 return foldToInputIfTypeMatches(
getType(), operand);
2075OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
2076 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().
getType());
2077 auto outputTy = llvm::dyn_cast<RankedTensorType>(
getType());
2079 if (!inputTy || !outputTy)
2082 if (inputTy == outputTy && inputTy.hasStaticShape())
2087 DenseElementsAttr startElems;
2093 llvm::all_of(startElems.
getValues<APInt>(),
2094 [](
const APInt &val) { return val.isZero(); });
2099 DenseElementsAttr sizeElems;
2103 auto inputShape = inputTy.getShape();
2104 auto sizeValues = sizeElems.
getValues<APInt>();
2106 bool sizeMatchesInput =
true;
2107 for (
const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
2108 int64_t size = sizeVal.getSExtValue();
2110 if (inputTy.isDynamicDim(i)) {
2114 sizeMatchesInput =
false;
2121 sizeMatchesInput =
false;
2127 if (sizeMatchesInput)
2132 if (!adaptor.getInput1())
2136 if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
2137 !outputTy.getElementType().isIntOrIndexOrFloat())
2140 auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
2141 if (operand.isSplat() && outputTy.hasStaticShape()) {
2145 if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
2146 outputTy.getNumElements() == 1) {
2147 llvm::SmallVector<uint64_t>
indices =
2148 llvm::to_vector(startElems.
getValues<uint64_t>());
2149 if (
auto values = operand.tryGetValues<Attribute>())
2156OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
2157 const Value pred = getPred();
2158 const Value onTrue = getOnTrue();
2159 const Value onFalse = getOnFalse();
2161 const auto predTy = llvm::dyn_cast<RankedTensorType>(pred.
getType());
2162 const auto onTrueTy = llvm::dyn_cast<RankedTensorType>(onTrue.
getType());
2163 const auto onFalseTy = llvm::dyn_cast<RankedTensorType>(onFalse.
getType());
2164 if (!predTy || !onTrueTy || !onFalseTy)
2167 const Type resultTy =
getType();
2169 const ArrayRef<int64_t> predShape = predTy.getShape();
2170 const ArrayRef<int64_t> onTrueShape = onTrueTy.getShape();
2172 if (onTrue == onFalse && onTrueTy == resultTy &&
2177 llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
2180 if (!predicate.isSplat())
2183 const bool predicateValue = predicate.getSplatValue<APInt>().getBoolValue();
2185 SmallVector<SmallVector<int64_t>, 3> shapes;
2186 shapes.emplace_back(predShape);
2187 shapes.emplace_back(onTrueShape);
2188 shapes.emplace_back(onFalseTy.getShape());
2189 const bool isBroadcastable =
2192 if (predicateValue ==
true && onTrueTy == resultTy && isBroadcastable)
2194 if (predicateValue ==
false && onFalseTy == resultTy && isBroadcastable)
2199OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
2201 if (
auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
2202 adaptor.getMultiples())) {
2203 if (multiples.isSplat() &&
2204 multiples.getSplatValue<APInt>().getSExtValue() == 1)
2206 if (
auto int_array_attr =
2207 llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
2208 if (llvm::all_of(int_array_attr.getValues<APInt>(),
2209 [](APInt v) { return v.getSExtValue() == 1; }))
2217OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
2218 auto resultTy = llvm::cast<ShapedType>(
getType());
2222 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
2223 if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
2224 input.
getType().getElementType() == resultTy.getElementType())
2225 return input.reshape(resultTy);
2229 const llvm::ArrayRef<int32_t> perms = getPerms();
2231 if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
2234 return foldToInputIfTypeMatches(
getType(), getInput1());
2237OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
2240 auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
2246 if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2247 failed(maybeIZp) || *maybeIZp != 0) {
2251 if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2252 failed(maybeOZp) || *maybeOZp != 0) {
2256 if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
2257 failed(maybeIZp) || *maybeIZp != 0) {
2261 if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
2262 failed(maybeOZp) || *maybeOZp != 0) {
2267 return foldToInputIfTypeMatches(
getType(), definingOp.getInput1());
2270OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
2271 auto input = getInput1();
2274 return foldToInputIfTypeMatches(
getType(), input);
2279OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
2280 auto input = adaptor.getInput1();
2282 auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
2284 if (!inputAttr || !inputAttr.isSplat())
2287 auto shapeType = llvm::cast<ShapedType>(
getType());
2288 if (!shapeType.hasRank() || !shapeType.hasStaticShape())
2290 if (
auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
2291 auto floatVal = inputAttr.getSplatValue<APFloat>();
2293 ReciprocalOp::calcOneElement(floatVal));
2299template <
typename Op,
typename OpFoldAdaptor>
2301 auto input1ConstShape =
2302 dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp());
2303 if (!input1ConstShape)
2306 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2312template <
typename Op,
typename OpFoldAdaptor>
2314 auto input1ConstShape =
2315 dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp());
2316 auto input2ConstShape =
2317 dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp());
2318 if (!input1ConstShape || !input2ConstShape)
2321 const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2322 const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
2325 input1Attr.getType(),
2329OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
2330 const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().
getType());
2331 if (!inputTy || !inputTy.hasRank())
2333 const int32_t axis = getAxis();
2334 const int64_t dimSize = inputTy.getDimSize(axis);
2335 if (ShapedType::isDynamic(dimSize))
2339 const auto resultAttrTy =
2340 RankedTensorType::get(1, builder.getIndexType());
2345 auto const inputs = op->getInput();
2351 concatDims.reserve( 64);
2352 for (
auto const &v : inputs) {
2353 auto vConstShape = dyn_cast<tosa::ConstShapeOp>(v.getDefiningOp());
2357 const auto vAttr = cast<DenseElementsAttr>(vConstShape.getValues());
2360 auto const vAttrVals = vAttr.getValues<APInt>();
2361 for (
auto const &v : vAttrVals) {
2362 concatDims.push_back(v);
2366 auto *ctx = op->getContext();
2367 assert(ctx !=
nullptr &&
"ctx is nullptr");
2368 auto const rankedTy = RankedTensorType::get(
2369 {
static_cast<int64_t>(concatDims.size())}, IndexType::get(ctx));
2375 auto const input1 = op->getInput();
2376 auto const input2 = op->getStart();
2377 auto const input3 = op->getSize();
2379 auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
2381 if (!input1ConstShape)
2384 auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
2388 auto const input1Vals = input1Attr.getValues<APInt>();
2389 auto const totalInput1 = input1Vals.size();
2394 if (failed(start) || failed(size))
2397 auto const startV =
static_cast<int32_t
>(start.value());
2398 auto const sizeV =
static_cast<int32_t
>(size.value());
2400 if ((sizeV <= 0) || (startV < 0) ||
2401 (
static_cast<size_t>(startV + sizeV) > totalInput1))
2405 sliceOfInput.reserve(totalInput1);
2407 for (
auto i = startV; i < (startV + sizeV); i++) {
2408 sliceOfInput.push_back(input1Vals[i]);
2411 auto *ctx = op->getContext();
2412 assert(ctx !=
nullptr &&
"ctx is nullptr");
2414 auto const rankedTy = RankedTensorType::get(
2415 {
static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
2420OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
2424OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
2428OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
2432OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
2433 return binaryFold<DivCeilShapeOp, DivFoldAdaptor<
true>>(
this);
2436OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
2437 return binaryFold<DivFloorShapeOp, DivFoldAdaptor<
false>>(
this);
2440OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
2444OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
2448OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
2452OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
2456OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
2460OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
2464OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
2468OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define REDUCE_FOLDER(OP)
OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, bool foldDenseValues=false)
static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, bool foldDenseValues=false)
OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op)
static FailureOr< int64_t > getSingleI64From1ElementTensor(Value v)
OpFoldResult binaryFold(Op *op)
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift)
OpFoldResult unaryShapeFold(Op *op)
static bool checkMatchingPadConstAndZp(Value padConst, Value zp)
static bool signsDiffer(const APInt &a, const APInt &b)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
int64_t size() const
Returns the number of elements held by this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
iterator begin() const
Iterator access to the integer element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
DynamicAPInt round(const Fraction &f)
TosaLevel getTosaLevelFromEnum(const Level level)
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
SmallVector< int64_t > convertFromIntAttr(const DenseElementsAttr &attr, const int rank)
TargetEnvAttr lookupTargetEnv(Operation *op)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::CastToBlockScaledOp castToBlockScaledOp, PatternRewriter &rewriter) const override
bool intersects(const ClampRange< T > &otherRange)
ClampRange(const T &start, const T &end)
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APInt > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APInt > fold(const APInt &value, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
bool isNarrowingCast(const ShapedType inType, const ShapedType outType) const
LogicalResult matchAndRewrite(tosa::CastOp castOp, PatternRewriter &rewriter) const override
bool supportsInf(const llvm::fltSemantics &semantics) const
bool supportsNaN(const llvm::fltSemantics &semantics) const
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const override
static FailureOr< APFloat > fold(const APFloat &lhs, const APFloat &rhs)
static FailureOr< APInt > fold(const APInt &lhs, const APInt &rhs, const bool isUnsigned)
LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
int32_t MAX_TENSOR_LIST_SIZE