18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Support/ErrorHandling.h"
20 #include "llvm/Support/RWMutex.h"
// NOTE(review): this is an excerpt of the original file -- the leading numbers
// are the original line numbers and the lines not listed here are omitted, so
// the comments below describe only the statements that are visible.
//
// Rewrites a convolution of type FHWCConvOp into the equivalent HWCFConvOp by
// transposing its filter (operand 1) and building a new convolution that
// consumes the transposed filter. Returns the new convolution's Operation*.
50 template <
typename FHWCConvOp,
typename HWCFConvOp>
51 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
// Filter dimension permutation: result dim i is taken from source dim
// filterPerm[i], i.e. (d0, d1, d2, d3) -> (d1, d2, d3, d0).
55 SmallVector<int64_t> filterPerm({1, 2, 3, 0});
58 auto filter = op->getOperand(1);
59 auto filterTy = cast<ShapedType>(filter.getType());
// Shape of the transposed filter: the original filter shape read through
// filterPerm (the lambda's `dim` counter walks the result dimensions in order).
60 SmallVector<int64_t> newFilterShape(filterPerm.size());
61 std::generate(std::begin(newFilterShape), std::end(newFilterShape),
62 [dim = 0, &filterTy, &filterPerm]()
mutable {
63 return filterTy.getShape()[filterPerm[dim++]];
// NOTE(review): the element type is read from operand 0 rather than from the
// filter itself; presumably both operands share an element type -- confirm.
68 auto inputType = op->getOperand(0).getType();
69 auto elementTy = cast<ShapedType>(inputType).getElementType();
70 auto loc = op->getLoc();
// The visible lines create the transpose destination either as a
// tensor::EmptyOp or through a memref::AllocOp; presumably `isTensorOp`
// selects between the two (the branch itself is not shown in this excerpt).
72 const auto isTensorOp = isa<TensorType>(inputType);
76 input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
80 .create<memref::AllocOp>(
// Materialize the transpose of the filter into `input` using filterPerm.
87 rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
// Copy the original inputs and swap the filter (index 1) for `newFilter`
// (its definition is on a line omitted from this excerpt).
96 SmallVector<Value> newInputs{op.getInputs()};
99 newInputs[1] = newFilter;
// Forward the type of result 0 only when the original op has results.
102 SmallVector<Type> resultTy;
103 if (op.getNumResults()) {
104 resultTy.push_back(op->getResult(0).getType());
// Build the replacement convolution with the original outputs, strides and
// dilations, then replace the original op with it.
107 rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
108 op.getStrides(), op.getDilations());
109 rewriter.replaceOp(op, newConv);
110 return newConv.getOperation();
// Rewrite pattern rooted at FHWCConvOp. Its matchAndRewrite delegates to
// transposeConv2DHelper<FHWCConvOp, HWCFConvOp> and checks whether that call
// failed. NOTE(review): the statements following the failed(...) check and the
// closing braces are on lines omitted from this excerpt.
113 template <
typename FHWCConvOp,
typename HWCFConvOp>
114 class ConvConverter :
public OpRewritePattern<FHWCConvOp> {
117 LogicalResult matchAndRewrite(FHWCConvOp op,
118 PatternRewriter &rewriter)
const final {
119 if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
128 linalg::Conv2DNhwcFhwcOp op) {
130 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
131 linalg::Conv2DNhwcHwcfOp>(rewriter, op);
135 linalg::Conv2DNhwcFhwcQOp op) {
137 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
138 linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
144 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
145 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
MLIRContext is the top-level object for a collection of MLIR operations.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
void populateTranposeConv2DPatterns(RewritePatternSet &patterns)
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.