MLIR  19.0.0git
TosaDecomposeDepthwise.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 namespace {
23 
24 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25  explicit DepthwiseConv2DIsMul(MLIRContext *context)
26  : OpRewritePattern(context) {}
27 
28  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29  PatternRewriter &rewriter) const override {
30  Value input = op.getInput();
31  Value weight = op.getWeight();
32  ShapedType inputType = cast<ShapedType>(input.getType());
33  ShapedType weightType = cast<ShapedType>(weight.getType());
34  ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
35 
36  if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37  resultType.hasStaticShape())) {
38  return failure();
39  }
40 
41  if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
42  return failure();
43 
44  // Only works for a 1x1 kernel.
45  ArrayRef<int64_t> weightShape = weightType.getShape();
46  if (weightShape[0] != 1 || weightShape[1] != 1) {
47  return failure();
48  }
49 
50  // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
51  ArrayRef<int64_t> inputShape = inputType.getShape();
52  llvm::SmallVector<int64_t, 2> revisedInputShape{
53  inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
54  inputType = RankedTensorType::get(
55  revisedInputShape,
56  dyn_cast<RankedTensorType>(input.getType()).getElementType());
57  input = rewriter
58  .create<tosa::ReshapeOp>(
59  op.getLoc(), inputType, input,
60  rewriter.getDenseI64ArrayAttr(revisedInputShape))
61  .getResult();
62 
63  if (inputType.getElementType() != resultType.getElementType()) {
64  inputType = inputType.clone(resultType.getElementType());
65  input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
66  }
67 
68  if (weightType.getElementType() != resultType.getElementType()) {
69  weightType = weightType.clone(resultType.getElementType());
70  weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
71  }
72 
73  if (auto quantizationInfo = op.getQuantizationInfo()) {
74  auto iZp = quantizationInfo->getInputZp();
75  auto wZp = quantizationInfo->getWeightZp();
76 
77  auto applyZp = [&](Value val, int64_t zp) -> Value {
78  if (zp == 0)
79  return val;
80  auto ety = cast<ShapedType>(val.getType()).getElementType();
81  std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
82  1);
83  auto zpTy = RankedTensorType::get(shape, ety);
84  auto zpAttr =
85  DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
86  auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
87  return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
88  zpVal);
89  };
90 
91  input = applyZp(input, iZp);
92  weight = applyZp(weight, wZp);
93  }
94 
95  ArrayRef<int64_t> padAttr = op.getPad();
96  llvm::SmallVector<int64_t> pad(10, 0);
97  for (const auto &it : llvm::enumerate(padAttr))
98  pad[it.index() + 2] = it.value();
99 
100  if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
101  Type inputETy = inputType.getElementType();
102  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
103 
104  llvm::SmallVector<int64_t> newShape(inputType.getShape());
105  for (int i = 0, s = pad.size(); i < s; ++i) {
106  if (newShape[i / 2] != ShapedType::kDynamic) {
107  newShape[i / 2] += pad[i];
108  }
109  }
110 
111  auto padSizeTy = RankedTensorType::get({5, 2}, rewriter.getI64Type());
112  auto padSize =
114  Value padSizeVal =
115  rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
116 
117  auto padTy = RankedTensorType::get({}, inputETy);
118  auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
119  Value padVal =
120  rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
121  inputType = RankedTensorType::get(newShape, inputETy);
122  input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
123  padSizeVal, padVal);
124  }
125 
126  // Perform an elementwise mul over the reshaped input and weight.
128  inputType.getDimSize(0), inputType.getDimSize(1),
129  inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
130  auto mulShapeType = RankedTensorType::get(
131  mulShape,
132  dyn_cast<RankedTensorType>(weight.getType()).getElementType());
133 
134  if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
135  return failure();
136  }
137 
138  Value mulValue = rewriter
139  .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
140  weight, /*shift=*/0)
141  .getResult();
142 
143  // Reshape output to [N, H, W, C * M].
144  auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
145  auto outputShapeType = RankedTensorType::get(
146  outputShape,
147  dyn_cast<RankedTensorType>(input.getType()).getElementType());
148  Value outputValue = rewriter.create<tosa::ReshapeOp>(
149  op.getLoc(), outputShapeType, mulValue,
150  rewriter.getDenseI64ArrayAttr(outputShape));
151 
152  Value bias = op.getBias();
153  if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
154  return failure();
155  }
156 
157  // Add in the bias.
158  rewriter
159  .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
160  .getResult();
161  return success();
162  }
163 };
164 
165 } // namespace
166 
168  RewritePatternSet &patterns) {
169  patterns.add<DepthwiseConv2DIsMul>(ctx);
170 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerType getI64Type()
Definition: Builders.cpp:85
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358