15 #define DEBUG_TYPE "linalg-transpose-matmul"
31 linalg::MatmulOp matmulOp,
35 if (matmulOp.hasUserDefinedMaps()) {
37 matmulOp,
"only matmul ops with non-extended semantics are supported");
40 if (!matmulOp.hasPureTensorSemantics())
42 matmulOp,
"only matmul ops with tensors are supported");
45 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
46 auto type = cast<ShapedType>(input.
getType());
49 if (type.isDynamicDim(1))
50 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
51 if (type.isDynamicDim(0))
52 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
55 Value empty = tensor::EmptyOp::create(rewriter, loc,
57 type.getElementType(), dynamicDims);
58 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
62 newMatmulOp = linalg::MatmulTransposeAOp::create(
63 rewriter, loc, matmulOp.getResultTypes(),
64 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
65 matmulOp.getOutputs());
67 newMatmulOp = linalg::MatmulTransposeBOp::create(
68 rewriter, loc, matmulOp.getResultTypes(),
69 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
70 matmulOp.getOutputs());
72 rewriter.
replaceOp(matmulOp, newMatmulOp);
86 FailureOr<Operation *>
88 linalg::BatchMatmulOp batchMatmulOp,
90 if (batchMatmulOp.hasUserDefinedMaps()) {
92 batchMatmulOp,
"ops with user-defined maps are not supported");
95 if (!batchMatmulOp.hasPureTensorSemantics())
97 batchMatmulOp,
"only matmul ops with tensors are supported");
99 Location loc = batchMatmulOp.getLoc();
100 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
101 auto type = cast<ShapedType>(input.
getType());
104 if (type.isDynamicDim(0))
105 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
106 if (type.isDynamicDim(2))
107 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2));
108 if (type.isDynamicDim(1))
109 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
112 Value empty = tensor::EmptyOp::create(
114 type.getElementType(), dynamicDims);
115 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
119 newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
120 rewriter, loc, batchMatmulOp.getResultTypes(),
121 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
122 batchMatmulOp.getOutputs());
124 newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
125 rewriter, loc, batchMatmulOp.getResultTypes(),
126 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
127 batchMatmulOp.getOutputs());
129 rewriter.
replaceOp(batchMatmulOp, newMatmulOp);
135 TransposeMatmul(
MLIRContext *ctx,
bool transposeLHS)
138 LogicalResult matchAndRewrite(linalg::MatmulOp op,
150 struct TransposeBatchMatmul final
152 TransposeBatchMatmul(
MLIRContext *ctx,
bool transposeLHS)
155 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
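Pattern to replace `linalg.batch_matmul(a, b)` with `linalg.batch_matmul_transpose_a(linalg.transpose(a), b)`. By default A is transposed; if `transposeLHS` is false, B is transposed instead, producing `linalg.batch_matmul_transpose_b(a, linalg.transpose(b))`.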
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.