#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
    // Canonicalize a tosa.concat that has a single operand.
    if (op.getInput1().size() != 1)
      return failure();
    if (op.getInput1().front().getType() != op.getType()) {
      // ... replace the concat with a cast of op.getInput1().front() to the
      // result type ...
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
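    // Example (informal sketch; exact IR syntax may vary by version): a
    // single-operand concat such as
    //   %0 = tosa.concat %arg0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<4xf32>
    // is replaced by %arg0 directly, or by a cast of %arg0 when the result
    // type differs from the operand type.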
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  // Absorb the logical_not by swapping the on_true / on_false operands.
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}
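// Example (sketch): select(logical_not(%p), %a, %b) is rewritten in place to
// select(%p, %b, %a); the predicate inversion is removed by swapping the
// on_true / on_false operands.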
    // Compose transpose(transpose(A)) into a single transpose.
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int32_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Compose the two permutations into one.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    // ... permsAttr: a constant attribute holding the composed permutation ...
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);
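    // Worked example (sketch): an inner transpose with perms [1, 0, 2] feeding
    // an outer transpose with perms [2, 0, 1] composes to
    // perms[i] = inner[outer[i]] = [2, 1, 0], so the pair is replaced by a
    // single transpose of the inner op's input.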
    // A transpose is a reshape when it does not reorder the non-unit
    // dimensions of its input.
    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    Value result = op.getResult();
    for (Operation *subop : result.getUsers())
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      // ... (bail out on unranked inputs) ...

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;
    // ... (bail out when more than one dimension is dynamic) ...

    // permAttr holds the constant permutation; collect it as int64_t values.
    SmallVector<int64_t> permValues = llvm::to_vector(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    // Only a reshape if the non-unit dimensions keep their relative order.
    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
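    // Example (sketch): transposing tensor<1x64x1x32xf32> with perms
    // [2, 0, 1, 3] only moves unit dimensions; the non-unit dims (64, 32)
    // keep their relative order, so the op is rewritten as a reshape to
    // tensor<1x1x64x32xf32>.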
    // Materialize an explicit pad_const operand on tosa.pad when it is absent.
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0f);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr)
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");

    // ... denseAttr: a rank-0 dense attribute wrapping constantAttr ...
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
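    // Example (sketch): a tosa.pad of a float tensor carrying no explicit
    // pad_const operand gets a rank-0 tosa.const with value 0.0 (or the input
    // zero point for quantized integer types) appended as its third operand.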
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }
    // The pooling is a no-op only when both spatial dimensions (H and W) of
    // the output and the input are 1.
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }
    rewriter.replaceOp(op, input);
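    // Example (sketch): a tosa.max_pool2d whose input and output are both
    // tensor<Nx1x1xCxf32> cannot change any value, so it is replaced by its
    // input.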
    // A clamp is a no-op when its bounds cover the full range of the type.
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();
    if (!inputType.hasStaticShape()) {
      return failure();
    }
    if (isa<FloatType>(inputElementType)) {
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }
    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }
    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }
    Value input = op.getInput();
    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    // Fold two back-to-back clamps into one clamp with the tighter bounds.
    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt),
          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
          rewriter.getF32FloatAttr(maxFp));
      return success();
    }
    return failure();
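    // Example (sketch): clamp(clamp(%x, min=0, max=6), min=2, max=10) becomes
    // a single clamp(%x, min=2, max=6); the folded bounds are the max of the
    // lower bounds and the min of the upper bounds.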
    // Replace a slice of a concat with a slice of the single concat operand
    // that fully contains the sliced region.
    Value sliceInput = sliceOp.getInput1();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    // ... sliceStart / sliceSize: the slice's start and size vectors;
    // inputs: the concat's operands ...
    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        replaceWithSlice =
            rewriter
                .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
                                       input, /*start=*/..., /*size=*/...)
                .getResult();
        break;
      }
      // The slice does not start in this operand: rebase the start index.
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
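    // Example (sketch): slicing rows [4, 6) out of
    //   concat(%a : tensor<4x8xf32>, %b : tensor<4x8xf32>) along axis 0
    // lies entirely inside %b, so it becomes a slice of %b starting at row 0
    // (the start index is rebased by subtracting %a's extent on the axis).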
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  // Only constant splat operands with matching element types are folded.
  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
  if (llvm::isa<IntegerType>(lETy)) {
    // l and r are the APInt splat values of lhs and rhs.
    auto result = IntFolder()(l, r);
    return DenseElementsAttr::get(returnTy, result);
  }
  if (llvm::isa<FloatType>(lETy)) {
    // l and r are the APFloat splat values of lhs and rhs.
    auto result = FloatFolder()(l, r);
    return DenseElementsAttr::get(returnTy, result);
  }
  return {};
}

static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

// isSplatOne: true when val is a splat equal to 1 scaled by `shift`.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 -> x and 0 + x -> x when the operand type matches the result type.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
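// Example (sketch): adding a splat-zero constant of the result type folds to
// the other operand (x + 0 -> x), and two constant splats fold to a new splat
// computed with std::plus, e.g. splat(3) + splat(4) -> splat(7).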
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  // If the reduced axis has a single element, every argmax index is 0.
  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);
  return {};
}
  // Integer-divide fold: 0 / x folds to the zero splat, x / 1 folds to x, and
  // two integer constant splats fold to their signed quotient.
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }
  return {};
static DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs,
                                         DenseElementsAttr rhs,
                                         RankedTensorType ty, int32_t shift) {
  // ... (only constant splat operands are folded; l and r are the splats) ...
  if (llvm::isa<IntegerType>(ty.getElementType())) {
    // Widen so the product cannot overflow, apply the shift, truncate back.
    auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
    l = l.sext(bitwidth * 2);
    r = r.sext(bitwidth * 2);
    auto result = l * r;
    result.lshrInPlace(shift);
    result = result.trunc(bitwidth);
    return DenseElementsAttr::get(ty, result);
  }
  if (llvm::isa<FloatType>(ty.getElementType())) {
    APFloat result = l * r;
    return DenseElementsAttr::get(ty, result);
  }
  return {};
}
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The shift only applies to integer results; treat it as zero otherwise.
  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    // ... symmetric cases: rhs is a splat zero or a splat one ...
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
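// Example (sketch): with i32 splats and shift = 1, splat(6) * splat(4) folds
// to splat((6 * 4) >> 1) = splat(12); multiplying by a splat one (scaled by
// the shift) returns the other operand, and by a splat zero returns the zero
// splat resized to the result type.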
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 -> x when the lhs type matches the result type.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
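// Example (sketch): x - splat(0) folds to x when the lhs type already matches
// the result type; two constant splats fold through std::minus, e.g.
// splat(7) - splat(4) -> splat(3).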
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};
struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value against itself is always true.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: convert the splat to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool losesInfo;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &losesInfo);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate or sign-/zero-extend to the destination width.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();
      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }
      return SplatElementsAttr::get(outTy, intVal);
    }
  }
  return {};
}
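// Example (sketch): casting a splat constant follows the element types, e.g.
// an i8 splat 3 cast to f32 becomes an f32 splat 3.0, an f32 splat 3.7 cast
// to i32 rounds to splat 4 (ties-to-even rounding), and an i8 splat cast to
// i32 is sign- or zero-extended depending on signedness.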
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                     \
      return {};                                                               \
    if (inputTy != getType())                                                   \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)           \
      return getInput();                                                       \
    return {};                                                                 \
  }
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  // Folding to the input is only safe with at most one dynamic dimension.
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(x'), for static output shapes.
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (!outputTy.hasStaticShape())
      return {};
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());
    // Only fold a non-splat constant when this is its single use.
    if (!getInput1().hasOneUse())
      return {};
    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }
  return {};
}
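// Example (sketch): reshape(reshape(%x)) re-points the outer reshape at %x,
// a reshape to the identical type folds to the input, and a reshaped splat
// constant folds to a splat of the new shape.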
  // PadOp::fold: padding by all zeros folds the op away to its input.
  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
  if (densePad && densePad.isSplat() &&
      densePad.getSplatValue<APInt>().isZero()) {
    return getInput1();
  }
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  // scale / offset / border below are the resize op's attribute arrays; the
  // op folds to its input only for unit scale, zero offset and zero border.
  if (scale[0] != scale[1] || scale[2] != scale[3])
    return {};
  if (offset[0] != 0 || offset[1] != 0)
    return {};
  if (border[0] != 0 || border[1] != 0)
    return {};

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};
  return input;
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput1();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
  if (operandAttr)
    return operandAttr;
  // Reversing a unit (or rank-0) dimension is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;
  return {};
}
OpFoldResult tosa::SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput1();

  if (!adaptor.getInput1())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput1());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  // A single-element slice of a static constant extracts that element.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    // ... `indices` holds the slice's start offsets ...
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }
  return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate || !predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}
  // TileOp::fold: tiling with all multiples equal to 1 is an identity.
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing a splat constant is just a reshape of the splat.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // An identity permutation folds to the input.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};
  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};
  return getInput1();
}
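// Example (sketch): a transpose whose constant perms are the identity
// permutation [0, 1, ..., rank-1] folds to its input, and transposing a splat
// constant is just a reshape of that splat.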
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x.
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }
  return {};
}

OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x.
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x.
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }
  return {};
}

OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x).
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }
  return {};
}
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;
    // Not foldable if the axes differ.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all of the producer's operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
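// Example (sketch): concat(concat(%a, %b), %c) along the same axis is
// flattened in place to concat(%a, %b, %c) by splicing the producer's
// operands into the outer concat's operand list.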
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat constant inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }
  return {};
}