29 template <
typename TosaOp,
typename... Args>
32 auto op = rewriter.
create<TosaOp>(loc, resultTy, args...);
34 InferShapedTypeOpInterface shapeInterface =
35 dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
41 .inferReturnTypeComponents(
53 auto predictedShape = returnedShapes[0];
54 auto currentKnowledge =
58 auto inferredKnowledge =
60 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
61 inferredKnowledge.hasRank = predictedShape.hasRank();
62 if (predictedShape.hasRank()) {
63 for (
auto dim : predictedShape.getDims()) {
64 inferredKnowledge.sizes.push_back(dim);
71 auto newTy = newKnowledge.getType();
72 result.setType(newTy);
76 class TransposeConvNonStridedConverter
87 ShapedType inputTy = cast<ShapedType>(input.
getType());
88 ShapedType weightTy = cast<ShapedType>(weight.
getType());
89 ShapedType biasTy = cast<ShapedType>(bias.
getType());
98 if (llvm::any_of(stride, [](int64_t v) {
return v != 1; }))
101 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
102 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
105 int64_t kernelHeight = weightTy.getDimSize(1);
106 int64_t kernelWidth = weightTy.getDimSize(2);
109 convPad[0] = kernelHeight - 1 + pad[0];
110 convPad[1] = kernelHeight - 1 + pad[1];
111 convPad[2] = kernelWidth - 1 + pad[2];
112 convPad[3] = kernelWidth - 1 + pad[3];
114 auto reverse1 = rewriter.create<tosa::ReverseOp>(
115 loc, weightTy, weight, rewriter.getI32IntegerAttr(1));
116 auto reverse2 = rewriter.create<tosa::ReverseOp>(
117 loc, weightTy, reverse1, rewriter.getI32IntegerAttr(2));
120 if (op.getQuantizationInfo()) {
121 conv2d = rewriter.create<tosa::Conv2DOp>(
122 loc, resultTy, input, reverse2, bias,
123 rewriter.getDenseI64ArrayAttr(convPad),
124 rewriter.getDenseI64ArrayAttr(stride),
125 rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
127 conv2d = rewriter.create<tosa::Conv2DOp>(
128 loc, resultTy, input, reverse2, bias,
129 rewriter.getDenseI64ArrayAttr(convPad),
130 rewriter.getDenseI64ArrayAttr(stride),
131 rewriter.getDenseI64ArrayAttr({1, 1}));
134 rewriter.replaceOp(op, conv2d);
139 class TransposeConvStridedConverter
150 ShapedType inputTy = cast<ShapedType>(input.
getType());
151 ShapedType weightTy = cast<ShapedType>(weight.
getType());
152 ShapedType biasTy = cast<ShapedType>(bias.
getType());
155 Type inputETy = inputTy.getElementType();
156 Type weightETy = weightTy.getElementType();
157 Type biasETy = biasTy.getElementType();
158 Type resultETy = resultTy.getElementType();
168 if (llvm::all_of(stride, [](int64_t v) {
return v == 1; }))
169 return rewriter.notifyMatchFailure(op,
"non-one stride found.");
171 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
172 !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
175 int64_t batch = inputTy.getDimSize(0);
177 int64_t outputChannels = weightTy.getDimSize(0);
178 int64_t weightHeight = weightTy.getDimSize(1);
179 int64_t weightWidth = weightTy.getDimSize(2);
180 int64_t inputChannels = weightTy.getDimSize(3);
185 weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
187 weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
190 Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
191 rewriter, loc, weightPaddingAttr.
getType(), weightPaddingAttr);
193 if (op.getQuantizationInfo().has_value()) {
194 auto quantInfo = op.getQuantizationInfo().value();
195 weight = createOpAndInfer<tosa::PadOp>(
197 weightPaddingVal,
nullptr,
198 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
201 weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
203 weight, weightPaddingVal);
206 weightTy = cast<ShapedType>(weight.
getType());
207 weightHeight = weightTy.getDimSize(1);
208 weightWidth = weightTy.getDimSize(2);
212 outputChannels, weightHeight / stride[0],
213 stride[0], weightWidth / stride[1],
214 stride[1], inputChannels};
215 weight = createOpAndInfer<tosa::ReshapeOp>(
217 rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
220 Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
222 rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
224 weight = createOpAndInfer<tosa::TransposeOp>(
230 outputChannels * stride[0] * stride[1], weightHeight / stride[0],
231 weightWidth / stride[1], inputChannels};
232 weight = createOpAndInfer<tosa::ReshapeOp>(
234 rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
235 ShapedType restridedWeightTy = cast<ShapedType>(weight.
getType());
237 weight = createOpAndInfer<tosa::ReverseOp>(
239 rewriter.getI32IntegerAttr(1));
240 weight = createOpAndInfer<tosa::ReverseOp>(
242 rewriter.getI32IntegerAttr(2));
246 inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
247 inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
248 inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
249 inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
254 Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
255 rewriter, loc, inputPaddingAttr.
getType(), inputPaddingAttr);
257 if (op.getQuantizationInfo().has_value()) {
258 auto quantInfo = op.getQuantizationInfo().value();
259 input = createOpAndInfer<tosa::PadOp>(
261 inputPaddingVal,
nullptr,
262 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
264 input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
266 input, inputPaddingVal);
270 auto zeroBias = rewriter.create<tosa::ConstOp>(
277 rewriter.getZeroAttr(biasETy)));
281 if (op.getQuantizationInfo()) {
282 conv2d = createOpAndInfer<tosa::Conv2DOp>(
285 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
286 rewriter.getDenseI64ArrayAttr({1, 1}),
287 rewriter.getDenseI64ArrayAttr({1, 1}),
288 *op.getQuantizationInfo())
291 conv2d = createOpAndInfer<tosa::Conv2DOp>(
294 rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
295 rewriter.getDenseI64ArrayAttr({1, 1}),
296 rewriter.getDenseI64ArrayAttr({1, 1}))
301 ShapedType convTy = cast<ShapedType>(conv2d.getType());
302 Type convETy = convTy.getElementType();
304 int64_t convHeight = convTy.getDimSize(1);
305 int64_t convWidth = convTy.getDimSize(2);
309 batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
310 conv2d = createOpAndInfer<tosa::ReshapeOp>(
312 rewriter.getDenseI64ArrayAttr(convReshapeDims0));
315 Value transposeConvVal = rewriter.create<tosa::ConstOp>(
317 rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
319 conv2d = createOpAndInfer<tosa::TransposeOp>(
325 batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
326 conv2d = createOpAndInfer<tosa::ReshapeOp>(
328 rewriter.getDenseI64ArrayAttr(convReshapeDims1));
331 int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
332 int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
333 int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
334 int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
337 int64_t resultSliceHeight =
338 std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
339 resultTy.getDimSize(1) - resultPadTop);
340 int64_t resultSliceWidth =
341 std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
342 resultTy.getDimSize(2) - resultPadLeft);
347 convReshapeDims1.end());
348 sliceSize[1] = resultSliceHeight;
349 sliceSize[2] = resultSliceWidth;
351 auto slice = createOpAndInfer<tosa::SliceOp>(
353 rewriter.getDenseI64ArrayAttr(sliceBegin),
354 rewriter.getDenseI64ArrayAttr(sliceSize))
358 resultPadding[2] = resultPadTop;
359 resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
360 resultPadding[4] = resultPadLeft;
361 resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
366 Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
367 rewriter, loc, resultPaddingAttr.
getType(), resultPaddingAttr);
369 Value resultPad = createOpAndInfer<tosa::PadOp>(
377 rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
386 patterns.
add<TransposeConvNonStridedConverter>(ctx);
387 patterns.
add<TransposeConvStridedConverter>(ctx);
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_range getOperands()
Returns an iterator on the underlying Value's.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the rank of two values equal.
void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This class represents an efficient way to signal success or failure.
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge getKnowledgeFromType(Type type)