TransposeConv2D.cpp
//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// NOTE: the original include block is elided in this listing; the headers
// below are the dialect and utility headers the code relies on.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

namespace mlir {
namespace linalg {
namespace {
// clang-format off
/// Convolution converter that applies the following rewrite:
///
/// Before:
///
///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
///                                  strides = dense<2> : tensor<2xi64>}
///     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
///
/// After:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = tensor.empty() : tensor<2x2x6x8xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
///   %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
///       permutation = [1, 2, 3, 0]
///   %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
///       ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
///       -> tensor<1x2x2x8xf32>
///
/// with an analogous example for the quantized case.
// clang-format on
template <typename FHWCConvOp, typename HWCFConvOp>
FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
                                             FHWCConvOp op) {
  // Construct a permutation of the filter tensor dimensions. For a 2D
  // convolution this will be known statically as [1, 2, 3, 0].
  SmallVector<int64_t> filterPerm = {1, 2, 3, 0};

  // Create the type for the transposed filter tensor.
  auto filter = op->getOperand(1);
  auto filterTy = cast<ShapedType>(filter.getType());
  SmallVector<int64_t> newFilterShape(filterPerm.size());
  std::generate(std::begin(newFilterShape), std::end(newFilterShape),
                [dim = 0, &filterTy, &filterPerm]() mutable {
                  return filterTy.getShape()[filterPerm[dim++]];
                });

  // Because linalg.transpose expects an "out" parameter, we need to pass it a
  // destination tensor (or buffer) of the transposed filter type, so we
  // construct that value here.
  auto inputType = op->getOperand(0).getType();
  auto elementTy = cast<ShapedType>(inputType).getElementType();
  auto loc = op->getLoc();

  const auto isTensorOp = isa<TensorType>(inputType);
  Value input;
  if (isTensorOp) {
    input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy)
                .getResult();
  } else {
    input = memref::AllocOp::create(rewriter, loc,
                                    MemRefType::get(newFilterShape, elementTy))
                .getResult();
  }

  // We can then construct the transposition on our filter.
  auto transpose =
      linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm);

  Value newFilter;
  if (isTensorOp) {
    newFilter = transpose.getResult()[0];
  } else {
    newFilter = input;
  }

  SmallVector<Value> newInputs{op.getInputs()};
  // The filter is always the second input argument; the other inputs can be
  // left as they are.
  newInputs[1] = newFilter;
  // It is possible the convolution doesn't define any results and its
  // out argument is just used instead.
  SmallVector<Type> resultTy;
  if (op.getNumResults()) {
    resultTy.push_back(op->getResult(0).getType());
  }
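  // Create the NHWC/HWCF convolution with the transposed filter and use it to
  // replace the original FHWC op.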
  auto newConv =
      HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(),
                         op.getStrides(), op.getDilations());
  rewriter.replaceOp(op, newConv);
  return newConv.getOperation();
}

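/// Rewrite pattern that wraps transposeConv2DHelper so the FHWC -> HWCF
/// conversion can be driven by the pattern rewriter.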
template <typename FHWCConvOp, typename HWCFConvOp>
class ConvConverter : public OpRewritePattern<FHWCConvOp> {
public:
  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(FHWCConvOp op,
                                PatternRewriter &rewriter) const final {
    if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
      return failure();
    }
    return success();
  }
};
} // namespace

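/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
/// materializing a transpose of the filter.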
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
                               linalg::Conv2DNhwcHwcfOp>(rewriter, op);
}

FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcQOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
                               linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
}

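/// Populate `patterns` with the FHWC -> HWCF conversion patterns for both the
/// regular and the quantized 2D convolutions.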
void populateTransposeConv2DPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<
      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
      context);
}
} // namespace linalg
} // namespace mlir
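// A minimal usage sketch (not part of this file), assuming a caller that
// already has an `Operation *root` it wants to rewrite. The helper name
// `convertFhwcConvsToHwcf` is hypothetical; the greedy driver entry point is
// spelled `applyPatternsGreedily` in recent MLIR (older releases use
// `applyPatternsAndFoldGreedily`).
//
// #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
// #include "mlir/IR/PatternMatch.h"
// #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
// #include <utility>
//
// static mlir::LogicalResult convertFhwcConvsToHwcf(mlir::Operation *root) {
//   // Collect the FHWC -> HWCF patterns registered by this file and apply
//   // them greedily over `root`.
//   mlir::RewritePatternSet patterns(root->getContext());
//   mlir::linalg::populateTransposeConv2DPatterns(patterns);
//   return mlir::applyPatternsGreedily(root, std::move(patterns));
// }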