27class TransposeConvNonStridedConverter
30 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
31 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
32 PatternRewriter &rewriter)
const final {
33 Location loc = op->getLoc();
34 Value input = op->getOperand(0);
35 Value weight = op->getOperand(1);
36 Value bias = op->getOperand(2);
38 ShapedType inputTy = cast<ShapedType>(input.
getType());
39 ShapedType weightTy = cast<ShapedType>(weight.
getType());
40 ShapedType biasTy = cast<ShapedType>(bias.
getType());
41 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
43 llvm::ArrayRef<int64_t> stride = op.getStride();
44 llvm::ArrayRef<int64_t> pad = op.getOutPad();
49 if (llvm::any_of(stride, [](int64_t v) {
return v != 1; }))
53 for (
unsigned int i = 1; i < 4; ++i) {
54 if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
58 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
61 int64_t kernelHeight = weightTy.getDimSize(1);
62 int64_t kernelWidth = weightTy.getDimSize(2);
64 llvm::SmallVector<int64_t> convPad(4, 0);
65 convPad[0] = kernelHeight - 1 + pad[0];
66 convPad[1] = kernelHeight - 1 + pad[1];
67 convPad[2] = kernelWidth - 1 + pad[2];
68 convPad[3] = kernelWidth - 1 + pad[3];
71 tosa::ReverseOp::create(rewriter, loc, weightTy, weight,
72 rewriter.getI32IntegerAttr(1));
74 tosa::ReverseOp::create(rewriter, loc, weightTy, reverse1,
75 rewriter.getI32IntegerAttr(2));
77 Value conv2d = tosa::Conv2DOp::create(
78 rewriter, loc, resultTy, input, reverse2, bias, op.getInputZp(),
79 op.getWeightZp(), rewriter.getDenseI64ArrayAttr(convPad),
80 rewriter.getDenseI64ArrayAttr(stride),
81 rewriter.getDenseI64ArrayAttr({1, 1}),
84 rewriter.replaceOp(op, conv2d);
89class TransposeConvStridedConverter
92 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
93 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
94 PatternRewriter &rewriter)
const final {
95 Location loc = op->getLoc();
96 Value input = op->getOperand(0);
97 Value weight = op->getOperand(1);
98 Value bias = op->getOperand(2);
100 ShapedType inputTy = cast<ShapedType>(input.
getType());
101 ShapedType weightTy = cast<ShapedType>(weight.
getType());
102 ShapedType biasTy = cast<ShapedType>(bias.
getType());
103 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
105 Type inputETy = inputTy.getElementType();
106 Type weightETy = weightTy.getElementType();
107 Type biasETy = biasTy.getElementType();
108 Type resultETy = resultTy.getElementType();
110 llvm::ArrayRef<int64_t> pad = op.getOutPad();
111 llvm::ArrayRef<int64_t> stride = op.getStride();
118 if (llvm::all_of(stride, [](int64_t v) {
return v == 1; }))
119 return rewriter.notifyMatchFailure(op,
"non-one stride found.");
122 for (
unsigned int i = 1; i < 4; ++i) {
123 if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
127 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
130 int64_t batch = inputTy.getDimSize(0);
132 int64_t outputChannels = weightTy.getDimSize(0);
133 int64_t weightHeight = weightTy.getDimSize(1);
134 int64_t weightWidth = weightTy.getDimSize(2);
135 int64_t inputChannels = weightTy.getDimSize(3);
138 llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
140 (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
142 (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
144 Value weightPaddingVal =
148 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
150 return rewriter.notifyMatchFailure(
151 op,
"input zero point cannot be statically determined");
153 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
155 return rewriter.notifyMatchFailure(
156 op,
"weight zero point cannot be statically determined");
158 int64_t inputZpVal = *maybeIZp;
159 int64_t weightZpVal = *maybeWZp;
161 if (op.verifyInputZeroPoint(inputZpVal).failed())
162 return rewriter.notifyMatchFailure(
163 op,
"input zero point must be zero for non-int8 integer types");
165 if (op.verifyWeightZeroPoint(weightZpVal).failed())
166 return rewriter.notifyMatchFailure(
167 op,
"weight zero point must be zero for non-int8 integer types");
170 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
171 const Value inputPadConst =
173 const Value weightPadConst =
177 rewriter, loc, UnrankedTensorType::get(weightETy), weight,
178 weightPaddingVal, weightPadConst);
180 weightTy = cast<ShapedType>(weight.
getType());
181 weightHeight = weightTy.getDimSize(1);
182 weightWidth = weightTy.getDimSize(2);
185 llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
186 outputChannels, weightHeight / stride[0],
187 stride[0], weightWidth / stride[1],
188 stride[1], inputChannels};
191 builder, UnrankedTensorType::get(weightETy), weight,
196 rewriter, loc, UnrankedTensorType::get(weightETy), weight,
197 rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
200 llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
201 outputChannels * stride[0] * stride[1], weightHeight / stride[0],
202 weightWidth / stride[1], inputChannels};
205 rewriter, loc, UnrankedTensorType::get(weightETy), weight,
207 ShapedType restridedWeightTy = cast<ShapedType>(weight.
getType());
210 rewriter, loc, UnrankedTensorType::get(weightETy), weight,
211 rewriter.getI32IntegerAttr(1));
213 rewriter, loc, UnrankedTensorType::get(weightETy), weight,
214 rewriter.getI32IntegerAttr(2));
217 llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
218 inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
219 inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
220 inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
221 inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
223 Value inputPaddingVal =
227 rewriter, loc, UnrankedTensorType::get(inputETy), input,
228 inputPaddingVal, inputPadConst);
231 auto zeroBias = tosa::ConstOp::create(
233 RankedTensorType::get({outputChannels * stride[0] * stride[1]},
236 RankedTensorType::get({outputChannels * stride[0] * stride[1]},
238 rewriter.getZeroAttr(biasETy)));
245 if (!inputZp.has_value() || !weightZp.has_value()) {
246 return rewriter.notifyMatchFailure(
247 op,
"fail to create a const zero point tensor");
252 rewriter, loc, UnrankedTensorType::get(resultETy), input,
253 weight, zeroBias, inputZp.value(), weightZp.value(),
254 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
255 rewriter.getDenseI64ArrayAttr({1, 1}),
256 rewriter.getDenseI64ArrayAttr({1, 1}),
261 ShapedType convTy = cast<ShapedType>(conv2d.
getType());
262 Type convETy = convTy.getElementType();
264 int64_t convHeight = convTy.getDimSize(1);
265 int64_t convWidth = convTy.getDimSize(2);
268 llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
269 batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
271 auto convReshapeDims0Value =
275 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
276 convReshapeDims0Value);
280 rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
281 rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
284 llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
285 batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
287 auto convReshapeDims1Value =
291 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
292 convReshapeDims1Value);
295 int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
296 int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
297 int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
298 int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
301 int64_t resultSliceHeight =
302 std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
303 resultTy.getDimSize(1) - resultPadTop);
304 int64_t resultSliceWidth =
305 std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
306 resultTy.getDimSize(2) - resultPadLeft);
308 llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
310 llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
311 convReshapeDims1.end());
312 sliceSize[1] = resultSliceHeight;
313 sliceSize[2] = resultSliceWidth;
316 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
321 llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
322 resultPadding[2] = resultPadTop;
323 resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
324 resultPadding[4] = resultPadLeft;
325 resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
327 Value resultPaddingVal =
331 rewriter, loc, UnrankedTensorType::get(resultETy), slice,
334 if (
EqualizeRanks(rewriter, op.getLoc(), resultPad, bias).failed()) {
338 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
347 patterns.add<TransposeConvNonStridedConverter>(ctx);
348 patterns.add<TransposeConvStridedConverter>(ctx);
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Type getType() const
Return the type of this value.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy, Args &&...args)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class rather than a raw Operation.