MLIR  19.0.0git
NamedOpConversions.cpp
Go to the documentation of this file.
1 //===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements conversions between named ops that can be seen as
10 // canonicalizations of named ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
18 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
32  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
33 }
34 
// Shared rewrite used by the depthwise-convolution patterns in this file.
//
// When the kernel's dimension 3 has size 1, the op is replaced by:
//   1. a tensor::CollapseShapeOp on the kernel (result keeps dims 0..2),
//   2. a tensor::CollapseShapeOp on the init tensor (result has four dims),
//   3. a DepthwiseConv2DNhwcHwcOp (or DepthwiseConv2DNhwcHwcQOp for the
//      quantized case) on the collapsed operands, and
//   4. a tensor::ExpandShapeOp that restores the original result type.
//
// Parameters (per the full signature shown in this listing's index):
//   operation        - the op being rewritten; also supplies loc and result 0.
//   input, kernel    - the first two convolution inputs.
//   iZp, kZp         - extra inputs forwarded only in the
//                      DepthwiseConv2DNhwcHwcmQOp case; the non-quantized
//                      caller passes nullptr for both.
//   init             - the output/init tensor operand.
//   stride, dilation - attributes forwarded unchanged to the new conv op.
//   rewriter         - used to create the new ops and replace `operation`.
//
// Returns failure() without modifying the IR when: `operation` is not a
// LinalgOp with pure tensor semantics; kernel, init, or result type is not a
// RankedTensorType; kernel dim 3 is not 1; or the op matches neither Case.
//
// NOTE(review): the embedded listing numbers jump 35->37, 57->59, 66->68 and
// 77->79, so this copy is missing the first signature line, the initializer
// of collapsedKernelDims, the first part of the initializer of
// collapsedInitDims, and the head of the .Case/.Default chain.
35 static LogicalResult
37  Value iZp, Value kZp, Value init, Attribute stride,
38  Attribute dilation, PatternRewriter &rewriter) {
39  Location loc = operation->getLoc();
40  auto linalgOp = dyn_cast<LinalgOp>(operation);
41  // Exit out on the memref version of this operation.
42  if (!linalgOp || !linalgOp.hasPureTensorSemantics())
43  return failure();
44 
45  auto result = operation->getResult(0);
46 
47  auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
48  auto initTy = dyn_cast<RankedTensorType>(init.getType());
49  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
50  if (!kernelTy || !initTy || !resultTy)
51  return failure();
52 
   // Only kernels whose dimension 3 has static size 1 are rewritten.
53  if (kernelTy.getDimSize(3) != 1)
54  return failure();
55 
56  // Collapse kernel dims.
   // NOTE(review): the reassociation groups for the kernel are on the line
   // missing from this listing; the collapsed type below keeps dims 0, 1, 2.
57  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
59  auto newKernelTy = RankedTensorType::get(
60  {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
61  kernelTy.getElementType());
62  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
63  loc, newKernelTy, kernel, collapsedKernelDims);
64 
65  // Collapse init dims.
   // The last visible group, getIndicesVector(3, 5), merges init dims 3 and 4;
   // the earlier groups are on the line missing from this listing.
66  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
68  getIndicesVector(3, 5)};
69  auto newInitTy =
70  RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
71  initTy.getDimSize(2), initTy.getDimSize(3)},
72  initTy.getElementType());
73  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
74  loc, newInitTy, init, collapsedInitDims);
75 
   // Attributes collected from the matched op via getPrunedAttributeList and
   // re-applied to the replacement conv below.
76  SmallVector<NamedAttribute> preservedAttrs;
   // NOTE(review): the dispatch head is missing from this listing; presumably
   // a TypeSwitch over `operation` (llvm/ADT/TypeSwitch.h is included) - confirm.
77  Operation *newConv =
79  .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
80  preservedAttrs = getPrunedAttributeList(op);
81  return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
82  loc, newInitTy, ValueRange{input, collapsedKernel},
83  ValueRange{collapsedInit}, stride, dilation);
84  })
85  .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
86  preservedAttrs = getPrunedAttributeList(op);
87  return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
88  loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
89  ValueRange{collapsedInit}, stride, dilation);
90  })
91  .Default([](Operation *op) { return nullptr; });
92  if (!newConv)
93  return failure();
94  for (auto attr : preservedAttrs)
95  newConv->setAttr(attr.getName(), attr.getValue());
96 
97  // Expand dimensions back out to the original result type (resultTy),
   // reusing the init reassociation used for the collapse above.
