24struct DepthwiseConv2DIsMul :
public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25 explicit DepthwiseConv2DIsMul(MLIRContext *context)
26 : OpRewritePattern(context) {}
28 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29 PatternRewriter &rewriter)
const override {
30 Value input = op.getInput();
31 Value weight = op.getWeight();
32 ShapedType inputType = cast<ShapedType>(input.
getType());
33 ShapedType weightType = cast<ShapedType>(weight.
getType());
34 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
37 for (
unsigned int i = 1; i < 4; ++i) {
38 if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
42 if (!weightType.hasStaticShape()) {
46 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
50 ArrayRef<int64_t> weightShape = weightType.getShape();
51 if (weightShape[0] != 1 || weightShape[1] != 1) {
55 Type inputETy = inputType.getElementType();
56 Type weightETy = weightType.getElementType();
61 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
64 op,
"input zero point cannot be statically determined");
66 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
69 op,
"weight zero point cannot be statically determined");
71 int64_t iZp = *maybeIZp;
72 int64_t wZp = *maybeWZp;
73 if (op.verifyInputZeroPoint(iZp).failed())
75 op,
"input zero point must be zero for non-int8 integer types");
76 if (op.verifyWeightZeroPoint(wZp).failed())
78 op,
"weight zero point must be zero for non-int8 integer types");
81 ArrayRef<int64_t> inputShape = inputType.getShape();
82 llvm::SmallVector<int64_t, 2> revisedInputShape{
83 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
84 inputType = RankedTensorType::get(
86 dyn_cast<RankedTensorType>(input.
getType()).getElementType());
87 auto revisedInputShapeValue =
89 input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input,
90 revisedInputShapeValue)
93 Type resultETy = resultType.getElementType();
95 if (inputETy != resultETy) {
96 inputType = inputType.clone(resultETy);
97 input = tosa::CastOp::create(rewriter, op.getLoc(), inputType, input);
100 if (weightETy != resultETy) {
101 weightType = weightType.clone(resultETy);
102 weight = tosa::CastOp::create(rewriter, op.getLoc(), weightType, weight);
105 if (iZp != 0 || wZp != 0) {
107 auto applyZp = [&](Value val, int64_t zp) -> Value {
110 auto ety = cast<ShapedType>(val.
getType()).getElementType();
111 std::vector<int64_t> shape(cast<ShapedType>(val.
getType()).getRank(),
113 auto zpTy = RankedTensorType::get(shape, ety);
116 auto zpVal = tosa::ConstOp::create(rewriter, op.getLoc(), zpTy, zpAttr);
117 return tosa::SubOp::create(rewriter, op.getLoc(), val.
getType(), val,
121 input = applyZp(input, iZp);
122 weight = applyZp(weight, wZp);
125 ArrayRef<int64_t> padAttr = op.getPad();
126 llvm::SmallVector<int64_t> pad(10, 0);
127 for (
const auto &it : llvm::enumerate(padAttr))
128 pad[it.index() + 2] = it.value();
130 if (llvm::any_of(pad, [](int64_t p) {
return p != 0; })) {
131 Attribute zeroAttr = rewriter.
getZeroAttr(inputETy);
133 llvm::SmallVector<int64_t> newShape(inputType.getShape());
134 for (
int i = 0, s = pad.size(); i < s; ++i) {
135 if (newShape[i / 2] != ShapedType::kDynamic) {
136 newShape[i / 2] += pad[i];
142 auto padTy = RankedTensorType::get({1}, inputETy);
145 tosa::ConstOp::create(rewriter, op->getLoc(), padTy, padAttr);
146 inputType = RankedTensorType::get(newShape, inputETy);
147 input = tosa::PadOp::create(rewriter, op->getLoc(), inputType, input,
152 llvm::SmallVector<int64_t, 2> mulShape{
153 inputType.getDimSize(0), inputType.getDimSize(1),
154 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
155 auto mulShapeType = RankedTensorType::get(
157 dyn_cast<RankedTensorType>(weight.
getType()).getElementType());
159 if (
EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
163 auto shiftElementType = IntegerType::get(rewriter.
getContext(), 8);
164 auto shiftType = RankedTensorType::get({1}, shiftElementType);
168 tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr);
169 Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType,
170 input, weight, constZero)
174 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
175 auto outputShapeType = RankedTensorType::get(
177 dyn_cast<RankedTensorType>(input.
getType()).getElementType());
178 auto outputShapeValue =
180 Value outputValue = tosa::ReshapeOp::create(
181 rewriter, op.getLoc(), outputShapeType, mulValue, outputShapeValue);
183 Value bias = op.getBias();
184 if (
EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
200 patterns.add<DepthwiseConv2DIsMul>(ctx);
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification; the original op's results are replaced by the new op's results.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Type getType() const
Return the type of this value.
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.