MLIR 22.0.0git
TransposeMatmul.cpp
Go to the documentation of this file.
1//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This is intended to be a simple high-level (target-agnostic) matmul
9// transposition transformation.
10//===----------------------------------------------------------------------===//
11
14
15#define DEBUG_TYPE "linalg-transpose-matmul"
16
17using namespace mlir;
18using namespace mlir::linalg;
19
20/// Pattern to replace
21///
22/// linalg.matmul(a, b)
23///
24/// with
25///
26/// linalg.matmul_transpose_a(linalg.transpose(a), b)
27///
28/// By default the LHS is transposed. Set `transposeLHS=false` to
29/// transpose RHS instead.
30FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
31 linalg::MatmulOp matmulOp,
32 bool transposeLHS) {
33 // Check to not let go the matmul with extended semantic, through this
34 // transform.
35 if (matmulOp.hasUserDefinedMaps()) {
36 return rewriter.notifyMatchFailure(
37 matmulOp, "only matmul ops with non-extended semantics are supported");
38 }
39
40 if (!matmulOp.hasPureTensorSemantics())
41 return rewriter.notifyMatchFailure(
42 matmulOp, "only matmul ops with tensors are supported");
43
44 Location loc = matmulOp.getLoc();
45 Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
46 auto type = cast<ShapedType>(input.getType());
47
48 SmallVector<Value> dynamicDims;
49 if (type.isDynamicDim(1))
50 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
51 if (type.isDynamicDim(0))
52 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
53
54 ArrayRef<int64_t> shape = type.getShape();
55 Value empty = tensor::EmptyOp::create(rewriter, loc,
57 type.getElementType(), dynamicDims);
58 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
59 ArrayRef<int64_t>{1, 0});
60 Operation *newMatmulOp;
61 if (transposeLHS) {
62 newMatmulOp = MatmulTransposeAOp::create(
63 rewriter, loc, matmulOp.getResultTypes(),
64 ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
65 matmulOp.getOutputs());
66 } else {
67 newMatmulOp = MatmulTransposeBOp::create(
68 rewriter, loc, matmulOp.getResultTypes(),
69 ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
70 matmulOp.getOutputs());
71 }
72 rewriter.replaceOp(matmulOp, newMatmulOp);
73 return newMatmulOp;
74}
75
76/// Pattern to replace
77///
78/// linalg.batch_matmul(a, b)
79///
80/// with
81///
82/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
83///
84/// Only the non-batch dimensions are transposed. By default the LHS is
85/// transposed. Set `transposeLHS=false` to transpose RHS instead.
86FailureOr<Operation *>
88 linalg::BatchMatmulOp batchMatmulOp,
89 bool transposeLHS) {
90 if (batchMatmulOp.hasUserDefinedMaps()) {
91 return rewriter.notifyMatchFailure(
92 batchMatmulOp, "ops with user-defined maps are not supported");
93 }
94
95 if (!batchMatmulOp.hasPureTensorSemantics())
96 return rewriter.notifyMatchFailure(
97 batchMatmulOp, "only matmul ops with tensors are supported");
98
99 Location loc = batchMatmulOp.getLoc();
100 Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
101 auto type = cast<ShapedType>(input.getType());
102
103 SmallVector<Value> dynamicDims;
104 if (type.isDynamicDim(0))
105 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
106 if (type.isDynamicDim(2))
107 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2));
108 if (type.isDynamicDim(1))
109 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
110
111 ArrayRef<int64_t> shape = type.getShape();
112 Value empty = tensor::EmptyOp::create(
113 rewriter, loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
114 type.getElementType(), dynamicDims);
115 auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
116 ArrayRef<int64_t>{0, 2, 1});
117 Operation *newMatmulOp;
118 if (transposeLHS) {
120 rewriter, loc, batchMatmulOp.getResultTypes(),
121 ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
122 batchMatmulOp.getOutputs());
123 } else {
125 rewriter, loc, batchMatmulOp.getResultTypes(),
126 ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
127 batchMatmulOp.getOutputs());
128 }
129 rewriter.replaceOp(batchMatmulOp, newMatmulOp);
130 return newMatmulOp;
131}
132
133namespace {
134struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
135 TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
136 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
137
138 LogicalResult matchAndRewrite(linalg::MatmulOp op,
139 PatternRewriter &rewriter) const override {
140 if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
141 return failure();
142 }
143 return success();
144 }
145
146private:
147 bool transposeLHS;
148};
149
150struct TransposeBatchMatmul final
151 : public OpRewritePattern<linalg::BatchMatmulOp> {
152 TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
153 : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
154
155 LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
156 PatternRewriter &rewriter) const override {
157 if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
158 return failure();
159 }
160 return success();
161 }
162
163private:
164 bool transposeLHS;
165};
166} // namespace
167
169 bool transposeLHS) {
170 patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
171 transposeLHS);
172}
return success()
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...