MLIR  20.0.0git
TosaDecomposeTransposeConv.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically
10 // (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping
11 // etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D
12 // including transposing/reversing/reshaping etc..
13 // of the weights and input/output tensors and reversing/reshaping etc .. of
14 // the weights
15 //
16 //===----------------------------------------------------------------------===//
17 
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Pass/Pass.h"
23 
24 using namespace mlir;
25 using namespace mlir::tosa;
26 
27 namespace {
28 
29 class TransposeConvNonStridedConverter
30  : public OpRewritePattern<tosa::TransposeConv2DOp> {
31 public:
33  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
34  PatternRewriter &rewriter) const final {
35  Location loc = op->getLoc();
36  Value input = op->getOperand(0);
37  Value weight = op->getOperand(1);
38  Value bias = op->getOperand(2);
39 
40  ShapedType inputTy = cast<ShapedType>(input.getType());
41  ShapedType weightTy = cast<ShapedType>(weight.getType());
42  ShapedType biasTy = cast<ShapedType>(bias.getType());
43  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
44 
45  llvm::ArrayRef<int64_t> stride = op.getStride();
46  llvm::ArrayRef<int64_t> pad = op.getOutPad();
47 
48  // If striding is all 1 we can modify padding and reverse the kernel along
49  // the x/y direction to make it a regular convolution. This is much simpler
50  // then handling striding....
51  if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
52  return failure();
53 
54  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
55  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
56  return failure();
57 
58  int64_t kernelHeight = weightTy.getDimSize(1);
59  int64_t kernelWidth = weightTy.getDimSize(2);
60 
61  llvm::SmallVector<int64_t> convPad(4, 0);
62  convPad[0] = kernelHeight - 1 + pad[0];
63  convPad[1] = kernelHeight - 1 + pad[1];
64  convPad[2] = kernelWidth - 1 + pad[2];
65  convPad[3] = kernelWidth - 1 + pad[3];
66 
67  auto reverse1 = rewriter.create<tosa::ReverseOp>(
68  loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
69  auto reverse2 = rewriter.create<tosa::ReverseOp>(
70  loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
71 
72  Value conv2d;
73  if (op.getQuantizationInfo()) {
74  conv2d = rewriter.create<tosa::Conv2DOp>(
75  loc, resultTy, input, reverse2, bias,
76  rewriter.getDenseI64ArrayAttr(convPad),
77  rewriter.getDenseI64ArrayAttr(stride),
78  rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
79  } else {
80  conv2d = rewriter.create<tosa::Conv2DOp>(
81  loc, resultTy, input, reverse2, bias,
82  rewriter.getDenseI64ArrayAttr(convPad),
83  rewriter.getDenseI64ArrayAttr(stride),
84  rewriter.getDenseI64ArrayAttr({1, 1}));
85  }
86 
87  rewriter.replaceOp(op, conv2d);
88  return success();
89  }
90 };
91 
92 class TransposeConvStridedConverter
93  : public OpRewritePattern<tosa::TransposeConv2DOp> {
94 public:
96  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
97  PatternRewriter &rewriter) const final {
98  Location loc = op->getLoc();
99  Value input = op->getOperand(0);
100  Value weight = op->getOperand(1);
101  Value bias = op->getOperand(2);
102 
103  ShapedType inputTy = cast<ShapedType>(input.getType());
104  ShapedType weightTy = cast<ShapedType>(weight.getType());
105  ShapedType biasTy = cast<ShapedType>(bias.getType());
106  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
107 
108  Type inputETy = inputTy.getElementType();
109  Type weightETy = weightTy.getElementType();
110  Type biasETy = biasTy.getElementType();
111  Type resultETy = resultTy.getElementType();
112 
113  llvm::ArrayRef<int64_t> pad = op.getOutPad();
114  llvm::ArrayRef<int64_t> stride = op.getStride();
115 
116  // If striding is all 1 we can modify padding and reverse the kernel along
117  // the x/y direction to make it a regular convolution. This is much simpler
118  // then handling striding....
119 
120  // If strides are all 1 we dont need to use this one.
121  if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
122  return rewriter.notifyMatchFailure(op, "non-one stride found.");
123 
124  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
125  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
126  return failure();
127 
128  int64_t batch = inputTy.getDimSize(0);
129 
130  int64_t outputChannels = weightTy.getDimSize(0);
131  int64_t weightHeight = weightTy.getDimSize(1);
132  int64_t weightWidth = weightTy.getDimSize(2);
133  int64_t inputChannels = weightTy.getDimSize(3);
134 
135  // Pad the weight so that it is modulo of the striding.
136  llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
137  weightPadding[3] =
138  (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
139  weightPadding[5] =
140  (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
141  DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
142  RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
143  Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
144  rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
145 
146  if (op.getQuantizationInfo().has_value()) {
147  auto quantInfo = op.getQuantizationInfo().value();
148  weight = CreateOpAndInferShape<tosa::PadOp>(
149  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
150  weightPaddingVal, nullptr,
151  rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
152 
153  } else {
154  weight = CreateOpAndInferShape<tosa::PadOp>(
155  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
156  weightPaddingVal);
157  }
158 
159  weightTy = cast<ShapedType>(weight.getType());
160  weightHeight = weightTy.getDimSize(1);
161  weightWidth = weightTy.getDimSize(2);
162 
163  // Split out the width / height by the stride dimensions.
164  llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
165  outputChannels, weightHeight / stride[0],
166  stride[0], weightWidth / stride[1],
167  stride[1], inputChannels};
168  weight = CreateOpAndInferShape<tosa::ReshapeOp>(
169  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
170  rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
171 
172  // Transpose the factored-out stride to the output channels.
173  Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
174  loc, RankedTensorType::get({6}, rewriter.getI32Type()),
175  rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
176 
177  weight = CreateOpAndInferShape<tosa::TransposeOp>(
178  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
179  transposeWeightVal);
180 
181  // Collapse the strides and output channels into a single dimension.
182  llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
183  outputChannels * stride[0] * stride[1], weightHeight / stride[0],
184  weightWidth / stride[1], inputChannels};
185  weight = CreateOpAndInferShape<tosa::ReshapeOp>(
186  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
187  rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
188  ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
189 
190  weight = CreateOpAndInferShape<tosa::ReverseOp>(
191  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
192  /* axis = */ rewriter.getI32IntegerAttr(1));
193  weight = CreateOpAndInferShape<tosa::ReverseOp>(
194  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
195  /* axis = */ rewriter.getI32IntegerAttr(2));
196 
197  // We need to pad the input far enough that we can pull all values.
198  llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
199  inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
200  inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
201  inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
202  inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
203 
204  DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
205  RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
206 
207  Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
208  rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
209 
210  if (op.getQuantizationInfo().has_value()) {
211  auto quantInfo = op.getQuantizationInfo().value();
212  input = CreateOpAndInferShape<tosa::PadOp>(
213  rewriter, loc, UnrankedTensorType::get(inputETy), input,
214  inputPaddingVal, nullptr,
215  rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
216  } else {
217  input = CreateOpAndInferShape<tosa::PadOp>(
218  rewriter, loc, UnrankedTensorType::get(inputETy), input,
219  inputPaddingVal);
220  }
221 
222  // We use a zero bias as we need to broadcast the bias.
223  auto zeroBias = rewriter.create<tosa::ConstOp>(
224  loc,
225  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
226  biasETy),
228  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
229  biasETy),
230  rewriter.getZeroAttr(biasETy)));
231 
232  // Perform the convolution using the zero bias.
233  Value conv2d;
234  if (op.getQuantizationInfo()) {
235  conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
236  rewriter, loc, UnrankedTensorType::get(resultETy), input,
237  weight, zeroBias,
238  /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
239  /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
240  /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
241  *op.getQuantizationInfo())
242  .getResult();
243  } else {
244  conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
245  rewriter, loc, UnrankedTensorType::get(resultETy), input,
246  weight, zeroBias,
247  /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
248  /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
249  /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
250  .getResult();
251  }
252 
253  // Factor the resulting width / height.
254  ShapedType convTy = cast<ShapedType>(conv2d.getType());
255  Type convETy = convTy.getElementType();
256 
257  int64_t convHeight = convTy.getDimSize(1);
258  int64_t convWidth = convTy.getDimSize(2);
259 
260  // Factor striding out of the convolution result.
261  llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
262  batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
263  conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
264  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
265  rewriter.getDenseI64ArrayAttr(convReshapeDims0));
266 
267  // Transpose the factored-out stride to the output channels.
268  Value transposeConvVal = rewriter.create<tosa::ConstOp>(
269  loc, RankedTensorType::get({6}, rewriter.getI32Type()),
270  rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
271 
272  conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
273  rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
274  transposeConvVal);
275 
276  // Fuse striding behavior back into width / height.
277  llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
278  batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
279  conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
280  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
281  rewriter.getDenseI64ArrayAttr(convReshapeDims1));
282 
283  // Determine the amount to slice / pad from the result start.
284  int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
285  int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
286  int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
287  int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
288 
289  // Try to slice the targetted result size, cap to the convolutions width.
290  int64_t resultSliceHeight =
291  std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
292  resultTy.getDimSize(1) - resultPadTop);
293  int64_t resultSliceWidth =
294  std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
295  resultTy.getDimSize(2) - resultPadLeft);
296 
297  llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
298  resultSliceLeft, 0};
299  llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
300  convReshapeDims1.end());
301  sliceSize[1] = resultSliceHeight;
302  sliceSize[2] = resultSliceWidth;
303 
304  auto slice = CreateOpAndInferShape<tosa::SliceOp>(
305  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
306  rewriter.getDenseI64ArrayAttr(sliceBegin),
307  rewriter.getDenseI64ArrayAttr(sliceSize))
308  .getResult();
309 
310  llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
311  resultPadding[2] = resultPadTop;
312  resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
313  resultPadding[4] = resultPadLeft;
314  resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
315 
316  DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
317  RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
318 
319  Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
320  rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
321 
322  Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
323  rewriter, loc, UnrankedTensorType::get(resultETy), slice,
324  resultPaddingVal);
325 
326  if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
327  return failure();
328  }
329 
330  rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
331  return success();
332  }
333 };
334 
335 } // namespace
336 
338  MLIRContext *ctx, RewritePatternSet &patterns) {
339  patterns.add<TransposeConvNonStridedConverter>(ctx);
340  patterns.add<TransposeConvStridedConverter>(ctx);
341 }
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358