MLIR  19.0.0git
TosaDecomposeTransposeConv.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) A non-strided (all strides equal to 1) TransposeConv2D becomes a Conv2D
//     on the spatially reversed weights, with the padding adjusted accordingly.
// (2) A strided TransposeConv2D becomes a Conv2D after padding, reshaping,
//     transposing and reversing the weights, and after padding / reshaping /
//     transposing / slicing the input and output tensors.
15 //
16 //===----------------------------------------------------------------------===//
17 
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

#include <algorithm>
#include <utility>
23 
24 using namespace mlir;
25 using namespace mlir::tosa;
26 
27 namespace {
28 
29 template <typename TosaOp, typename... Args>
30 TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
31  Args &&...args) {
32  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
33 
34  InferShapedTypeOpInterface shapeInterface =
35  dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
36  if (!shapeInterface)
37  return op;
38 
39  SmallVector<ShapedTypeComponents> returnedShapes;
40  if (shapeInterface
41  .inferReturnTypeComponents(
42  op.getContext(), op.getLoc(), op->getOperands(),
44  op->getRegions(), returnedShapes)
45  .failed())
46  return op;
47 
48  // We need to use the element type of the existing result type to generate
49  // the new result shaped type. This is because rescale can include a cast to
50  // different bit-width types and does not have a TypeAttr to define the
51  // target type.
52  auto result = op->getResult(0);
53  auto predictedShape = returnedShapes[0];
54  auto currentKnowledge =
56 
57  // Compute the knowledge based on the inferred type.
58  auto inferredKnowledge =
60  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
61  inferredKnowledge.hasRank = predictedShape.hasRank();
62  if (predictedShape.hasRank()) {
63  for (auto dim : predictedShape.getDims()) {
64  inferredKnowledge.sizes.push_back(dim);
65  }
66  }
67 
68  // Compute the new type based on the joined version.
69  auto newKnowledge =
70  mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
71  auto newTy = newKnowledge.getType();
72  result.setType(newTy);
73  return op;
74 }
75 
76 class TransposeConvNonStridedConverter
77  : public OpRewritePattern<tosa::TransposeConv2DOp> {
78 public:
80  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
81  PatternRewriter &rewriter) const final {
82  Location loc = op->getLoc();
83  Value input = op->getOperand(0);
84  Value weight = op->getOperand(1);
85  Value bias = op->getOperand(2);
86 
87  ShapedType inputTy = cast<ShapedType>(input.getType());
88  ShapedType weightTy = cast<ShapedType>(weight.getType());
89  ShapedType biasTy = cast<ShapedType>(bias.getType());
90  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
91 
92  llvm::ArrayRef<int64_t> stride = op.getStride();
93  llvm::ArrayRef<int64_t> pad = op.getOutPad();
94 
95  // If striding is all 1 we can modify padding and reverse the kernel along
96  // the x/y direction to make it a regular convolution. This is much simpler
97  // then handling striding....
98  if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
99  return failure();
100 
101  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
102  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
103  return failure();
104 
105  int64_t kernelHeight = weightTy.getDimSize(1);
106  int64_t kernelWidth = weightTy.getDimSize(2);
107 
108  llvm::SmallVector<int64_t> convPad(4, 0);
109  convPad[0] = kernelHeight - 1 + pad[0];
110  convPad[1] = kernelHeight - 1 + pad[1];
111  convPad[2] = kernelWidth - 1 + pad[2];
112  convPad[3] = kernelWidth - 1 + pad[3];
113 
114  auto reverse1 = rewriter.create<tosa::ReverseOp>(
115  loc, weightTy, weight, /* axis = */ rewriter.getI32IntegerAttr(1));
116  auto reverse2 = rewriter.create<tosa::ReverseOp>(
117  loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
118 
119  Value conv2d;
120  if (op.getQuantizationInfo()) {
121  conv2d = rewriter.create<tosa::Conv2DOp>(
122  loc, resultTy, input, reverse2, bias,
123  rewriter.getDenseI64ArrayAttr(convPad),
124  rewriter.getDenseI64ArrayAttr(stride),
125  rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
126  } else {
127  conv2d = rewriter.create<tosa::Conv2DOp>(
128  loc, resultTy, input, reverse2, bias,
129  rewriter.getDenseI64ArrayAttr(convPad),
130  rewriter.getDenseI64ArrayAttr(stride),
131  rewriter.getDenseI64ArrayAttr({1, 1}));
132  }
133 
134  rewriter.replaceOp(op, conv2d);
135  return success();
136  }
137 };
138 
139 class TransposeConvStridedConverter
140  : public OpRewritePattern<tosa::TransposeConv2DOp> {
141 public:
143  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
144  PatternRewriter &rewriter) const final {
145  Location loc = op->getLoc();
146  Value input = op->getOperand(0);
147  Value weight = op->getOperand(1);
148  Value bias = op->getOperand(2);
149 
150  ShapedType inputTy = cast<ShapedType>(input.getType());
151  ShapedType weightTy = cast<ShapedType>(weight.getType());
152  ShapedType biasTy = cast<ShapedType>(bias.getType());
153  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
154 
155  Type inputETy = inputTy.getElementType();
156  Type weightETy = weightTy.getElementType();
157  Type biasETy = biasTy.getElementType();
158  Type resultETy = resultTy.getElementType();
159 
160  llvm::ArrayRef<int64_t> pad = op.getOutPad();
161  llvm::ArrayRef<int64_t> stride = op.getStride();
162 
163  // If striding is all 1 we can modify padding and reverse the kernel along
164  // the x/y direction to make it a regular convolution. This is much simpler
165  // then handling striding....
166 
167  // If strides are all 1 we dont need to use this one.
168  if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
169  return rewriter.notifyMatchFailure(op, "non-one stride found.");
170 
171  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
172  !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
173  return failure();
174 
175  int64_t batch = inputTy.getDimSize(0);
176 
177  int64_t outputChannels = weightTy.getDimSize(0);
178  int64_t weightHeight = weightTy.getDimSize(1);
179  int64_t weightWidth = weightTy.getDimSize(2);
180  int64_t inputChannels = weightTy.getDimSize(3);
181 
182  // Pad the weight so that it is modulo of the striding.
183  llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
184  weightPadding[3] =
185  weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
186  weightPadding[5] =
187  weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
188  DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
189  RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
190  Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
191  rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
192 
193  if (op.getQuantizationInfo().has_value()) {
194  auto quantInfo = op.getQuantizationInfo().value();
195  weight = createOpAndInfer<tosa::PadOp>(
196  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
197  weightPaddingVal, nullptr,
198  rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
199 
200  } else {
201  weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
202  UnrankedTensorType::get(weightETy),
203  weight, weightPaddingVal);
204  }
205 
206  weightTy = cast<ShapedType>(weight.getType());
207  weightHeight = weightTy.getDimSize(1);
208  weightWidth = weightTy.getDimSize(2);
209 
210  // Split out the width / height by the stride dimensions.
211  llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
212  outputChannels, weightHeight / stride[0],
213  stride[0], weightWidth / stride[1],
214  stride[1], inputChannels};
215  weight = createOpAndInfer<tosa::ReshapeOp>(
216  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
217  rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
218 
219  // Transpose the factored-out stride to the output channels.
220  Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
221  loc, RankedTensorType::get({6}, rewriter.getI32Type()),
222  rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
223 
224  weight = createOpAndInfer<tosa::TransposeOp>(
225  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
226  transposeWeightVal);
227 
228  // Collapse the strides and output channels into a single dimension.
229  llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
230  outputChannels * stride[0] * stride[1], weightHeight / stride[0],
231  weightWidth / stride[1], inputChannels};
232  weight = createOpAndInfer<tosa::ReshapeOp>(
233  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
234  rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
235  ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
236 
237  weight = createOpAndInfer<tosa::ReverseOp>(
238  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
239  /* axis = */ rewriter.getI32IntegerAttr(1));
240  weight = createOpAndInfer<tosa::ReverseOp>(
241  rewriter, loc, UnrankedTensorType::get(weightETy), weight,
242  /* axis = */ rewriter.getI32IntegerAttr(2));
243 
244  // We need to pad the input far enough that we can pull all values.
245  llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
246  inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
247  inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
248  inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
249  inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
250 
251  DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
252  RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
253 
254  Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
255  rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
256 
257  if (op.getQuantizationInfo().has_value()) {
258  auto quantInfo = op.getQuantizationInfo().value();
259  input = createOpAndInfer<tosa::PadOp>(
260  rewriter, loc, UnrankedTensorType::get(inputETy), input,
261  inputPaddingVal, nullptr,
262  rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
263  } else {
264  input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
265  UnrankedTensorType::get(inputETy),
266  input, inputPaddingVal);
267  }
268 
269  // We use a zero bias as we need to broadcast the bias.
270  auto zeroBias = rewriter.create<tosa::ConstOp>(
271  loc,
272  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
273  biasETy),
275  RankedTensorType::get({outputChannels * stride[0] * stride[1]},
276  biasETy),
277  rewriter.getZeroAttr(biasETy)));
278 
279  // Perform the convolution using the zero bias.
280  Value conv2d;
281  if (op.getQuantizationInfo()) {
282  conv2d = createOpAndInfer<tosa::Conv2DOp>(
283  rewriter, loc, UnrankedTensorType::get(resultETy), input,
284  weight, zeroBias,
285  /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
286  /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
287  /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
288  *op.getQuantizationInfo())
289  .getResult();
290  } else {
291  conv2d = createOpAndInfer<tosa::Conv2DOp>(
292  rewriter, loc, UnrankedTensorType::get(resultETy), input,
293  weight, zeroBias,
294  /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
295  /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
296  /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
297  .getResult();
298  }
299 
300  // Factor the resulting width / height.
301  ShapedType convTy = cast<ShapedType>(conv2d.getType());
302  Type convETy = convTy.getElementType();
303 
304  int64_t convHeight = convTy.getDimSize(1);
305  int64_t convWidth = convTy.getDimSize(2);
306 
307  // Factor striding out of the convolution result.
308  llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
309  batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
310  conv2d = createOpAndInfer<tosa::ReshapeOp>(
311  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
312  rewriter.getDenseI64ArrayAttr(convReshapeDims0));
313 
314  // Transpose the factored-out stride to the output channels.
315  Value transposeConvVal = rewriter.create<tosa::ConstOp>(
316  loc, RankedTensorType::get({6}, rewriter.getI32Type()),
317  rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
318 
319  conv2d = createOpAndInfer<tosa::TransposeOp>(
320  rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
321  transposeConvVal);
322 
323  // Fuse striding behavior back into width / height.
324  llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
325  batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
326  conv2d = createOpAndInfer<tosa::ReshapeOp>(
327  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
328  rewriter.getDenseI64ArrayAttr(convReshapeDims1));
329 
330  // Determine the amount to slice / pad from the result start.
331  int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
332  int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
333  int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
334  int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
335 
336  // Try to slice the targetted result size, cap to the convolutions width.
337  int64_t resultSliceHeight =
338  std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
339  resultTy.getDimSize(1) - resultPadTop);
340  int64_t resultSliceWidth =
341  std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
342  resultTy.getDimSize(2) - resultPadLeft);
343 
344  llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
345  resultSliceLeft, 0};
346  llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
347  convReshapeDims1.end());
348  sliceSize[1] = resultSliceHeight;
349  sliceSize[2] = resultSliceWidth;
350 
351  auto slice = createOpAndInfer<tosa::SliceOp>(
352  rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
353  rewriter.getDenseI64ArrayAttr(sliceBegin),
354  rewriter.getDenseI64ArrayAttr(sliceSize))
355  .getResult();
356 
357  llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
358  resultPadding[2] = resultPadTop;
359  resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
360  resultPadding[4] = resultPadLeft;
361  resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
362 
363  DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
364  RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
365 
366  Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
367  rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
368 
369  Value resultPad = createOpAndInfer<tosa::PadOp>(
370  rewriter, loc, UnrankedTensorType::get(resultETy), slice,
371  resultPaddingVal);
372 
373  if (EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
374  return failure();
375  }
376 
377  rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
378  return success();
379  }
380 };
381 
382 } // namespace
383 
385  MLIRContext *ctx, RewritePatternSet &patterns) {
386  patterns.add<TransposeConvNonStridedConverter>(ctx);
387  patterns.add<TransposeConvStridedConverter>(ctx);
388 }
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:496
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
Definition: Operation.h:896
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:44
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:81
static ValueKnowledge getPessimisticValueState()
Definition: ShapeUtils.h:61
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45