MLIR  20.0.0git
Specialize.cpp
Go to the documentation of this file.
1 //===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a method to specialize generic operations to named
10 // operations. Conceptually it is the opposite of generalize.cpp.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "linalg-specialization"
31 
32 #define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
33  (rewriter.replaceOpWithNewOp<NEWOP>( \
34  genericOp, \
35  ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
36  genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
37  ValueRange{genericOp.getDpsInits()[0]}))
38 
39 #define REPLACE_UNARY_OP(NEWOP) \
40  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
41  ValueRange{genericOp.getDpsInputs()[0]}, \
42  ValueRange{genericOp.getDpsInits()[0]}))
43 
44 using namespace mlir;
45 using namespace mlir::linalg;
46 
47 // Given a elementwise single binary linalg generic op, checks whether the
48 // binary op accesses operands as swapped. e.g.
49 // this differentiates between a linalg-generic body that contains:
50 // ^bb0(%a: f32, %b: f32, %c : f32):
51 // %0 = arith.subf %a, %b : f32
52 // linalg.yield %0: f32
53 // against:
54 // ^bb0(%a: f32, %b: f32, %c : f32):
55 // %0 = arith.subf %b, %a : f32
56 // linalg.yield %0: f32
57 // Former is linalg.sub(a,b), latter is linalg.sub(b,a).
58 static bool areBinOpsSwapped(GenericOp genericOp) {
59  Block *body = genericOp.getBody();
60  Operation *op = &body->front();
61  bool swapped = false;
62  if (op->getOpOperand(0).get() != body->getArgument(0)) {
63  swapped = true;
64  assert(op->getOpOperand(0).get() == body->getArgument(1) &&
65  op->getOpOperand(1).get() == body->getArgument(0) &&
66  "binary op uses just one block arg");
67  }
68  return swapped;
69 }
70 
//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
// Identifies a linalg.generic that is essentially a named op of the form:
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//
// A linalg.generic may implement a matrix multiply, but not in a
// straightforward way. For example, below is a matrix multiply over slices:
// ```
//  %0 = linalg.generic {
//        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
//                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
//        iterator_types = ["parallel", "parallel", "parallel"]}
//        ins(%A, %B : tensor<20x20x20xf32>,  tensor<20x20x20xf32>)
//        outs(%C : tensor<20x20x20xf32>) {
//          ^bb0(%a: f32, %b: f32, %c : f32):
//             %mul = arith.mulf %a, %b : f32
//             %add = arith.addf %mul, %c : f32
//             linalg.yield %add : f32
//       } -> tensor<20x20x20xf32>
// ```
// The above cannot be represented as a named op. Its constant indices (3, 5
// and 13) select slices, so its indexing maps are not projected permutations.
// For example, linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is not
// the same as the linalg.generic above.
96 namespace {
97 enum class IndexMatchResult {
98  Match = 0, // identity map.
99  Transposed, // transposed map.
100  Mismatch // none of the above.
101 };
102 
103 // Checks whether the input Affine `map` contains two consecutive dims that
104 // can be interpreted as accessing a 2D matrix. It is assumed that the row
105 // column dimension are adjacent axis (in this order) and start at
106 // `rowDimIdx` in the input map.
107 //
108 // e.g. consider A matrix in `C[M,N] = A[M,K] * B[K,N]`. We will check
109 // whether the map of A is identity (match), transposed, or something
110 // completely different (mis-match). Similar for B and C.
111 static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
112  unsigned expectedPosOfRowDim,
113  unsigned expectedPosOfColDim) {
114  // Get the matrix multiply indices. They are past the batch indices.
115  auto exprOfRowDim = map.getResults()[rowDimIdx];
116  auto exprOfColDim = map.getResults()[rowDimIdx + 1];
117 
118  // They should be pure dimension ids.
119  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
120  exprOfColDim.getKind() != AffineExprKind::DimId)
121  return IndexMatchResult::Mismatch;
122 
123  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
124  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();
125 
126  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
127  return IndexMatchResult::Match;
128 
129  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
130  return IndexMatchResult::Transposed;
131 
132  return IndexMatchResult::Mismatch;
133 }
134 
135 // Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
136 // All the variants expressed as pseudo regular expression:
137 // `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
138 // have same number of ins/out, so its easy to stamp different versions.
139 template <typename NamedOpTy>
140 static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
141  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
142  op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
143  ValueRange{op.getDpsInits()[0]});
144  return namedOp;
145 }
146 
147 // Converts linalg.generic to named linalg.*matmul* where possible.
148 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
149  GenericOp genericOp) {
150  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
151  return failure();
152 
153  // Early exit if not projected permutations.
154  auto mapRange = genericOp.getIndexingMapsArray();
155  if (llvm::any_of(mapRange,
156  [](AffineMap m) { return !m.isProjectedPermutation(); }))
157  return failure();
158 
159  // Linalg generic contraction can be across multiple axis e.g.
160  // ```
161  // linalg.generic
162  // {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
163  // affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
164  // affine_map<(m, n, k1, k2) -> (m, n)>],
165  // iterator_types = ["parallel", "parallel",
166  // "reduction", "reduction"]}
167  // ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
168  // outs(%C : tensor<10x40xf32>) {
169  // ^bb0(%a: f32, %b: f32, %c: f32):
170  // %1 = arith.mulf %a, %b : f32
171  // %2 = arith.addf %c, %1 : f32
172  // linalg.yield %2 : f32
173  // } -> tensor<10x40xf32>
174  // ```
175  // In above contraction, there are two reduction dimensions {k1, k2}
176  // and although a valid linalg contraction, it is not a named-op
177  // matrix multiply kind. Therefore, reject multi-dim reduction.
178  auto res = inferContractionDims(genericOp);
179  if (!succeeded(res))
180  return failure();
181  auto dims = *res;
182  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
183  return failure();
184 
186  *genericOp.getBlock(), [](Operation *first, Operation *second) {
187  if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
188  (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
189  (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
190  return true;
191  return false;
192  }))
193  return failure();
194 
195  // Check rank of operands
196  auto indexingMaps = genericOp.getIndexingMapsArray();
197  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
198  return m.getResults().size() !=
199  dims.batch.size() + 2 /* any two of {m,n,k} */;
200  }))
201  return failure();
202 
203  auto numOfBatchDims = dims.batch.size();
204  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
205  return failure();
206 
207  if (numOfBatchDims) {
208  // Each operand in a linalg generic contraction could express different
209  // permutations for its batch dimension. But for named op it must be
210  // identity since separate maps are not specified.
211  if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
212  for (unsigned i = 0; i < numOfBatchDims; ++i) {
213  auto expr = m.getResults()[i];
214  if (expr.getKind() != AffineExprKind::DimId ||
215  cast<AffineDimExpr>(expr).getPosition() != i)
216  return true;
217  }
218  return false;
219  }))
220  return failure();
221  }
222 
223  auto a =
224  matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
225  auto b =
226  matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
227  auto c =
228  matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);
229 
230  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
231  return failure();
232 
233  if (c != IndexMatchResult::Match ||
234  (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
235  return failure();
236 
237  /// Codegen the different matmul variants.
238  if (numOfBatchDims) {
239  if (a == IndexMatchResult::Transposed)
240  return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
241  genericOp);
242  if (b == IndexMatchResult::Transposed)
243  return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
244  genericOp);
245  return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
246  }
247 
248  if (a == IndexMatchResult::Transposed)
249  return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
250  if (b == IndexMatchResult::Transposed)
251  return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
252  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
253 }
254 
255 } // namespace
256 
257 //===----------------------------------------------------------------------===//
258 // Categorize linalg generic to named op where possible.
259 //===----------------------------------------------------------------------===//
260 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
261  GenericOp genericOp) {
262  if (isaCopyOpInterface(genericOp)) {
263  LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
264  genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
265  return namedOp;
266  }
267 
268  if (isaFillOpInterface(genericOp)) {
269  LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
270  genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
271  return namedOp;
272  }
273 
274  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
275  Operation *op = &genericOp.getBody()->front();
276  if (isa<math::ExpOp>(op)) {
277  LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
278  return namedOp;
279  }
280  }
281 
282  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
283  bool swap = areBinOpsSwapped(genericOp);
284  Operation *op = &genericOp.getBody()->front();
285  if (isa<arith::AddFOp>(op)) {
286  LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
287  return namedOp;
288  }
289  if (isa<arith::SubFOp>(op)) {
290  LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
291  return namedOp;
292  }
293  if (isa<arith::MulFOp>(op)) {
294  LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
295  return namedOp;
296  }
297  if (isa<arith::DivFOp>(op)) {
298  LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
299  return namedOp;
300  }
301  }
302 
303  if (isaContractionOpInterface(genericOp)) {
304  return specializeLinalgContractions(rewriter, genericOp);
305  }
306  return failure();
307 }
308 
309 namespace {
310 struct LinalgSpecializeGenericOpsPass
311  : public impl::LinalgSpecializeGenericOpsPassBase<
312  LinalgSpecializeGenericOpsPass> {
313 
314  using impl::LinalgSpecializeGenericOpsPassBase<
315  LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
316  void runOnOperation() override;
317 };
318 } // namespace
319 
320 void LinalgSpecializeGenericOpsPass::runOnOperation() {
321  RewritePatternSet patterns(&getContext());
323 
324  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
325  signalPassFailure();
326 }
327 
329  RewritePatternSet &patterns) {
330  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
331 }
static MLIRContext * getContext(OpFoldResult val)
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
Definition: Specialize.cpp:32
static bool areBinOpsSwapped(GenericOp genericOp)
Definition: Specialize.cpp:58
#define REPLACE_UNARY_OP(NEWOP)
Definition: Specialize.cpp:39
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
Operation & front()
Definition: Block.h:151
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateLinalgGenericOpsSpecializationPatterns(RewritePatternSet &patterns)
Populates patterns with patterns to convert linalg.generic ops to named ops where possible.
Definition: Specialize.cpp:328
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
@ DimId
Dimensional identifier.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...