//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to Conv2D by adjusting the
//     padding and reversing the weights along the spatial dimensions.
// (2) Convert a strided TransposeConv2D to Conv2D by padding, reshaping,
//     transposing, and reversing the weights, and padding, slicing, and
//     reshaping the input/output tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

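// Illustrative sketch of pattern (1); the IR below is schematic, with types,
// zero-point operands, and most attributes elided. A unit-stride
//
//   %0 = tosa.transpose_conv2d %in, %w, %b {stride = array<i64: 1, 1>, ...}
//
// becomes, for a 3x3 kernel with zero out_pad,
//
//   %wr = tosa.reverse %w {axis = 1 : i32}
//   %wr2 = tosa.reverse %wr {axis = 2 : i32}
//   %0 = tosa.conv2d %in, %wr2, %b {pad = array<i64: 2, 2, 2, 2>, ...}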
class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    llvm::ArrayRef<int64_t> stride = op.getStride();
    llvm::ArrayRef<int64_t> pad = op.getOutPad();

    // If the stride is all 1s we can modify the padding and reverse the
    // kernel along the x/y directions to make it a regular convolution. This
    // is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    // All dimensions other than the batch dimension must be static for both
    // the input and the output.
    for (unsigned int i = 1; i < 4; ++i) {
      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
        return failure();
    }

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 + pad[0];
    convPad[1] = kernelHeight - 1 + pad[1];
    convPad[2] = kernelWidth - 1 + pad[2];
    convPad[3] = kernelWidth - 1 + pad[3];
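    // e.g. a 3x3 kernel with zero out_pad yields convPad = {2, 2, 2, 2}: a
    // stride-1 transpose convolution is the "full" convolution of the input
    // with the spatially reversed kernel.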

    auto reverse1 =
        tosa::ReverseOp::create(rewriter, loc, weightTy, weight,
                                /* axis = */ rewriter.getI32IntegerAttr(1));
    auto reverse2 =
        tosa::ReverseOp::create(rewriter, loc, weightTy, reverse1,
                                /* axis = */ rewriter.getI32IntegerAttr(2));

    Value conv2d = tosa::Conv2DOp::create(
        rewriter, loc, resultTy, input, reverse2, bias, op.getInputZp(),
        op.getWeightZp(), rewriter.getDenseI64ArrayAttr(convPad),
        rewriter.getDenseI64ArrayAttr(stride),
        rewriter.getDenseI64ArrayAttr({1, 1}),
        /* acc_type = */ op.getAccType());

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::ArrayRef<int64_t> pad = op.getOutPad();
    llvm::ArrayRef<int64_t> stride = op.getStride();

    // This pattern handles the strided case; if the stride is all 1s, the
    // simpler non-strided pattern above applies instead.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return rewriter.notifyMatchFailure(
          op, "all strides are 1; handled by the non-strided pattern");

    // All dimensions other than the batch dimension must be static for both
    // the input and the output.
    for (unsigned int i = 1; i < 4; ++i) {
      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
        return failure();
    }

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return failure();
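
    // Strategy: a transpose convolution with stride {sH, sW} decomposes into
    // sH * sW stride-1 convolutions, one per output phase, whose results
    // interleave. The per-phase sub-kernels are split out of the weights and
    // stacked on the output-channel axis, a single stride-1 conv2d computes
    // all phases at once, and a depth-to-space style reshape then recovers
    // the full-resolution output.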

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that each spatial dimension is a multiple of the
    // stride.
    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
    weightPadding[5] =
        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
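    // e.g. with stride = {2, 2} a 3x5 kernel is padded to 4x6 so that each
    // spatial dimension splits evenly into stride-many sub-kernels.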

    Value weightPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), weightPadding);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    int64_t inputZpVal = *maybeIZp;
    int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Construct pad_const values from the zero-point values.
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    const Value inputPadConst =
        createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
    const Value weightPadConst =
        createPadConstTensor(builder, op->getLoc(), weight, weightZpVal);

    weight = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        weightPaddingVal, weightPadConst);

    weightTy = cast<ShapedType>(weight.getType());
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
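    // e.g. an {OC, 4, 6, IC} kernel with stride = {2, 2} is reshaped to
    // {OC, 2, 2, 3, 2, IC}: each spatial dimension splits into
    // (size / stride, stride) so the per-phase sub-kernels can be separated.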

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        builder, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    weight = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};

    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        getTosaConstShape(rewriter, loc, weightReshapeDims1));
    ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(1));
    weight = CreateOpAndInferShape<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        /* axis = */ rewriter.getI32IntegerAttr(2));

    // We need to pad the input far enough that we can pull all values.
    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
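    // e.g. with 2x3 sub-kernels the input is padded by 1 row and 2 columns
    // on each side: the "full" padding needed so the stride-1 convolution
    // produces every partial result the transpose convolution would.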

    Value inputPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), inputPadding);

    input = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(inputETy), input,
        inputPaddingVal, inputPadConst);

    // Use a zero bias: the original bias is broadcast-added back in after the
    // result has been reshaped to its final size.
    auto zeroBias = tosa::ConstOp::create(
        rewriter, loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    auto inputZp =
        createZeroPointTensor(rewriter, loc, input.getType(), inputZpVal);
    auto weightZp =
        createZeroPointTensor(rewriter, loc, weight.getType(), weightZpVal);

    if (!inputZp.has_value() || !weightZp.has_value()) {
      return rewriter.notifyMatchFailure(
          op, "failed to create a constant zero-point tensor");
    }

    // Perform the convolution using the zero bias.
    Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                       rewriter, loc, UnrankedTensorType::get(resultETy), input,
                       weight, zeroBias, inputZp.value(), weightZp.value(),
                       /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                       /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                       /* acc_type = */ op.getAccType())
                       .getResult();

    // Factor the resulting width / height.
    ShapedType convTy = cast<ShapedType>(conv2d.getType());
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};

    auto convReshapeDims0Value =
        getTosaConstShape(rewriter, loc, convReshapeDims0);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims0Value);

    // Transpose the factored-out stride next to the spatial dimensions.
    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
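    // Together the reshape / transpose / reshape act as a depth-to-space
    // step, e.g. {N, H, W, 2, 2, OC} -> {N, H, 2, W, 2, OC} ->
    // {N, 2*H, 2*W, OC}, interleaving the stride phases back into the
    // spatial dimensions.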

    auto convReshapeDims1Value =
        getTosaConstShape(rewriter, loc, convReshapeDims1);

    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        convReshapeDims1Value);

    // Determine the amount to slice / pad from the result start.
    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);

    // Try to slice the targeted result size, capped to the convolution's
    // width / height.
    int64_t resultSliceHeight =
        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
                          resultTy.getDimSize(1) - resultPadTop);
    int64_t resultSliceWidth =
        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
                          resultTy.getDimSize(2) - resultPadLeft);
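    // out_pad entries may be negative (rows / columns are sliced off the
    // result) or positive (zero rows / columns are padded on), e.g.
    // out_pad[0] = -1 trims one row from the top while out_pad[0] = 1 adds
    // one.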

    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
                                                resultSliceLeft, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
                                            convReshapeDims1.end());
    sliceSize[1] = resultSliceHeight;
    sliceSize[2] = resultSliceWidth;

    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     getTosaConstShape(rewriter, loc, sliceBegin),
                     getTosaConstShape(rewriter, loc, sliceSize))
                     .getResult();

    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    resultPadding[2] = resultPadTop;
    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
    resultPadding[4] = resultPadLeft;
    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];

    Value resultPaddingVal =
        getTosaConstShape(rewriter, op->getLoc(), resultPadding);

    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
        resultPaddingVal);

    if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}