MLIR 22.0.0git
SimplifyDepthwiseConv.cpp
Go to the documentation of this file.
1//===- SimplifyDepthwiseConv.cpp - Simplify depthwise convolutions --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements conversions between named ops that can be seen as
10// canonicalizations of named ops.
11//
12//===----------------------------------------------------------------------===//
13
15
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
25#include "mlir/Dialect/Linalg/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::linalg;
30
31static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
32 return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
33}
34
35static LogicalResult
37 Value iZp, Value kZp, Value init, Attribute stride,
38 Attribute dilation, PatternRewriter &rewriter) {
39 Location loc = operation->getLoc();
40 auto linalgOp = dyn_cast<LinalgOp>(operation);
41 // Exit out on the memref version of this operation.
42 if (!linalgOp || !linalgOp.hasPureTensorSemantics())
43 return failure();
44
45 auto result = operation->getResult(0);
46
47 auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
48 auto initTy = dyn_cast<RankedTensorType>(init.getType());
49 auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50 if (!kernelTy || !initTy || !resultTy)
51 return failure();
52
53 if (kernelTy.getDimSize(3) != 1)
54 return failure();
55
56 // Collapse kernel dims.
57 SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
59 auto newKernelTy = RankedTensorType::get(
60 {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61 kernelTy.getElementType());
62 auto collapsedKernel = tensor::CollapseShapeOp::create(
63 rewriter, loc, newKernelTy, kernel, collapsedKernelDims);
64
65 // Collapse init dims.
66 SmallVector<ReassociationIndices, 4> collapsedInitDims = {
68 getIndicesVector(3, 5)};
69 auto newInitTy =
70 RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
71 initTy.getDimSize(2), initTy.getDimSize(3)},
72 initTy.getElementType());
73 auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy,
74 init, collapsedInitDims);
75
76 SmallVector<NamedAttribute> preservedAttrs;
77 Operation *newConv =
79 .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
80 preservedAttrs = getPrunedAttributeList(op);
81 return DepthwiseConv2DNhwcHwcOp::create(
82 rewriter, loc, newInitTy, ValueRange{input, collapsedKernel},
83 ValueRange{collapsedInit}, stride, dilation);
84 })
85 .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
86 preservedAttrs = getPrunedAttributeList(op);
87 return DepthwiseConv2DNhwcHwcQOp::create(
88 rewriter, loc, newInitTy,
89 ValueRange{input, collapsedKernel, iZp, kZp},
90 ValueRange{collapsedInit}, stride, dilation);
91 })
92 .Default(nullptr);
93 if (!newConv)
94 return failure();
95 for (auto attr : preservedAttrs)
96 newConv->setAttr(attr.getName(), attr.getValue());
97
98 // Expand dimensions back out to
99 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
100 operation, resultTy, newConv->getResult(0), collapsedInitDims);
101 return success();
102}
103
104namespace {
105struct SimplifyDepthwiseConvOp
106 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
107 using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
108
109 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
110 PatternRewriter &rewriter) const override {
111 Operation *operation = op.getOperation();
112 Value input = op.getDpsInputOperand(0)->get();
113 Value kernel = op.getDpsInputOperand(1)->get();
114 Value init = op.getDpsInitOperand(0)->get();
115
116 auto stride = op.getStrides();
117 auto dilation = op.getDilations();
118
119 return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
120 nullptr, init, stride, dilation,
121 rewriter);
122 }
123};
124
125struct SimplifyDepthwiseConvQOp
126 : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
127 using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
128
129 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
130 PatternRewriter &rewriter) const override {
131 Operation *operation = op.getOperation();
132 Value input = op.getDpsInputOperand(0)->get();
133 Value kernel = op.getDpsInputOperand(1)->get();
134 Value iZp = op.getDpsInputOperand(2)->get();
135 Value kZp = op.getDpsInputOperand(3)->get();
136 Value init = op.getDpsInitOperand(0)->get();
137
138 auto stride = op.getStrides();
139 auto dilation = op.getDilations();
140
141 return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
142 init, stride, dilation, rewriter);
143 }
144};
145
146struct SimplifyDepthwiseConvPass
147 : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
148 using impl::SimplifyDepthwiseConvPassBase<
149 SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;
150
151 void runOnOperation() override {
152 Operation *op = getOperation();
153 RewritePatternSet patterns(op->getContext());
155 if (failed(applyPatternsGreedily(op, std::move(patterns))))
156 return signalPassFailure();
157 }
158};
159} // namespace
160
163 patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
164 patterns.getContext());
165}
return success()
static LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter)
static llvm::SmallVector< int64_t > getIndicesVector(int start, int end)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns)
Patterns to simplify depthwise convolutions.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Definition Utils.h:377
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...