16 #include "llvm/ADT/SmallVector.h"
44 template <
typename FHWCConvOp,
typename HWCFConvOp>
45 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
49 SmallVector<int64_t> filterPerm = {1, 2, 3, 0};
52 auto filter = op->getOperand(1);
53 auto filterTy = cast<ShapedType>(filter.getType());
54 SmallVector<int64_t> newFilterShape(filterPerm.size());
55 std::generate(std::begin(newFilterShape), std::end(newFilterShape),
56 [dim = 0, &filterTy, &filterPerm]()
mutable {
57 return filterTy.getShape()[filterPerm[dim++]];
62 auto inputType = op->getOperand(0).getType();
63 auto elementTy = cast<ShapedType>(inputType).getElementType();
64 auto loc = op->getLoc();
66 const auto isTensorOp = isa<TensorType>(inputType);
70 input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy)
73 input = memref::AllocOp::create(rewriter, loc,
80 linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm);
84 newFilter = transpose.getResult()[0];
89 SmallVector<Value> newInputs{op.getInputs()};
92 newInputs[1] = newFilter;
95 SmallVector<Type> resultTy;
96 if (op.getNumResults()) {
97 resultTy.push_back(op->getResult(0).getType());
100 HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(),
101 op.getStrides(), op.getDilations());
102 rewriter.replaceOp(op, newConv);
103 return newConv.getOperation();
106 template <
typename FHWCConvOp,
typename HWCFConvOp>
107 class ConvConverter :
public OpRewritePattern<FHWCConvOp> {
110 LogicalResult matchAndRewrite(FHWCConvOp op,
111 PatternRewriter &rewriter)
const final {
112 if (
failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
121 linalg::Conv2DNhwcFhwcOp op) {
123 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
124 linalg::Conv2DNhwcHwcfOp>(rewriter, op);
128 linalg::Conv2DNhwcFhwcQOp op) {
130 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
131 linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
137 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
138 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
MLIRContext is the top-level object for a collection of MLIR operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateTransposeConv2DPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...