MLIR  22.0.0git
NamedOpConversions.cpp
Go to the documentation of this file.
1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements conversions between named ops that can be seen as
10 // canonicalizations of named ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
18 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
32  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
33 }
34 
35 static LogicalResult
37  Value iZp, Value kZp, Value init, Attribute stride,
38  Attribute dilation, PatternRewriter &rewriter) {
39  Location loc = operation->getLoc();
40  auto linalgOp = dyn_cast<LinalgOp>(operation);
41  // Exit out on the memref version of this operation.
42  if (!linalgOp || !linalgOp.hasPureTensorSemantics())
43  return failure();
44 
45  auto result = operation->getResult(0);
46 
47  auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
48  auto initTy = dyn_cast<RankedTensorType>(init.getType());
49  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50  if (!kernelTy || !initTy || !resultTy)
51  return failure();
52 
53  if (kernelTy.getDimSize(3) != 1)
54  return failure();
55 
56  // Collapse kernel dims.
57  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
59  auto newKernelTy = RankedTensorType::get(
60  {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61  kernelTy.getElementType());
62  auto collapsedKernel = tensor::CollapseShapeOp::create(
63  rewriter, loc, newKernelTy, kernel, collapsedKernelDims);
64 
65  // Collapse init dims.
66  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
68  getIndicesVector(3, 5)};
69  auto newInitTy =
70  RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
71  initTy.getDimSize(2), initTy.getDimSize(3)},
72  initTy.getElementType());
73  auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy,
74  init, collapsedInitDims);
75 
76  SmallVector<NamedAttribute> preservedAttrs;
77  Operation *newConv =
79  .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
80  preservedAttrs = getPrunedAttributeList(op);
81  return DepthwiseConv2DNhwcHwcOp::create(
82  rewriter, loc, newInitTy, ValueRange{input, collapsedKernel},
83  ValueRange{collapsedInit}, stride, dilation);
84  })
85  .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
86  preservedAttrs = getPrunedAttributeList(op);
87  return DepthwiseConv2DNhwcHwcQOp::create(
88  rewriter, loc, newInitTy,
89  ValueRange{input, collapsedKernel, iZp, kZp},
90  ValueRange{collapsedInit}, stride, dilation);
91  })
92  .Default([](Operation *op) { return nullptr; });
93  if (!newConv)
94  return failure();
95  for (auto attr : preservedAttrs)
96  newConv->setAttr(attr.getName(), attr.getValue());
97 
98  // Expand dimensions back out to
99  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
100  operation, resultTy, newConv->getResult(0), collapsedInitDims);
101  return success();
102 }
103 
104 namespace {
105 struct SimplifyDepthwiseConvOp
106  : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
108 
109  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
110  PatternRewriter &rewriter) const override {
111  Operation *operation = op.getOperation();
112  Value input = op.getDpsInputOperand(0)->get();
113  Value kernel = op.getDpsInputOperand(1)->get();
114  Value init = op.getDpsInitOperand(0)->get();
115 
116  auto stride = op.getStrides();
117  auto dilation = op.getDilations();
118 
119  return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
120  nullptr, init, stride, dilation,
121  rewriter);
122  }
123 };
124 
125 struct SimplifyDepthwiseConvQOp
126  : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
128 
129  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
130  PatternRewriter &rewriter) const override {
131  Operation *operation = op.getOperation();
132  Value input = op.getDpsInputOperand(0)->get();
133  Value kernel = op.getDpsInputOperand(1)->get();
134  Value iZp = op.getDpsInputOperand(2)->get();
135  Value kZp = op.getDpsInputOperand(3)->get();
136  Value init = op.getDpsInitOperand(0)->get();
137 
138  auto stride = op.getStrides();
139  auto dilation = op.getDilations();
140 
141  return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
142  init, stride, dilation, rewriter);
143  }
144 };
145 
146 struct LinalgNamedOpConversionPass
147  : public impl::LinalgNamedOpConversionPassBase<
148  LinalgNamedOpConversionPass> {
149  using impl::LinalgNamedOpConversionPassBase<
150  LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
151 
152  void runOnOperation() override {
153  Operation *op = getOperation();
156  if (failed(applyPatternsGreedily(op, std::move(patterns))))
157  return signalPassFailure();
158  }
159 };
160 } // namespace
161 
// Registers both depthwise-conv simplification patterns. Through
// matchAndReplaceDepthwiseConv they rewrite DepthwiseConv2DNhwcHwcm(Q)Op into
// DepthwiseConv2DNhwcHwc(Q)Op when the kernel's trailing dimension is 1.
164  patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
165  patterns.getContext());
166 }
static llvm::SmallVector< int64_t > getIndicesVector(int start, int end)
static LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns)
Patterns to convert from one named op to another.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Definition: Utils.h:385
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314