MLIR 22.0.0git
TosaDecomposeDepthwise.cpp
Go to the documentation of this file.
1//===- TosaDecomposeDepthwise.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
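// Decompose the TOSA depthwise convolution operation into a series of simpler
// TOSA ops. Specifically:
// (1) Convert a 1x1, unit-stride depthwise convolution to
//     Reshape -> Mul -> Reshape -> Add, inserting Cast, Sub (zero point) and
//     Pad ops where needed.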
11//
12//===----------------------------------------------------------------------===//
13
18
19using namespace mlir;
20using namespace mlir::tosa;
21
22namespace {
23
24struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
25 explicit DepthwiseConv2DIsMul(MLIRContext *context)
26 : OpRewritePattern(context) {}
27
28 LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
29 PatternRewriter &rewriter) const override {
30 Value input = op.getInput();
31 Value weight = op.getWeight();
32 ShapedType inputType = cast<ShapedType>(input.getType());
33 ShapedType weightType = cast<ShapedType>(weight.getType());
34 ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
35
36 // Any dimensions other than batchSize cannot be dynamic for input/output
37 for (unsigned int i = 1; i < 4; ++i) {
38 if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
39 return failure();
40 }
41
42 if (!weightType.hasStaticShape()) {
43 return failure();
44 }
45
46 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
47 return failure();
48
49 // Only works for a 1x1 kernel.
50 ArrayRef<int64_t> weightShape = weightType.getShape();
51 if (weightShape[0] != 1 || weightShape[1] != 1) {
52 return failure();
53 }
54
55 Type inputETy = inputType.getElementType();
56 Type weightETy = weightType.getElementType();
57 if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
58 return rewriter.notifyMatchFailure(op, "unsupported type");
59
60 // Get and verify zero points.
61 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
62 if (failed(maybeIZp))
63 return rewriter.notifyMatchFailure(
64 op, "input zero point cannot be statically determined");
65
66 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
67 if (failed(maybeWZp))
68 return rewriter.notifyMatchFailure(
69 op, "weight zero point cannot be statically determined");
70
71 int64_t iZp = *maybeIZp;
72 int64_t wZp = *maybeWZp;
73 if (op.verifyInputZeroPoint(iZp).failed())
74 return rewriter.notifyMatchFailure(
75 op, "input zero point must be zero for non-int8 integer types");
76 if (op.verifyWeightZeroPoint(wZp).failed())
77 return rewriter.notifyMatchFailure(
78 op, "weight zero point must be zero for non-int8 integer types");
79
80 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
81 ArrayRef<int64_t> inputShape = inputType.getShape();
82 llvm::SmallVector<int64_t, 2> revisedInputShape{
83 inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
84 inputType = RankedTensorType::get(
85 revisedInputShape,
86 dyn_cast<RankedTensorType>(input.getType()).getElementType());
87 auto revisedInputShapeValue =
88 getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
89 input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input,
90 revisedInputShapeValue)
91 .getResult();
92
93 Type resultETy = resultType.getElementType();
94
95 if (inputETy != resultETy) {
96 inputType = inputType.clone(resultETy);
97 input = tosa::CastOp::create(rewriter, op.getLoc(), inputType, input);
98 }
99
100 if (weightETy != resultETy) {
101 weightType = weightType.clone(resultETy);
102 weight = tosa::CastOp::create(rewriter, op.getLoc(), weightType, weight);
103 }
104
105 if (iZp != 0 || wZp != 0) {
106
107 auto applyZp = [&](Value val, int64_t zp) -> Value {
108 if (zp == 0)
109 return val;
110 auto ety = cast<ShapedType>(val.getType()).getElementType();
111 std::vector<int64_t> shape(cast<ShapedType>(val.getType()).getRank(),
112 1);
113 auto zpTy = RankedTensorType::get(shape, ety);
114 auto zpAttr =
115 DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
116 auto zpVal = tosa::ConstOp::create(rewriter, op.getLoc(), zpTy, zpAttr);
117 return tosa::SubOp::create(rewriter, op.getLoc(), val.getType(), val,
118 zpVal);
119 };
120
121 input = applyZp(input, iZp);
122 weight = applyZp(weight, wZp);
123 }
124
125 ArrayRef<int64_t> padAttr = op.getPad();
126 llvm::SmallVector<int64_t> pad(10, 0);
127 for (const auto &it : llvm::enumerate(padAttr))
128 pad[it.index() + 2] = it.value();
129
130 if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
131 Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
132
133 llvm::SmallVector<int64_t> newShape(inputType.getShape());
134 for (int i = 0, s = pad.size(); i < s; ++i) {
135 if (newShape[i / 2] != ShapedType::kDynamic) {
136 newShape[i / 2] += pad[i];
137 }
138 }
139
140 Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
141
142 auto padTy = RankedTensorType::get({1}, inputETy);
143 auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
144 Value padVal =
145 tosa::ConstOp::create(rewriter, op->getLoc(), padTy, padAttr);
146 inputType = RankedTensorType::get(newShape, inputETy);
147 input = tosa::PadOp::create(rewriter, op->getLoc(), inputType, input,
148 padSizeVal, padVal);
149 }
150
151 // Perform an elementwise mul over the reshaped input and weight.
152 llvm::SmallVector<int64_t, 2> mulShape{
153 inputType.getDimSize(0), inputType.getDimSize(1),
154 inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
155 auto mulShapeType = RankedTensorType::get(
156 mulShape,
157 dyn_cast<RankedTensorType>(weight.getType()).getElementType());
158
159 if (EqualizeRanks(rewriter, op.getLoc(), input, weight).failed()) {
160 return failure();
161 }
162
163 auto shiftElementType = IntegerType::get(rewriter.getContext(), 8);
164 auto shiftType = RankedTensorType::get({1}, shiftElementType);
165 auto shiftZeroAttr = DenseElementsAttr::get(
166 shiftType, rewriter.getIntegerAttr(shiftElementType, 0));
167 Value constZero =
168 tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr);
169 Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType,
170 input, weight, constZero)
171 .getResult();
172
173 // Reshape output to [N, H, W, C * M].
174 auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
175 auto outputShapeType = RankedTensorType::get(
176 outputShape,
177 dyn_cast<RankedTensorType>(input.getType()).getElementType());
178 auto outputShapeValue =
179 getTosaConstShape(rewriter, op->getLoc(), outputShape);
180 Value outputValue = tosa::ReshapeOp::create(
181 rewriter, op.getLoc(), outputShapeType, mulValue, outputShapeValue);
182
183 Value bias = op.getBias();
184 if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
185 return failure();
186 }
187
188 // Add in the bias.
189 rewriter
190 .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, bias)
191 .getResult();
192 return success();
193 }
194};
195
196} // namespace
197
200 patterns.add<DepthwiseConv2DIsMul>(ctx);
201}
return success()
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
Type getType() const
Return the type of this value.
Definition Value.h:105
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns)
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...