29 class TransposeConvNonStridedConverter
33 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
36 Value input = op->getOperand(0);
37 Value weight = op->getOperand(1);
38 Value bias = op->getOperand(2);
40 ShapedType inputTy = cast<ShapedType>(input.
getType());
41 ShapedType weightTy = cast<ShapedType>(weight.
getType());
42 ShapedType biasTy = cast<ShapedType>(bias.
getType());
43 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
51 if (llvm::any_of(stride, [](int64_t v) {
return v != 1; }))
54 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
55 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
58 int64_t kernelHeight = weightTy.getDimSize(1);
59 int64_t kernelWidth = weightTy.getDimSize(2);
62 convPad[0] = kernelHeight - 1 + pad[0];
63 convPad[1] = kernelHeight - 1 + pad[1];
64 convPad[2] = kernelWidth - 1 + pad[2];
65 convPad[3] = kernelWidth - 1 + pad[3];
67 auto reverse1 = rewriter.create<tosa::ReverseOp>(
68 loc, weightTy, weight, rewriter.getI32IntegerAttr(1));
69 auto reverse2 = rewriter.create<tosa::ReverseOp>(
70 loc, weightTy, reverse1, rewriter.getI32IntegerAttr(2));
73 if (op.getQuantizationInfo()) {
74 conv2d = rewriter.create<tosa::Conv2DOp>(
75 loc, resultTy, input, reverse2, bias,
76 rewriter.getDenseI64ArrayAttr(convPad),
77 rewriter.getDenseI64ArrayAttr(stride),
78 rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
80 conv2d = rewriter.create<tosa::Conv2DOp>(
81 loc, resultTy, input, reverse2, bias,
82 rewriter.getDenseI64ArrayAttr(convPad),
83 rewriter.getDenseI64ArrayAttr(stride),
84 rewriter.getDenseI64ArrayAttr({1, 1}));
87 rewriter.replaceOp(op, conv2d);
92 class TransposeConvStridedConverter
96 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
99 Value input = op->getOperand(0);
100 Value weight = op->getOperand(1);
101 Value bias = op->getOperand(2);
103 ShapedType inputTy = cast<ShapedType>(input.
getType());
104 ShapedType weightTy = cast<ShapedType>(weight.
getType());
105 ShapedType biasTy = cast<ShapedType>(bias.
getType());
106 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
108 Type inputETy = inputTy.getElementType();
109 Type weightETy = weightTy.getElementType();
110 Type biasETy = biasTy.getElementType();
111 Type resultETy = resultTy.getElementType();
121 if (llvm::all_of(stride, [](int64_t v) {
return v == 1; }))
122 return rewriter.notifyMatchFailure(op,
"non-one stride found.");
124 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
125 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
128 int64_t batch = inputTy.getDimSize(0);
130 int64_t outputChannels = weightTy.getDimSize(0);
131 int64_t weightHeight = weightTy.getDimSize(1);
132 int64_t weightWidth = weightTy.getDimSize(2);
133 int64_t inputChannels = weightTy.getDimSize(3);
138 (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
140 (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
143 Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
144 rewriter, loc, weightPaddingAttr.
getType(), weightPaddingAttr);
146 if (op.getQuantizationInfo().has_value()) {
147 auto quantInfo = op.getQuantizationInfo().value();
148 weight = CreateOpAndInferShape<tosa::PadOp>(
150 weightPaddingVal,
nullptr,
151 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
154 weight = CreateOpAndInferShape<tosa::PadOp>(
159 weightTy = cast<ShapedType>(weight.
getType());
160 weightHeight = weightTy.getDimSize(1);
161 weightWidth = weightTy.getDimSize(2);
165 outputChannels, weightHeight / stride[0],
166 stride[0], weightWidth / stride[1],
167 stride[1], inputChannels};
168 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
170 rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
173 Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
175 rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
177 weight = CreateOpAndInferShape<tosa::TransposeOp>(
183 outputChannels * stride[0] * stride[1], weightHeight / stride[0],
184 weightWidth / stride[1], inputChannels};
185 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
187 rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
188 ShapedType restridedWeightTy = cast<ShapedType>(weight.
getType());
190 weight = CreateOpAndInferShape<tosa::ReverseOp>(
192 rewriter.getI32IntegerAttr(1));
193 weight = CreateOpAndInferShape<tosa::ReverseOp>(
195 rewriter.getI32IntegerAttr(2));
199 inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
200 inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
201 inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
202 inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
207 Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
208 rewriter, loc, inputPaddingAttr.
getType(), inputPaddingAttr);
210 if (op.getQuantizationInfo().has_value()) {
211 auto quantInfo = op.getQuantizationInfo().value();
212 input = CreateOpAndInferShape<tosa::PadOp>(
214 inputPaddingVal,
nullptr,
215 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
217 input = CreateOpAndInferShape<tosa::PadOp>(
223 auto zeroBias = rewriter.create<tosa::ConstOp>(
230 rewriter.getZeroAttr(biasETy)));
234 if (op.getQuantizationInfo()) {
235 conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
238 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
239 rewriter.getDenseI64ArrayAttr({1, 1}),
240 rewriter.getDenseI64ArrayAttr({1, 1}),
241 *op.getQuantizationInfo())
244 conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
247 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
248 rewriter.getDenseI64ArrayAttr({1, 1}),
249 rewriter.getDenseI64ArrayAttr({1, 1}))
254 ShapedType convTy = cast<ShapedType>(conv2d.getType());
255 Type convETy = convTy.getElementType();
257 int64_t convHeight = convTy.getDimSize(1);
258 int64_t convWidth = convTy.getDimSize(2);
262 batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
263 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
265 rewriter.getDenseI64ArrayAttr(convReshapeDims0));
268 Value transposeConvVal = rewriter.create<tosa::ConstOp>(
270 rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
272 conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
278 batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
279 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
281 rewriter.getDenseI64ArrayAttr(convReshapeDims1));
284 int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
285 int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
286 int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
287 int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
290 int64_t resultSliceHeight =
291 std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
292 resultTy.getDimSize(1) - resultPadTop);
293 int64_t resultSliceWidth =
294 std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
295 resultTy.getDimSize(2) - resultPadLeft);
300 convReshapeDims1.end());
301 sliceSize[1] = resultSliceHeight;
302 sliceSize[2] = resultSliceWidth;
304 auto slice = CreateOpAndInferShape<tosa::SliceOp>(
306 rewriter.getDenseI64ArrayAttr(sliceBegin),
307 rewriter.getDenseI64ArrayAttr(sliceSize))
311 resultPadding[2] = resultPadTop;
312 resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
313 resultPadding[4] = resultPadLeft;
314 resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
319 Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
320 rewriter, loc, resultPaddingAttr.
getType(), resultPaddingAttr);
322 Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
326 if (
EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
330 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
339 patterns.
add<TransposeConvNonStridedConverter>(ctx);
340 patterns.
add<TransposeConvStridedConverter>(ctx);
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.