#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
    // In ConcatOptimization::matchAndRewrite: a tosa.concat with a single
    // operand folds to that operand, inserting a cast when the types differ.
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  op.getInput1().front());
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
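// For illustration (hypothetical IR): tosa.concat(%x) with a single operand
// of matching type folds straight to %x; with a differing result type it
// becomes tensor.cast(%x) instead.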
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  // select(logical_not(pred), a, b) => select(pred, b, a).
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}
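// For illustration (hypothetical SSA names): tosa.select(tosa.logical_not(%p),
// %a, %b) becomes tosa.select(%p, %b, %a); the logical_not is then erased as
// dead code if it has no other users.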
    // In ConsolidateTransposeOptimization::matchAndRewrite: rewrite
    // transpose(transpose(x)) into a single transpose with composed perms.
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int32_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Consolidate the two permutations into one.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permsTy =
        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
    Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
                                                      permsTy, permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);

    return success();
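// For illustration (hypothetical perms): with inner perms [1, 2, 0] and outer
// perms [2, 0, 1], perms[i] = inner[outer[i]] gives [0, 1, 2] -- the two
// transposes cancel, and the identity transpose is folded away separately by
// TransposeOp::fold below.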
    // In TransposeIsReshape::matchAndRewrite: a transpose that only moves
    // unit dimensions is really a reshape.
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;

    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    // Unit dimensions can be moved freely; the transpose is a reshape iff the
    // remaining (non-unit) dimensions keep their relative order.
    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        getTosaConstShape(rewriter, op.getLoc(), newShape));
    return success();
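// For illustration (hypothetical shapes): transposing tensor<1x512x1xf32>
// with perms [1, 0, 2] only moves unit dims -- the single non-unit dim keeps
// its order -- so the pattern rewrites it to a tosa.reshape producing
// tensor<512x1x1xf32>.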
    // In MaterializePadValue::matchAndRewrite: give tosa.pad an explicit
    // pad_const when it has none, using 0 or the input zero point.
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
      int64_t value = op.getInputZpAttr().getInt();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
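// For illustration (hypothetical values): for an i8 input with input_zp =
// -128, the materialized pad constant is a rank-0 tosa.const of -128, so the
// padding stays value-preserving for the quantized tensor; a float input
// simply pads with 0.0.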
    // In MaxPool2dIsNoOp::matchAndRewrite: max_pool2d over 1x1 spatial
    // dimensions returns the input unchanged.
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // The pool is a no-op only when input and output are both NxHxWxC with
    // H = W = 1.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
    // In ClampIsNoOp::matchAndRewrite: a clamp whose bounds cover the whole
    // range of the element type does nothing.
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();

    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating-point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }
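// For illustration (hypothetical bounds): clamping an i8 tensor with
// min_int = -128 and max_int = 127 covers the whole signed 8-bit range
// [-128, 127], so the clamp folds away; min_int = -100 would keep it.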
// Helper type holding one clamp's bounds for the overlap test below.
template <typename T>
struct ClampRange {
  ClampRange(const T &start, const T &end) : start(start), end(end) {}
  T start;
  T end;

  // Helper to determine if two Clamp ranges intersect.
  bool intersects(const ClampRange<T> &otherRange) {
    return start < otherRange.end && otherRange.start < end;
  }
};
    // In ClampClampOptimization::matchAndRewrite: merge clamp(clamp(x)) into
    // a single clamp over the intersection of both ranges.
    auto clampOp =
        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
    if (!clampOp)
      return failure();

    // Check we have a valid NaN propagation combination.
    const auto opNanMode = op.getNanMode();
    const auto clampNanMode = clampOp.getNanMode();
    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
      return failure();

    // Check we have intersecting integer ranges.
    const auto opMinInt = op.getMinInt();
    const auto opMaxInt = op.getMaxInt();
    const auto clampOpMinInt = clampOp.getMinInt();
    const auto clampOpMaxInt = clampOp.getMaxInt();
    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
    if (!opRangeIntRange.intersects(clampRangeIntRange))
      return failure();

    // ... and intersecting floating-point ranges.
    const auto opMinFloat = op.getMinFp();
    const auto opMaxFloat = op.getMaxFp();
    const auto clampOpMinFloat = clampOp.getMinFp();
    const auto clampOpMaxFloat = clampOp.getMaxFp();
    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
      return failure();

    // Run the transformation using the tightest bounds from both clamps.
    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
    const auto minInt = std::max(opMinInt, clampOpMinInt);
    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                           : opNanMode));
    return success();
    // In ConcatSliceOptimization::matchAndRewrite: a slice that reads
    // entirely from one operand of a concat can slice that operand directly.
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    DenseElementsAttr startElems;
    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    DenseElementsAttr sizeElems;
    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    // Walk the concat operands, shifting the slice start into each operand's
    // local coordinates until the slice fits inside a single operand.
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStarts[axis] >= 0 && (sliceStarts[axis] + sliceSizes[axis]) <=
                                        inputType.getDimSize(axis)) {
        auto start_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceStarts);
        auto size_op =
            getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, start_op, size_op)
                .getResult();
        break;
      }
      sliceStarts[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
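// For illustration (hypothetical sizes): slicing concat(%a : 4xf32,
// %b : 6xf32) at start 5, size 3 does not fit inside %a, so the start is
// shifted by 5 - 4 = 1 and the slice lands fully inside %b, becoming
// slice(%b, start = 1, size = 3).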
// Generic splat-by-splat constant folder shared by the binary folds below;
// IntFolder/FloatFolder supply the element-wise operation.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    // With a result shift, the "one" splat is 1 << shift.
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 => x, as long as no broadcast is involved.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
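// For illustration (hypothetical SSA names): tosa.add(%x, %zeros), where
// %zeros is a splat-0 constant of the same type, folds straight to %x; two
// splat constants fold to a new splat via binaryFolder above.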
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  // An axis of extent 1 has a single maximum, always at index 0.
  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  // IntDivOp inputs must be integer types, so there is no quantized case.
  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    // 0 / x => 0.
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    // x / 1 => x.
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      if (!r.isZero()) {
        APInt result = l.sdiv(r);
        return DenseElementsAttr::get(resultTy, result);
      }
    }
  }

  return {};
}
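// For illustration (hypothetical splats): dividing splat 7 by splat 2
// computes sdiv(7, 2) = 3 (signed, truncating); a zero divisor is left
// unfolded rather than producing undefined behavior at compile time.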
static DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs,
                                         DenseElementsAttr rhs,
                                         RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0) {
        return DenseElementsAttr::get(ty, l * r);
      }

      // Widen to twice the bitwidth so the product cannot overflow, apply
      // the result right-shift, then truncate back to the result width.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
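// Worked example (hypothetical i32 splats): l = 6, r = 4, shift = 2 widens
// both to 64 bits, multiplies to 24, shifts to 24 >> 2 = 6, then truncates
// the result back to i32.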
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The result right-shift only applies to i32; synthesize a zero shift for
  // every other data type.
  int32_t shift = 0;
  if (resultETy.isInteger(32)) {
    ElementsAttr shift_elem;
    if (getShift().getImpl()) {
      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
        // Cannot be folded when the shift value is unknown.
        return {};
      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
    }
  }

  // x * 0 => 0 and x * 1 => x (with 1 adjusted for the shift).
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 => x. Unlike add, only the right-hand side may be dropped.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
// Equality and float comparisons reuse the std functors; signed integer
// ordering needs APInt's sgt/sge, handled by the dedicated functors below.
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value to itself is always true. This is not safe
  // for floats because of NaN.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
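// For illustration: tosa.equal(%x, %x) on a statically shaped integer tensor
// folds to a splat-true constant; the same fold would be unsound for floats,
// where NaN != NaN.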
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: round to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: round to nearest even.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate, or zero/sign extend based on input signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
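// For illustration (hypothetical splat): this folder narrows an i16 splat of
// 300 to i8 by truncating the bit pattern (300 & 0xFF = 44); widening casts
// instead zero- or sign-extend depending on the source signedness.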
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
// Shared folder for the tosa.reduce_* ops: reducing a rank-0 tensor or an
// axis of extent 1 is an identity.
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }
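// For illustration: tosa.reduce_sum over axis 1 of tensor<4x1x8xf32> folds
// to its input, since the reduced axis already has extent 1 and the input
// and output types match.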
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // Fold when the input and output types match. This is only safe with at
  // most one dynamic dimension; with two or more, the reshape can still be
  // productive.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x).
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(reshaped x).
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate other (non-splat) constants.
    if (!getInput1().hasOneUse())
      return {};

    llvm::SmallVector<int64_t> shapeVec;
    if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
  }

  return {};
}
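// For illustration (hypothetical shapes): reshape(reshape(%x to 2x6) to 3x4)
// collapses to reshape(%x to 3x4) by re-pointing the operand in place, and a
// splat constant input is simply re-splatted at the output shape.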
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // A pad of all zeros is a no-op.
  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
  if (densePad && densePad.isSplat() &&
      densePad.getSplatValue<APInt>().isZero()) {
    return getInput1();
  }

  return {};
}

// Fold away cases where a tosa.resize does nothing.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> scale = getScale();
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  // Reversing a splat constant changes nothing.
  if (operandAttr)
    return operandAttr;

  // If the axis has extent 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // A full-size slice is an identity.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A single-element slice of a constant folds to that element.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    DenseElementsAttr startElems;
    if (!matchPattern(getStart(), m_Constant(&startElems)))
      return {};

    llvm::SmallVector<uint64_t> indices =
        llvm::to_vector(startElems.getValues<uint64_t>());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  // Both branches identical: the predicate is irrelevant.
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  // A splat predicate picks one branch wholesale.
  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  if (getInput1().getType() == getType()) {
    // Tiling by 1 along every dimension is an identity.
    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
            adaptor.getMultiples())) {
      if (multiples.isSplat() &&
          multiples.getSplatValue<APInt>().getSExtValue() == 1)
        return getInput1();
      if (auto int_array_attr =
              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
        if (llvm::all_of(int_array_attr.getValues<APInt>(),
                         [](APInt v) { return v.getSExtValue() == 1; }))
          return getInput1();
      }
    }
  }
  return {};
}
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Otherwise fold only the identity permutation.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}
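// For illustration: a transpose with perms [0, 1, 2] is the identity and
// folds to its input -- exactly what ConsolidateTransposeOptimization
// produces when two transposes cancel each other.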
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x.
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x.
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x.
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x).
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op. Keep track
  // of the operands so we can construct a new concat later; conservatively
  // assume folding doubles the number of operands.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes differ.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
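// For illustration (hypothetical SSA names): concat(concat(%a, %b, axis = 0),
// %c, axis = 0) is flattened in place to concat(%a, %b, %c, axis = 0);
// producers on a different axis are left alone.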
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}