24 return to_vector(llvm::map_range(shape, [](int64_t dim) {
25 return ShapedType::isDynamic(dim) ? -1 : dim;
30 explicit Conv2DIsFullyConnected(
MLIRContext *context)
33 LogicalResult matchAndRewrite(tosa::Conv2DOp op,
35 Value input = op.getInput();
36 Value weight = op.getWeight();
37 ShapedType inputType = cast<ShapedType>(input.
getType());
38 ShapedType weightType = cast<ShapedType>(weight.
getType());
39 ShapedType resultType = cast<ShapedType>(op.getType());
42 llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
45 op,
"at most one dim in input may be dynamic");
46 if (!weightType.hasRank())
49 if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
54 if (weightShape[1] != 1 || weightShape[2] != 1)
60 pad[it.index() + 2] = it.value();
62 if (llvm::any_of(pad, [](int64_t p) {
return p != 0; })) {
63 Type inputETy = inputType.getElementType();
65 if (op.getQuantizationInfo()) {
66 auto quantizationInfo = op.getQuantizationInfo();
67 int64_t iZp = quantizationInfo->getInputZp();
71 op,
"tosa.conv op quantization has zp outside of input range");
78 for (
int i = 0, s = newShape.size(); i < s; ++i) {
79 if (newShape[i] != ShapedType::kDynamic) {
80 newShape[i] += pad[i * 2] + pad[i * 2 + 1];
88 rewriter.
create<tosa::ConstOp>(op->
getLoc(), padSizeTy, padSize);
93 rewriter.
create<tosa::ConstOp>(op->
getLoc(), padTy, padAttr);
95 input = rewriter.
create<tosa::PadOp>(op->
getLoc(), inputType, input,
101 int64_t combined = ShapedType::kDynamic;
103 combined = inputShape[0] * inputShape[1] * inputShape[2];
105 auto revisedInputShapeType =
107 auto reshapedInput = rewriter
109 op.
getLoc(), revisedInputShapeType, input,
111 convertFromMlirShape(revisedInputShape)))
119 dyn_cast<RankedTensorType>(weight.
getType()).getElementType());
120 auto reshapedWeight = rewriter
122 op.
getLoc(), revisedWeightShapeType, weight,
124 convertFromMlirShape(revisedWeightShape)))
129 auto fullyConnectedShapeType =
132 Value fullyConnectedValue;
133 if (op.getQuantizationInfo()) {
134 fullyConnectedValue =
136 .
create<tosa::FullyConnectedOp>(
137 op.
getLoc(), fullyConnectedShapeType, reshapedInput,
138 reshapedWeight, op.getBias(), *op.getQuantizationInfo())
141 fullyConnectedValue = rewriter
142 .
create<tosa::FullyConnectedOp>(
143 op.
getLoc(), fullyConnectedShapeType,
144 reshapedInput, reshapedWeight, op.getBias())
150 inputShape[2], weightShape[0]};
152 op, resultType, fullyConnectedValue,
162 patterns.
add<Conv2DIsFullyConnected>(ctx);
Attributes are known-constant values of operations.
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns)
bool validIntegerRange(IntegerType ty, int64_t value)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...