//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

#define DEBUG_TYPE "linalg-transpose-matmul"

using namespace mlir;
using namespace mlir::linalg;

/// Pattern to replace
///
/// linalg.matmul(a, b)
///
/// with
///
/// linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to transpose
/// the RHS instead.
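///
/// For example (an illustrative sketch; the SSA names and static shapes are
/// hypothetical):
///
///   %empty = tensor.empty() : tensor<4x8xf32>
///   %t = linalg.transpose ins(%a : tensor<8x4xf32>)
///                         outs(%empty : tensor<4x8xf32>) permutation = [1, 0]
///   %0 = linalg.matmul_transpose_a
///          ins(%t, %b : tensor<4x8xf32>, tensor<4x16xf32>)
///          outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>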
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  // Bail out on matmul ops with extended semantics (user-defined indexing
  // maps); this transform only handles the default maps.
  if (matmulOp.hasUserDefinedMaps()) {
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with non-extended semantics are supported");
  }

  if (!matmulOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

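  // Dynamic sizes are collected in the order of the transposed result shape
  // ({dim 1, dim 0}), since the tensor.empty created below expects its dynamic
  // sizes in result-shape order.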
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(1))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));
  if (type.isDynamicDim(0))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = tensor::EmptyOp::create(rewriter, loc,
                                        ArrayRef<int64_t>{shape[1], shape[0]},
                                        type.getElementType(), dynamicDims);
  auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                 ArrayRef<int64_t>{1, 0});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = linalg::MatmulTransposeAOp::create(
        rewriter, loc, matmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  } else {
    newMatmulOp = linalg::MatmulTransposeBOp::create(
        rewriter, loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
        matmulOp.getOutputs());
  }
  rewriter.replaceOp(matmulOp, newMatmulOp);
  return newMatmulOp;
}

/// Pattern to replace
///
/// linalg.batch_matmul(a, b)
///
/// with
///
/// linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead.
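///
/// For example (an illustrative sketch; the SSA names and static shapes are
/// hypothetical):
///
///   %empty = tensor.empty() : tensor<2x4x8xf32>
///   %t = linalg.transpose ins(%a : tensor<2x8x4xf32>)
///                         outs(%empty : tensor<2x4x8xf32>)
///                         permutation = [0, 2, 1]
///   %0 = linalg.batch_matmul_transpose_a
///          ins(%t, %b : tensor<2x4x8xf32>, tensor<2x4x16xf32>)
///          outs(%c : tensor<2x8x16xf32>) -> tensor<2x8x16xf32>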
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (batchMatmulOp.hasUserDefinedMaps()) {
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "ops with user-defined maps are not supported");
  }

  if (!batchMatmulOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only batch_matmul ops with tensors are supported");
98 
99  Location loc = batchMatmulOp.getLoc();
100  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
101  auto type = cast<ShapedType>(input.getType());
102 
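  // As above, dynamic sizes are collected in the order of the transposed
  // result shape ({dim 0, dim 2, dim 1}) expected by the tensor.empty below.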
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(0))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
  if (type.isDynamicDim(2))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2));
  if (type.isDynamicDim(1))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = tensor::EmptyOp::create(
      rewriter, loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
      type.getElementType(), dynamicDims);
  auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                 ArrayRef<int64_t>{0, 2, 1});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
        rewriter, loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  } else {
    newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
        rewriter, loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
        batchMatmulOp.getOutputs());
  }
  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
  return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};

struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};
} // namespace

void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                      transposeLHS);
}
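
// Example usage (a sketch, assuming an enclosing pass that applies the
// patterns with the greedy rewrite driver; `getOperation()`, `getContext()`,
// and `signalPassFailure()` come from that pass):
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateTransposeMatmulPatterns(patterns, /*transposeLHS=*/true);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();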