MLIR  21.0.0git
TosaDecomposeDepthwise.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Depthwise operation to a series of TOSA Ops specifically
10 // (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 using namespace mlir::tosa;
22 
23 namespace {
24 
25 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
26  explicit DepthwiseConv2DIsMul(MLIRContext *context)
27  : OpRewritePattern(context) {}
28 
29  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
30  PatternRewriter &rewriter) const override {
31  Value input = op.getInput();
32  Value weight = op.getWeight();
33  ShapedType inputType = cast<ShapedType>(input.getType());
34  ShapedType weightType = cast<ShapedType>(weight.getType());
35  ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
36 
37  if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
38  resultType.hasStaticShape())) {
39  return failure();
40  }
41 
42  if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
43  return failure();
44 
45  // Only works for a 1x1 kernel.
46  ArrayRef<int64_t> weightShape = weightType.getShape();
47  if (weightShape[0] != 1 || weightShape[1] != 1) {
48  return failure();
49  }
50 
51  Type inputETy = inputType.getElementType();
52  Type weightETy = weightType.getElementType();
53  if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
54  return rewriter.notifyMatchFailure(op, "unsupported type");
55 
56  // Get and verify zero points.
57  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
58  if (failed(maybeIZp))
59  return rewriter.notifyMatchFailure(
60  op, "input zero point cannot be statically determined");
61 
62  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
63  if (failed(maybeWZp))
64  return rewriter.notifyMatchFailure(
65  op, "weight zero point cannot be statically determined");
66 
67  int64_t iZp = *maybeIZp;
68  int64_t wZp = *maybeWZp;
69  if (op.verifyInputZeroPoint(iZp).failed())
70  return rewriter.notifyMatchFailure(
71  op, "input zero point must be zero for non-int8 integer types");
72  if (op.verifyWeightZeroPoint(wZp).failed())
73  return rewriter.notifyMatchFailure(
74  op, "weight zero point must be zero for non-int8 integer types");
75 
76  // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
77  ArrayRef<int64_t> inputShape = inputType.getShape();
78  llvm::SmallVector<int64_t, 2> revisedInputShape{
79  inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
80  inputType = RankedTensorType::get(
81  revisedInputShape,
82  dyn_cast<RankedTensorType>(input.getType()).getElementType());
83  auto revisedInputShapeValue =
84  getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
85  input = rewriter
86  .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
87  revisedInputShapeValue)
88  .getResult();
89 
90  Type resultETy = resultType.getElementType();
91 
92  if (inputETy != resultETy) {
93  inputType = inputType.clone(resultETy);
94  input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
95  }
96 
97  if (weightETy != resultETy) {
98  weightType = weightType.clone(resultETy);
99  weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
100  }
101 
102  if (iZp != 0 || wZp != 0) {
103 
104  auto applyZp = [&](Value val, int64_t zp) -> Value {
105  if (zp == 0)
106  return val;
107  auto ety = cast<ShapedType>(val.getType()).getElementType();
108  std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
109  1);
110  auto zpTy = RankedTensorType::get(shape, ety);
111  auto zpAttr =
112  DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
113  auto zpVal = rewriter.create<tosa::ConstOp>(op.getLoc(), zpTy, zpAttr);
114  return rewriter.create<tosa::SubOp>(op.getLoc(), val.getType(), val,
115  zpVal);
116  };
117 
118  input = applyZp(input, iZp);
119  weight = applyZp(weight, wZp);
120  }
121 
122  ArrayRef<int64_t> padAttr = op.getPad();
123  llvm::SmallVector<int64_t> pad(10, 0);
124  for (const auto &it : llvm::enumerate(padAttr))
125  pad[it.index() + 2] = it.value();
126 
127  if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
128  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
129 
130  llvm::SmallVector<int64_t> newShape(inputType.getShape());
131  for (int i = 0, s = pad.size(); i < s; ++i) {
132  if (newShape[i / 2] != ShapedType::kDynamic) {
133  newShape[i / 2] += pad[i];
134  }
135  }
136 
137  Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
138 
139  auto padTy = RankedTensorType::get({1}, inputETy);
140  auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
141  Value padVal =
142  rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
143  inputType = RankedTensorType::get(newShape, inputETy);
144  input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
145  padSizeVal, padVal);
146  }
147 
148  // Perform an elementwise mul over the reshaped input and weight.
150  inputType.getDimSize(0), inputType.getDimSize(1),
151  inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
152  auto mulShapeType = RankedTensorType::get(
153  mulShape,
154  dyn_cast<RankedTensorType>(weight.getType()).getElementType());
155 
156  if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
157  return failure();
158  }
159 
160  auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
161  auto shiftType = RankedTensorType::get({1}, shiftElementType);
162  auto shiftZeroAttr = DenseElementsAttr::get(
163  shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
164  Value constZero =
165  rewriter.create<tosa::ConstOp>(op.getLoc(), shiftType, shiftZeroAttr);
166  Value mulValue = rewriter
167  .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
168  weight, constZero)
169  .getResult();
170 
171  // Reshape output to [N, H, W, C * M].
172  auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
173  auto outputShapeType = RankedTensorType::get(
174  outputShape,
175  dyn_cast<RankedTensorType>(input.getType()).getElementType());
176  auto outputShapeValue =
177  getTosaConstShape(rewriter, op->getLoc(), outputShape);
178  Value outputValue = rewriter.create<tosa::ReshapeOp>(
179  op.getLoc(), outputShapeType, mulValue, outputShapeValue);
180 
181  Value bias = op.getBias();
182  if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
183  return failure();
184  }
185 
186  // Add in the bias.
187  rewriter
188  .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
189  .getResult();
190  return success();
191  }
192 };
193 
194 } // namespace
195 
198  patterns.add<DepthwiseConv2DIsMul>(ctx);
199 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314