MLIR  21.0.0git
TransposeMatmul.cpp
Go to the documentation of this file.
1 //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This is intended to be a simple high-level (target-agnostic) matmul
9 // transposition transformation.
10 //===----------------------------------------------------------------------===//
11 
13 #include "mlir/IR/PatternMatch.h"
15 
16 #define DEBUG_TYPE "linalg-transpose-matmul"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Pattern to replace
22 ///
23 /// linalg.matmul(a, b)
24 ///
25 /// with
26 ///
27 /// linalg.matmul_transpose_a(linalg.transpose(a), b)
28 ///
29 /// By default the LHS is transposed. Set `transposeLHS=false` to
30 /// transpose RHS instead.
31 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32  linalg::MatmulOp matmulOp,
33  bool transposeLHS) {
34  // Check to not let go the matmul with extended semantic, through this
35  // transform.
36  if (matmulOp.hasUserDefinedMaps()) {
37  return rewriter.notifyMatchFailure(
38  matmulOp, "only matmul ops with non-extended semantics are supported");
39  }
40 
41  if (!bufferization::hasTensorSemantics(matmulOp))
42  return rewriter.notifyMatchFailure(
43  matmulOp, "only matmul ops with tensors are supported");
44 
45  Location loc = matmulOp.getLoc();
46  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
47  auto type = cast<ShapedType>(input.getType());
48 
49  SmallVector<Value> dynamicDims;
50  if (type.isDynamicDim(1))
51  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
52  if (type.isDynamicDim(0))
53  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
54 
55  ArrayRef<int64_t> shape = type.getShape();
56  Value empty = rewriter.create<tensor::EmptyOp>(
57  loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
58  dynamicDims);
59  auto transposeOp = rewriter.create<linalg::TransposeOp>(
60  loc, input, empty, ArrayRef<int64_t>{1, 0});
61  Operation *newMatmulOp;
62  if (transposeLHS) {
63  newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
64  loc, matmulOp.getResultTypes(),
65  ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
66  matmulOp.getOutputs());
67  } else {
68  newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
69  loc, matmulOp.getResultTypes(),
70  ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
71  matmulOp.getOutputs());
72  }
73  rewriter.replaceOp(matmulOp, newMatmulOp);
74  return newMatmulOp;
75 }
76 
77 /// Pattern to replace
78 ///
79 /// linalg.batch_matmul(a, b)
80 ///
81 /// with
82 ///
83 /// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84 ///
85 /// Only the non-batch dimensions are transposed. By default the LHS is
86 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
87 FailureOr<Operation *>
89  linalg::BatchMatmulOp batchMatmulOp,
90  bool transposeLHS) {
91  if (batchMatmulOp.hasUserDefinedMaps()) {
92  return rewriter.notifyMatchFailure(
93  batchMatmulOp, "ops with user-defined maps are not supported");
94  }
95 
96  if (!bufferization::hasTensorSemantics(batchMatmulOp))
97  return rewriter.notifyMatchFailure(
98  batchMatmulOp, "only matmul ops with tensors are supported");
99 
100  Location loc = batchMatmulOp.getLoc();
101  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
102  auto type = cast<ShapedType>(input.getType());
103 
104  SmallVector<Value> dynamicDims;
105  if (type.isDynamicDim(0))
106  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
107  if (type.isDynamicDim(2))
108  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
109  if (type.isDynamicDim(1))
110  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
111 
112  ArrayRef<int64_t> shape = type.getShape();
113  Value empty = rewriter.create<tensor::EmptyOp>(
114  loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
115  type.getElementType(), dynamicDims);
116  auto transposeOp = rewriter.create<linalg::TransposeOp>(
117  loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
118  Operation *newMatmulOp;
119  if (transposeLHS) {
120  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
121  loc, batchMatmulOp.getResultTypes(),
122  ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
123  batchMatmulOp.getOutputs());
124  } else {
125  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
126  loc, batchMatmulOp.getResultTypes(),
127  ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
128  batchMatmulOp.getOutputs());
129  }
130  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
131  return newMatmulOp;
132 }
133 
134 namespace {
135 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
136  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
137  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
138 
139  LogicalResult matchAndRewrite(linalg::MatmulOp op,
140  PatternRewriter &rewriter) const override {
141  if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
142  return failure();
143  }
144  return success();
145  }
146 
147 private:
148  bool transposeLHS;
149 };
150 
151 struct TransposeBatchMatmul final
152  : public OpRewritePattern<linalg::BatchMatmulOp> {
153  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
154  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
155 
156  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
157  PatternRewriter &rewriter) const override {
158  if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
159  return failure();
160  }
161  return success();
162  }
163 
164 private:
165  bool transposeLHS;
166 };
167 } // namespace
168 
170  bool transposeLHS) {
171  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
172  transposeLHS);
173 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
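Pattern to replace linalg.batch_matmul(a, b) with linalg.batch_matmul_transpose_a(linalg.transpose(a), b); set transposeLHS=false to transpose the RHS instead.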
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358