98  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
99  operation, resultTy, newConv->getResult(0), collapsedInitDims);
100  return success();
101 }
102 
103 namespace {
// Rewrite pattern rooted on DepthwiseConv2DNhwcHwcmOp. It gathers the op's
// operands and attributes and delegates the rewrite to
// matchAndReplaceDepthwiseConv.
//
// NOTE(review): the embedded listing numbers jump 105->107, so one line of
// the struct body is missing from this copy (presumably the inherited
// constructor using-declaration - confirm against the real file).
104 struct SimplifyDepthwiseConvOp
105  : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
107 
   // Returns whatever matchAndReplaceDepthwiseConv returns for this op.
108  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
109  PatternRewriter &rewriter) const override {
110  Operation *operation = op.getOperation();
   // DPS inputs 0 and 1 are the input and kernel; DPS init 0 is the output.
111  Value input = op.getDpsInputOperand(0)->get();
112  Value kernel = op.getDpsInputOperand(1)->get();
113  Value init = op.getDpsInitOperand(0)->get();
114 
115  auto stride = op.getStrides();
116  auto dilation = op.getDilations();
117 
   // This op has no zero-point operands, so nullptr is passed for iZp/kZp.
118  return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
119  nullptr, init, stride, dilation,
120  rewriter);
121  }
122 };
123 
// Rewrite pattern rooted on DepthwiseConv2DNhwcHwcmQOp. Same flow as the
// non-quantized pattern, but it additionally forwards DPS inputs 2 and 3 as
// the iZp/kZp arguments of matchAndReplaceDepthwiseConv.
//
// NOTE(review): the embedded listing numbers jump 125->127, so one line of
// the struct body is missing from this copy (presumably the inherited
// constructor using-declaration - confirm against the real file).
124 struct SimplifyDepthwiseConvQOp
125  : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
127 
   // Returns whatever matchAndReplaceDepthwiseConv returns for this op.
128  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
129  PatternRewriter &rewriter) const override {
130  Operation *operation = op.getOperation();
   // DPS inputs, in order: input, kernel, iZp, kZp; DPS init 0 is the output.
131  Value input = op.getDpsInputOperand(0)->get();
132  Value kernel = op.getDpsInputOperand(1)->get();
133  Value iZp = op.getDpsInputOperand(2)->get();
134  Value kZp = op.getDpsInputOperand(3)->get();
135  Value init = op.getDpsInitOperand(0)->get();
136 
137  auto stride = op.getStrides();
138  auto dilation = op.getDilations();
139 
140  return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
141  init, stride, dilation, rewriter);
142  }
143 };
144 
// Pass wrapper deriving from the generated LinalgNamedOpConversionPassBase.
// It builds a RewritePatternSet and applies it greedily to the pass's root
// operation.
//
// NOTE(review): the embedded listing numbers jump 153->155, so the line
// between creating `patterns` and applying them is missing from this copy
// (presumably the call that populates the pattern set - confirm).
145 struct LinalgNamedOpConversionPass
146  : public impl::LinalgNamedOpConversionPassBase<
147  LinalgNamedOpConversionPass> {
   // Inherit the base class constructors.
148  using impl::LinalgNamedOpConversionPassBase<
149  LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
150 
151  void runOnOperation() override {
152  Operation *op = getOperation();
153  RewritePatternSet patterns(op->getContext());
   // Signal pass failure if the greedy driver reports failure.
155  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
156  return signalPassFailure();
157  }
158 };
159 } // namespace
160 
// Adds SimplifyDepthwiseConvOp and SimplifyDepthwiseConvQOp to `patterns`,
// constructing each with the pattern set's own context.
//
// NOTE(review): the embedded listing numbers jump 159->162, so the first line
// of this signature is missing from this copy; the listing's index gives the
// name as populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns).
162  RewritePatternSet &patterns) {
163  patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
164  patterns.getContext());
165 }
static llvm::SmallVector< int64_t > getIndicesVector(int start, int end)
static LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns)
Patterns to convert from one named op to another.
SmallVector< NamedAttribute > getPrunedAttributeList(OpTy op)
Returns an attribute list that excludes pre-defined attributes.
Definition: Utils.h:371
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358