MLIR  20.0.0git
TosaDecomposeConv2D.cpp
Go to the documentation of this file.
1 //===- TosaDecomposeConv2D.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Decompose TOSA Conv2D operations into a series of other TOSA ops, specifically:
10 // (1) Convert a 1x1, unit-stride convolution to [Pad->]Reshape->FC->Reshape
11 //
12 //===----------------------------------------------------------------------===//
13 
17 
18 using namespace mlir;
19 using namespace mlir::tosa;
20 
21 namespace {
22 
23 SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
24  return to_vector(llvm::map_range(shape, [](int64_t dim) {
25  return ShapedType::isDynamic(dim) ? -1 : dim;
26  }));
27 }
28 
29 struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
30  explicit Conv2DIsFullyConnected(MLIRContext *context)
31  : OpRewritePattern(context) {}
32 
33  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
34  PatternRewriter &rewriter) const override {
35  Value input = op.getInput();
36  Value weight = op.getWeight();
37  ShapedType inputType = cast<ShapedType>(input.getType());
38  ShapedType weightType = cast<ShapedType>(weight.getType());
39  ShapedType resultType = cast<ShapedType>(op.getType());
40 
41  auto numDynamic =
42  llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
43  if (numDynamic > 1)
44  return rewriter.notifyMatchFailure(
45  op, "at most one dim in input may be dynamic");
46  if (!weightType.hasRank())
47  return rewriter.notifyMatchFailure(op, "unranked weight input");
48 
49  if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
50  return failure();
51 
52  // Only works for a 1x1 kernel.
53  ArrayRef<int64_t> weightShape = weightType.getShape();
54  if (weightShape[1] != 1 || weightShape[2] != 1)
55  return failure();
56 
57  llvm::ArrayRef<int64_t> padAttr = op.getPad();
59  for (const auto &it : llvm::enumerate(padAttr))
60  pad[it.index() + 2] = it.value();
61 
62  if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
63  Type inputETy = inputType.getElementType();
64  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
65  if (op.getQuantizationInfo()) {
66  auto quantizationInfo = op.getQuantizationInfo();
67  int64_t iZp = quantizationInfo->getInputZp();
68 
69  if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
70  return rewriter.notifyMatchFailure(
71  op, "tosa.conv op quantization has zp outside of input range");
72 
73  zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
74  }
75 
76  llvm::SmallVector<int64_t> newShape(inputType.getShape());
77 
78  for (int i = 0, s = newShape.size(); i < s; ++i) {
79  if (newShape[i] != ShapedType::kDynamic) {
80  newShape[i] += pad[i * 2] + pad[i * 2 + 1];
81  }
82  }
83 
84  auto padSizeTy = RankedTensorType::get({4, 2}, rewriter.getI64Type());
85  auto padSize =
87  Value padSizeVal =
88  rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
89 
90  auto padTy = RankedTensorType::get({}, inputETy);
91  auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
92  Value padVal =
93  rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
94  inputType = RankedTensorType::get(newShape, inputETy);
95  input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
96  padSizeVal, padVal);
97  }
98 
99  // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
100  ArrayRef<int64_t> inputShape = inputType.getShape();
101  int64_t combined = ShapedType::kDynamic;
102  if (numDynamic == 0)
103  combined = inputShape[0] * inputShape[1] * inputShape[2];
104  llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
105  auto revisedInputShapeType =
106  RankedTensorType::get(revisedInputShape, inputType.getElementType());
107  auto reshapedInput = rewriter
108  .create<tosa::ReshapeOp>(
109  op.getLoc(), revisedInputShapeType, input,
110  rewriter.getDenseI64ArrayAttr(
111  convertFromMlirShape(revisedInputShape)))
112  .getResult();
113 
114  // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
115  llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
116  weightShape[3]};
117  auto revisedWeightShapeType = RankedTensorType::get(
118  revisedWeightShape,
119  dyn_cast<RankedTensorType>(weight.getType()).getElementType());
120  auto reshapedWeight = rewriter
121  .create<tosa::ReshapeOp>(
122  op.getLoc(), revisedWeightShapeType, weight,
123  rewriter.getDenseI64ArrayAttr(
124  convertFromMlirShape(revisedWeightShape)))
125  .getResult();
126 
127  // Perform a fully connected network over the reshaped input and weight.
128  llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
129  auto fullyConnectedShapeType =
130  RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
131 
132  Value fullyConnectedValue;
133  if (op.getQuantizationInfo()) {
134  fullyConnectedValue =
135  rewriter
136  .create<tosa::FullyConnectedOp>(
137  op.getLoc(), fullyConnectedShapeType, reshapedInput,
138  reshapedWeight, op.getBias(), *op.getQuantizationInfo())
139  .getResult();
140  } else {
141  fullyConnectedValue = rewriter
142  .create<tosa::FullyConnectedOp>(
143  op.getLoc(), fullyConnectedShapeType,
144  reshapedInput, reshapedWeight, op.getBias())
145  .getResult();
146  }
147 
148  // Reshape output to [N, IH, IW, OC].
149  llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
150  inputShape[2], weightShape[0]};
151  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
152  op, resultType, fullyConnectedValue,
153  rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
154  return success();
155  }
156 };
157 
158 } // namespace
159 
161  RewritePatternSet &patterns) {
162  patterns.add<Conv2DIsFullyConnected>(ctx);
163 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
IntegerType getI64Type()
Definition: Builders.cpp:109
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns)
bool validIntegerRange(IntegerType ty, int64_t value)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358