MLIR  21.0.0git
TosaDecomposeTransposeConv.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
10 // (1) A non-strided (unit-stride) TransposeConv2D becomes a Conv2D whose
11 // weights are reversed along the spatial axes, with adjusted padding.
12 // (2) A strided TransposeConv2D becomes a unit-stride Conv2D over padded,
13 // reshaped, transposed and reversed weights and a padded input, followed by
14 // reshaping/transposing/slicing/padding of the output tensor.
15 //
16 //===----------------------------------------------------------------------===//
17 
22 #include "mlir/Pass/Pass.h"
23 
24 using namespace mlir;
25 using namespace mlir::tosa;
26 
27 namespace {
28 
29 class TransposeConvNonStridedConverter
30  : public OpRewritePattern<tosa::TransposeConv2DOp> {
31 public:
33  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
34  PatternRewriter &rewriter) const final {
35  Location loc = op->getLoc();
36  Value input = op->getOperand(0);
37  Value weight = op->getOperand(1);
38  Value bias = op->getOperand(2);
39 
40  ShapedType inputTy = cast<ShapedType>(input.getType());
41  ShapedType weightTy = cast<ShapedType>(weight.getType());
42  ShapedType biasTy = cast<ShapedType>(bias.getType());
43  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
44 
45  llvm::ArrayRef<int64_t> stride = op.getStride();
46  llvm::ArrayRef<int64_t> pad = op.getOutPad();
47 
48  // If striding is all 1 we can modify padding and reverse the kernel along
49  // the x/y direction to make it a regular convolution. This is much simpler
50  // then handling striding....
51  if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
52  return failure();
53 
54  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
55  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
56  return failure();
57 
58  int64_t kernelHeight = weightTy.getDimSize(1);
59  int64_t kernelWidth = weightTy.getDimSize(2);
60 
61  llvm::SmallVector<int64_t> convPad(4, 0);
62  convPad[0] = kernelHeight - 1 + pad[0];
63  convPad[1] = kernelHeight - 1 + pad[1];
64  convPad[2] = kernelWidth - 1 + pad[2];
65  convPad[3] = kernelWidth - 1 + pad[3];
66 
67  auto reverse1 = rewriter.create<tosa::ReverseOp>(
68  loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
69  auto reverse2 = rewriter.create<tosa::ReverseOp>(
70  loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
71 
72  Value conv2d = rewriter.create<tosa::Conv2DOp>(
73  loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
74  rewriter.getDenseI64ArrayAttr(convPad),
75  rewriter.getDenseI64ArrayAttr(stride),
76  rewriter.getDenseI64ArrayAttr({1, 1}),
77  /* acc_type = */ op.getAccType());
78 
79  rewriter.replaceOp(op, conv2d);
80  return success();
81  }
82 };
83 
84 class TransposeConvStridedConverter
85  : public OpRewritePattern<tosa::TransposeConv2DOp> {
86 public:
88  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
89  PatternRewriter &rewriter) const final {
90  Location loc = op->getLoc();
91  Value input = op->getOperand(0);
92  Value weight = op->getOperand(1);
93  Value bias = op->getOperand(2);
94 
95  ShapedType inputTy = cast<ShapedType>(input.getType());
96  ShapedType weightTy = cast<ShapedType>(weight.getType());
97  ShapedType biasTy = cast<ShapedType>(bias.getType());
98  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
99 
100  Type inputETy = inputTy.getElementType();
101  Type weightETy = weightTy.getElementType();
102  Type biasETy = biasTy.getElementType();
103  Type resultETy = resultTy.getElementType();
104 
105  llvm::ArrayRef<int64_t> pad = op.getOutPad();
106  llvm::ArrayRef<int64_t> stride = op.getStride();
107 
108  // If striding is all 1 we can modify padding and reverse the kernel along
109  // the x/y direction to make it a regular convolution. This is much simpler
110  // then handling striding....
111 
112  // If strides are all 1 we dont need to use this one.
113  if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
114  return rewriter.notifyMatchFailure(op, "non-one stride found.");
115 
116  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
117  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
118  return failure();
119 
120  int64_t batch = inputTy.getDimSize(0);
121 
122  int64_t outputChannels = weightTy.getDimSize(0);
123  int64_t weightHeight = weightTy.getDimSize(1);
124  int64_t weightWidth = weightTy.getDimSize(2);
125  int64_t inputChannels = weightTy.getDimSize(3);
126 
127  // Pad the weight so that it is modulo of the striding.
128  llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
129  weightPadding[3] =
130  (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
131  weightPadding[5] =
132  weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
133 
134  Value weightPaddingVal =
135  getTosaConstShape(rewriter, op->getLoc(), weightPadding);
136 
137  // Get and verify zero points.
138  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
139  if (failed(maybeIZp))
140  return rewriter.notifyMatchFailure(
141  op, "input zero point cannot be statically determined");
142 
143  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
144  if (failed(maybeWZp))
145  return rewriter.notifyMatchFailure(
146  op, "weight zero point cannot be statically determined");
147 
148  int64_t inputZpVal = *maybeIZp;
149  int64_t weightZpVal = *maybeWZp;
150 
151  if (op.verifyInputZeroPoint(inputZpVal).failed())
152  return rewriter.notifyMatchFailure(
153  op, "input zero point must be zero for non-int8 integer types");
154 
155  if (op.verifyWeightZeroPoint(weightZpVal).failed())
156  return rewriter.notifyMatchFailure(
157  op, "weight zero point must be zero for non-int8 integer types");
158 
159  // construct pad_const values from zp values
160  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
161  const Value inputPadConst =
162  createPadConstTensor(builder, op->getLoc(), input, inputZpVal);
163  const Value weightPadConst =
164  createPadConstTensor(builder, op->getLoc(), input, weightZpVal);
165 
166  weight = CreateOpAndInferShape<tosa::PadOp>(
167  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
168  weightPaddingVal, weightPadConst);
169 
170  weightTy = cast<ShapedType>(weight.getType());
171  weightHeight = weightTy.getDimSize(1);
172  weightWidth = weightTy.getDimSize(2);
173 
174  // Split out the width / height by the stride dimensions.
175  llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
176  outputChannels, weightHeight / stride[0],
177  stride[0], weightWidth / stride[1],
178  stride[1], inputChannels};
179 
180  weight = CreateOpAndInferShape<tosa::ReshapeOp>(
181  builder, UnrankedTensorType::get(weightETy), weight,
182  getTosaConstShape(rewriter, loc, weightReshapeDims0));
183 
184  // Transpose the factored-out stride to the output channels.
185  weight = CreateOpAndInferShape<tosa::TransposeOp>(
186  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
187  rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
188 
189  // Collapse the strides and output channels into a single dimension.
190  llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
191  outputChannels * stride[0] * stride[1], weightHeight / stride[0],
192  weightWidth / stride[1], inputChannels};
193 
194  weight = CreateOpAndInferShape<tosa::ReshapeOp>(
195  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
196  getTosaConstShape(rewriter, loc, weightReshapeDims1));
197  ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
198 
199  weight = CreateOpAndInferShape<tosa::ReverseOp>(
200  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
201  /* axis = */ rewriter.getI32IntegerAttr(1));
202  weight = CreateOpAndInferShape<tosa::ReverseOp>(
203  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
204  /* axis = */ rewriter.getI32IntegerAttr(2));
205 
206  // We need to pad the input far enough that we can pull all values.
207  llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
208  inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
209  inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
210  inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
211  inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
212 
213  Value inputPaddingVal =
214  getTosaConstShape(rewriter, op->getLoc(), inputPadding);
215 
216  input = CreateOpAndInferShape<tosa::PadOp>(
217  rewriter, loc, UnrankedTensorType::get(inputETy), input,
218  inputPaddingVal, inputPadConst);
219 
220  // We use a zero bias as we need to broadcast the bias.
221  auto zeroBias = rewriter.create<tosa::ConstOp>(
222  loc,
223  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
224  biasETy),
226  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
227  biasETy),
228  rewriter.getZeroAttr(biasETy)));
229 
230  auto inputZp =
231  createZeroPointTensor(rewriter, loc, input.getType(), inputZpVal);
232  auto weightZp =
233  createZeroPointTensor(rewriter, loc, weight.getType(), weightZpVal);
234 
235  if (!inputZp.has_value() || !weightZp.has_value()) {
236  return rewriter.notifyMatchFailure(
237  op, "fail to create a const zero point tensor");
238  }
239 
240  // Perform the convolution using the zero bias.
241  Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
242  rewriter, loc, UnrankedTensorType::get(resultETy), input,
243  weight, zeroBias, inputZp.value(), weightZp.value(),
244  /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
245  /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
246  /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
247  /* acc_type = */ op.getAccType())
248  .getResult();
249 
250  // Factor the resulting width / height.
251  ShapedType convTy = cast<ShapedType>(conv2d.getType());
252  Type convETy = convTy.getElementType();
253 
254  int64_t convHeight = convTy.getDimSize(1);
255  int64_t convWidth = convTy.getDimSize(2);
256 
257  // Factor striding out of the convolution result.
258  llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
259  batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
260 
261  auto convReshapeDims0Value =
262  getTosaConstShape(rewriter, loc, convReshapeDims0);
263 
264  conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
265  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
266  convReshapeDims0Value);
267 
268  // Transpose the factored-out stride to the output channels.
269  conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
270  rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
271  rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
272 
273  // Fuse striding behavior back into width / height.
274  llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
275  batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
276 
277  auto convReshapeDims1Value =
278  getTosaConstShape(rewriter, loc, convReshapeDims1);
279 
280  conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
281  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
282  convReshapeDims1Value);
283 
284  // Determine the amount to slice / pad from the result start.
285  int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
286  int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
287  int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
288  int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
289 
290  // Try to slice the targetted result size, cap to the convolutions width.
291  int64_t resultSliceHeight =
292  std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
293  resultTy.getDimSize(1) - resultPadTop);
294  int64_t resultSliceWidth =
295  std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
296  resultTy.getDimSize(2) - resultPadLeft);
297 
298  llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
299  resultSliceLeft, 0};
300  llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
301  convReshapeDims1.end());
302  sliceSize[1] = resultSliceHeight;
303  sliceSize[2] = resultSliceWidth;
304 
305  auto slice = CreateOpAndInferShape<tosa::SliceOp>(
306  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
307  getTosaConstShape(rewriter, loc, sliceBegin),
308  getTosaConstShape(rewriter, loc, sliceSize))
309  .getResult();
310 
311  llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
312  resultPadding[2] = resultPadTop;
313  resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
314  resultPadding[4] = resultPadLeft;
315  resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
316 
317  Value resultPaddingVal =
318  getTosaConstShape(rewriter, op->getLoc(), resultPadding);
319 
320  Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
321  rewriter, loc, UnrankedTensorType::get(resultETy), slice,
322  resultPaddingVal);
323 
324  if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
325  return failure();
326  }
327 
328  rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
329  return success();
330  }
331 };
332 
333 } // namespace
334 
337  patterns.add<TransposeConvNonStridedConverter>(ctx);
338  patterns.add<TransposeConvStridedConverter>(ctx);
339 }
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:3300
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:220
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358