//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(genericOp,                               \
                                      ValueRange{genericOp.getDpsInputs()[0]}, \
                                      ValueRange{genericOp.getDpsInits()[0]}))
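
// e.g. `REPLACE_BINARY_OP(SubOp, /*OPERANDS_SWAP=*/true)` rebuilds the generic
// op as `linalg.sub` with the two DPS inputs swapped, which matches a body
// that computes `arith.subf %b, %a`.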

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands swapped, e.g. it differentiates between a
// linalg.generic body that contains:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %a, %b : f32
//     linalg.yield %0 : f32
// and:
//   ^bb0(%a: f32, %b: f32, %c : f32):
//     %0 = arith.subf %b, %a : f32
//     linalg.yield %0 : f32
// The former is linalg.sub(a, b); the latter is linalg.sub(b, a).
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *body = genericOp.getBody();
  Operation *op = &body->front();
  bool swapped = false;
  if (op->getOpOperand(0).get() != body->getArgument(0)) {
    swapped = true;
    assert(op->getOpOperand(0).get() == body->getArgument(1) &&
           op->getOpOperand(1).get() == body->getArgument(0) &&
           "binary op uses just one block arg");
  }
  return swapped;
}
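
// Note: callers are expected to have already matched
// `isaElemwiseSingleBinaryOpInterface` (see specializeGenericOp below), so
// the block's first operation is the single binary op being inspected.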

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
/// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
///
/// A linalg.generic may implement a matmul, but not in a straightforward way;
/// e.g. below is a matrix multiply over some slices:
/// ```
/// %0 = linalg.generic {
///        indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
///                         affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
///                         affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
///        iterator_types = ["parallel", "parallel", "parallel"]}
///        ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
///        outs(%C : tensor<20x20x20xf32>) {
///          ^bb0(%a: f32, %b: f32, %c : f32):
///            %mul = arith.mulf %a, %b : f32
///            %add = arith.addf %mul, %c : f32
///            linalg.yield %add : f32
///      } -> tensor<20x20x20xf32>
/// ```
/// It is not possible to represent the above as a named op; e.g.
/// `linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...)` is not the same
/// computation as the linalg.generic above.
namespace {
enum class IndexMatchResult {
  Match = 0,  // identity map.
  Transposed, // transposed map.
  Mismatch    // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order), starting at
// `rowDimIdx` in the input map.
//
// e.g. consider matrix A in `C[M,N] = A[M,K] * B[K,N]`. We check whether the
// map of A is the identity (match), transposed, or something completely
// different (mismatch). Similarly for B and C.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // Get the matrix multiply indices. They are past the batch indices.
  auto exprOfRowDim = map.getResults()[rowDimIdx];
  auto exprOfColDim = map.getResults()[rowDimIdx + 1];

  // They should be pure dimension ids.
  if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
      exprOfColDim.getKind() != AffineExprKind::DimId)
    return IndexMatchResult::Mismatch;

  auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
  auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

  if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
    return IndexMatchResult::Match;

  if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
    return IndexMatchResult::Transposed;

  return IndexMatchResult::Mismatch;
}
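
// For example, with `rowDimIdx` = 1 (one leading batch dim),
// `expectedPosOfRowDim` = 1 and `expectedPosOfColDim` = 2, the map
// `(d0, d1, d2, d3) -> (d0, d1, d2)` yields Match, while
// `(d0, d1, d2, d3) -> (d0, d2, d1)` yields Transposed.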

// Replaces genericOp with the `NamedOpTy` op supplied as a template arg.
// All the variants expressed by the pseudo regular expression
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the different
// versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

// Converts linalg.generic to a named linalg.*matmul* op where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Early exit if the indexing maps are not projected permutations.
  auto mapRange = genericOp.getIndexingMapsArray();
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A linalg.generic contraction can be across multiple axes, e.g.
  // ```
  // linalg.generic
  //      {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
  //                        affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
  //                        affine_map<(m, n, k1, k2) -> (m, n)>],
  //       iterator_types = ["parallel", "parallel",
  //                         "reduction", "reduction"]}
  //      ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
  //      outs(%C : tensor<10x40xf32>) {
  //        ^bb0(%a: f32, %b: f32, %c: f32):
  //          %1 = arith.mulf %a, %b : f32
  //          %2 = arith.addf %c, %1 : f32
  //          linalg.yield %2 : f32
  //      } -> tensor<10x40xf32>
  // ```
  // In the above contraction there are two reduction dimensions {k1, k2};
  // although it is a valid linalg contraction, it is not of the named-op
  // matrix-multiply kind. Therefore, reject multi-dimensional reductions.
  auto res = inferContractionDims(genericOp);
  if (!succeeded(res))
    return failure();
  auto dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  // Check that the body is a multiply-accumulate on a supported element type.
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                (isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
              return true;
            return false;
          }))
    return failure();

  // Check the rank of the operands: every map must have the batch dims plus
  // exactly two of {m, n, k}.
  auto indexingMaps = genericOp.getIndexingMapsArray();
  if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
        return m.getResults().size() !=
               dims.batch.size() + 2 /* any two of {m,n,k} */;
      }))
    return failure();

  auto numOfBatchDims = dims.batch.size();
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  if (numOfBatchDims) {
    // Each operand in a linalg generic contraction could express a different
    // permutation of its batch dimensions. But for a named op they must be
    // the identity, since separate maps are not specified.
    if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
          for (unsigned i = 0; i < numOfBatchDims; ++i) {
            auto expr = m.getResults()[i];
            if (expr.getKind() != AffineExprKind::DimId ||
                cast<AffineDimExpr>(expr).getPosition() != i)
              return true;
          }
          return false;
        }))
      return failure();
  }
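
  // In the check above, e.g. with one batch dim, a map such as
  // `(b, m, n, k) -> (m, b, k)` is rejected because its leading result is not
  // the batch dim `b`.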

  auto a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  auto b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  auto c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Codegen the different matmul variants.
  if (numOfBatchDims) {
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
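
// As an illustration (the shapes here are arbitrary), a generic with the
// standard matmul indexing maps:
// ```
// linalg.generic
//      {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                        affine_map<(m, n, k) -> (k, n)>,
//                        affine_map<(m, n, k) -> (m, n)>],
//       iterator_types = ["parallel", "parallel", "reduction"]}
//      ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
//      outs(%C : tensor<16x32xf32>) {
//        ^bb0(%a: f32, %b: f32, %c: f32):
//          %1 = arith.mulf %a, %b : f32
//          %2 = arith.addf %c, %1 : f32
//          linalg.yield %2 : f32
//      } -> tensor<16x32xf32>
// ```
// passes all of the checks above (no batch dims, all three maps Match) and is
// rewritten to `linalg.matmul ins(%A, %B ...) outs(%C ...)`.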

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
    // Always use the detected fill value, regardless of the matched pattern.
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, *fillValue, genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
      return namedOp;
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
      return namedOp;
    }
    if (isa<arith::SubFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
      return namedOp;
    }
    if (isa<arith::MulFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
      return namedOp;
    }
    if (isa<arith::DivFOp>(op)) {
      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
      return namedOp;
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
  return failure();
}
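
// A minimal sketch of driving `specializeGenericOp` from a rewrite pattern
// (the in-tree `LinalgSpecializationPattern` declared in
// mlir/Dialect/Linalg/Transforms/Transforms.h plays this role; the name
// `MySpecializePattern` below is illustrative only):
//
//   struct MySpecializePattern : public OpRewritePattern<GenericOp> {
//     using OpRewritePattern<GenericOp>::OpRewritePattern;
//     LogicalResult matchAndRewrite(GenericOp op,
//                                   PatternRewriter &rewriter) const override {
//       // specializeGenericOp already replaces the op on success.
//       FailureOr<LinalgOp> namedOp = specializeGenericOp(rewriter, op);
//       return success(succeeded(namedOp));
//     }
//   };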

namespace {
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {

  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
  void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateLinalgGenericOpsSpecializationPatterns(patterns);
  populateDecomposeProjectedPermutationPatterns(patterns);

  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
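
// The pass can be exercised from the command line, e.g. (assuming the
// tablegen'd flag is `linalg-specialize-generic-ops`, as suggested by the
// GEN_PASS_DEF name above):
//
//   mlir-opt --linalg-specialize-generic-ops input.mlir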