16 #define DEBUG_TYPE "linalg-transpose-matmul"
32 linalg::MatmulOp matmulOp,
36 matmulOp,
"only matmul ops with tensors are supported");
39 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
40 auto type = cast<ShapedType>(input.
getType());
43 if (type.isDynamicDim(1))
44 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 1));
45 if (type.isDynamicDim(0))
46 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 0));
52 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
56 newMatmulOp = rewriter.
create<linalg::MatmulTransposeAOp>(
57 loc, matmulOp.getResultTypes(),
58 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
59 matmulOp.getOutputs());
61 newMatmulOp = rewriter.
create<linalg::MatmulTransposeBOp>(
62 loc, matmulOp.getResultTypes(),
63 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
64 matmulOp.getOutputs());
66 rewriter.
replaceOp(matmulOp, newMatmulOp);
80 FailureOr<Operation *>
82 linalg::BatchMatmulOp batchMatmulOp,
86 batchMatmulOp,
"only matmul ops with tensors are supported");
88 Location loc = batchMatmulOp.getLoc();
89 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
90 auto type = cast<ShapedType>(input.
getType());
93 if (type.isDynamicDim(0))
94 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 0));
95 if (type.isDynamicDim(2))
96 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 2));
97 if (type.isDynamicDim(1))
98 dynamicDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, 1));
103 type.getElementType(), dynamicDims);
104 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
108 newMatmulOp = rewriter.
create<linalg::BatchMatmulTransposeAOp>(
109 loc, batchMatmulOp.getResultTypes(),
110 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
111 batchMatmulOp.getOutputs());
113 newMatmulOp = rewriter.
create<linalg::BatchMatmulTransposeBOp>(
114 loc, batchMatmulOp.getResultTypes(),
115 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
116 batchMatmulOp.getOutputs());
118 rewriter.
replaceOp(batchMatmulOp, newMatmulOp);
124 TransposeMatmul(
MLIRContext *ctx,
bool transposeLHS)
127 LogicalResult matchAndRewrite(linalg::MatmulOp op,
139 struct TransposeBatchMatmul final
141 TransposeBatchMatmul(
MLIRContext *ctx,
bool transposeLHS)
144 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
159 patterns.
add<TransposeMatmul, TransposeBatchMatmul>(patterns.
getContext(),
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...