#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
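// ConcatOptimization: a tosa.concat with a single input is an identity and
// can be replaced by that input (inserting a cast when the result type
// differs). Illustrative IR (hypothetical shapes):
//   %0 = tosa.concat %arg0 {axis = 0 : i32}
//            : (tensor<2x3xf32>) -> tensor<2x3xf32>
// canonicalizes to a direct use of %arg0.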
LogicalResult matchAndRewrite(tosa::ConcatOp op,
                              PatternRewriter &rewriter) const override {
  if (op.getInput1().size() != 1)
    return failure();
  if (op.getInput1().front().getType() != op.getType()) {
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, op.getType(),
                                              op.getInput1().front());
    return success();
  }

  rewriter.replaceOp(op, op.getInput1().front());
  return success();
}
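// tosa.select canonicalization: when the predicate comes from
// tosa.logical_not, drop the not and swap the two branches. Illustrative
// IR (hypothetical types):
//   %n = tosa.logical_not %p : (tensor<i1>) -> tensor<i1>
//   %r = tosa.select %n, %a, %b : ...
// becomes %r = tosa.select %p, %b, %a.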
LogicalResult tosa::SelectOp::canonicalize(SelectOp op,
                                           PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}
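// ConsolidateTransposeOptimization: transpose(transpose(x, pInner), pOuter)
// is a single transpose whose permutation is the composition
// perms[i] = pInner[pOuter[i]].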
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
                              PatternRewriter &rewriter) const override {
  // The input must itself be produced by a transpose.
  auto innerTranspose =
      transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
  if (!innerTranspose)
    return rewriter.notifyMatchFailure(transposeOp,
                                       "input must be transpose operation");

  SmallVector<int32_t> transposePerms, innerTransposePerms;
  if (transposeOp.getConstantPerms(transposePerms).failed())
    return rewriter.notifyMatchFailure(transposeOp,
                                       "transpose perms must be constant");
  if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
    return rewriter.notifyMatchFailure(
        transposeOp, "inner transpose perms must be constant");
  if (transposePerms.size() != innerTransposePerms.size())
    return rewriter.notifyMatchFailure(
        transposeOp,
        "transpose and inner transpose perms sizes must be equal");
  if (transposePerms.empty())
    return rewriter.notifyMatchFailure(
        transposeOp, "transpose perms sizes must be positive");

  // Compose the two permutations.
  SmallVector<int32_t> perms(transposePerms.size());
  for (int i = 0, s = transposePerms.size(); i < s; ++i)
    perms[i] = innerTransposePerms[transposePerms[i]];

  auto permsAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(perms.size())},
                            rewriter.getI32Type()),
      perms);
  Value permsValue =
      rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
      transposeOp, transposeOp.getResult().getType(),
      innerTranspose.getInput1(), permsValue);
  return success();
}
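// TransposeIsReshape: a transpose that only relocates unit dimensions does
// not change the element order in memory, so it can be rewritten as a
// reshape. Illustrative example: transposing tensor<1x16x1xf32> with perms
// [1, 0, 2] only moves size-1 dims and is just a reshape to
// tensor<16x1x1xf32>.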
LogicalResult matchAndRewrite(tosa::TransposeOp op,
                              PatternRewriter &rewriter) const override {
  DenseIntElementsAttr permAttr;
  if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
    return rewriter.notifyMatchFailure(op, "Non-constant permutation");

  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
    return rewriter.notifyMatchFailure(
        op, "Src is from transpose, can compose transposes");

  for (Operation *subop : op.getResult().getUsers())
    if (dyn_cast_or_null<tosa::TransposeOp>(subop))
      return rewriter.notifyMatchFailure(
          op, "Dest is used by transpose, can compose transposes");

  auto input = op.getInput1();
  auto inputTy = llvm::cast<ShapedType>(input.getType());
  if (!inputTy.hasRank())
    return rewriter.notifyMatchFailure(op, "Unranked input.");

  int64_t numDynDims = 0;
  for (int i = 0; i < inputTy.getRank(); ++i)
    if (inputTy.isDynamicDim(i))
      numDynDims++;
  if (numDynDims > 1)
    return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

  SmallVector<int64_t> permValues = llvm::to_vector<6>(
      llvm::map_range(permAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  // Unit dimensions can move freely: the transpose is a pure reshape only
  // if the non-unit dimensions keep their relative order.
  SmallVector<int64_t> nonZeroPerms;
  nonZeroPerms.reserve(permValues.size());
  for (auto idx : permValues) {
    auto sz = inputTy.getDimSize(idx);
    if (sz != 1)
      nonZeroPerms.push_back(idx);
  }

  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
    if (nonZeroPerms[i - 1] > nonZeroPerms[i])
      return rewriter.notifyMatchFailure(op,
                                         "Transpose changes memory layout.");

  SmallVector<int64_t> newShape;
  newShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
    newShape.push_back(inputTy.getDimSize(permValues[i]));

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
      op, op.getType(), op.getInput1(),
      rewriter.getDenseI64ArrayAttr(newShape));
  return success();
}
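// MaterializePadValue: give tosa.pad an explicit pad-value operand when it
// has none, using 0 / 0.0 for plain types or the input zero point for
// quantized integer types.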
LogicalResult matchAndRewrite(tosa::PadOp op,
                              PatternRewriter &rewriter) const override {
  if (op.getPadConst())
    return failure();

  auto input = op.getInput1();
  auto padding = op.getPadding();

  ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
  Type elementTy = inputTy.getElementType();

  Attribute constantAttr;
  if (llvm::isa<FloatType>(elementTy)) {
    constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
  } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
    constantAttr = rewriter.getIntegerAttr(elementTy, 0);
  } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
    // Quantized pads use the input zero point rather than literal zero.
    auto value = op.getQuantizationInfo()->getInputZp();
    constantAttr = rewriter.getIntegerAttr(elementTy, value);
  }

  if (!constantAttr)
    return rewriter.notifyMatchFailure(
        op,
        "tosa.pad to linalg lowering encountered an unknown element type");

  auto denseAttr = DenseElementsAttr::get(
      RankedTensorType::get({}, elementTy), constantAttr);
  auto constantVal = rewriter.create<tosa::ConstOp>(
      op.getLoc(), denseAttr.getType(), denseAttr);

  rewriter.replaceOpWithNewOp<tosa::PadOp>(
      op, op.getType(), ValueRange{input, padding, constantVal},
      op->getAttrs());
  return success();
}
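// MaxPool2dIsNoOp: max-pooling a 1x1 spatial input into a 1x1 spatial
// output (NHWC layout) reduces over a single element, so the op is an
// identity.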
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                              PatternRewriter &rewriter) const override {
  Value input = op.getInput();
  Value output = op.getOutput();
  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
  ShapedType outputType = llvm::cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
    return failure();
  }

  // If both the input and output spatial dims are 1x1, this is a no-op.
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (outputShape[1] != 1 || outputShape[2] != 1) {
    return failure();
  }

  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (inputShape[1] != 1 || inputShape[2] != 1) {
    return failure();
  }

  rewriter.replaceOp(op, input);
  return success();
}
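// ClampIsNoOp: a clamp whose bounds cover the whole representable range of
// the element type (e.g. [-inf, +inf] for floats, the full signed or
// unsigned range for integers) never changes a value and can be removed.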
LogicalResult matchAndRewrite(tosa::ClampOp op,
                              PatternRewriter &rewriter) const override {
  Value input = op.getInput();
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  auto inputElementType = inputType.getElementType();

  if (!inputType.hasStaticShape()) {
    return failure();
  }

  if (isa<FloatType>(inputElementType)) {
    // Unlike integer types, floating point types can represent infinity.
    auto minClamp = op.getMinFp();
    auto maxClamp = op.getMaxFp();
    bool isMin = minClamp.isInfinity() && minClamp.isNegative();
    bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

    if (isMin && isMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }

  if (inputElementType.isUnsignedInteger()) {
    int64_t minClamp = op.getMinInt();
    int64_t maxClamp = op.getMaxInt();
    int64_t intMin =
        APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
            .getZExtValue();
    int64_t intMax =
        APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
            .getZExtValue();

    if (minClamp <= intMin && maxClamp >= intMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }

  if (llvm::isa<IntegerType>(inputElementType)) {
    int64_t minClamp = op.getMinInt();
    int64_t maxClamp = op.getMaxInt();
    int64_t intMin =
        APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
            .getSExtValue();
    int64_t intMax =
        APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
            .getSExtValue();

    if (minClamp <= intMin && maxClamp >= intMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }

  return failure();
}
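// ClampClampOptimization: clamp(clamp(x, a, b), c, d) folds to a single
// clamp with the intersected bounds [max(a, c), min(b, d)].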
LogicalResult matchAndRewrite(tosa::ClampOp op,
                              PatternRewriter &rewriter) const override {
  Value input = op.getInput();

  Operation *definingOp = input.getDefiningOp();
  if (!definingOp)
    return failure();

  if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
    // Take the tighter of the two bounds on each side.
    auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
    auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

    auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
    auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp));
    return success();
  }

  return failure();
}
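// ConcatSliceOptimization: if a slice of concat(x0, x1, ...) lies entirely
// within one input xi along the concatenation axis, slice xi directly.
// Illustrative example (hypothetical shapes): slicing rows [4, 6) of
// concat(tensor<4x8xf32>, tensor<4x8xf32>) {axis = 0} reads only from the
// second operand, so it becomes a slice of that operand at rows [0, 2).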
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                              PatternRewriter &rewriter) const override {
  Value sliceInput = sliceOp.getInput();
  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
  if (!concatOp)
    return rewriter.notifyMatchFailure(
        sliceOp, "slice input must be concat operation");

  OperandRange inputs = concatOp.getInput1();
  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
  if (!concatType || !concatType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        sliceOp, "slice input must be a static ranked tensor");
  int32_t axis = concatOp.getAxis();

  llvm::SmallVector<int64_t> sliceStart = llvm::to_vector(sliceOp.getStart());
  llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

  // Walk the concat inputs along the concatenated axis; if the slice fits
  // entirely inside one input, slice that input directly.
  std::optional<Value> replaceWithSlice;
  for (auto input : inputs) {
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "concat input must be a static ranked tensor");

    if (sliceStart[axis] >= 0 &&
        (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
      replaceWithSlice =
          rewriter
              .create<tosa::SliceOp>(
                  sliceOp.getLoc(), sliceOp.getType(), input,
                  rewriter.getDenseI64ArrayAttr(sliceStart),
                  rewriter.getDenseI64ArrayAttr(sliceSize))
              .getResult();
      break;
    }
    sliceStart[axis] -= inputType.getDimSize(axis);
  }

  if (!replaceWithSlice)
    return rewriter.notifyMatchFailure(
        sliceOp, "corresponding concat input not found for slice");

  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
  return success();
}
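// Constant-folding helpers. binaryFolder applies IntFolder/FloatFolder to
// two splat attributes of matching element type and materializes the
// result as a splat of the result type.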
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}
// True if `val` is a splat of zero of the given element type.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

// True if `val` is a splat of one; for integers "one" is 1 << shift, the
// multiplicative identity of tosa.mul's shifted fixed-point product.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
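// tosa.add folder: x + splat(0) folds to x when no broadcast is implied
// (operand type equals result type), and splat + splat folds to a new
// constant.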
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
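// tosa.argmax folder: reducing an axis of extent 1 can only yield index 0.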
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}
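// Integer divide folder: splat(0) / x => 0, x / splat(1) => x, and
// splat / splat is computed directly. (The op class is named IntDivOp in
// recent MLIR and DivOp in older trees; the name below assumes the former.)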
OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      // Constant-fold splat / splat (assumes a nonzero divisor).
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}
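// mulBinaryFolder: constant-folds splat * splat. For integers with a
// nonzero shift, tosa.mul computes (l * r) >> shift, so the product is
// formed at twice the bit width before shifting and truncating.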
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      if (shift == 0)
        return DenseElementsAttr::get(ty, l * r);

      // Widen to 2x the bit width, multiply, shift right, truncate back.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
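// tosa.mul folder: x * splat(0) => 0 and x * splat(1) => x, where "one"
// is 1 << shift for shifted integer multiplies; otherwise fold
// splat * splat through mulBinaryFolder.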
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
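// tosa.sub folder: x - splat(0) => x, otherwise fold splat - splat.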
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
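// Functors adapting APInt/APFloat comparisons for binaryFolder. The APInt
// variants use signed comparisons, matching TOSA semantics; results are
// materialized as 1-bit integers.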
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
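// Comparison folders: tosa.greater, tosa.greater_equal and tosa.equal fold
// splat inputs through binaryFolder. tosa.equal additionally folds x == x
// to true for integer element types (floats are excluded because of NaN).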
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing a value against itself is always true for integers; this
  // does not hold for floats because NaN != NaN.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
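// tosa.cast folder: an identity cast folds to its input; splat constants
// are converted element-wise across the four int/float direction pairs.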
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: round to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int: truncate toward zero.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate, or extend according to the source signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
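// Reduce folders: reducing along an axis of extent 1 (or a rank-0 tensor)
// is an identity whenever the result type matches the input type.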
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }

// REDUCE_FOLDER is then instantiated once per tosa reduce op (the
// instantiation list is elided in this listing).
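// tosa.reshape folder: identity reshapes fold away, reshape-of-reshape
// collapses to a single op, and constant inputs are reshaped in place.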
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // No-op reshape: with fewer than two dynamic dims, matching types imply
  // matching shapes.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x).
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Constant-fold by reshaping the attribute data.
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());

    // Don't duplicate non-splat constants unless this is the only user.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}
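// tosa.pad folder: padding by zero elements on every side is a no-op.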
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
      return getInput1();
    }
  }

  return {};
}
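// tosa.resize folder: a resize with unit scale, no offset and no border
// returns a copy of the input image.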
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Check unit scaling.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }

  // There should be no offset.
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }

  // There should be no border.
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
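// tosa.reverse folder: reversing a splat constant or an axis of extent 1
// changes nothing.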
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  if (operandAttr)
    return operandAttr;

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}
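// tosa.slice folder: a slice covering the whole input is an identity;
// splat constants and single-element slices fold to constants.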
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // A slice that spans the entire (static) input is a no-op.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index element types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A single-element slice of a static constant folds to that element.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart().begin(), getStart().end());
    auto value = operand.getValues<Attribute>()[indices];
    return DenseElementsAttr::get(outputTy, value);
  }

  return {};
}
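// tosa.select folder: identical branches fold to either branch; a splat
// predicate picks one branch statically.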
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}
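// tosa.tile folder: all-ones multiples make the tile an identity.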
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}
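// tosa.transpose folder: splat constants fold to a reshaped splat, and an
// identity permutation folds to the input.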
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Only an identity transpose on matching types can fold further.
  if (getInput1().getType() != getType())
    return {};

  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}
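// Inverse-pair folders: log(exp(x)) => x and exp(log(x)) => x.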
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: log(exp(x)) = x.
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: exp(log(x)) = x.
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }

  return {};
}
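// Involution and idempotence folders: -(-x) => x and abs(abs(x)) => abs(x).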
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: -(-x) = x.
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }

  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: abs(abs(x)) = abs(x).
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }

  return {};
}
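// tosa.concat folder: flatten nested concats that use the same axis into a
// single variadic concat by splicing the producer's operands in place.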
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  // Conservatively assume folding doubles the number of operands.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Copy all operands, expanding any operand that is itself a foldable
  // concat.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes differ.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all incoming operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
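// tosa.reciprocal folder: a splat float constant folds to
// calcOneElement(splat), i.e. 1.0 / x computed on the splat value.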
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}