#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
    if (op.getInput1().size() != 1)
      return failure();
    // A lone input whose type differs from the result type is reconciled
    // with a cast; otherwise the concat is replaced by the input directly.
    if (op.getInput1().front().getType() != op.getType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  op.getInput1().front());
      return success();
    }

    rewriter.replaceOp(op, op.getInput1().front());
    return success();
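    // Illustrative note: a tosa.concat with a single operand is an identity
    // on that operand, so IR roughly of the form
    //   %0 = "tosa.concat"(%arg0) {axis = 0 : i32}
    //            : (tensor<4x8xf32>) -> tensor<4x8xf32>
    // is replaced directly by %arg0; a cast is only needed when the result
    // type differs from the operand type.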
  // tosa.select fed by tosa.logical_not: drop the not, swap true/false.
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.updateRootInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
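  // Illustrative note: the rewrite above swaps the on-true/on-false operands
  // so the logical_not can be dropped, e.g.
  //   select(logical_not(%pred), %a, %b)  ==>  select(%pred, %b, %a)
  // The in-place operand update leaves the select op itself where it was.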
    // Canonicalization for transpose(transpose(x)): compose the two constant
    // permutations into a single tosa.transpose.
    auto innerTranspose =
        transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
    if (!innerTranspose)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "input must be transpose operation");

    SmallVector<int64_t> transposePerms, innerTransposePerms;
    if (transposeOp.getConstantPerms(transposePerms).failed())
      return rewriter.notifyMatchFailure(transposeOp,
                                         "transpose perms must be constant");
    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
      return rewriter.notifyMatchFailure(
          transposeOp, "inner transpose perms must be constant");
    if (transposePerms.size() != innerTransposePerms.size())
      return rewriter.notifyMatchFailure(
          transposeOp,
          "transpose and inner transpose perms sizes must be equal");
    if (transposePerms.empty())
      return rewriter.notifyMatchFailure(
          transposeOp, "transpose perms sizes must be positive");

    // Compose the permutations: the outer transpose indexes into the inner
    // transpose's permutation.
    SmallVector<int32_t> perms(transposePerms.size());
    for (int i = 0, s = transposePerms.size(); i < s; ++i)
      perms[i] = innerTransposePerms[transposePerms[i]];

    auto permType = RankedTensorType::get(
        {static_cast<int64_t>(perms.size())}, rewriter.getI32Type());
    auto permsAttr = DenseIntElementsAttr::get(permType, perms);
    Value permsValue =
        rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
        transposeOp, transposeOp.getResult().getType(),
        innerTranspose.getInput1(), permsValue);
    return success();
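// Illustrative sketch: the permutation composition above, isolated as a
// hypothetical standalone helper (composePerms is not part of the TOSA
// sources). For example, an outer permutation {2, 0, 1} applied after an
// inner permutation {1, 2, 0} composes to {0, 1, 2}, i.e. the two transposes
// cancel to the identity.
static SmallVector<int32_t> composePerms(ArrayRef<int32_t> outerPerms,
                                         ArrayRef<int32_t> innerPerms) {
  // Output dim i of the outer transpose reads dim outerPerms[i] of the inner
  // transpose's result, which itself reads dim innerPerms[outerPerms[i]] of
  // the original input.
  SmallVector<int32_t> composed(outerPerms.size());
  for (int i = 0, s = outerPerms.size(); i < s; ++i)
    composed[i] = innerPerms[outerPerms[i]];
  return composed;
}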
    // Canonicalization for tosa.transpose that only moves unit dimensions:
    // rewrite it as a tosa.reshape.
    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return rewriter.notifyMatchFailure(op, "Non-constant permutation");

    if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
      return rewriter.notifyMatchFailure(
          op, "Src is from transpose, can compose transposes");

    for (Operation *subop : op.getResult().getUsers()) {
      if (dyn_cast_or_null<tosa::TransposeOp>(subop))
        return rewriter.notifyMatchFailure(
            op, "Dest is used by transpose, can compose transposes");
    }

    auto input = op.getInput1();
    auto inputTy = llvm::cast<ShapedType>(input.getType());
    if (!inputTy.hasRank())
      return rewriter.notifyMatchFailure(op, "Unranked input.");

    int64_t numDynDims = 0;
    for (int i = 0; i < inputTy.getRank(); ++i)
      if (inputTy.isDynamicDim(i))
        numDynDims++;
    if (numDynDims > 1)
      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

    SmallVector<int64_t> permValues = llvm::to_vector<6>(
        llvm::map_range(permAttr.getValues<APInt>(),
                        [](const APInt &val) { return val.getSExtValue(); }));

    // Keep only the permuted dims of size != 1; if their relative order is
    // unchanged, the transpose moves no data.
    SmallVector<int64_t> nonZeroPerms;
    nonZeroPerms.reserve(permValues.size());
    for (auto idx : permValues) {
      auto sz = inputTy.getDimSize(idx);
      if (sz != 1)
        nonZeroPerms.push_back(idx);
    }

    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
        return rewriter.notifyMatchFailure(op,
                                           "Transpose changes memory layout.");

    SmallVector<int64_t> newShape;
    newShape.reserve(inputTy.getRank());
    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
      newShape.push_back(inputTy.getDimSize(permValues[i]));

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, op.getType(), op.getInput1(),
        rewriter.getDenseI64ArrayAttr(newShape));
    return success();
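    // Illustrative note: a transpose that only moves unit dimensions does not
    // reorder data. Permuting tensor<1x512x1xf32> with perms = [1, 0, 2]
    // gives tensor<512x1x1xf32>; the lone non-unit dimension keeps its
    // relative position, so the op is rewritten as a tosa.reshape to the
    // permuted shape.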
    // Canonicalization for tosa.pad without an explicit pad_const: create the
    // implicit pad value as a tosa.const and pass it explicitly.
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();

    ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
      auto value = op.getQuantizationInfo()->getInputZp();
      constantAttr = rewriter.getIntegerAttr(elementTy, value);
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
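    // Illustrative note: this pattern only materializes the implicit pad
    // value. Roughly,
    //   %p = "tosa.pad"(%input, %padding) : (...) -> ...
    // becomes
    //   %c = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
    //   %p = "tosa.pad"(%input, %padding, %c) : (...) -> ...
    // and for quantized integer inputs the constant is the input zero point
    // rather than 0.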
    Value input = op.getInput();
    Value output = op.getOutput();
    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
    ShapedType outputType = llvm::cast<ShapedType>(output.getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
      return failure();
    }

    // A 1x1 output pooled from a 1x1 input is the identity.
    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[1] != 1 || outputShape[2] != 1) {
      return failure();
    }

    ArrayRef<int64_t> inputShape = inputType.getShape();
    if (inputShape[1] != 1 || inputShape[2] != 1) {
      return failure();
    }

    rewriter.replaceOp(op, input);
    return success();
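    // Illustrative note: tosa.max_pool2d operates on NHWC tensors, so dims 1
    // and 2 are the spatial height and width. With both input and output at
    // 1x1 spatially (e.g. tensor<8x1x1x64xf32>), each pooled window holds a
    // single element and the op is the identity on its input.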
    // Canonicalization for tosa.clamp whose bounds cannot exclude any value
    // of the element type: the clamp is the identity.
    Value input = op.getInput();
    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
    auto inputElementType = inputType.getElementType();
    if (!inputType.hasStaticShape()) {
      return failure();
    }

    if (llvm::isa<FloatType>(inputElementType)) {
      // Floating-point clamps are removable only for [-inf, +inf] bounds.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }
    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();
      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
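    // Illustrative note: for a signed i8 input,
    // APInt::getSignedMinValue(8)/getSignedMaxValue(8) give -128 and 127, so
    // a clamp with min_int <= -128 and max_int >= 127 cannot change any value
    // and is removed. The unsigned branch performs the same check against
    // [0, 255] for ui8.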
    // Canonicalization for clamp(clamp(x)): merge the two clamps by
    // intersecting their ranges.
    Value input = op.getInput();

    Operation *definingOp = input.getDefiningOp();
    if (!definingOp)
      return failure();

    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
      auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
      auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

      auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
      auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
          op, op.getType(), clampOp.getInput(),
          rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
          rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp));
      return success();
    }

    return failure();
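    // Illustrative note: clamp(clamp(x, a, b), c, d) becomes a single clamp
    // over the intersection of the ranges, clamp(x, max(a, c), min(b, d));
    // e.g. nested integer clamps to [0, 100] and then [-20, 50] combine into
    // one clamp to [0, 50].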
    // Canonicalization for slice(concat(...)): when the sliced region lies
    // entirely within one concat operand, slice that operand directly.
    Value sliceInput = sliceOp.getInput();
    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    OperandRange inputs = concatOp.getInput1();
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");
    int32_t axis = concatOp.getAxis();

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    std::optional<Value> replaceWithSlice;
    for (auto input : inputs) {
      auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      if (sliceStart[axis] >= 0 &&
          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
        replaceWithSlice = rewriter
                               .create<tosa::SliceOp>(
                                   sliceOp.getLoc(), sliceOp.getType(), input,
                                   rewriter.getDenseI64ArrayAttr(sliceStart),
                                   rewriter.getDenseI64ArrayAttr(sliceSize))
                               .getResult();
        break;
      }
      sliceStart[axis] -= inputType.getDimSize(axis);
    }

    if (!replaceWithSlice)
      return rewriter.notifyMatchFailure(
          sliceOp, "corresponding concat input not found for slice");

    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
    return success();
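    // Illustrative note: the loop above walks the concat operands along the
    // concatenated axis, subtracting each operand's extent from the slice
    // start. For a concat of sizes 4 and 6 along axis 0, a slice with
    // start = 5 and size = 2 becomes start = 1 inside the second operand, so
    // the slice is re-rooted onto that operand and the concat is bypassed.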
// Splat-only constant folder shared by the elementwise binary folds below.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    // With a TOSA shift, the integer encoding of 1.0 is 1 << shift.
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 => x, when the kept operand already has the result type.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
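// Illustrative note: the additive-identity folds above only fire when the
// surviving operand already has the result type. %x : tensor<4x4xf32> plus a
// zero splat folds to %x, but %x : tensor<1xf32> plus a zero splat of
// tensor<4x4xf32> does not, because the result still requires broadcasting
// %x.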
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // 0 / x => 0.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  // x / 1 => x.
  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // Constant-fold a splat integer division.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}
static DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs,
                                         DenseElementsAttr rhs,
                                         RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();

      // Widen, multiply, apply the TOSA shift, then narrow back.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
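// Illustrative sketch: the widen-multiply-shift-narrow idea above on plain
// 32-bit integers, using a hypothetical helper (mulWithShift is not part of
// the TOSA sources; the real folder shifts the doubled-width APInt instead).
// With shift = 6, mulWithShift(96, 32, 6) computes (96 * 32) >> 6 = 48
// without overflowing the narrow type.
static int32_t mulWithShift(int32_t a, int32_t b, int32_t shift) {
  // Widen to 64 bits so the full product is representable, then apply the
  // right shift and narrow back down.
  int64_t wide = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  return static_cast<int32_t>(wide >> shift);
}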
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x * 0 => 0 and x * 1 => x, where 1 is encoded as 1 << shift for integers.
  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 => x.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value against itself is always true; floats must
  // still account for NaN, so they are excluded here.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // Float -> float: convert the splat to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // Int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // Float -> int: round toward zero.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // Int -> int: truncate, or zero-/sign-extend based on the input's
    // signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }
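// Illustrative note: REDUCE_FOLDER is instantiated once per TOSA reduction
// op. Each generated folder turns a reduction over a rank-0 tensor or over an
// axis of size 1 into the identity; e.g. tosa.reduce_sum over axis 1 of
// tensor<4x1x8xf32> already has the reduced shape and simply returns its
// input.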
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // No-op reshape.
  if (inputTy == outputTy)
    return getInput1();

  // reshape(reshape(x)) => reshape(x): fold into this op in place.
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Constant folding: reinterpret a constant operand with the new shape.
  if (auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape.
    if (!outputTy.hasStaticShape())
      return {};

    // Okay to duplicate splat constants.
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());

    // Don't duplicate other constants unless this is the only user.
    if (!getInput1().hasOneUse())
      return {};

    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // A pad whose padding amounts are all zero is the identity.
  if (adaptor.getPadding()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero())
      return getInput1();
  }
  return {};
}
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> scale = getScale();
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();

  // The resize folds away only for unit scale, zero offset, and zero border,
  // and only when the input type already matches the result type.
  if (scale[0] != scale[1] || scale[2] != scale[3]) {
    return {};
  }
  if (offset[0] != 0 || offset[1] != 0) {
    return {};
  }
  if (border[0] != 0 || border[1] != 0) {
    return {};
  }

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  // Reversing a splat leaves it unchanged.
  if (operandAttr)
    return operandAttr;

  // If the reversed axis has length 1 (or the tensor is rank 0), the reverse
  // is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

  if (!inputTy || !outputTy)
    return {};

  // A slice that keeps the whole (static) input is the identity.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // An ElementsAttr can only be built from int/index/float element types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (operand.isSplat() && outputTy.hasStaticShape()) {
    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
  }

  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart().begin(), getStart().end());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  // Tiling by 1 along every dimension is the identity.
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
  return {};
}
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        inputTy.getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Only the identity transpose (same type, identity perms) folds away.
  if (getInput1().getType() != getType())
    return {};

  SmallVector<int64_t> perms;
  if (getConstantPerms(perms).failed())
    return {};
  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};
  return getInput1();
}
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise log(exp(x)) = x.
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise exp(log(x)) = x.
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise negate(negate(x)) = x.
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x).
  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
    return input;
  }
  return {};
}
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Fold consecutive concats on the same axis into a single op. Keep track of
  // the operands so a new operand list can be constructed; conservatively
  // assume folding doubles the operand count.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  // Find all operands that are foldable concats.
  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;

    // Not foldable if the axes are not the same.
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the original operand with all of the producer's operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
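// Illustrative note: the fold above flattens nested concatenations along the
// same axis,
//   concat(concat(%a, %b, axis = 0), %c, axis = 0)
//     ==> concat(%a, %b, %c, axis = 0)
// by splicing the producer's operands into this op's operand list in place;
// concats along different axes are left untouched.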