// Decomposes a tosa.transpose_conv2d whose strides are all 1 into two
// tosa.reverse ops over the kernel followed by a single tosa.conv2d that
// replaces the original op.
//
// NOTE(review): this excerpt is missing source lines (see the gaps in the
// leading original line numbers). Missing pieces include the base class, the
// failure `return`s that follow the guards below, the declarations of `loc`,
// `pad`, `stride` and `convPad`, and the tail of the Conv2DOp builder call.
29 class TransposeConvNonStridedConverter
33 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
// Operands in order: input, weight (kernel), bias.
36 Value input = op->getOperand(0);
37 Value weight = op->getOperand(1);
38 Value bias = op->getOperand(2);
40 ShapedType inputTy = cast<ShapedType>(input.
getType());
41 ShapedType weightTy = cast<ShapedType>(weight.
getType());
42 ShapedType biasTy = cast<ShapedType>(bias.
getType());
43 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
// Guard: any stride != 1 disqualifies the op. The failure return is elided
// here, but presumably only stride-1 ops reach the rewrite below.
51 if (llvm::any_of(stride, [](int64_t v) {
return v != 1; }))
// Input, weight, bias and result must all have static shapes (the failure
// return is likewise elided).
54 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
55 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
// Kernel spatial extents are read from weight dims 1 and 2.
58 int64_t kernelHeight = weightTy.getDimSize(1);
59 int64_t kernelWidth = weightTy.getDimSize(2);
// Each conv pad entry is (kernel extent - 1) plus the matching transpose-conv
// pad entry. Entries 0-1 use the kernel height; entries 2-3 use the width.
62 convPad[0] = kernelHeight - 1 + pad[0];
63 convPad[1] = kernelHeight - 1 + pad[1];
64 convPad[2] = kernelWidth - 1 + pad[2];
65 convPad[3] = kernelWidth - 1 + pad[3];
// Flip the kernel spatially: reverse along axis 1 (height), then along
// axis 2 (width). Both reverses keep the original weight type.
67 auto reverse1 = rewriter.create<tosa::ReverseOp>(
68 loc, weightTy, weight, rewriter.getI32IntegerAttr(1));
69 auto reverse2 = rewriter.create<tosa::ReverseOp>(
70 loc, weightTy, reverse1, rewriter.getI32IntegerAttr(2));
// Ordinary conv2d over the flipped kernel with the enlarged padding. It is
// typed with the original result type and forwards the op's input/weight
// zero points unchanged. The trailing {1, 1} array is presumably the
// dilation (the remaining builder arguments are elided); confirm against the
// tosa::Conv2DOp builder.
72 Value conv2d = rewriter.create<tosa::Conv2DOp>(
73 loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
74 rewriter.getDenseI64ArrayAttr(convPad),
75 rewriter.getDenseI64ArrayAttr(stride),
76 rewriter.getDenseI64ArrayAttr({1, 1}),
// The conv2d result replaces the transpose convolution.
79 rewriter.replaceOp(op, conv2d);
// Decomposes a strided tosa.transpose_conv2d into the following sequence:
// - the kernel is padded, reshaped, transposed, reshaped and reversed;
// - the input is padded;
// - a stride-1 tosa.conv2d runs with a zero bias;
// - a reshape/transpose/reshape interleaves the conv output into the
//   upsampled spatial grid;
// - a slice and a pad fit that grid to the result shape;
// - a final tosa.add applies the original bias.
//
// NOTE(review): this excerpt is missing source lines (see the gaps in the
// leading original line numbers). Missing pieces include the base class,
// several failure `return`s, the declarations of `loc`, `pad`, `stride`, the
// padding vectors and shape values, the zero-point tensor creation, and
// several builder-call argument lists.
84 class TransposeConvStridedConverter
88 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
// Operands in order: input, weight (kernel), bias.
91 Value input = op->getOperand(0);
92 Value weight = op->getOperand(1);
93 Value bias = op->getOperand(2);
95 ShapedType inputTy = cast<ShapedType>(input.
getType());
96 ShapedType weightTy = cast<ShapedType>(weight.
getType());
97 ShapedType biasTy = cast<ShapedType>(bias.
getType());
98 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
// Element types. biasETy types the zero bias created below; the others are
// not used in the visible lines (presumably consumed by elided code).
100 Type inputETy = inputTy.getElementType();
101 Type weightETy = weightTy.getElementType();
102 Type biasETy = biasTy.getElementType();
103 Type resultETy = resultTy.getElementType();
// Bail out when every stride is 1; TransposeConvNonStridedConverter rejects
// exactly the complementary case.
// NOTE(review): the message "non-one stride found." is emitted when NO
// non-one stride exists, so it reads inverted. Something like "all strides
// are one" would be clearer.
113 if (llvm::all_of(stride, [](int64_t v) {
return v == 1; }))
114 return rewriter.notifyMatchFailure(op,
"non-one stride found.");
// All shapes must be static (the failure return is elided).
116 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
117 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
// Batch comes from input dim 0. The weight dims are read as
// [outputChannels, height, width, inputChannels], per the names below.
120 int64_t batch = inputTy.getDimSize(0);
122 int64_t outputChannels = weightTy.getDimSize(0);
123 int64_t weightHeight = weightTy.getDimSize(1);
124 int64_t weightWidth = weightTy.getDimSize(2);
125 int64_t inputChannels = weightTy.getDimSize(3);
// Extra padding that rounds the kernel height/width up to the next multiple
// of the corresponding stride (0 when already divisible). The declaration
// heads are elided.
130 (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
132 weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
// Padding shape value for the kernel pad (initializer elided).
134 Value weightPaddingVal =
// Zero points must be statically known and must pass the op's own
// zero-point verification.
138 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
139 if (failed(maybeIZp))
140 return rewriter.notifyMatchFailure(
141 op,
"input zero point cannot be statically determined");
143 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
144 if (failed(maybeWZp))
145 return rewriter.notifyMatchFailure(
146 op,
"weight zero point cannot be statically determined");
148 int64_t inputZpVal = *maybeIZp;
149 int64_t weightZpVal = *maybeWZp;
151 if (op.verifyInputZeroPoint(inputZpVal).failed())
152 return rewriter.notifyMatchFailure(
153 op,
"input zero point must be zero for non-int8 integer types");
155 if (op.verifyWeightZeroPoint(weightZpVal).failed())
156 return rewriter.notifyMatchFailure(
157 op,
"weight zero point must be zero for non-int8 integer types");
// Pad-value constants for the input and kernel pads (initializers elided).
161 const Value inputPadConst =
163 const Value weightPadConst =
// Pad the kernel so its spatial extents divide evenly by the strides, then
// re-read the (now larger) type and extents.
166 weight = CreateOpAndInferShape<tosa::PadOp>(
168 weightPaddingVal, weightPadConst);
170 weightTy = cast<ShapedType>(weight.
getType());
171 weightHeight = weightTy.getDimSize(1);
172 weightWidth = weightTy.getDimSize(2);
// Split each spatial kernel axis into (extent / stride, stride), giving
// [OC, H/s0, s0, W/s1, s1, IC]. The declaration head is elided.
176 outputChannels, weightHeight / stride[0],
177 stride[0], weightWidth / stride[1],
178 stride[1], inputChannels};
180 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
// Permutation {2, 4, 0, 1, 3, 5} moves the stride phases to the front:
// [s0, s1, OC, H/s0, W/s1, IC].
185 weight = CreateOpAndInferShape<tosa::TransposeOp>(
187 rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
// Fold the stride phases into the output-channel axis:
// [OC*s0*s1, H/s0, W/s1, IC].
191 outputChannels * stride[0] * stride[1], weightHeight / stride[0],
192 weightWidth / stride[1], inputChannels};
194 weight = CreateOpAndInferShape<tosa::ReshapeOp>(
// Record the restrided kernel type; its spatial dims size the input padding
// below.
197 ShapedType restridedWeightTy = cast<ShapedType>(weight.
getType());
// Flip the restrided kernel along axis 1 (height), then axis 2 (width).
199 weight = CreateOpAndInferShape<tosa::ReverseOp>(
201 rewriter.getI32IntegerAttr(1));
202 weight = CreateOpAndInferShape<tosa::ReverseOp>(
204 rewriter.getI32IntegerAttr(2));
// Grow the input padding by (restrided kernel extent - 1): entries 2-3 by
// the height term and entries 4-5 by the width term. These are presumably
// the (low, high) pairs for the H and W axes of the input; confirm against
// the elided initialization of inputPadding.
208 inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
209 inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
210 inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
211 inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
// Pad the input (the padding shape value initializer is elided).
213 Value inputPaddingVal =
216 input = CreateOpAndInferShape<tosa::PadOp>(
218 inputPaddingVal, inputPadConst);
// Constant all-zero bias of the bias element type for the conv (shape
// arguments elided). The real bias is added only after the output has been
// rearranged, sliced and padded.
221 auto zeroBias = rewriter.create<tosa::ConstOp>(
228 rewriter.getZeroAttr(biasETy)));
// The zero-point tensors are created on elided lines; bail out if either
// could not be created.
235 if (!inputZp.has_value() || !weightZp.has_value()) {
236 return rewriter.notifyMatchFailure(
237 op,
"fail to create a const zero point tensor");
// Unpadded conv over the pre-padded input with the flipped, restrided kernel
// and the zero bias. Of the two {1, 1} arrays, presumably the first is the
// stride and the second the dilation; confirm against the tosa::Conv2DOp
// builder.
241 Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
243 weight, zeroBias, inputZp.value(), weightZp.value(),
244 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
245 rewriter.getDenseI64ArrayAttr({1, 1}),
246 rewriter.getDenseI64ArrayAttr({1, 1}),
// Conv output spatial extents. convETy is not used in the visible lines.
251 ShapedType convTy = cast<ShapedType>(conv2d.
getType());
252 Type convETy = convTy.getElementType();
254 int64_t convHeight = convTy.getDimSize(1);
255 int64_t convWidth = convTy.getDimSize(2);
// Split the channel axis back into stride phases:
// [N, convH, convW, s0, s1, OC]. The declaration head is elided.
259 batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
261 auto convReshapeDims0Value =
264 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
266 convReshapeDims0Value);
// Permutation {0, 1, 3, 2, 4, 5} interleaves each phase with its spatial
// position: [N, convH, s0, convW, s1, OC].
269 conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
271 rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
// Merge into the upsampled grid [N, convH*s0, convW*s1, OC]. This list is the
// tail of the declaration indexed below as convReshapeDims1.
275 batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
277 auto convReshapeDims1Value =
280 conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
282 convReshapeDims1Value);
// A negative pad[0]/pad[2] becomes a top/left crop (slice offset). A
// positive one becomes leading padding on the result instead.
285 int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
286 int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
287 int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
288 int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
// Slice extent is whatever remains past the crop offset, capped so that the
// slice plus the leading pad fits inside the result height/width.
291 int64_t resultSliceHeight =
292 std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
293 resultTy.getDimSize(1) - resultPadTop);
294 int64_t resultSliceWidth =
295 std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
296 resultTy.getDimSize(2) - resultPadLeft);
// sliceSize copies convReshapeDims1 (declaration head elided), with the
// spatial extents replaced by the clamped values.
301 convReshapeDims1.end());
302 sliceSize[1] = resultSliceHeight;
303 sliceSize[2] = resultSliceWidth;
305 auto slice = CreateOpAndInferShape<tosa::SliceOp>(
// Result padding for H (entries 2-3) and W (entries 4-5). The leading pad
// comes from pad[0]/pad[2]; the trailing pad fills the rest of the result
// extent.
312 resultPadding[2] = resultPadTop;
313 resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
314 resultPadding[4] = resultPadLeft;
315 resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
// Pad up to the final result shape (value and pad-const arguments elided).
317 Value resultPaddingVal =
320 Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
// Reshape where necessary so the bias rank matches the padded result before
// the add (the failure handling is elided).
324 if (
EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
// Apply the original bias last; the sum replaces the transpose convolution.
328 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
337 patterns.add<TransposeConvNonStridedConverter>(ctx);
338 patterns.add<TransposeConvStridedConverter>(ctx);
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...