MLIR  20.0.0git
TransposeMatmul.cpp
Go to the documentation of this file.
1 //===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This is intended to be a simple high-level (target-agnostic) matmul
9 // transposition transformation.
10 //===----------------------------------------------------------------------===//
11 
13 #include "mlir/IR/PatternMatch.h"
15 
16 #define DEBUG_TYPE "linalg-transpose-matmul"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 /// Pattern to replace
22 ///
23 /// linalg.matmul(a, b)
24 ///
25 /// with
26 ///
27 /// linalg.matmul_transpose_a(linalg.transpose(a), b)
28 ///
29 /// By default the LHS is transposed. Set `transposeLHS=false` to
30 /// transpose RHS instead.
31 FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
32  linalg::MatmulOp matmulOp,
33  bool transposeLHS) {
34  if (!bufferization::hasTensorSemantics(matmulOp))
35  return rewriter.notifyMatchFailure(
36  matmulOp, "only matmul ops with tensors are supported");
37 
38  Location loc = matmulOp.getLoc();
39  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
40  auto type = cast<ShapedType>(input.getType());
41 
42  SmallVector<Value> dynamicDims;
43  if (type.isDynamicDim(1))
44  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
45  if (type.isDynamicDim(0))
46  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
47 
48  ArrayRef<int64_t> shape = type.getShape();
49  Value empty = rewriter.create<tensor::EmptyOp>(
50  loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
51  dynamicDims);
52  auto transposeOp = rewriter.create<linalg::TransposeOp>(
53  loc, input, empty, ArrayRef<int64_t>{1, 0});
54  Operation *newMatmulOp;
55  if (transposeLHS) {
56  newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
57  loc, matmulOp.getResultTypes(),
58  ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
59  matmulOp.getOutputs());
60  } else {
61  newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
62  loc, matmulOp.getResultTypes(),
63  ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
64  matmulOp.getOutputs());
65  }
66  rewriter.replaceOp(matmulOp, newMatmulOp);
67  return newMatmulOp;
68 }
69 
70 /// Pattern to replace
71 ///
72 /// linalg.batch_matmul(a, b)
73 ///
74 /// with
75 ///
76 /// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
77 ///
78 /// Only the non-batch dimensions are transposed. By default the LHS is
79 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
80 FailureOr<Operation *>
82  linalg::BatchMatmulOp batchMatmulOp,
83  bool transposeLHS) {
84  if (!bufferization::hasTensorSemantics(batchMatmulOp))
85  return rewriter.notifyMatchFailure(
86  batchMatmulOp, "only matmul ops with tensors are supported");
87 
88  Location loc = batchMatmulOp.getLoc();
89  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
90  auto type = cast<ShapedType>(input.getType());
91 
92  SmallVector<Value> dynamicDims;
93  if (type.isDynamicDim(0))
94  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
95  if (type.isDynamicDim(2))
96  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
97  if (type.isDynamicDim(1))
98  dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
99 
100  ArrayRef<int64_t> shape = type.getShape();
101  Value empty = rewriter.create<tensor::EmptyOp>(
102  loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
103  type.getElementType(), dynamicDims);
104  auto transposeOp = rewriter.create<linalg::TransposeOp>(
105  loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
106  Operation *newMatmulOp;
107  if (transposeLHS) {
108  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
109  loc, batchMatmulOp.getResultTypes(),
110  ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
111  batchMatmulOp.getOutputs());
112  } else {
113  newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
114  loc, batchMatmulOp.getResultTypes(),
115  ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
116  batchMatmulOp.getOutputs());
117  }
118  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
119  return newMatmulOp;
120 }
121 
122 namespace {
123 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
124  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
125  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
126 
127  LogicalResult matchAndRewrite(linalg::MatmulOp op,
128  PatternRewriter &rewriter) const override {
129  if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
130  return failure();
131  }
132  return success();
133  }
134 
135 private:
136  bool transposeLHS;
137 };
138 
139 struct TransposeBatchMatmul final
140  : public OpRewritePattern<linalg::BatchMatmulOp> {
141  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
142  : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
143 
144  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
145  PatternRewriter &rewriter) const override {
146  if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
147  return failure();
148  }
149  return success();
150  }
151 
152 private:
153  bool transposeLHS;
154 };
155 } // namespace
156 
158  bool transposeLHS) {
159  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
160  transposeLHS);
161 }
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
void populateTransposeMatmulPatterns(RewritePatternSet &patterns, bool transposeLHS=true)
Patterns to convert Linalg matmul ops to transposed variants.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
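Pattern to replace linalg.batch_matmul with a transposed batch-matmul variant (LHS or RHS operand explicitly transposed).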
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358