MLIR  22.0.0git
TransposeConv2D.cpp
Go to the documentation of this file.
1 //===- TransposeConv2D.cpp - Convolution transposition -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/SmallVector.h"
17 
18 namespace mlir {
19 namespace linalg {
20 namespace {
21 // clang-format off
22 /// Convolution converter that applies the following rewrite:
23 ///
24 /// Before:
25 ///
26 /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
27 /// strides = dense<2> : tensor<2xi64>}
28 /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
29 /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
30 ///
31 /// After:
32 ///
33 /// %cst = arith.constant 0.000000e+00 : f32
34 /// %0 = tensor.empty() : tensor<2x2x6x8xf32>
35 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
36 /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
37 /// permutation = [1, 2, 3, 0]
38 /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
39 /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
40 /// -> tensor<1x2x2x8xf32>
41 ///
42 /// with an analogous example for the quantized case.
43 // clang-format on
44 template <typename FHWCConvOp, typename HWCFConvOp>
45 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
46  FHWCConvOp op) {
47  // Construct a permutation of the filter tensor dimensions. For a 2D
48  // convolution this will be known statically as [1, 2, 3, 0].
49  SmallVector<int64_t> filterPerm = {1, 2, 3, 0};
50 
51  // Create the type for the transposed filter tensor.
52  auto filter = op->getOperand(1);
53  auto filterTy = cast<ShapedType>(filter.getType());
54  SmallVector<int64_t> newFilterShape(filterPerm.size());
55  std::generate(std::begin(newFilterShape), std::end(newFilterShape),
56  [dim = 0, &filterTy, &filterPerm]() mutable {
57  return filterTy.getShape()[filterPerm[dim++]];
58  });
59 
60  // Because linalg.transpose expects an "out" parameter we need to pass it a
61  // tensor of zeros of the result type so here we construct that tensor.
62  auto inputType = op->getOperand(0).getType();
63  auto elementTy = cast<ShapedType>(inputType).getElementType();
64  auto loc = op->getLoc();
65 
66  const auto isTensorOp = isa<TensorType>(inputType);
67  Value input;
68  if (isTensorOp) {
69 
70  input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy)
71  .getResult();
72  } else {
73  input = memref::AllocOp::create(rewriter, loc,
74  MemRefType::get(newFilterShape, elementTy))
75  .getResult();
76  }
77 
78  // We can then construct the transposition on our filter.
79  auto transpose =
80  linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm);
81 
82  Value newFilter;
83  if (isTensorOp) {
84  newFilter = transpose.getResult()[0];
85  } else {
86  newFilter = input;
87  }
88 
89  SmallVector<Value> newInputs{op.getInputs()};
90  // The filter is always the second input argument, the other inputs can be
91  // left as they are.
92  newInputs[1] = newFilter;
93  // It is possible the convolution doesn't define any results and its
94  // out argument is just used instead.
95  SmallVector<Type> resultTy;
96  if (op.getNumResults()) {
97  resultTy.push_back(op->getResult(0).getType());
98  }
99  auto newConv =
100  HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(),
101  op.getStrides(), op.getDilations());
102  rewriter.replaceOp(op, newConv);
103  return newConv.getOperation();
104 }
105 
106 template <typename FHWCConvOp, typename HWCFConvOp>
107 class ConvConverter : public OpRewritePattern<FHWCConvOp> {
108 public:
110  LogicalResult matchAndRewrite(FHWCConvOp op,
111  PatternRewriter &rewriter) const final {
112  if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
113  return failure();
114  }
115  return success();
116  }
117 };
118 } // namespace
119 
120 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
121  linalg::Conv2DNhwcFhwcOp op) {
122 
123  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
124  linalg::Conv2DNhwcHwcfOp>(rewriter, op);
125 }
126 
127 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
128  linalg::Conv2DNhwcFhwcQOp op) {
129 
130  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
131  linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
132 }
133 
135  MLIRContext *context = patterns.getContext();
136  patterns.insert<
137  ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
138  ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
139  context);
140 }
141 } // namespace linalg
142 } // namespace mlir
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateTransposeConv2DPatterns(RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319