24 struct DepthwiseConv2DIsMul :
public OpRewritePattern<tosa::DepthwiseConv2DOp> {
28 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
30 Value input = op.getInput();
31 Value weight = op.getWeight();
32 ShapedType inputType = cast<ShapedType>(input.
getType());
33 ShapedType weightType = cast<ShapedType>(weight.
getType());
34 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
36 if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37 resultType.hasStaticShape())) {
41 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
46 if (weightShape[0] != 1 || weightShape[1] != 1) {
53 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
56 dyn_cast<RankedTensorType>(input.
getType()).getElementType());
59 op.getLoc(), inputType, input,
63 if (inputType.getElementType() != resultType.getElementType()) {
64 inputType = inputType.clone(resultType.getElementType());
65 input = rewriter.
create<tosa::CastOp>(op.getLoc(), inputType, input);
68 if (weightType.getElementType() != resultType.getElementType()) {
69 weightType = weightType.
clone(resultType.getElementType());
70 weight = rewriter.
create<tosa::CastOp>(op.getLoc(), weightType, weight);
73 if (
auto quantizationInfo = op.getQuantizationInfo()) {
74 auto iZp = quantizationInfo->getInputZp();
75 auto wZp = quantizationInfo->getWeightZp();
77 auto applyZp = [&](
Value val, int64_t zp) ->
Value {
80 auto ety = cast<ShapedType>(val.
getType()).getElementType();
81 std::vector<int64_t> shape(cast<ShapedType>(val.
getType()).getRank(),
86 auto zpVal = rewriter.
create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
87 return rewriter.
create<tosa::SubOp>(op.getLoc(), val.
getType(), val,
91 input = applyZp(input, iZp);
92 weight = applyZp(weight, wZp);
98 pad[it.index() + 2] = it.value();
100 if (llvm::any_of(pad, [](int64_t p) {
return p != 0; })) {
101 Type inputETy = inputType.getElementType();
105 for (
int i = 0, s = pad.size(); i < s; ++i) {
106 if (newShape[i / 2] != ShapedType::kDynamic) {
107 newShape[i / 2] += pad[i];
115 rewriter.
create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
120 rewriter.
create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
122 input = rewriter.
create<tosa::PadOp>(op->getLoc(), inputType, input,
128 inputType.getDimSize(0), inputType.getDimSize(1),
129 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
132 dyn_cast<RankedTensorType>(weight.
getType()).getElementType());
134 if (
EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
138 Value mulValue = rewriter
139 .
create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
144 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
147 dyn_cast<RankedTensorType>(input.getType()).getElementType());
148 Value outputValue = rewriter.
create<tosa::ReshapeOp>(
149 op.getLoc(), outputShapeType, mulValue,
152 Value bias = op.getBias();
153 if (
EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
169 patterns.
add<DepthwiseConv2DIsMul>(ctx);
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...