#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
  // tosa.concat canonicalization: a concat with a single operand folds to
  // that operand (through a cast when the result type differs).
  if (op.getInput1().size() != 1)
    return failure();
  if (op.getInput1().front().getType() != op.getType()) {
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, op.getType(),
                                              op.getInput1().front());
    return success();
  }

  rewriter.replaceOp(op, op.getInput1().front());
  return success();
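// Illustrative sketch (schematic IR, not taken from this file): a
// single-operand concat such as
//   %r = tosa.concat %x {axis = 0} : (tensor<4xf32>) -> tensor<4xf32>
// is replaced by %x directly, or by a tosa.cast of %x when the result type
// differs from the operand type.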
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
  if (!notOp)
    return failure();
  rewriter.modifyOpInPlace(op, [&]() {
    op.getOperation()->setOperands(
        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
  });
  return success();
}
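// Illustrative sketch (schematic, not from this file): the canonicalization
// rewrites select(logical_not(%p), %a, %b) into select(%p, %b, %a), keeping
// the select in place and only swapping its operands.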
  // transpose(transpose(x)): compose the two permutations into a single
  // transpose.
  auto innerTranspose =
      transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
  if (!innerTranspose)
    return rewriter.notifyMatchFailure(transposeOp,
                                       "input must be transpose operation");

  SmallVector<int32_t> transposePerms, innerTransposePerms;
  if (transposeOp.getConstantPerms(transposePerms).failed())
    return rewriter.notifyMatchFailure(transposeOp,
                                       "transpose perms must be constant");
  if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
    return rewriter.notifyMatchFailure(
        transposeOp, "inner transpose perms must be constant");
  if (transposePerms.size() != innerTransposePerms.size())
    return rewriter.notifyMatchFailure(
        transposeOp,
        "transpose and inner transpose perms sizes must be equal");
  if (transposePerms.empty())
    return rewriter.notifyMatchFailure(
        transposeOp, "transpose perms sizes must be positive");

  // Compose the permutations.
  SmallVector<int32_t> perms(transposePerms.size());
  for (int i = 0, s = transposePerms.size(); i < s; ++i)
    perms[i] = innerTransposePerms[transposePerms[i]];

  // Materialize the composed permutation as a constant i32 tensor (the exact
  // attribute construction is elided in this listing).
  auto permsAttr = rewriter.getI32TensorAttr(perms);
  Value permsValue =
      rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);

  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
      transposeOp, transposeOp.getResult().getType(),
      innerTranspose.getInput1(), permsValue);
  return success();
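// Worked example (illustrative): for transpose(transpose(x, innerPerms),
// outerPerms) with innerPerms = [2, 0, 1] and outerPerms = [1, 2, 0], the
// composed permutation is perms[i] = innerPerms[outerPerms[i]] = [0, 1, 2],
// i.e. the two rotations cancel into a single identity transpose.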
  // A transpose is equivalent to a reshape when it only moves dimensions of
  // size 1: the non-unit dimensions must keep their relative order.
  DenseIntElementsAttr permAttr;
  if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
    return rewriter.notifyMatchFailure(op, "Non-constant permutation");

  if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
    return rewriter.notifyMatchFailure(
        op, "Src is from transpose, can compose transposes");

  for (Operation *subop : op.getResult().getUsers())
    if (dyn_cast_or_null<tosa::TransposeOp>(subop))
      return rewriter.notifyMatchFailure(
          op, "Dest is used by transpose, can compose transposes");

  auto input = op.getInput1();
  auto inputTy = llvm::cast<ShapedType>(input.getType());
  if (!inputTy.hasRank())
    return rewriter.notifyMatchFailure(op, "Unranked input.");

  int64_t numDynDims = 0;
  for (int i = 0; i < inputTy.getRank(); ++i)
    if (inputTy.isDynamicDim(i))
      numDynDims++;
  if (numDynDims > 1)
    return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

  SmallVector<int64_t> permValues = llvm::to_vector<6>(
      llvm::map_range(permAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

  // Collect the permuted indices of the non-unit dimensions.
  SmallVector<int64_t> nonZeroPerms;
  nonZeroPerms.reserve(permValues.size());
  for (auto idx : permValues) {
    auto sz = inputTy.getDimSize(idx);
    if (sz != 1)
      nonZeroPerms.push_back(idx);
  }

  for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
    if (nonZeroPerms[i - 1] > nonZeroPerms[i])
      return rewriter.notifyMatchFailure(op,
                                         "Transpose changes memory layout.");

  SmallVector<int64_t> newShape;
  newShape.reserve(inputTy.getRank());
  for (int i = 0, s = inputTy.getRank(); i < s; ++i)
    newShape.push_back(inputTy.getDimSize(permValues[i]));

  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
      op, op.getType(), op.getInput1(),
      rewriter.getDenseI64ArrayAttr(newShape));
  return success();
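// Worked example (illustrative): transposing tensor<1x16x1x32xf32> with
// perms = [2, 0, 1, 3] only moves the two size-1 dimensions; the non-unit
// dims (16, 32) keep their relative order, so the transpose is rewritten as
// a reshape to tensor<1x1x16x32xf32>.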
  // If tosa.pad has no explicit pad_const, materialize one: zero for
  // float/int inputs, the input zero-point for quantized inputs.
  if (op.getPadConst())
    return failure();

  auto input = op.getInput1();
  auto padding = op.getPadding();

  ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
  Type elementTy = inputTy.getElementType();

  Attribute constantAttr;
  if (llvm::isa<FloatType>(elementTy)) {
    constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
  } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
    constantAttr = rewriter.getIntegerAttr(elementTy, 0);
  } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
    auto value = op.getQuantizationInfo()->getInputZp();
    constantAttr = rewriter.getIntegerAttr(elementTy, value);
  }

  if (!constantAttr)
    return rewriter.notifyMatchFailure(
        op,
        "tosa.pad to linalg lowering encountered an unknown element type");

  auto denseAttr = DenseElementsAttr::get(
      RankedTensorType::get({}, elementTy), constantAttr);
  auto constantVal = rewriter.create<tosa::ConstOp>(
      op.getLoc(), denseAttr.getType(), denseAttr);

  rewriter.replaceOpWithNewOp<tosa::PadOp>(
      op, op.getType(), ValueRange{input, padding, constantVal},
      op->getAttrs());
  return success();
  // A max_pool2d whose input and output spatial dims (NHWC dims 1 and 2) are
  // both 1x1 is a no-op.
  Value input = op.getInput();
  Value output = op.getOutput();
  ShapedType inputType = llvm::cast<ShapedType>(input.getType());
  ShapedType outputType = llvm::cast<ShapedType>(output.getType());
  if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (outputShape[1] != 1 || outputShape[2] != 1)
    return failure();
  ArrayRef<int64_t> inputShape = inputType.getShape();
  if (inputShape[1] != 1 || inputShape[2] != 1)
    return failure();

  rewriter.replaceOp(op, input);
  return success();
  // A clamp whose bounds cover the element type's entire value range is a
  // no-op.
  Value input = op.getInput();
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  auto inputElementType = inputType.getElementType();
  if (!inputType.hasStaticShape())
    return failure();

  if (isa<FloatType>(inputElementType)) {
    // Float bounds must be [-inf, +inf].
    auto minClamp = op.getMinFp();
    auto maxClamp = op.getMaxFp();
    bool isMin = minClamp.isInfinity() && minClamp.isNegative();
    bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
    if (isMin && isMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }

  if (inputElementType.isUnsignedInteger()) {
    int64_t minClamp = op.getMinInt();
    int64_t maxClamp = op.getMaxInt();
    int64_t intMin =
        APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
            .getZExtValue();
    int64_t intMax =
        APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
            .getZExtValue();
    if (minClamp <= intMin && maxClamp >= intMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }

  if (llvm::isa<IntegerType>(inputElementType)) {
    int64_t minClamp = op.getMinInt();
    int64_t maxClamp = op.getMaxInt();
    int64_t intMin =
        APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
            .getSExtValue();
    int64_t intMax =
        APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
            .getSExtValue();
    if (minClamp <= intMin && maxClamp >= intMax) {
      rewriter.replaceOp(op, input);
      return success();
    }
    return failure();
  }
  // clamp(clamp(x)): fold into a single clamp over the intersected bounds.
  Value input = op.getInput();
  Operation *definingOp = input.getDefiningOp();
  if (!definingOp)
    return failure();

  if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
    auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
    auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
    auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
    auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp));
    return success();
  }

  return failure();
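// Worked example (illustrative): clamp(clamp(x, min=0, max=6), min=2, max=10)
// becomes clamp(x, min=2, max=6), i.e. the intersection of the two ranges.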
  // slice(concat(...)): when the slice lies entirely within one concat
  // input, slice that input directly.
  Value sliceInput = sliceOp.getInput();
  auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
  if (!concatOp)
    return rewriter.notifyMatchFailure(
        sliceOp, "slice input must be concat operation");

  auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
  if (!concatType || !concatType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        sliceOp, "slice input must be a static ranked tensor");
  int32_t axis = concatOp.getAxis();

  llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
  llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

  // Walk the concat inputs; the slice must span exactly one of them.
  std::optional<Value> replaceWithSlice;
  for (auto input : concatOp.getInput1()) {
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "concat input must be a static ranked tensor");

    if (sliceStart[axis] >= 0 &&
        (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
      replaceWithSlice =
          rewriter
              .create<tosa::SliceOp>(
                  sliceOp.getLoc(), sliceOp.getType(), input,
                  rewriter.getDenseI64ArrayAttr(sliceStart),
                  rewriter.getDenseI64ArrayAttr(sliceSize))
              .getResult();
      break;
    }
    sliceStart[axis] -= inputType.getDimSize(axis);
  }

  if (!replaceWithSlice)
    return rewriter.notifyMatchFailure(
        sliceOp, "corresponding concat input not found for slice");

  rewriter.replaceOp(sliceOp, replaceWithSlice.value());
  return success();
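// Worked example (illustrative): for %c = concat(%a : tensor<2x8xf32>,
// %b : tensor<3x8xf32>) on axis 0, a slice of %c with start = [3, 0] and
// size = [2, 8] lies entirely inside %b, so it is rewritten as a slice of %b
// with start = [1, 0] and size = [2, 8].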
// Constant-fold an elementwise binary op over two splat operands, using
// IntFolder for integer element types and FloatFolder for float ones.
template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                               RankedTensorType returnTy) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
    if (lETy != rETy)
      return {};

    if (llvm::isa<IntegerType>(lETy)) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      auto result = IntFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }

    if (llvm::isa<FloatType>(lETy)) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      auto result = FloatFolder()(l, r);
      return DenseElementsAttr::get(returnTy, result);
    }
  }

  return {};
}
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
  return false;
}

static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType))
    return val && val.isSplat() &&
           val.getSplatValue<APFloat>().isExactlyValue(1.0);
  if (llvm::isa<IntegerType>(elemType)) {
    // For integers, "one" is scaled by the tosa.mul shift.
    const int64_t shifted = 1LL << shift;
    return val && val.isSplat() &&
           val.getSplatValue<APInt>().getSExtValue() == shifted;
  }
  return false;
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x + 0 -> x (and 0 + x -> x) when the shapes already match.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
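// Illustrative examples: adding a splat-zero constant of matching type folds
// to the other operand (x + 0 -> x), and adding two splat constants folds to
// a new splat computed with std::plus on the underlying APInt/APFloat values.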
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
      !outputTy.hasStaticShape())
    return {};

  // Argmax over an axis of size 1 is always index 0.
  if (inputTy.getDimSize(getAxis()) == 1)
    return DenseElementsAttr::get(outputTy, 0);

  return {};
}
  // Folder for TOSA's elementwise integer divide (the enclosing fold
  // signature is elided in this listing).
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // 0 / x -> 0.
  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  // x / 1 -> x.
  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  // splat / splat folds via signed division.
  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
// Constant-fold splat * splat for tosa.mul, applying the op's right-shift
// for integer element types.
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                  RankedTensorType ty, int32_t shift) {
  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      APInt l = lhs.getSplatValue<APInt>();
      APInt r = rhs.getSplatValue<APInt>();
      // Widen, multiply, shift, then truncate back to the element width.
      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
      l = l.sext(bitwidth * 2);
      r = r.sext(bitwidth * 2);
      auto result = l * r;
      result.lshrInPlace(shift);
      result = result.trunc(bitwidth);
      return DenseElementsAttr::get(ty, result);
    }

    if (llvm::isa<FloatType>(ty.getElementType())) {
      APFloat l = lhs.getSplatValue<APFloat>();
      APFloat r = rhs.getSplatValue<APFloat>();
      APFloat result = l * r;
      return DenseElementsAttr::get(ty, result);
    }
  }

  return {};
}
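// Worked example (illustrative): for an i32 tosa.mul with shift = 6 and
// splat operands 8 and 8, the values are sign-extended to 64 bits,
// multiplied (64), shifted right by 6 (1), and truncated back to i32, so the
// fold produces a splat constant 1.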
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x * 0 -> 0 and x * 1 -> x (the integer "one" accounts for the shift).
  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
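// Illustrative examples: x * splat(0) folds to a zero splat of the result
// type, and x * splat(1) folds to x; for integer tensors the "one" must be
// 1 << shift so the scaling built into tosa.mul is preserved.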
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!lhsTy.getElementType().isIntOrIndexOrFloat() ||
      !rhsTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // x - 0 -> x.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
// Wraps a standard comparison functor for use with binaryFolder; the result
// is a 1-bit APInt, matching TOSA's boolean (i1) result tensors.
template <typename Cmp>
struct ComparisonFold {
  ComparisonFold() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, Cmp()(l, r));
  }

  APInt operator()(const APFloat &l, const APFloat &r) {
    return APInt(1, Cmp()(l, r));
  }
};

// Signed integer comparisons need APInt's sgt/sge rather than operator>.
struct APIntFoldGreater {
  APIntFoldGreater() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sgt(r));
  }
};

struct APIntFoldGreaterEqual {
  APIntFoldGreaterEqual() = default;
  APInt operator()(const APInt &l, const APInt &r) {
    return APInt(1, l.sge(r));
  }
};
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // Comparing an integer value against itself is always true.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
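// Illustrative example: tosa.equal %x, %x on an integer tensor with a static
// result shape folds to an all-true splat constant; otherwise two splat
// constants fold element-wise through ComparisonFold<std::equal_to<...>>.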
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  auto inETy = inTy.getElementType();
  auto outETy = outTy.getElementType();

  if (operand.isSplat()) {
    // float -> float: round to the destination semantics.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
      bool overflow;
      auto splatVal = operand.getSplatValue<APFloat>();
      auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
      splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                       &overflow);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // int -> float.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
      splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
                                llvm::RoundingMode::NearestTiesToEven);
      return SplatElementsAttr::get(outTy, splatVal);
    }

    // float -> int.
    if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
      auto intVal = APSInt(
          llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
      auto floatVal = operand.getSplatValue<APFloat>();
      bool exact;
      floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
                                &exact);
      return SplatElementsAttr::get(outTy, intVal);
    }

    // int -> int: truncate, or zero/sign-extend based on input signedness.
    if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
      auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
      bool trunc =
          inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
      auto intVal = operand.getSplatValue<APInt>();
      auto bitwidth = outETy.getIntOrFloatBitWidth();

      if (trunc) {
        intVal = intVal.trunc(bitwidth);
      } else if (unsignIn) {
        intVal = intVal.zext(bitwidth);
      } else {
        intVal = intVal.sext(bitwidth);
      }

      return SplatElementsAttr::get(outTy, intVal);
    }
  }

  return {};
}
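// Illustrative example: casting a splat i8 constant 3 to f32 folds to a
// splat 3.0 constant; int-to-int casts truncate or zero/sign-extend the
// splat value according to the source signedness, as above.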
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
// A reduction over a rank-0 input, or over an axis of size 1, is the
// identity (provided the input and result types match).
#define REDUCE_FOLDER(OP)                                                      \
  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
    ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
    if (!inputTy.hasRank())                                                    \
      return {};                                                               \
    if (inputTy != getType())                                                  \
      return {};                                                               \
    if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
      return getInput();                                                       \
    return {};                                                                 \
  }
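// Illustrative example (schematic IR): tosa.reduce_sum over axis 1 of
// tensor<4x1xf32> produces tensor<4x1xf32> again, so the reduction folds to
// its input; the file then instantiates this macro for the tosa.reduce_* ops.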
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  // Fold identity reshapes (safe with at most one dynamic dimension).
  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
          getInput1().getDefiningOp())) {
    getInput1Mutable().assign(reshapeOp.getInput1());
    return getResult();
  }

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  // reshape(const(x)) -> const(x reshaped)
  if (auto operand =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    // Constants must have static shape; splats may be duplicated freely,
    // other constants only when this reshape is their sole user.
    if (!outputTy.hasStaticShape())
      return {};
    if (operand.isSplat())
      return SplatElementsAttr::get(outputTy,
                                    operand.getSplatValue<Attribute>());
    if (!getInput1().hasOneUse())
      return {};
    return operand.reshape(
        llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
  }

  return {};
}
  // tosa.pad folds to its input when the (constant) padding amounts are all
  // zero and the input/result types match.
  if (adaptor.getPadding() && getInput1().getType() == getType()) {
    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero())
      return getInput1();
  }
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
  ArrayRef<int64_t> offset = getOffset();
  ArrayRef<int64_t> border = getBorder();
  ArrayRef<int64_t> scale = getScale();

  // Fold to the input only for unit scale, no offset, and no border.
  if (scale[0] != scale[1] || scale[2] != scale[3])
    return {};
  if (offset[0] != 0 || offset[1] != 0)
    return {};
  if (border[0] != 0 || border[1] != 0)
    return {};

  auto input = getInput();
  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
  auto resultTy = llvm::cast<RankedTensorType>(getType());
  if (inputTy != resultTy)
    return {};

  return input;
}
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();
  auto operandAttr =
      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
  // Reversing a splat constant leaves it unchanged.
  if (operandAttr)
    return operandAttr;

  // Reversing along a unit (or rank-0) axis is a no-op.
  if (operandTy.hasRank() &&
      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
    return operand;

  return {};
}
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!inputTy || !outputTy)
    return {};

  // A full-size slice is the identity.
  if (inputTy == outputTy && inputTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types.
  if (!inputTy.getElementType().isIntOrIndexOrFloat() ||
      !outputTy.getElementType().isIntOrIndexOrFloat())
    return {};

  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
  // Slicing a splat constant just reshapes the splat.
  if (operand.isSplat() && outputTy.hasStaticShape())
    return SplatElementsAttr::get(outputTy,
                                  operand.getSplatValue<Attribute>());

  // A single-element slice of a static constant can be extracted directly.
  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
      outputTy.getNumElements() == 1) {
    llvm::SmallVector<uint64_t> indices(getStart().begin(), getStart().end());
    auto value = operand.getValues<Attribute>()[indices];
    return SplatElementsAttr::get(outputTy, value);
  }

  return {};
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate)
    return {};

  if (!predicate.isSplat())
    return {};
  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
                                                         : getOnFalse();
}
  // tosa.tile with all multiples equal to 1 (and matching types) is the
  // identity.
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing a splat constant is just a reshape of the attribute.
  if (auto input =
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
    if (input.isSplat() && resultTy.hasStaticShape() &&
        input.getType().getElementType() == resultTy.getElementType())
      return input.reshape(resultTy);
  }

  // Otherwise fold only the identity permutation.
  SmallVector<int32_t> perms;
  if (getConstantPerms(perms).failed())
    return {};

  if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
    return {};

  return getInput1();
}
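// Illustrative example: a transpose whose constant perms are [0, 1, ..., n-1]
// (the identity permutation) folds to its input; a transposed splat constant
// folds by simply reshaping the splat attribute to the result type.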
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: log(exp(x)) = x.
  if (auto op = input.getDefiningOp<tosa::ExpOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: exp(log(x)) = x.
  if (auto op = input.getDefiningOp<tosa::LogOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise: negate(negate(x)) = x.
  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
    return op.getInput1();
  }
  return {};
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // abs(abs(x)) = abs(x).
  if (auto op = input.getDefiningOp<tosa::AbsOp>())
    return input;
  return {};
}
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Flatten nested concats on the same axis into a single concat.
  SmallVector<Value, 8> concatOperands;
  concatOperands.reserve(2 * getNumOperands());

  bool foundFoldableConcat = false;
  for (Value operand : getOperands()) {
    concatOperands.emplace_back(operand);

    // Only concat producers on the same axis can be folded in.
    auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
    if (!producer)
      continue;
    if (getAxis() != producer.getAxis())
      continue;

    // Replace the producer's result with its own operands.
    foundFoldableConcat = true;
    concatOperands.pop_back();
    llvm::append_range(concatOperands, producer->getOperands());
  }

  if (!foundFoldableConcat)
    return {};

  getOperation()->setOperands(concatOperands);
  return getResult();
}
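// Illustrative example (schematic): concat(concat(%a, %b, axis = 1), %c,
// axis = 1) is flattened in place into concat(%a, %b, %c, axis = 1); concats
// on different axes are left untouched.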
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto input = adaptor.getInput1();

  auto inputAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(input);
  // Fold splat float inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
    auto floatVal = inputAttr.getSplatValue<APFloat>();
    return DenseElementsAttr::get(shapeType,
                                  ReciprocalOp::calcOneElement(floatVal));
  }

  return {};
}
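// Illustrative only (not part of this file): patterns and folders like the
// ones above are normally exercised through MLIR's canonicalizer. A minimal
// sketch, assuming an Operation *root to canonicalize:
//
//   RewritePatternSet patterns(root->getContext());
//   for (RegisteredOperationName name :
//        root->getContext()->getRegisteredOperations())
//     name.getCanonicalizationPatterns(patterns, root->getContext());
//   (void)applyPatternsAndFoldGreedily(root, std::move(patterns));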