20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
32 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
40 auto linalgOp = dyn_cast<LinalgOp>(operation);
42 if (!linalgOp || !linalgOp.hasPureTensorSemantics())
47 auto kernelTy = dyn_cast<RankedTensorType>(kernel.
getType());
48 auto initTy = dyn_cast<RankedTensorType>(init.
getType());
49 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50 if (!kernelTy || !initTy || !resultTy)
53 if (kernelTy.getDimSize(3) != 1)
60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61 kernelTy.getElementType());
62 auto collapsedKernel = tensor::CollapseShapeOp::create(
63 rewriter, loc, newKernelTy, kernel, collapsedKernelDims);
71 initTy.getDimSize(2), initTy.getDimSize(3)},
72 initTy.getElementType());
73 auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy,
74 init, collapsedInitDims);
79 .Case<DepthwiseConv2DNhwcHwcmOp>([&](
auto op) {
81 return DepthwiseConv2DNhwcHwcOp::create(
82 rewriter, loc, newInitTy,
ValueRange{input, collapsedKernel},
85 .Case<DepthwiseConv2DNhwcHwcmQOp>([&](
auto op) {
87 return DepthwiseConv2DNhwcHwcQOp::create(
88 rewriter, loc, newInitTy,
92 .Default([](
Operation *op) {
return nullptr; });
95 for (
auto attr : preservedAttrs)
96 newConv->
setAttr(attr.getName(), attr.getValue());
100 operation, resultTy, newConv->
getResult(0), collapsedInitDims);
105 struct SimplifyDepthwiseConvOp
109 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
111 Operation *operation = op.getOperation();
112 Value input = op.getDpsInputOperand(0)->get();
113 Value kernel = op.getDpsInputOperand(1)->get();
114 Value init = op.getDpsInitOperand(0)->get();
116 auto stride = op.getStrides();
117 auto dilation = op.getDilations();
120 nullptr, init, stride, dilation,
125 struct SimplifyDepthwiseConvQOp
129 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
131 Operation *operation = op.getOperation();
132 Value input = op.getDpsInputOperand(0)->get();
133 Value kernel = op.getDpsInputOperand(1)->get();
134 Value iZp = op.getDpsInputOperand(2)->get();
135 Value kZp = op.getDpsInputOperand(3)->get();
136 Value init = op.getDpsInitOperand(0)->get();
138 auto stride = op.getStrides();
139 auto dilation = op.getDilations();
142 init, stride, dilation, rewriter);
146 struct LinalgNamedOpConversionPass
147 :
public impl::LinalgNamedOpConversionPassBase<
148 LinalgNamedOpConversionPass> {
149 using impl::LinalgNamedOpConversionPassBase<
150 LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
152 void runOnOperation()
override {
157 return signalPassFailure();
164 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
static llvm::SmallVector< int64_t > getIndicesVector(int start, int end)
static LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns)
Patterns to convert from one named op to another.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...