25 struct DepthwiseConv2DIsMul :
public OpRewritePattern<tosa::DepthwiseConv2DOp> {
29 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
31 Value input = op.getInput();
32 Value weight = op.getWeight();
33 ShapedType inputType = cast<ShapedType>(input.
getType());
34 ShapedType weightType = cast<ShapedType>(weight.
getType());
35 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
37 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
38 resultType.hasStaticShape())) {
42 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
47 if (weightShape[0] != 1 || weightShape[1] != 1) {
51 Type inputETy = inputType.getElementType();
52 Type weightETy = weightType.getElementType();
57 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
60 op,
"input zero point cannot be statically determined");
62 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
65 op,
"weight zero point cannot be statically determined");
67 int64_t iZp = *maybeIZp;
68 int64_t wZp = *maybeWZp;
69 if (op.verifyInputZeroPoint(iZp).failed())
71 op,
"input zero point must be zero for non-int8 integer types");
72 if (op.verifyWeightZeroPoint(wZp).failed())
74 op,
"weight zero point must be zero for non-int8 integer types");
79 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
82 dyn_cast<RankedTensorType>(input.
getType()).getElementType());
83 auto revisedInputShapeValue =
86 .
create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
87 revisedInputShapeValue)
90 Type resultETy = resultType.getElementType();
92 if (inputETy != resultETy) {
93 inputType = inputType.clone(resultETy);
94 input = rewriter.
create<tosa::CastOp>(op.getLoc(), inputType, input);
97 if (weightETy != resultETy) {
98 weightType = weightType.
clone(resultETy);
99 weight = rewriter.
create<tosa::CastOp>(op.getLoc(), weightType, weight);
102 if (iZp != 0 || wZp != 0) {
104 auto applyZp = [&](
Value val, int64_t zp) ->
Value {
107 auto ety = cast<ShapedType>(val.
getType()).getElementType();
108 std::vector<int64_t> shape(cast<ShapedType>(val.
getType()).getRank(),
113 auto zpVal = rewriter.
create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
114 return rewriter.
create<tosa::SubOp>(op.getLoc(), val.
getType(), val,
118 input = applyZp(input, iZp);
119 weight = applyZp(weight, wZp);
125 pad[it.index() + 2] = it.value();
127 if (llvm::any_of(pad, [](int64_t p) {
return p != 0; })) {
131 for (
int i = 0, s = pad.size(); i < s; ++i) {
132 if (newShape[i / 2] != ShapedType::kDynamic) {
133 newShape[i / 2] += pad[i];
142 rewriter.
create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
144 input = rewriter.
create<tosa::PadOp>(op->getLoc(), inputType, input,
150 inputType.getDimSize(0), inputType.getDimSize(1),
151 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
154 dyn_cast<RankedTensorType>(weight.
getType()).getElementType());
156 if (
EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
165 rewriter.
create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
166 Value mulValue = rewriter
167 .
create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
172 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
175 dyn_cast<RankedTensorType>(input.getType()).getElementType());
176 auto outputShapeValue =
178 Value outputValue = rewriter.
create<tosa::ReshapeOp>(
179 op.getLoc(), outputShapeType, mulValue, outputShapeValue);
181 Value bias = op.getBias();
182 if (
EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
198 patterns.add<DepthwiseConv2DIsMul>(ctx);
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...