MLIR  19.0.0git
TransposeConv2D.cpp
Go to the documentation of this file.
1 //===- TransposeConv2D.cpp - Convolution transposition -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/ValueRange.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/RWMutex.h"
22 #include <memory>
23 #include <numeric>
24 
25 namespace mlir {
26 namespace linalg {
27 namespace {
28 // clang-format off
29 /// Convolution converter that applies the following rewrite:
30 ///
31 /// Before:
32 ///
33 /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
34 /// strides = dense<2> : tensor<2xi64>}
35 /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
36 /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
37 ///
38 /// After:
39 ///
40 /// %cst = arith.constant 0.000000e+00 : f32
41 /// %0 = tensor.empty() : tensor<2x2x6x8xf32>
42 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
43 /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
44 /// permutation = [1, 2, 3, 0]
45 /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
46 /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
47 /// -> tensor<1x2x2x8xf32>
48 ///
49 /// with an analogous example for the quantized case.
50 // clang-format on
51 template <typename FHWCConvOp, typename HWCFConvOp>
52 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
53  FHWCConvOp op) {
54  // Construct a permutation of the filter tensor dimensions. For a 2D
55  // convolution this will be known statically as [1, 2, 3, 0].
56  SmallVector<int64_t> filterPerm({1, 2, 3, 0});
57 
58  // Create the type for the transposed filter tensor.
59  auto filter = op->getOperand(1);
60  auto filterTy = cast<ShapedType>(filter.getType());
61  SmallVector<int64_t> newFilterShape(filterPerm.size());
62  std::generate(std::begin(newFilterShape), std::end(newFilterShape),
63  [dim = 0, &filterTy, &filterPerm]() mutable {
64  return filterTy.getShape()[filterPerm[dim++]];
65  });
66 
67  // Because linalg.transpose expects an "out" parameter we need to pass it a
68  // tensor of zeros of the result type so here we construct that tensor.
69  auto inputType = op->getOperand(0).getType();
70  auto elementTy = cast<ShapedType>(inputType).getElementType();
71  auto loc = op->getLoc();
72 
73  const auto isTensorOp = isa<TensorType>(inputType);
74  Value input;
75  if (isTensorOp) {
76 
77  input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
78  .getResult();
79  } else {
80  input = rewriter
81  .create<memref::AllocOp>(
82  loc, MemRefType::get(newFilterShape, elementTy))
83  .getResult();
84  }
85 
86  // We can then construct the transposition on our filter.
87  auto transpose =
88  rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
89 
90  Value newFilter;
91  if (isTensorOp) {
92  newFilter = transpose.getResult()[0];
93  } else {
94  newFilter = input;
95  }
96 
97  SmallVector<Value> newInputs{op.getInputs()};
98  // The filter is always the second input argument, the other inputs can be
99  // left as they are.
100  newInputs[1] = newFilter;
101  // It is possible the convolution doesn't define any results and its
102  // out argument is just used instead.
103  SmallVector<Type> resultTy;
104  if (op.getNumResults()) {
105  resultTy.push_back(op->getResult(0).getType());
106  }
107  auto newConv =
108  rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
109  op.getStrides(), op.getDilations());
110  rewriter.replaceOp(op, newConv);
111  return newConv.getOperation();
112 }
113 
114 template <typename FHWCConvOp, typename HWCFConvOp>
115 class ConvConverter : public OpRewritePattern<FHWCConvOp> {
116 public:
118  LogicalResult matchAndRewrite(FHWCConvOp op,
119  PatternRewriter &rewriter) const final {
120  if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
121  return failure();
122  }
123  return success();
124  }
125 };
126 } // namespace
127 
129  linalg::Conv2DNhwcFhwcOp op) {
130 
131  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
132  linalg::Conv2DNhwcHwcfOp>(rewriter, op);
133 }
134 
136  linalg::Conv2DNhwcFhwcQOp op) {
137 
138  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
139  linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
140 }
141 
143  MLIRContext *context = patterns.getContext();
144  patterns.insert<
145  ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
146  ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
147  context);
148 }
149 } // namespace linalg
150 } // namespace mlir
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:893
MLIRContext * getContext() const
Definition: PatternMatch.h:785
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void populateTranposeConv2DPatterns(RewritePatternSet &patterns)
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361