MLIR  20.0.0git
TransposeMatmul.cpp
Go to the documentation of this file.
1 //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This is intended to be a simple high-level (target-agnostic) matmul
9 // transposition transformation.
10 //===----------------------------------------------------------------------===//
11 
13 #include "mlir/IR/PatternMatch.h"
15 
16 #define DEBUG_TYPE "linalg-transpose-matmul"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Pattern to replace
22 ///
23 /// linalg.matmul(a, b)
24 ///
25 /// with
26 ///
27 /// linalg.matmul_transpose_a(linalg.transpose(a), b)
28 ///
29 /// By default the LHS is transposed. Set `transposeLHS=false` to
30 /// transpose RHS instead.
31 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32  linalg::MatmulOp matmulOp,
33  bool transposeLHS) {
34  // Check to not let go the matmul with extended semantic, through this
35  // transform.
36  if (matmulOp.hasUserDefinedMaps()) {
37  return rewriter.notifyMatchFailure(
38  matmulOp, "only matmul ops with non-extended semantics are supported");
39  }
40 
41  if (!bufferization::hasTensorSemantics(matmulOp))
42  return rewriter.notifyMatchFailure(
43  matmulOp, "only matmul ops with tensors are supported");
44 
45  Location loc = matmulOp.getLoc();
46  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
47  auto type = cast<ShapedType>(input.getType());
48 
49  SmallVector<Value> dynamicDims;
50  if (type.isDynamicDim(1))
51  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
52  if (type.isDynamicDim(0))
53  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
54 
55  ArrayRef<int64_t> shape = type.getShape();
56  Value empty = rewriter.create<tensor::EmptyOp>(
57  loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
58  dynamicDims);
59  auto transposeOp = rewriter.create<linalg::TransposeOp>(
60  loc, input, empty, ArrayRef<int64_t>{1, 0});
61  Operation *newMatmulOp;
62  if (transposeLHS) {
63  newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
64  loc, matmulOp.getResultTypes(),
65  ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
66  matmulOp.getOutputs());
67  } else {
68  newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
69  loc, matmulOp.getResultTypes(),
70  ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
71  matmulOp.getOutputs());
72  }
73  rewriter.replaceOp(matmulOp, newMatmulOp);
74  return newMatmulOp;
75 }
76 
77 /// Pattern to replace
78 ///
79 /// linalg.batch_matmul(a, b)
80 ///
81 /// with
82 ///
83 /// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
84 ///
85 /// Only the non-batch dimensions are transposed. By default the LHS is
86 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
87 FailureOr<Operation *>
89  linalg::BatchMatmulOp batchMatmulOp,
90  bool transposeLHS) {
91  if (!bufferization::hasTensorSemantics(batchMatmulOp))
92  return rewriter.notifyMatchFailure(
93  batchMatmulOp, "only matmul ops with tensors are supported");
94 
95  Location loc = batchMatmulOp.getLoc();
96  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
97  auto type = cast<ShapedType>(input.getType());
98 
99  SmallVector<Value> dynamicDims;
100  if (type.isDynamicDim(0))
101  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
102  if (type.isDynamicDim(2))
103  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
104  if (type.isDynamicDim(1))
105  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
106 
107  ArrayRef<int64_t> shape = type.getShape();
108  Value empty = rewriter.create<tensor::EmptyOp>(
109  loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
110  type.getElementType(), dynamicDims);
111  auto transposeOp = rewriter.create<linalg::TransposeOp>(
112  loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
113  Operation *newMatmulOp;
114  if (transposeLHS) {
115  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
116  loc, batchMatmulOp.getResultTypes(),
117  ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
118  batchMatmulOp.getOutputs());
119  } else {
120  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
121  loc, batchMatmulOp.getResultTypes(),
122  ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
123  batchMatmulOp.getOutputs());
124  }
125  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
126  return newMatmulOp;
127 }
128 
129 namespace {
130 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
131  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
132  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
133 
134  LogicalResult matchAndRewrite(linalg::MatmulOp op,
135  PatternRewriter &rewriter) const override {
136  if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
137  return failure();
138  }
139  return success();
140  }
141 
142 private:
143  bool transposeLHS;
144 };
145 
146 struct TransposeBatchMatmul final
147  : public OpRewritePattern<linalg::BatchMatmulOp> {
148  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
149  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
150 
151  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
152  PatternRewriter &rewriter) const override {
153  if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
154  return failure();
155  }
156  return success();
157  }
158 
159 private:
160  bool transposeLHS;
161 };
162 } // namespace
163 
165  bool transposeLHS) {
166  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
167  transposeLHS);
168 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace linalg.batch_matmul(a, b) with linalg.batch_matmul_transpose_a(linalg.transpose(a), b).
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358