MLIR  22.0.0git
TosaDecomposeDepthwise.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose the TOSA DepthwiseConv2D operation into a series of TOSA ops:
10 // (1) Convert a 1x1 depthwise conv into Reshape -> Mul -> Reshape -> Add,
// inserting Cast, zero-point Sub, and Pad ops only where they are needed.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/BuiltinTypes.h"
18 
19 using namespace mlir;
20 using namespace mlir::tosa;
21 
22 namespace {
23 
24 struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25  explicit DepthwiseConv2DIsMul(MLIRContext *context)
26  : OpRewritePattern(context) {}
27 
28  LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29  PatternRewriter &rewriter) const override {
30  Value input = op.getInput();
31  Value weight = op.getWeight();
32  ShapedType inputType = cast<ShapedType>(input.getType());
33  ShapedType weightType = cast<ShapedType>(weight.getType());
34  ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
35 
36  if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
37  resultType.hasStaticShape())) {
38  return failure();
39  }
40 
41  if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
42  return failure();
43 
44  // Only works for a 1x1 kernel.
45  ArrayRef<int64_t> weightShape = weightType.getShape();
46  if (weightShape[0] != 1 || weightShape[1] != 1) {
47  return failure();
48  }
49 
50  Type inputETy = inputType.getElementType();
51  Type weightETy = weightType.getElementType();
52  if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
53  return rewriter.notifyMatchFailure(op, "unsupported type");
54 
55  // Get and verify zero points.
56  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
57  if (failed(maybeIZp))
58  return rewriter.notifyMatchFailure(
59  op, "input zero point cannot be statically determined");
60 
61  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
62  if (failed(maybeWZp))
63  return rewriter.notifyMatchFailure(
64  op, "weight zero point cannot be statically determined");
65 
66  int64_t iZp = *maybeIZp;
67  int64_t wZp = *maybeWZp;
68  if (op.verifyInputZeroPoint(iZp).failed())
69  return rewriter.notifyMatchFailure(
70  op, "input zero point must be zero for non-int8 integer types");
71  if (op.verifyWeightZeroPoint(wZp).failed())
72  return rewriter.notifyMatchFailure(
73  op, "weight zero point must be zero for non-int8 integer types");
74 
75  // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
76  ArrayRef<int64_t> inputShape = inputType.getShape();
77  llvm::SmallVector<int64_t, 2> revisedInputShape{
78  inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
79  inputType = RankedTensorType::get(
80  revisedInputShape,
81  dyn_cast<RankedTensorType>(input.getType()).getElementType());
82  auto revisedInputShapeValue =
83  getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
84  input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input,
85  revisedInputShapeValue)
86  .getResult();
87 
88  Type resultETy = resultType.getElementType();
89 
90  if (inputETy != resultETy) {
91  inputType = inputType.clone(resultETy);
92  input = tosa::CastOp::create(rewriter, op.getLoc(), inputType, input);
93  }
94 
95  if (weightETy != resultETy) {
96  weightType = weightType.clone(resultETy);
97  weight = tosa::CastOp::create(rewriter, op.getLoc(), weightType, weight);
98  }
99 
100  if (iZp != 0 || wZp != 0) {
101 
102  auto applyZp = [&](Value val, int64_t zp) -> Value {
103  if (zp == 0)
104  return val;
105  auto ety = cast<ShapedType>(val.getType()).getElementType();
106  std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
107  1);
108  auto zpTy = RankedTensorType::get(shape, ety);
109  auto zpAttr =
110  DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
111  auto zpVal = tosa::ConstOp::create(rewriter, op.getLoc(), zpTy, zpAttr);
112  return tosa::SubOp::create(rewriter, op.getLoc(), val.getType(), val,
113  zpVal);
114  };
115 
116  input = applyZp(input, iZp);
117  weight = applyZp(weight, wZp);
118  }
119 
120  ArrayRef<int64_t> padAttr = op.getPad();
121  llvm::SmallVector<int64_t> pad(10, 0);
122  for (const auto &it : llvm::enumerate(padAttr))
123  pad[it.index() + 2] = it.value();
124 
125  if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
126  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
127 
128  llvm::SmallVector<int64_t> newShape(inputType.getShape());
129  for (int i = 0, s = pad.size(); i < s; ++i) {
130  if (newShape[i / 2] != ShapedType::kDynamic) {
131  newShape[i / 2] += pad[i];
132  }
133  }
134 
135  Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
136 
137  auto padTy = RankedTensorType::get({1}, inputETy);
138  auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
139  Value padVal =
140  tosa::ConstOp::create(rewriter, op->getLoc(), padTy, padAttr);
141  inputType = RankedTensorType::get(newShape, inputETy);
142  input = tosa::PadOp::create(rewriter, op->getLoc(), inputType, input,
143  padSizeVal, padVal);
144  }
145 
146  // Perform an elementwise mul over the reshaped input and weight.
148  inputType.getDimSize(0), inputType.getDimSize(1),
149  inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
150  auto mulShapeType = RankedTensorType::get(
151  mulShape,
152  dyn_cast<RankedTensorType>(weight.getType()).getElementType());
153 
154  if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
155  return failure();
156  }
157 
158  auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
159  auto shiftType = RankedTensorType::get({1}, shiftElementType);
160  auto shiftZeroAttr = DenseElementsAttr::get(
161  shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
162  Value constZero =
163  tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr);
164  Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType,
165  input, weight, constZero)
166  .getResult();
167 
168  // Reshape output to [N, H, W, C * M].
169  auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
170  auto outputShapeType = RankedTensorType::get(
171  outputShape,
172  dyn_cast<RankedTensorType>(input.getType()).getElementType());
173  auto outputShapeValue =
174  getTosaConstShape(rewriter, op->getLoc(), outputShape);
175  Value outputValue = tosa::ReshapeOp::create(
176  rewriter, op.getLoc(), outputShapeType, mulValue, outputShapeValue);
177 
178  Value bias = op.getBias();
179  if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
180  return failure();
181  }
182 
183  // Add in the bias.
184  rewriter
185  .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
186  .getResult();
187  return success();
188  }
189 };
190 
191 } // namespace
192 
195  patterns.add<DepthwiseConv2DIsMul>(ctx);
196 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:323
MLIRContext * getContext() const
Definition: Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314