15#define DEBUG_TYPE "linalg-transpose-matmul"
31 linalg::MatmulOp matmulOp,
35 if (matmulOp.hasUserDefinedMaps()) {
37 matmulOp,
"only matmul ops with non-extended semantics are supported");
40 if (!matmulOp.hasPureTensorSemantics())
42 matmulOp,
"only matmul ops with tensors are supported");
45 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
46 auto type = cast<ShapedType>(input.
getType());
49 if (type.isDynamicDim(1))
50 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
51 if (type.isDynamicDim(0))
52 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
55 Value empty = tensor::EmptyOp::create(rewriter, loc,
57 type.getElementType(), dynamicDims);
58 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
63 rewriter, loc, matmulOp.getResultTypes(),
64 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
65 matmulOp.getOutputs());
68 rewriter, loc, matmulOp.getResultTypes(),
69 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
70 matmulOp.getOutputs());
72 rewriter.
replaceOp(matmulOp, newMatmulOp);
88 linalg::BatchMatmulOp batchMatmulOp,
90 if (batchMatmulOp.hasUserDefinedMaps()) {
92 batchMatmulOp,
"ops with user-defined maps are not supported");
95 if (!batchMatmulOp.hasPureTensorSemantics())
97 batchMatmulOp,
"only matmul ops with tensors are supported");
99 Location loc = batchMatmulOp.getLoc();
100 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
101 auto type = cast<ShapedType>(input.
getType());
104 if (type.isDynamicDim(0))
105 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
106 if (type.isDynamicDim(2))
107 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2));
108 if (type.isDynamicDim(1))
109 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
112 Value empty = tensor::EmptyOp::create(
114 type.getElementType(), dynamicDims);
115 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
120 rewriter, loc, batchMatmulOp.getResultTypes(),
121 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
122 batchMatmulOp.getOutputs());
125 rewriter, loc, batchMatmulOp.getResultTypes(),
126 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
127 batchMatmulOp.getOutputs());
129 rewriter.
replaceOp(batchMatmulOp, newMatmulOp);
135 TransposeMatmul(
MLIRContext *ctx,
bool transposeLHS)
138 LogicalResult matchAndRewrite(linalg::MatmulOp op,
150struct TransposeBatchMatmul final
152 TransposeBatchMatmul(MLIRContext *ctx,
bool transposeLHS)
153 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
155 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
156 PatternRewriter &rewriter)
const override {
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...