19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/Support/RWMutex.h"
51 template <
typename FHWCConvOp,
typename HWCFConvOp>
52 FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
56 SmallVector<int64_t> filterPerm({1, 2, 3, 0});
59 auto filter = op->getOperand(1);
60 auto filterTy = cast<ShapedType>(filter.getType());
61 SmallVector<int64_t> newFilterShape(filterPerm.size());
62 std::generate(std::begin(newFilterShape), std::end(newFilterShape),
63 [dim = 0, &filterTy, &filterPerm]()
mutable {
64 return filterTy.getShape()[filterPerm[dim++]];
69 auto inputType = op->getOperand(0).getType();
70 auto elementTy = cast<ShapedType>(inputType).getElementType();
71 auto loc = op->getLoc();
73 const auto isTensorOp = isa<TensorType>(inputType);
77 input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
81 .create<memref::AllocOp>(
88 rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
97 SmallVector<Value> newInputs{op.getInputs()};
100 newInputs[1] = newFilter;
103 SmallVector<Type> resultTy;
104 if (op.getNumResults()) {
105 resultTy.push_back(op->getResult(0).getType());
108 rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
109 op.getStrides(), op.getDilations());
110 rewriter.replaceOp(op, newConv);
111 return newConv.getOperation();
114 template <
typename FHWCConvOp,
typename HWCFConvOp>
115 class ConvConverter :
public OpRewritePattern<FHWCConvOp> {
118 LogicalResult matchAndRewrite(FHWCConvOp op,
119 PatternRewriter &rewriter)
const final {
120 if (
failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
129 linalg::Conv2DNhwcFhwcOp op) {
131 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
132 linalg::Conv2DNhwcHwcfOp>(rewriter, op);
136 linalg::Conv2DNhwcFhwcQOp op) {
138 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
139 linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
145 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
146 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void populateTranposeConv2DPatterns(RewritePatternSet &patterns)
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...