16 #define DEBUG_TYPE "linalg-transpose-matmul"
32 linalg::MatmulOp matmulOp,
36 if (matmulOp.hasUserDefinedMaps()) {
38 matmulOp,
"only matmul ops with non-extended semantics are supported");
43 matmulOp,
"only matmul ops with tensors are supported");
46 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
47 auto type = cast<ShapedType>(input.
getType());
50 if (type.isDynamicDim(1))
51 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 1));
52 if (type.isDynamicDim(0))
53 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 0));
59 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
63 newMatmulOp = rewriter.
create<linalg::MatmulTransposeAOp>(
64 loc, matmulOp.getResultTypes(),
65 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
66 matmulOp.getOutputs());
68 newMatmulOp = rewriter.
create<linalg::MatmulTransposeBOp>(
69 loc, matmulOp.getResultTypes(),
70 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
71 matmulOp.getOutputs());
73 rewriter.
replaceOp(matmulOp, newMatmulOp);
87 FailureOr<Operation *>
89 linalg::BatchMatmulOp batchMatmulOp,
93 batchMatmulOp,
"only matmul ops with tensors are supported");
95 Location loc = batchMatmulOp.getLoc();
96 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
97 auto type = cast<ShapedType>(input.
getType());
100 if (type.isDynamicDim(0))
101 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 0));
102 if (type.isDynamicDim(2))
103 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 2));
104 if (type.isDynamicDim(1))
105 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 1));
110 type.getElementType(), dynamicDims);
111 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
115 newMatmulOp = rewriter.
create<linalg::BatchMatmulTransposeAOp>(
116 loc, batchMatmulOp.getResultTypes(),
117 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
118 batchMatmulOp.getOutputs());
120 newMatmulOp = rewriter.
create<linalg::BatchMatmulTransposeBOp>(
121 loc, batchMatmulOp.getResultTypes(),
122 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
123 batchMatmulOp.getOutputs());
125 rewriter.
replaceOp(batchMatmulOp, newMatmulOp);
131 TransposeMatmul(
MLIRContext *ctx,
bool transposeLHS)
134 LogicalResult matchAndRewrite(linalg::MatmulOp op,
146 struct TransposeBatchMatmul final
148 TransposeBatchMatmul(
MLIRContext *ctx,
bool transposeLHS)
151 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
166 patterns.
add<TransposeMatmul, TransposeBatchMatmul>(patterns.
getContext(),
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The original operation is erased.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
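Pattern to replace `linalg.batch_matmul(a, b)` with `linalg.batch_matmul_transpose_a(linalg.transpose(a), b)`. Only the non-batch dimensions are transposed. Set `transposeLHS=false` to transpose the RHS instead, producing `linalg.batch_matmul_transpose_b(a, linalg.transpose(b))`.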
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.